Commit 950d42b4 authored by dongcl's avatar dongcl
Browse files

megatron patch

parent 4e2de453
...@@ -110,7 +110,7 @@ class CoreAdaptation(MegatronAdaptationABC): ...@@ -110,7 +110,7 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_core_transformers(self): def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper, transformer_block_forward from ..core import transformer_block_init_wrapper, transformer_block_forward
from ..core.transformer.transformer_config import TransformerConfig, MLATransformerConfig from ..core.transformer.transformer_config import TransformerConfig, MLATransformerConfig
# Transformer block # Transformer block
MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__', MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
transformer_block_init_wrapper) transformer_block_init_wrapper)
...@@ -141,7 +141,8 @@ class CoreAdaptation(MegatronAdaptationABC): ...@@ -141,7 +141,8 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_core_extentions(self): def patch_core_extentions(self):
import transformer_engine as te import transformer_engine as te
from ..core.extensions.transformer_engine import te_dot_product_attention_init, TEGroupedLinear from ..core.extensions.transformer_engine import te_dot_product_attention_init
from megatron.core.extensions.transformer_engine import TEGroupedLinear
MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__', MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
te_dot_product_attention_init) te_dot_product_attention_init)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment