Commit f796eb80 authored by wenjh's avatar wenjh
Browse files

Fix new gemm


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parent 04afba37
...@@ -45,7 +45,6 @@ from transformer_engine.pytorch.attention.inference import InferenceParams ...@@ -45,7 +45,6 @@ from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm, batchgemm from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm, batchgemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe from transformer_engine.common import recipe
import transformer_engine_torch as tex import transformer_engine_torch as tex
......
...@@ -13,8 +13,7 @@ import warnings ...@@ -13,8 +13,7 @@ import warnings
import torch import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.cpp_extensions import gemm from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
def use_hipblaslt(): def use_hipblaslt():
...@@ -142,7 +141,7 @@ def run_gemm(): ...@@ -142,7 +141,7 @@ def run_gemm():
N = 32 N = 32
datatype = torch.float16 datatype = torch.float16
inp = torch.randn((N, N), device="cuda", dtype=datatype) inp = torch.randn((N, N), device="cuda", dtype=datatype)
_, _, _ = gemm(A=inp, B=inp, dtype=datatype, workspace=get_workspace()) _, _, _ = general_gemm(A=inp, B=inp, dtype=datatype)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment