Unverified Commit 58a05b0c authored by dolpm's avatar dolpm Committed by GitHub
Browse files

[fix] CPUDNNLGEMMHandler pointer baked into inductor artifact (#32913)


Signed-off-by: default avatardolpm <34420038+dolpm@users.noreply.github.com>
parent 6ee7f18f
......@@ -360,13 +360,14 @@ void onednn_scaled_mm(
const std::optional<torch::Tensor>& azp, // [M] or [1]
const std::optional<torch::Tensor>& azp_adj, // [M] or [1]
const std::optional<torch::Tensor>& bias, // [N]
int64_t handler) {
const torch::Tensor& handler_tensor) {
CPU_KERNEL_GUARD_IN(onednn_scaled_mm)
TORCH_CHECK(a.dim() == 2);
TORCH_CHECK(a.is_contiguous());
TORCH_CHECK(c.is_contiguous());
W8A8MatMulPrimitiveHandler* ptr =
reinterpret_cast<W8A8MatMulPrimitiveHandler*>(handler);
reinterpret_cast<W8A8MatMulPrimitiveHandler*>(
handler_tensor.item<int64_t>());
const int32_t* azp_ptr = nullptr;
if (azp.has_value()) {
azp_ptr = azp->data_ptr<int32_t>();
......@@ -519,13 +520,14 @@ int64_t create_onednn_mm_handler(const torch::Tensor& b,
void onednn_mm(torch::Tensor& c, // [M, OC], row-major
const torch::Tensor& a, // [M, IC], row-major
const std::optional<torch::Tensor>& bias, int64_t handler) {
const std::optional<torch::Tensor>& bias,
const torch::Tensor& handler_tensor) {
CPU_KERNEL_GUARD_IN(onednn_mm)
TORCH_CHECK(a.dim() == 2);
TORCH_CHECK(a.stride(-1) == 1);
TORCH_CHECK(c.stride(-1) == 1);
MatMulPrimitiveHandler* ptr =
reinterpret_cast<MatMulPrimitiveHandler*>(handler);
reinterpret_cast<MatMulPrimitiveHandler*>(handler_tensor.item<int64_t>());
// ACL matmuls expect contiguous source tensors
#ifdef VLLM_USE_ACL
......
......@@ -19,13 +19,14 @@ void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
const std::optional<torch::Tensor>& azp,
const std::optional<torch::Tensor>& azp_adj,
const std::optional<torch::Tensor>& bias,
int64_t handler);
const torch::Tensor& handler_tensor);
int64_t create_onednn_mm_handler(const torch::Tensor& b,
int64_t primitive_cache_size);
void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
const std::optional<torch::Tensor>& bias, int64_t handler);
const std::optional<torch::Tensor>& bias,
const torch::Tensor& handler_tensor);
bool is_onednn_acl_supported();
......@@ -196,7 +197,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// oneDNN GEMM
ops.def(
"onednn_mm(Tensor! c, Tensor a, Tensor? bias, "
"int handler) -> ()");
"Tensor handler_tensor) -> ()");
ops.impl("onednn_mm", torch::kCPU, &onednn_mm);
// Check if oneDNN was built with ACL backend
......@@ -212,7 +213,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// oneDNN scaled_mm for W8A8 with static per-tensor activation quantization
ops.def(
"onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, "
"Tensor? azp_adj, Tensor? bias, int handler) -> ()");
"Tensor? azp_adj, Tensor? bias, Tensor handler_tensor) -> ()");
ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm);
// Compute int8 quantized tensor for given scaling factor.
......
......@@ -2845,13 +2845,13 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
class CPUDNNLGEMMHandler:
def __init__(self) -> None:
self.handler: int | None = None
self.handler_tensor: torch.Tensor | None = None
self.n = -1
self.k = -1
def __del__(self):
if self.handler is not None:
torch.ops._C.release_dnnl_matmul_handler(self.handler)
if self.handler_tensor is not None:
torch.ops._C.release_dnnl_matmul_handler(self.handler_tensor.item())
_supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler"))
......@@ -2867,8 +2867,10 @@ def create_onednn_mm(
) -> CPUDNNLGEMMHandler:
handler = CPUDNNLGEMMHandler()
handler.k, handler.n = weight.size()
handler.handler = torch.ops._C.create_onednn_mm_handler(
weight, primitive_cache_size
# store the handler pointer in a tensor it doesn't get inlined
handler.handler_tensor = torch.tensor(
torch.ops._C.create_onednn_mm_handler(weight, primitive_cache_size),
dtype=torch.int64,
)
return handler
......@@ -2880,7 +2882,7 @@ def onednn_mm(
) -> torch.Tensor:
output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype)
torch.ops._C.onednn_mm(
output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler
output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler_tensor
)
return output
......@@ -2896,8 +2898,17 @@ def create_onednn_scaled_mm(
) -> CPUDNNLGEMMHandler:
handler = CPUDNNLGEMMHandler()
handler.k, handler.n = weight.size()
handler.handler = torch.ops._C.create_onednn_scaled_mm_handler(
weight, weight_scales, output_type, dynamic_quant, use_azp, primitive_cache_size
# store the handler pointer in a tensor so it doesn't get inlined
handler.handler_tensor = torch.tensor(
torch.ops._C.create_onednn_scaled_mm_handler(
weight,
weight_scales,
output_type,
dynamic_quant,
use_azp,
primitive_cache_size,
),
dtype=torch.int64,
)
return handler
......@@ -2950,7 +2961,13 @@ def onednn_scaled_mm(
bias: torch.Tensor | None,
) -> torch.Tensor:
torch.ops._C.onednn_scaled_mm(
output, x, input_scale, input_zp, input_zp_adj, bias, dnnl_handler.handler
output,
x,
input_scale,
input_zp,
input_zp_adj,
bias,
dnnl_handler.handler_tensor,
)
return output
......
......@@ -289,16 +289,11 @@ def use_aot_compile() -> bool:
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer
default_value = (
"1"
if is_torch_equal_or_newer("2.10.0.dev")
and not disable_compile_cache()
# Disabling AOT_COMPILE for CPU
# See: https://github.com/vllm-project/vllm/issues/32033
and not current_platform.is_cpu()
if is_torch_equal_or_newer("2.10.0.dev") and not disable_compile_cache()
else "0"
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment