"vscode:/vscode.git/clone" did not exist on "915140fd18c9ff4193e994e6d756ea762a52240a"
Commit 23eb9b17 authored by dongcl's avatar dongcl
Browse files

modify TEGroupedLinear base

parent 43770f8e
@@ -5,6 +5,8 @@ import types
 import argparse
 import torch
+from megatron.core.utils import is_te_min_version
+
 class MegatronAdaptation:
     """
@@ -132,12 +134,13 @@ class CoreAdaptation(MegatronAdaptationABC):
         from ..core.extensions.transformer_engine import TEDotProductAttentionPatch
         from megatron.core.extensions.transformer_engine import TEGroupedLinear
-        # kv channels, te_min_version 1.10.0 -> 1.9.0
-        MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
-                                    TEDotProductAttentionPatch.__init__)
+        if not is_te_min_version("1.10.0"):
+            # kv channels, te_min_version 1.10.0 -> 1.9.0
+            MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
+                                        TEDotProductAttentionPatch.__init__)
         if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
-            TEGroupedLinear.__bases__ = (te.pytorch.BatchLinear,)
+            TEGroupedLinear.__bases__ = (te.pytorch.BatchedLinear if is_te_min_version("2.3.0") else te.pytorch.BatchLinear,)
     def patch_tensor_parallel(self):
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment