Unverified commit b95c1818, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

[PyTorch] set SP attr on bias param for reduction (#440)



Fix for sequence-parallel
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c32a62cc
...@@ -737,6 +737,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -737,6 +737,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.register_parameter( self.register_parameter(
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
) )
if parallel_mode == "row":
setattr(getattr(self, bname), "sequence_parallel", sequence_parallel)
else: else:
setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device)) setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device))
......
...@@ -1054,6 +1054,9 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1054,6 +1054,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fc2_bias = Parameter( self.fc2_bias = Parameter(
torch.empty(hidden_size, device=device, dtype=params_dtype) torch.empty(hidden_size, device=device, dtype=params_dtype)
) )
# RPL
if self.set_parallel_mode:
setattr(self.fc2_bias, "sequence_parallel", sequence_parallel)
else: else:
self.fc2_bias = torch.Tensor().to(dtype=params_dtype, device=device) self.fc2_bias = torch.Tensor().to(dtype=params_dtype, device=device)
......
...@@ -628,6 +628,8 @@ class Linear(TransformerEngineBaseModule): ...@@ -628,6 +628,8 @@ class Linear(TransformerEngineBaseModule):
self.register_parameter( self.register_parameter(
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
) )
if parallel_mode == "row":
setattr(getattr(self, bname), "sequence_parallel", sequence_parallel)
else: else:
setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device)) setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment