Unverified Commit 4b523d29 authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Set flags in norm modules for Mcore sequence-parallel support (#1528)



Set flag in norm modules for Mcore sequence-parallel support
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent d3efaebb
......@@ -104,6 +104,9 @@ class LayerNorm(_LayerNormOp):
# Flag for sequence parallelism (custom Megatron-LM integration)
self.sequence_parallel: Optional[bool] = sequence_parallel
if sequence_parallel is not None:
self.weight.sequence_parallel = sequence_parallel
self.bias.sequence_parallel = sequence_parallel
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
......
......@@ -108,6 +108,8 @@ class RMSNorm(_RMSNormOp):
# Flag for sequence parallelism (custom Megatron-LM integration)
self.sequence_parallel: Optional[bool] = sequence_parallel
if sequence_parallel is not None:
self.weight.sequence_parallel = sequence_parallel
def reset_rms_norm_parameters(self) -> None:
"""Deprecated"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment