Unverified commit 4b523d29, authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Set flags in norm modules for Mcore sequence-parallel support (#1528)



Set flag in norm modules for Mcore sequence-parallel support
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent d3efaebb
@@ -104,6 +104,9 @@ class LayerNorm(_LayerNormOp):
        # Flag for sequence parallelism (custom Megatron-LM integration)
        self.sequence_parallel: Optional[bool] = sequence_parallel
if sequence_parallel is not None:
self.weight.sequence_parallel = sequence_parallel
self.bias.sequence_parallel = sequence_parallel
    def reset_layer_norm_parameters(self) -> None:
        """Init LN params"""
......
@@ -108,6 +108,8 @@ class RMSNorm(_RMSNormOp):
        # Flag for sequence parallelism (custom Megatron-LM integration)
        self.sequence_parallel: Optional[bool] = sequence_parallel
if sequence_parallel is not None:
self.weight.sequence_parallel = sequence_parallel
    def reset_rms_norm_parameters(self) -> None:
        """Deprecated"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.