"vscode:/vscode.git/clone" did not exist on "998d9d15095e7a69629f9e131c8b59bfdd1c6314"
Unverified commit 8061412b, authored by Evgeny Tsykunov, committed by GitHub

Set sequence_parallel before super().__init__() in norm modules (#1771)



* Set sequence_parallel before super().__init__() in norm modules

Signed-off-by: Evgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>

* getattr(self, sequence_parallel, None) -> self.sequence_parallel

Signed-off-by: Evgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>

---------

Signed-off-by: Evgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>
Co-authored-by: Evgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>
parent c203f527
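The reasoning behind the reordering, as a minimal sketch rather than the actual TransformerEngine classes (`_NormOpSketch` and `LayerNormSketch` below are hypothetical stand-ins): the base operation's `__init__()` calls `reset_parameters()` during construction, and the overridden `reset_parameters()` reads `self.sequence_parallel`, so the attribute must already exist when `super().__init__()` runs.

```python
# Minimal sketch, assuming the base norm op invokes reset_parameters() from
# its __init__() (hypothetical classes, not the actual TransformerEngine code).
from typing import Optional

import torch


class _NormOpSketch(torch.nn.Module):
    """Stand-in for the base norm op: construction resets parameters."""

    def __init__(self, size: int) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(size))
        self.reset_parameters()  # runs during construction

    def reset_parameters(self) -> None:
        torch.nn.init.ones_(self.weight)


class LayerNormSketch(_NormOpSketch):
    def __init__(self, size: int, sequence_parallel: Optional[bool] = None) -> None:
        # Must be set *before* super().__init__(); otherwise the
        # reset_parameters() call made during construction would hit an
        # AttributeError when it reads self.sequence_parallel.
        self.sequence_parallel: Optional[bool] = sequence_parallel
        super().__init__(size)

    def reset_parameters(self) -> None:
        super().reset_parameters()
        # Mirror the Megatron-LM-style flag onto the parameter.
        if self.sequence_parallel is not None:
            self.weight.sequence_parallel = self.sequence_parallel


norm = LayerNormSketch(8, sequence_parallel=True)
print(norm.weight.sequence_parallel)  # True
```

With the attribute guaranteed to exist, re-running reset_parameters() later also keeps the flag in sync with the (re-created) parameters, which is what the getattr(...) -> self.sequence_parallel simplification in this diff relies on.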
@@ -94,6 +94,9 @@ class LayerNorm(_LayerNormOp):
             )
         kwargs["dtype"] = params_dtype
 
+        # Flag for sequence parallelism (custom Megatron-LM integration)
+        self.sequence_parallel: Optional[bool] = sequence_parallel
+
         # Initialize layer norm operation
         super().__init__(
             normalized_shape,
@@ -102,12 +105,6 @@ class LayerNorm(_LayerNormOp):
             **kwargs,
         )
 
-        # Flag for sequence parallelism (custom Megatron-LM integration)
-        self.sequence_parallel: Optional[bool] = sequence_parallel
-        if sequence_parallel is not None:
-            self.weight.sequence_parallel = sequence_parallel
-            self.bias.sequence_parallel = sequence_parallel
-
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
         warnings.warn(
@@ -136,7 +133,7 @@ class LayerNorm(_LayerNormOp):
         super().reset_parameters()
 
         # Set flag for sequence parallelism (custom Megatron-LM integration)
-        if getattr(self, "sequence_parallel", None) is not None:
+        if self.sequence_parallel is not None:
             self.weight.sequence_parallel = self.sequence_parallel
             self.bias.sequence_parallel = self.sequence_parallel
@@ -98,6 +98,9 @@ class RMSNorm(_RMSNormOp):
             )
         kwargs["dtype"] = params_dtype
 
+        # Flag for sequence parallelism (custom Megatron-LM integration)
+        self.sequence_parallel: Optional[bool] = sequence_parallel
+
         # Initialize RMSNorm operation
         super().__init__(
             normalized_shape,
@@ -106,11 +109,6 @@ class RMSNorm(_RMSNormOp):
             **kwargs,
         )
 
-        # Flag for sequence parallelism (custom Megatron-LM integration)
-        self.sequence_parallel: Optional[bool] = sequence_parallel
-        if sequence_parallel is not None:
-            self.weight.sequence_parallel = sequence_parallel
-
     def reset_rms_norm_parameters(self) -> None:
         """Deprecated"""
         warnings.warn(
@@ -139,7 +137,7 @@ class RMSNorm(_RMSNormOp):
         super().reset_parameters()
 
         # Flag for sequence parallelism (custom Megatron-LM integration)
-        if getattr(self, "sequence_parallel", None) is not None:
+        if self.sequence_parallel is not None:
             self.weight.sequence_parallel = self.sequence_parallel
 
     @property
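For context on why the flag is mirrored onto weight (and bias), a hedged sketch of a downstream consumer, assuming a Megatron-LM-style setup (this usage is an assumption, not part of the diff): sequence-parallel norm gradients typically get an extra all-reduce across the tensor-parallel group, and the per-parameter sequence_parallel attribute is how those parameters are identified.

```python
# Hedged sketch of an assumed Megatron-LM-style consumer (not part of this
# change): collect parameters whose gradients need the extra all-reduce.
import torch


def sequence_parallel_params(module: torch.nn.Module) -> list[torch.nn.Parameter]:
    """Return parameters tagged with sequence_parallel=True."""
    return [
        param
        for param in module.parameters()
        if getattr(param, "sequence_parallel", False)
    ]


# Example with a hypothetical tensor-parallel process group `tp_group`:
# for param in sequence_parallel_params(norm):
#     torch.distributed.all_reduce(param.grad, group=tp_group)
```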