Unverified Commit d34f5fe9 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[Bugfix][CPU] Fallback oneDNN linear to torch linear to fix half gemm support...


[Bugfix][CPU] Fallback oneDNN linear to torch linear to fix half gemm support on legecy platforms (#27526)
Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
Co-authored-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent bdb01a38
...@@ -79,7 +79,7 @@ RUN echo 'ulimit -c 0' >> ~/.bashrc ...@@ -79,7 +79,7 @@ RUN echo 'ulimit -c 0' >> ~/.bashrc
######################### BUILD IMAGE ######################### ######################### BUILD IMAGE #########################
FROM base AS vllm-build FROM base AS vllm-build
ARG max_jobs=2 ARG max_jobs=32
ENV MAX_JOBS=${max_jobs} ENV MAX_JOBS=${max_jobs}
ARG GIT_REPO_CHECK=0 ARG GIT_REPO_CHECK=0
......
...@@ -8,9 +8,12 @@ import torch ...@@ -8,9 +8,12 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
def shuffle_weight(w: torch.Tensor) -> torch.Tensor: def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
# Shuffle weight along the last dimension so that # Shuffle weight along the last dimension so that
...@@ -178,19 +181,28 @@ def dispatch_cpu_unquantized_gemm( ...@@ -178,19 +181,28 @@ def dispatch_cpu_unquantized_gemm(
) )
if remove_weight: if remove_weight:
layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
return
elif ( elif (
ops._supports_onednn ops._supports_onednn
and current_platform.get_cpu_architecture() != CpuArchEnum.POWERPC and current_platform.get_cpu_architecture() != CpuArchEnum.POWERPC
): ):
origin_weight = layer.weight try:
if remove_weight: origin_weight = layer.weight
layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) handler = ops.create_onednn_mm(origin_weight.t(), 32)
handler = ops.create_onednn_mm(origin_weight.t(), 32) layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias)
layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias) if remove_weight:
else: layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear( return
x, weight, bias except RuntimeError as e:
) logger.warning_once(
"Failed to create oneDNN linear, fallback to torch linear."
f" Exception: {e}"
)
# fallback case
layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
x, weight, bias
)
def cpu_unquantized_gemm( def cpu_unquantized_gemm(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment