Unverified Commit 9277a0b5 authored by Kirthi Shankar Sivamani, committed by GitHub

DDP support for no-bias option [PyTorch] (#194)



Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 186cfaf3
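
The change is the same in every hunk below: the empty bias placeholder registered when bias is disabled used to be created with `torch.Tensor().type(params_dtype)`, which leaves it on the CPU; it is now created directly on the current CUDA device. A minimal sketch of why this plausibly matters for DistributedDataParallel follows (illustrative only, not code from this commit; `NoBiasLinear` is a hypothetical stand-in for the patched modules): DDP broadcasts module buffers from rank 0, and with the NCCL backend that broadcast needs every buffer, even a zero-element one, to live on the module's GPU.

```python
# Minimal sketch (not part of this commit) of the DDP failure mode being fixed.
# NoBiasLinear is a hypothetical stand-in for the patched modules; assumes one
# GPU per process and the NCCL backend.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

class NoBiasLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, params_dtype=torch.float32):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.empty(out_features, in_features,
                        device=torch.cuda.current_device(), dtype=params_dtype))
        # Before: torch.Tensor().type(params_dtype) left this empty placeholder
        # on the CPU, so DDP's buffer broadcast could not handle it alongside
        # the module's CUDA tensors.
        # After: the placeholder is created on the current CUDA device.
        self.register_buffer("bias_tensor",
                             torch.Tensor().to(dtype=params_dtype,
                                               device=torch.cuda.current_device()),
                             persistent=False)

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight)

if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    model = DDP(NoBiasLinear(16, 16))  # buffer broadcast succeeds on CUDA
```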
@@ -1465,9 +1465,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
                                      dtype=params_dtype),
                                  persistent=False)
         else:
-            self.register_buffer(
-                "bias_tensor", torch.Tensor().type(params_dtype), persistent=False
-            )
+            self.register_buffer("bias_tensor",
+                                 torch.Tensor().to(dtype=params_dtype,
+                                                   device=torch.cuda.current_device()),
+                                 persistent=False)
         with torch.no_grad():
             self.bias_tensor.zero_()
@@ -1504,7 +1505,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
                     bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
                 )
             else:
-                self.register_buffer(bname, torch.Tensor().type(params_dtype), persistent=False)
+                self.register_buffer(bname,
+                                     torch.Tensor().to(dtype=params_dtype,
+                                                       device=torch.cuda.current_device()),
+                                     persistent=False)
             if parallel_mode == "column":
                 set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
@@ -2173,9 +2177,10 @@ class Linear(TransformerEngineBaseModule):
                                      dtype=params_dtype),
                                  persistent=False)
         else:
-            self.register_buffer(
-                "bias_tensor", torch.Tensor().type(params_dtype), persistent=False
-            )
+            self.register_buffer("bias_tensor",
+                                 torch.Tensor().to(dtype=params_dtype,
+                                                   device=torch.cuda.current_device()),
+                                 persistent=False)
         with torch.no_grad():
             self.bias_tensor.zero_()
@@ -2212,7 +2217,10 @@ class Linear(TransformerEngineBaseModule):
                     bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
                 )
             else:
-                self.register_buffer(bname, torch.Tensor().type(params_dtype), persistent=False)
+                self.register_buffer(bname,
+                                     torch.Tensor().to(dtype=params_dtype,
+                                                       device=torch.cuda.current_device()),
+                                     persistent=False)
             if parallel_mode == "column":
                 set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
@@ -3249,7 +3257,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
             )
             set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)
         else:
-            self.register_buffer("fc1_bias", torch.Tensor().type(params_dtype), persistent=False)
+            self.register_buffer("fc1_bias",
+                                 torch.Tensor().to(dtype=params_dtype,
+                                                   device=torch.cuda.current_device()),
+                                 persistent=False)
         with torch.no_grad():
             self.fc1_bias.zero_()
@@ -3280,7 +3291,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 )
             )
         else:
-            self.register_buffer("fc2_bias", torch.Tensor().type(params_dtype), persistent=False)
+            self.register_buffer("fc2_bias",
+                                 torch.Tensor().to(dtype=params_dtype,
+                                                   device=torch.cuda.current_device()),
+                                 persistent=False)
         # For RPL, bias has to be added after TP collectives
         # So it cannot be fused with the GEMM
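
The trailing comment in the last hunk points at a separate constraint worth spelling out: in a row-parallel linear (RPL) layer, each tensor-parallel rank produces only a partial output, the partials are summed with an all-reduce, and the bias must be added once after that collective; fused into each rank's GEMM it would be added tp_size times. A rough sketch of the idea (hypothetical helper, not code from this repository):

```python
import torch
import torch.distributed as dist

def row_parallel_linear(x_shard, w_shard, bias, tp_group):
    """Hypothetical row-parallel matmul: bias is applied after the collective."""
    partial = x_shard @ w_shard.t()           # per-rank partial result
    dist.all_reduce(partial, group=tp_group)  # sum partials across TP ranks
    # Adding bias inside each rank's GEMM would count it tp_group.size() times.
    return partial if bias is None else partial + bias
```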