Commit bb9e670a authored by xuxzh1

update linear.py

parent ee9541af
@@ -3,19 +3,19 @@ from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
import os
# if SYSTEM == "rocm":
#     ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
#         "true",
#         "1",
#     )
ROCM_USE_SKINNY_GEMM = False
# if ROCM_USE_SKINNY_GEMM:
#     try:
#         from vllm import _custom_C
#     except Exception as e:
#         raise ImportError(
#             f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
#         )
if SYSTEM == "rocm":
    ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
        "true",
        "1",
    )
    if ROCM_USE_SKINNY_GEMM:
        try:
            from vllm import _custom_ops
        except Exception as e:
            raise ImportError(
                f"Could not load `vllm._custom_ops` for ROCm skinny gemm. Full error: {e}"
            )
class FastLinear(torch.nn.Module):
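Note on the first hunk: the added lines re-enable the import-time gating that the parent commit had replaced with a hard-coded `ROCM_USE_SKINNY_GEMM = False`. The flag is read from the `ROCM_USE_SKINNY_GEMM` environment variable only when the detected platform is ROCm, and the optional vLLM kernels are imported only if the flag is set. Below is a minimal sketch of that gating pattern, not part of the commit: `SYSTEM` is stubbed as an assumption, and it falls back to disabling the flag instead of raising `ImportError` as the real code does.

```python
import os

SYSTEM = "rocm"  # assumption: platform detection is done elsewhere in the real file
ROCM_USE_SKINNY_GEMM = False
_custom_ops = None

if SYSTEM == "rocm":
    # Treat the env var as a boolean flag; anything other than
    # "true"/"1" (case-insensitive) disables the skinny GEMM path.
    ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
        "true",
        "1",
    )

if ROCM_USE_SKINNY_GEMM:
    try:
        from vllm import _custom_ops  # optional ROCm kernels shipped with vLLM
    except Exception:
        # Soft fallback for this sketch: disable the flag rather than
        # failing at import time (the committed code raises ImportError).
        ROCM_USE_SKINNY_GEMM = False
        _custom_ops = None
```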
@@ -91,17 +91,17 @@ class FastLinearROCm(torch.nn.Module):
batched = True
m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
if m > 8 and n <= 4:
    out = torch.empty(
        inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
    )
    _custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
elif m % 4 == 0 and n == 1 and k <= 8192:
    out = torch.empty(
        inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
    )
    _custom_C.LLMM1(weight, inp, out, 4)
else:
# if m > 8 and n <= 4:
#     out = torch.empty(
#         inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
#     )
#     _custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
# elif m % 4 == 0 and n == 1 and k <= 8192:
#     out = torch.empty(
#         inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
#     )
#     _custom_C.LLMM1(weight, inp, out, 4)
# else:
out = F.linear(inp, weight)
if batched:
......
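Note on the second hunk: the removed lines picked a vLLM skinny GEMM kernel from the GEMM shape, `wvSpltK` when m > 8 and n <= 4, `LLMM1` when m is a multiple of 4 with n == 1 and k <= 8192, and fell back to `F.linear` otherwise; with those branches commented out, every shape now goes through `F.linear`. The sketch below illustrates that dispatch under stated assumptions: `rocm_linear` is a hypothetical helper, the kernel calls are placeholders, and only the `F.linear` path is exercised.

```python
import torch
import torch.nn.functional as F


def rocm_linear(inp: torch.Tensor, weight: torch.Tensor,
                use_skinny_gemm: bool = False) -> torch.Tensor:
    # Shape-based dispatch mirroring the removed hunk:
    #   wvSpltK  when m > 8 and n <= 4
    #   LLMM1    when m % 4 == 0, n == 1 and k <= 8192
    #   F.linear otherwise (the only path left after this commit)
    m, n, k = weight.shape[0], inp.shape[0], inp.shape[1]
    if use_skinny_gemm and m > 8 and n <= 4:
        # Would allocate `out` and call _custom_C.wvSpltK(weight, inp, out, n, cu_count).
        raise NotImplementedError("skinny GEMM kernel not available in this sketch")
    elif use_skinny_gemm and m % 4 == 0 and n == 1 and k <= 8192:
        # Would allocate `out` and call _custom_C.LLMM1(weight, inp, out, 4).
        raise NotImplementedError("skinny GEMM kernel not available in this sketch")
    else:
        out = F.linear(inp, weight)
    return out


# With the kernels disabled, everything falls through to F.linear.
x = torch.randn(1, 4096)
w = torch.randn(11008, 4096)
print(rocm_linear(x, w).shape)  # torch.Size([1, 11008])
```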