Unverified Commit 8061412b authored by Evgeny Tsykunov, committed by GitHub

Set sequence_parallel before super().__init__() in norm modules (#1771)



* Set sequence_parallel before super().__init__() in norm modules
Signed-off-by: Evgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>

* getattr(self, sequence_parallel, None) -> self.sequence_parallel
Signed-off-by: Evgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>

---------
Signed-off-by: Evgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>
Co-authored-by: Evgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>
parent c203f527
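
The motivation is initialization order: the base op's constructor calls reset_parameters(), which (after this change) reads self.sequence_parallel, so the flag must exist before super().__init__() runs. A minimal sketch of the pattern, using a hypothetical _BaseNormOp stand-in rather than the real _LayerNormOp:

from typing import Optional

import torch


class _BaseNormOp:
    # Stand-in for _LayerNormOp: the base constructor creates the
    # parameters and immediately calls reset_parameters().
    def __init__(self, normalized_shape: int) -> None:
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape))
        self.bias = torch.nn.Parameter(torch.zeros(normalized_shape))
        self.reset_parameters()  # runs before the subclass __init__ resumes

    def reset_parameters(self) -> None:
        pass


class LayerNorm(_BaseNormOp):
    def __init__(
        self, normalized_shape: int, sequence_parallel: Optional[bool] = None
    ) -> None:
        # Must be assigned *before* super().__init__(), because the base
        # constructor already triggers reset_parameters() below.
        self.sequence_parallel: Optional[bool] = sequence_parallel
        super().__init__(normalized_shape)

    def reset_parameters(self) -> None:
        super().reset_parameters()
        # The attribute is guaranteed to exist by now, so the old
        # getattr(self, "sequence_parallel", None) guard can go.
        if self.sequence_parallel is not None:
            self.weight.sequence_parallel = self.sequence_parallel
            self.bias.sequence_parallel = self.sequence_parallel
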
@@ -94,6 +94,9 @@ class LayerNorm(_LayerNormOp):
                 )
             kwargs["dtype"] = params_dtype
 
+        # Flag for sequence parallelism (custom Megatron-LM integration)
+        self.sequence_parallel: Optional[bool] = sequence_parallel
+
         # Initialize layer norm operation
         super().__init__(
             normalized_shape,
@@ -102,12 +105,6 @@ class LayerNorm(_LayerNormOp):
             **kwargs,
         )
 
-        # Flag for sequence parallelism (custom Megatron-LM integration)
-        self.sequence_parallel: Optional[bool] = sequence_parallel
-        if sequence_parallel is not None:
-            self.weight.sequence_parallel = sequence_parallel
-            self.bias.sequence_parallel = sequence_parallel
-
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
         warnings.warn(
@@ -136,7 +133,7 @@ class LayerNorm(_LayerNormOp):
         super().reset_parameters()
 
         # Set flag for sequence parallelism (custom Megatron-LM integration)
-        if getattr(self, "sequence_parallel", None) is not None:
+        if self.sequence_parallel is not None:
             self.weight.sequence_parallel = self.sequence_parallel
             self.bias.sequence_parallel = self.sequence_parallel
 
@@ -98,6 +98,9 @@ class RMSNorm(_RMSNormOp):
                 )
             kwargs["dtype"] = params_dtype
 
+        # Flag for sequence parallelism (custom Megatron-LM integration)
+        self.sequence_parallel: Optional[bool] = sequence_parallel
+
         # Initialize RMSNorm operation
         super().__init__(
             normalized_shape,
@@ -106,11 +109,6 @@ class RMSNorm(_RMSNormOp):
             **kwargs,
         )
 
-        # Flag for sequence parallelism (custom Megatron-LM integration)
-        self.sequence_parallel: Optional[bool] = sequence_parallel
-        if sequence_parallel is not None:
-            self.weight.sequence_parallel = sequence_parallel
-
     def reset_rms_norm_parameters(self) -> None:
         """Deprecated"""
         warnings.warn(
@@ -139,7 +137,7 @@ class RMSNorm(_RMSNormOp):
         super().reset_parameters()
 
         # Flag for sequence parallelism (custom Megatron-LM integration)
-        if getattr(self, "sequence_parallel", None) is not None:
+        if self.sequence_parallel is not None:
             self.weight.sequence_parallel = self.sequence_parallel
 
     @property
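
Continuing the sketch above (hypothetical classes, not the real transformer_engine modules), a quick check that the flag now reaches the parameters during construction, and is skipped entirely when left at the default:

# Flag propagates to the parameters while super().__init__() runs.
ln = LayerNorm(8, sequence_parallel=True)
assert ln.weight.sequence_parallel is True
assert ln.bias.sequence_parallel is True

# With the default None, the `is not None` guard means the attribute
# is never stamped onto the parameters at all.
ln_default = LayerNorm(8)
assert not hasattr(ln_default.weight, "sequence_parallel")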