Commit c686efc1 authored by yuguo's avatar yuguo
Browse files

[DCU] fix batch gemm

parent cdb862cd
......@@ -84,7 +84,7 @@ from transformer_engine.pytorch.module import Linear
from transformer_engine.pytorch.module import LayerNormMLP
from transformer_engine.pytorch.module import LayerNorm
from transformer_engine.pytorch.module import RMSNorm
from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.pytorch.module import GroupedLinear, BatchedLinear
from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding
from transformer_engine.pytorch.module import initialize_ub
from transformer_engine.pytorch.module import destroy_ub
......
......@@ -18,6 +18,7 @@ from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
__all__ = [
"general_gemm",
"general_grouped_gemm",
"general_batched_gemm",
]
......
......@@ -5,7 +5,7 @@
"""Module level PyTorch APIs"""
from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .grouped_linear import GroupedLinear
from .grouped_linear import GroupedLinear, BatchedLinear
from .layernorm_mlp import LayerNormMLP
from .layernorm import LayerNorm
from .rmsnorm import RMSNorm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment