Commit 950d42b4 authored by dongcl's avatar dongcl
Browse files

megatron patch

parent 4e2de453
......@@ -110,7 +110,7 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper, transformer_block_forward
from ..core.transformer.transformer_config import TransformerConfig, MLATransformerConfig
# Transformer block
MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
transformer_block_init_wrapper)
......@@ -141,7 +141,8 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_core_extentions(self):
import transformer_engine as te
from ..core.extensions.transformer_engine import te_dot_product_attention_init, TEGroupedLinear
from ..core.extensions.transformer_engine import te_dot_product_attention_init
from megatron.core.extensions.transformer_engine import TEGroupedLinear
MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
te_dot_product_attention_init)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment