Commit f796eb80 authored by wenjh

Fix new gemm


Signed-off-by: wenjh <wenjh@sugon.com>
parent 04afba37
@@ -45,7 +45,6 @@ from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm, batchgemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex
@@ -13,8 +13,7 @@ import warnings
 import torch
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
-from transformer_engine.pytorch.cpp_extensions import gemm
-from transformer_engine.pytorch.module.base import get_workspace
+from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm
 def use_hipblaslt():
@@ -142,7 +141,7 @@ def run_gemm():
     N = 32
     datatype = torch.float16
     inp = torch.randn((N, N), device="cuda", dtype=datatype)
-    _, _, _ = gemm(A=inp, B=inp, dtype=datatype, workspace=get_workspace())
+    _, _, _ = general_gemm(A=inp, B=inp, dtype=datatype)
 if __name__ == "__main__":
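For context, the change drops the caller-managed workspace: the legacy gemm wrapper took an explicit workspace obtained via get_workspace(), while general_gemm is invoked here without one. Below is a minimal sketch of the migrated call site, mirroring the diff above; it assumes (this commit does not confirm it) that general_gemm manages its own workspace internally and that its three-element return mirrors the old wrapper.

# Minimal sketch of the migrated call site, based on this diff.
# Assumption: general_gemm handles workspace allocation internally,
# so the caller no longer passes workspace=get_workspace().
import torch
from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm

N = 32
datatype = torch.float16
inp = torch.randn((N, N), device="cuda", dtype=datatype)

# Before this commit:
#   _, _, _ = gemm(A=inp, B=inp, dtype=datatype, workspace=get_workspace())
# After: no caller-supplied cuBLAS workspace at the call site.
_, _, _ = general_gemm(A=inp, B=inp, dtype=datatype)

If that assumption holds, the practical effect is that test code no longer needs to import or invoke the workspace helper at every GEMM call; that rationale is inferred from the diff, not stated in the commit message.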