Commit bb9e670a authored by xuxzh1's avatar xuxzh1 🎱
Browse files

update linear.py

parent ee9541af
...@@ -3,19 +3,19 @@ from text_generation_server.utils.import_utils import SYSTEM ...@@ -3,19 +3,19 @@ from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
import os

# Whether to use vLLM's ROCm "skinny GEMM" kernels (wvSpltK / LLMM1) for
# tall-and-skinny matmul shapes. Default to False so the flag is always
# defined — previously this was only assigned inside the SYSTEM == "rocm"
# branch, raising NameError on every non-ROCm system.
ROCM_USE_SKINNY_GEMM = False
if SYSTEM == "rocm":
    # Enabled by default on ROCm; opt out with ROCM_USE_SKINNY_GEMM=0.
    # Accepts "true"/"1" in any case.
    ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
        "true",
        "1",
    )

if ROCM_USE_SKINNY_GEMM:
    try:
        # Imported for availability; the kernels are invoked elsewhere in this
        # module. NOTE(review): call sites below still reference `_custom_C` —
        # confirm they were updated to the `_custom_ops` API as well.
        from vllm import _custom_ops
    except Exception as e:
        # Name the module we actually import (was stale: `vllm._custom_C`),
        # and chain the cause for a full traceback.
        raise ImportError(
            f"Could not load `vllm._custom_ops` for ROCm skinny gemm. Full error: {e}"
        ) from e
class FastLinear(torch.nn.Module): class FastLinear(torch.nn.Module):
...@@ -91,18 +91,18 @@ class FastLinearROCm(torch.nn.Module): ...@@ -91,18 +91,18 @@ class FastLinearROCm(torch.nn.Module):
batched = True batched = True
m, n, k = weight.shape[0], inp_shape[0], inp_shape[1] m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
if m > 8 and n <= 4: # if m > 8 and n <= 4:
out = torch.empty( # out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device # inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
) # )
_custom_C.wvSpltK(weight, inp, out, n, self.cu_count) # _custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
elif m % 4 == 0 and n == 1 and k <= 8192: # elif m % 4 == 0 and n == 1 and k <= 8192:
out = torch.empty( # out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device # inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
) # )
_custom_C.LLMM1(weight, inp, out, 4) # _custom_C.LLMM1(weight, inp, out, 4)
else: # else:
out = F.linear(inp, weight) out = F.linear(inp, weight)
if batched: if batched:
out.view(*inp_shape[:-1], out.shape[-1]) out.view(*inp_shape[:-1], out.shape[-1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment