"transformer_engine/pytorch/attention.py" did not exist on "7324fe2b06de33dcf030ad0cec3bd6348fc4a7b1"
Commit b95c1818 authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] set SP attr on bias param for reduction (#440)



Fix for sequence parallelism: mark the bias parameters of row-parallel layers with the sequence_parallel attribute so that their gradients are reduced across the tensor-parallel group.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c32a62cc
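
Context for the change: with sequence parallelism, the output of a row-parallel GEMM is reduce-scattered along the sequence dimension, so each tensor-parallel rank computes the bias gradient from only its shard of the sequence. The partial bias gradients therefore have to be summed across the tensor-parallel group, and Megatron-LM-style training loops find the parameters that need this by checking a sequence_parallel attribute; this commit sets that attribute on the affected biases. Below is a minimal sketch of such a reduction loop; allreduce_sequence_parallel_grads is a hypothetical helper for illustration, not part of Transformer Engine's API.

import torch
import torch.distributed as dist

def allreduce_sequence_parallel_grads(model: torch.nn.Module, tp_group) -> None:
    # Sum the gradients of every parameter tagged with the sequence_parallel
    # attribute across the tensor-parallel process group. Each rank only saw
    # its own shard of the sequence, so these gradients are partial sums.
    for param in model.parameters():
        if getattr(param, "sequence_parallel", False) and param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=tp_group)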
@@ -737,6 +737,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 self.register_parameter(
                     bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
                 )
+                if parallel_mode == "row":
+                    setattr(getattr(self, bname), "sequence_parallel", sequence_parallel)
             else:
                 setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device))
@@ -1054,6 +1054,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.fc2_bias = Parameter(
                 torch.empty(hidden_size, device=device, dtype=params_dtype)
             )
+            # RPL
+            if self.set_parallel_mode:
+                setattr(self.fc2_bias, "sequence_parallel", sequence_parallel)
         else:
             self.fc2_bias = torch.Tensor().to(dtype=params_dtype, device=device)
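
In LayerNormMLP, fc2 is the row-parallel projection of the MLP (the "# RPL" comment presumably stands for row-parallel linear), so its bias is marked whenever set_parallel_mode is enabled. A hypothetical check of the new behavior, assuming transformer_engine is installed and tp_group is an initialized tensor-parallel process group:

import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    set_parallel_mode=True,
    sequence_parallel=True,
    tp_group=tp_group,  # assumed to be initialized elsewhere
)
# After this commit, the fc2 bias carries the reduction marker.
assert getattr(mlp.fc2_bias, "sequence_parallel", False)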
@@ -628,6 +628,8 @@ class Linear(TransformerEngineBaseModule):
                 self.register_parameter(
                     bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
                 )
+                if parallel_mode == "row":
+                    setattr(getattr(self, bname), "sequence_parallel", sequence_parallel)
             else:
                 setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device))
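
The same marking applies to the per-split bias parameters of LayerNormLinear and Linear, but only when parallel_mode == "row": column-parallel biases are sharded along the feature dimension and each rank already computes its shard's full gradient, so no extra reduction is needed. A hypothetical usage sketch under the same assumptions as above:

import transformer_engine.pytorch as te

proj = te.Linear(
    in_features=4096,
    out_features=1024,
    parallel_mode="row",
    sequence_parallel=True,
    tp_group=tp_group,  # assumed to be initialized elsewhere
)
for name, param in proj.named_parameters():
    if "bias" in name:
        print(name, getattr(param, "sequence_parallel", False))  # expect True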