Unverified Commit 9277a0b5 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

DDP support for no-bias option [PyTorch] (#194)



DDP support for no-bias option
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 186cfaf3
...@@ -1465,9 +1465,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1465,9 +1465,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
dtype=params_dtype), dtype=params_dtype),
persistent=False) persistent=False)
else: else:
self.register_buffer( self.register_buffer("bias_tensor",
"bias_tensor", torch.Tensor().type(params_dtype), persistent=False torch.Tensor().to(dtype=params_dtype,
) device=torch.cuda.current_device()),
persistent=False)
with torch.no_grad(): with torch.no_grad():
self.bias_tensor.zero_() self.bias_tensor.zero_()
...@@ -1504,7 +1505,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1504,7 +1505,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
) )
else: else:
self.register_buffer(bname, torch.Tensor().type(params_dtype), persistent=False) self.register_buffer(bname,
torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()),
persistent=False)
if parallel_mode == "column": if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1) set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
...@@ -2173,9 +2177,10 @@ class Linear(TransformerEngineBaseModule): ...@@ -2173,9 +2177,10 @@ class Linear(TransformerEngineBaseModule):
dtype=params_dtype), dtype=params_dtype),
persistent=False) persistent=False)
else: else:
self.register_buffer( self.register_buffer("bias_tensor",
"bias_tensor", torch.Tensor().type(params_dtype), persistent=False torch.Tensor().to(dtype=params_dtype,
) device=torch.cuda.current_device()),
persistent=False)
with torch.no_grad(): with torch.no_grad():
self.bias_tensor.zero_() self.bias_tensor.zero_()
...@@ -2212,7 +2217,10 @@ class Linear(TransformerEngineBaseModule): ...@@ -2212,7 +2217,10 @@ class Linear(TransformerEngineBaseModule):
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
) )
else: else:
self.register_buffer(bname, torch.Tensor().type(params_dtype), persistent=False) self.register_buffer(bname,
torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()),
persistent=False)
if parallel_mode == "column": if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1) set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
...@@ -3249,7 +3257,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -3249,7 +3257,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
) )
set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1) set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)
else: else:
self.register_buffer("fc1_bias", torch.Tensor().type(params_dtype), persistent=False) self.register_buffer("fc1_bias",
torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()),
persistent=False)
with torch.no_grad(): with torch.no_grad():
self.fc1_bias.zero_() self.fc1_bias.zero_()
...@@ -3280,7 +3291,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -3280,7 +3291,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
) )
) )
else: else:
self.register_buffer("fc2_bias", torch.Tensor().type(params_dtype), persistent=False) self.register_buffer("fc2_bias",
torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()),
persistent=False)
# For RPL, bias has to be added after TP collectives # For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM # So it cannot be fused with the GEMM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment