"vscode:/vscode.git/clone" did not exist on "74d5543ec589daaa4ac042d65d52dccf26ee3f2c"
Unverified Commit d34f5fe9 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[Bugfix][CPU] Fallback oneDNN linear to torch linear to fix half gemm support...


[Bugfix][CPU] Fallback oneDNN linear to torch linear to fix half gemm support on legecy platforms (#27526)
Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
Co-authored-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent bdb01a38
...@@ -79,7 +79,7 @@ RUN echo 'ulimit -c 0' >> ~/.bashrc ...@@ -79,7 +79,7 @@ RUN echo 'ulimit -c 0' >> ~/.bashrc
######################### BUILD IMAGE ######################### ######################### BUILD IMAGE #########################
FROM base AS vllm-build FROM base AS vllm-build
ARG max_jobs=2 ARG max_jobs=32
ENV MAX_JOBS=${max_jobs} ENV MAX_JOBS=${max_jobs}
ARG GIT_REPO_CHECK=0 ARG GIT_REPO_CHECK=0
......
...@@ -8,9 +8,12 @@ import torch ...@@ -8,9 +8,12 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
def shuffle_weight(w: torch.Tensor) -> torch.Tensor: def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
# Shuffle weight along the last dimension so that # Shuffle weight along the last dimension so that
...@@ -178,16 +181,25 @@ def dispatch_cpu_unquantized_gemm( ...@@ -178,16 +181,25 @@ def dispatch_cpu_unquantized_gemm(
) )
if remove_weight: if remove_weight:
layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
return
elif ( elif (
ops._supports_onednn ops._supports_onednn
and current_platform.get_cpu_architecture() != CpuArchEnum.POWERPC and current_platform.get_cpu_architecture() != CpuArchEnum.POWERPC
): ):
try:
origin_weight = layer.weight origin_weight = layer.weight
if remove_weight:
layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
handler = ops.create_onednn_mm(origin_weight.t(), 32) handler = ops.create_onednn_mm(origin_weight.t(), 32)
layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias) layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias)
else: if remove_weight:
layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
return
except RuntimeError as e:
logger.warning_once(
"Failed to create oneDNN linear, fallback to torch linear."
f" Exception: {e}"
)
# fallback case
layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear( layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
x, weight, bias x, weight, bias
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment