Commit 54740897 authored by dongcl's avatar dongcl
Browse files

check te version

parent 2862a32a
__pycache__
...@@ -140,7 +140,7 @@ class CoreAdaptation(MegatronAdaptationABC): ...@@ -140,7 +140,7 @@ class CoreAdaptation(MegatronAdaptationABC):
TEDotProductAttentionPatch.__init__) TEDotProductAttentionPatch.__init__)
if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')): if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
TEGroupedLinear.__bases__ = (te.pytorch.BatchedLinear if is_te_min_version("2.3.0") else te.pytorch.BatchLinear,) TEGroupedLinear.__bases__ = (te.pytorch.BatchedLinear if is_te_min_version("2.3.0.dev0") else te.pytorch.BatchLinear,)
def patch_tensor_parallel(self): def patch_tensor_parallel(self):
from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment