Commit 425a9899 authored by dongcl

bug fix: rename the legacy-model patch classes to ParallelMLPPatch / ParallelAttentionPatch so they no longer share names with the Megatron classes they replace, and register the renamed classes' methods instead of the originals.

parent 9800dec4
......
@@ -241,15 +241,15 @@ class LegacyAdaptation(MegatronAdaptationABC):
         self.patch_legacy_models()

     def patch_legacy_models(self):
-        from ..legacy.model.transformer import ParallelMLP, ParallelAttention
+        from ..legacy.model.transformer import ParallelMLPPatch, ParallelAttentionPatch
         from ..legacy.model.utils import get_norm

         # ParallelMLP
         MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelMLP.__init__',
-                                    ParallelMLP.__init__)
+                                    ParallelMLPPatch.__init__)
         MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.forward',
-                                    ParallelAttention.forward)
+                                    ParallelAttentionPatch.forward)

         # rms_norm.RMSNorm
         MegatronAdaptation.register('megatron.legacy.model.rms_norm.RMSNorm.forward',
......
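For context, each MegatronAdaptation.register call pairs the dotted path of a Megatron attribute with the callable that should replace it; this commit makes the replacement come from the renamed *Patch classes. Below is a minimal sketch of how such a registry could resolve and install its patches; register and apply_patches here are illustrative stand-ins, not this project's actual MegatronAdaptation implementation.

import importlib

# Illustrative patch registry (a sketch, not the real MegatronAdaptation):
# maps a dotted attribute path to the callable that should replace it.
_PATCHES = {}

def register(orig_path, replacement):
    """Record that `orig_path` (e.g. 'pkg.module.Class.method') should become `replacement`."""
    _PATCHES[orig_path] = replacement

def apply_patches():
    """Import each target module and overwrite the registered attribute with its replacement."""
    for path, replacement in _PATCHES.items():
        module_path, cls_name, attr_name = path.rsplit('.', 2)
        module = importlib.import_module(module_path)
        target_cls = getattr(module, cls_name)
        setattr(target_cls, attr_name, replacement)  # monkey-patch the method in place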
......
@@ -10,7 +10,7 @@ from megatron.legacy.model.utils import (
 )


-class ParallelMLP(MegatronModule):
+class ParallelMLPPatch(MegatronModule):
     """MLP.

     MLP will take the input with h hidden state, project it to 4*h
......
@@ -74,7 +74,7 @@ class ParallelMLP(MegatronModule):
 )


-class ParallelAttention(MegatronModule):
+class ParallelAttentionPatch(MegatronModule):
     """Parallel self-attention layer abstract class.

     Self-attention layer takes input with size [s, b, h]
......
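The renamed classes follow the usual patch pattern: each *Patch class provides the method that gets registered (__init__ for the MLP, forward for the attention), and the adaptation layer presumably installs that function onto the original Megatron class. The toy example below, using hypothetical Original/OriginalPatch classes, shows the effect of assigning a patch class's method onto the class being patched.

class Original:
    def forward(self):
        return "original behaviour"

class OriginalPatch:
    def forward(self):
        return "patched behaviour"

# Installing the patch class's function onto the original class rebinds the method,
# analogous to what the adaptation does for ParallelMLP.__init__ and
# ParallelAttention.forward in this commit.
Original.forward = OriginalPatch.forward
assert Original().forward() == "patched behaviour"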