Commit 0b2b5417 authored by dongcl's avatar dongcl
Browse files

rewrite transformer_engine

parent f098f250
...@@ -143,11 +143,11 @@ class CoreAdaptation(MegatronAdaptationABC): ...@@ -143,11 +143,11 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_core_extentions(self): def patch_core_extentions(self):
import transformer_engine as te import transformer_engine as te
from ..core.extensions.transformer_engine import te_dot_product_attention_init from ..core.extensions.transformer_engine import TEDotProductAttentionPatch
from megatron.core.extensions.transformer_engine import TEGroupedLinear from megatron.core.extensions.transformer_engine import TEGroupedLinear
MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__', MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
te_dot_product_attention_init) TEDotProductAttentionPatch.__init__)
if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')): if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
TEGroupedLinear.__bases__ = (te.pytorch.BatchLinear,) TEGroupedLinear.__bases__ = (te.pytorch.BatchLinear,)
......
import os import os
import dataclasses import dataclasses
import transformer_engine as te
from typing import Any, Optional from typing import Any, Optional
from packaging.version import Version as PkgVersion from packaging.version import Version as PkgVersion
...@@ -19,7 +20,8 @@ from megatron.core.parallel_state import ( ...@@ -19,7 +20,8 @@ from megatron.core.parallel_state import (
) )
def te_dot_product_attention_init( class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
def __init__(
self, self,
config: TransformerConfig, config: TransformerConfig,
layer_number: int, layer_number: int,
...@@ -30,7 +32,7 @@ def te_dot_product_attention_init( ...@@ -30,7 +32,7 @@ def te_dot_product_attention_init(
k_channels: Optional[int] = None, k_channels: Optional[int] = None,
v_channels: Optional[int] = None, v_channels: Optional[int] = None,
cp_comm_type: str = "p2p", cp_comm_type: str = "p2p",
): ):
self.config = config self.config = config
self.te_forward_mask_type = False self.te_forward_mask_type = False
self.qkv_format: str = 'sbhd' self.qkv_format: str = 'sbhd'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment