Commit 950d42b4 authored by dongcl
Browse files

megatron patch

parent 4e2de453
...@@ -141,7 +141,8 @@ class CoreAdaptation(MegatronAdaptationABC): ...@@ -141,7 +141,8 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_core_extentions(self): def patch_core_extentions(self):
import transformer_engine as te import transformer_engine as te
from ..core.extensions.transformer_engine import te_dot_product_attention_init, TEGroupedLinear from ..core.extensions.transformer_engine import te_dot_product_attention_init
from megatron.core.extensions.transformer_engine import TEGroupedLinear
MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__', MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
te_dot_product_attention_init) te_dot_product_attention_init)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment