Unverified Commit 11c5d588 authored by Kirthi Shankar Sivamani, committed by GitHub

Remove buffer registration for FSDP-like cases (#318)



Remove extra buffers
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 8c3110d1
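
Why register_buffer versus a plain attribute matters here: a tensor registered via register_buffer becomes part of nn.Module state, so it is returned by named_buffers(), moved and cast by .to(), and picked up by wrappers such as FSDP that walk a module's buffers, whereas a tensor stored as an ordinary Python attribute is invisible to that machinery. The weight_tensor / bias_tensor objects in this diff are staging storage that later backs the real Parameters, so (per the commit title) keeping them registered exposes extra per-layer tensors to FSDP-like sharding. A minimal sketch of the distinction, using plain PyTorch only (the Toy module below is illustrative, not code from this repository):

    import torch
    from torch import nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            # Registered buffer: tracked by nn.Module, listed by named_buffers(),
            # moved/cast by .to(), and visible to FSDP-style wrappers.
            self.register_buffer("staged", torch.empty(4), persistent=False)
            # Plain attribute: an ordinary Python attribute holding a tensor;
            # nn.Module bookkeeping (and FSDP-like wrappers) never sees it.
            self.plain = torch.empty(4)

    m = Toy()
    print([name for name, _ in m.named_buffers()])  # ['staged']
    print("plain" in dict(m.named_buffers()))       # False
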
@@ -696,13 +696,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self.reset_layer_norm_parameters()

         if not skip_weight_param_allocation:
-            self.register_buffer("weight_tensor",
-                                 torch.empty(
-                                     self.out_features,
-                                     self.in_features,
-                                     device=torch.cuda.current_device(),
-                                     dtype=params_dtype),
-                                 persistent=False)
+            self.weight_tensor = torch.empty(
+                self.out_features, self.in_features,
+                device=torch.cuda.current_device(),
+                dtype=params_dtype)

             initialize_affine_weight_gpu(
                 self.weight_tensor,
@@ -713,17 +710,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
             )

         if self.use_bias:
-            self.register_buffer("bias_tensor",
-                                 torch.empty(
-                                     self.out_features,
-                                     device=torch.cuda.current_device(),
-                                     dtype=params_dtype),
-                                 persistent=False)
+            self.bias_tensor = torch.empty(
+                self.out_features,
+                device=torch.cuda.current_device(),
+                dtype=params_dtype)
         else:
-            self.register_buffer("bias_tensor",
-                                 torch.Tensor().to(dtype=params_dtype,
-                                                   device=torch.cuda.current_device()),
-                                 persistent=False)
+            self.bias_tensor = torch.Tensor().to(dtype=params_dtype,
+                                                 device=torch.cuda.current_device())

         with torch.no_grad():
             self.bias_tensor.zero_()
@@ -760,10 +753,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
                         bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
                     )
                 else:
-                    self.register_buffer(bname,
-                                         torch.Tensor().to(dtype=params_dtype,
-                                                           device=torch.cuda.current_device()),
-                                         persistent=False)
+                    setattr(self, bname, torch.Tensor().to(dtype=params_dtype,
+                                                           device=torch.cuda.current_device()))

                 if parallel_mode == "column":
                     set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
@@ -1048,10 +1048,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             )
             set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)
         else:
-            self.register_buffer("fc1_bias",
-                                 torch.Tensor().to(dtype=params_dtype,
-                                                   device=torch.cuda.current_device()),
-                                 persistent=False)
+            self.fc1_bias = torch.Tensor().to(dtype=params_dtype,
+                                              device=torch.cuda.current_device())

         with torch.no_grad():
             self.fc1_bias.zero_()
@@ -1082,10 +1080,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 )
             )
         else:
-            self.register_buffer("fc2_bias",
-                                 torch.Tensor().to(dtype=params_dtype,
-                                                   device=torch.cuda.current_device()),
-                                 persistent=False)
+            self.fc2_bias = torch.Tensor().to(dtype=params_dtype,
+                                              device=torch.cuda.current_device())

         # For RPL, bias has to be added after TP collectives
         # So it cannot be fused with the GEMM
@@ -559,13 +559,10 @@ class Linear(TransformerEngineBaseModule):
         self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

         if not skip_weight_param_allocation:
-            self.register_buffer("weight_tensor",
-                                 torch.empty(
-                                     self.out_features,
-                                     self.in_features,
-                                     device=torch.cuda.current_device(),
-                                     dtype=params_dtype),
-                                 persistent=False)
+            self.weight_tensor = torch.empty(
+                self.out_features, self.in_features,
+                device=torch.cuda.current_device(),
+                dtype=params_dtype)

             initialize_affine_weight_gpu(
                 self.weight_tensor,
@@ -576,17 +573,13 @@ class Linear(TransformerEngineBaseModule):
             )

         if self.use_bias:
-            self.register_buffer("bias_tensor",
-                                 torch.empty(
-                                     self.out_features,
-                                     device=torch.cuda.current_device(),
-                                     dtype=params_dtype),
-                                 persistent=False)
+            self.bias_tensor = torch.empty(
+                self.out_features,
+                device=torch.cuda.current_device(),
+                dtype=params_dtype)
         else:
-            self.register_buffer("bias_tensor",
-                                 torch.Tensor().to(dtype=params_dtype,
-                                                   device=torch.cuda.current_device()),
-                                 persistent=False)
+            self.bias_tensor = torch.Tensor().to(dtype=params_dtype,
+                                                 device=torch.cuda.current_device())

         with torch.no_grad():
             self.bias_tensor.zero_()
@@ -623,10 +616,8 @@ class Linear(TransformerEngineBaseModule):
                         bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
                     )
                 else:
-                    self.register_buffer(bname,
-                                         torch.Tensor().to(dtype=params_dtype,
-                                                           device=torch.cuda.current_device()),
-                                         persistent=False)
+                    setattr(self, bname, torch.Tensor().to(dtype=params_dtype,
+                                                           device=torch.cuda.current_device()))

                 if parallel_mode == "column":
                     set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
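
A quick way to observe the effect of this change, sketched under the assumption that transformer_engine.pytorch is installed and a CUDA device is available (the exact set of remaining buffers depends on the Transformer Engine version):

    import transformer_engine.pytorch as te

    layer = te.Linear(128, 128, bias=True)
    # With this commit, the staging tensors ("weight_tensor", "bias_tensor")
    # are plain attributes rather than buffers, so they no longer appear here
    # and are ignored by FSDP-style wrappers that iterate over named_buffers().
    print([name for name, _ in layer.named_buffers()])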