Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
57b1ce94
Unverified
Commit
57b1ce94
authored
Sep 04, 2025
by
Li, Jiang
Committed by
GitHub
Sep 04, 2025
Browse files
[CPU] Refactor CPU unquantized linear (#24150)
Signed-off-by:
jiang1.li
<
jiang1.li@intel.com
>
parent
cb55ad86
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
466 additions
and
26 deletions
+466
-26
csrc/cpu/dnnl_helper.cpp
csrc/cpu/dnnl_helper.cpp
+177
-0
csrc/cpu/dnnl_helper.h
csrc/cpu/dnnl_helper.h
+74
-0
csrc/cpu/dnnl_kernels.cpp
csrc/cpu/dnnl_kernels.cpp
+54
-0
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+18
-0
tests/kernels/test_onednn.py
tests/kernels/test_onednn.py
+70
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+29
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+4
-21
vllm/model_executor/layers/utils.py
vllm/model_executor/layers/utils.py
+34
-5
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+6
-0
No files found.
csrc/cpu/dnnl_helper.cpp
View file @
57b1ce94
...
...
@@ -22,6 +22,23 @@ void release_dnnl_matmul_handler(int64_t handler) {
delete
ptr
;
}
DNNLScratchPadManager
::
DNNLScratchPadManager
()
:
size_
(
0
),
ptr_
(
nullptr
)
{
this
->
realloc
(
allocation_unit
*
128
);
}
void
DNNLScratchPadManager
::
realloc
(
size_t
new_size
)
{
new_size
=
round
(
new_size
);
if
(
new_size
>
size_
)
{
ptr_
=
std
::
aligned_alloc
(
64
,
new_size
);
size_
=
new_size
;
}
}
DNNLScratchPadManager
*
DNNLScratchPadManager
::
get_dnnl_scratchpad_manager
()
{
static
DNNLScratchPadManager
manager
;
return
&
manager
;
}
template
<
typename
KT
,
typename
VT
>
class
DNNLPrimitiveCache
{
public:
...
...
@@ -166,6 +183,23 @@ struct hash<W8A8MatMulPrimitiveHandler::MSizeCacheKey> {
hash
<
int
>
()(
static_cast
<
int
>
(
val
.
bias_type
));
}
};
template
<
>
struct
hash
<
MatMulPrimitiveHandler
::
ClassMatmulCacheKey
>
{
size_t
operator
()(
const
MatMulPrimitiveHandler
::
ClassMatmulCacheKey
&
val
)
const
{
return
hash
<
dnnl_dim_t
>
()(
val
.
b_n_size
)
^
hash
<
dnnl_dim_t
>
()(
val
.
b_k_size
);
}
};
template
<
>
struct
hash
<
MatMulPrimitiveHandler
::
MSizeCacheKey
>
{
size_t
operator
()(
const
MatMulPrimitiveHandler
::
MSizeCacheKey
&
val
)
const
{
return
hash
<
dnnl_dim_t
>
()(
val
.
a_m_size
)
^
hash
<
dnnl_dim_t
>
()(
val
.
a_m_stride
)
^
hash
<
bool
>
()(
val
.
use_bias
)
^
hash
<
int
>
()(
static_cast
<
int
>
(
val
.
bias_type
));
}
};
}
// namespace std
bool
operator
==
(
const
W8A8MatMulPrimitiveHandler
::
ClassMatmulCacheKey
&
l
,
...
...
@@ -181,6 +215,17 @@ bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
l
.
bias_type
==
r
.
bias_type
;
}
bool
operator
==
(
const
MatMulPrimitiveHandler
::
ClassMatmulCacheKey
&
l
,
const
MatMulPrimitiveHandler
::
ClassMatmulCacheKey
&
r
)
{
return
l
.
b_n_size
==
r
.
b_n_size
&&
l
.
b_k_size
==
r
.
b_k_size
;
}
bool
operator
==
(
const
MatMulPrimitiveHandler
::
MSizeCacheKey
&
l
,
const
MatMulPrimitiveHandler
::
MSizeCacheKey
&
r
)
{
return
l
.
a_m_size
==
r
.
a_m_size
&&
l
.
a_m_stride
==
r
.
a_m_stride
&&
l
.
use_bias
==
r
.
use_bias
&&
l
.
bias_type
==
r
.
bias_type
;
}
static
std
::
shared_ptr
<
W8A8MatMulPrimitiveHandler
::
MSizeCache
>
get_w8a8_class_primitive_cache
(
const
W8A8MatMulPrimitiveHandler
::
ClassMatmulCacheKey
&
key
,
...
...
@@ -239,6 +284,11 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
}
dnnl
::
matmul
matmul
=
get_matmul_cache
(
args
);
auto
&&
[
scratchpad_storage
,
scratchpad_mem_desc
]
=
get_runtime_memory_ptr
(
5
);
scratchpad_storage
->
set_data_handle
(
DNNLScratchPadManager
::
get_dnnl_scratchpad_manager
()
->
get_data
<
void
>
());
matmul
.
execute
(
default_stream
(),
memory_cache_
);
default_stream
().
wait
();
}
...
...
@@ -257,6 +307,8 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
return
m_size_cache_
->
get_or_create
(
key
,
[
&
]()
{
dnnl
::
matmul
::
primitive_desc
desc
=
this
->
create_primitive_desc
(
key
,
false
);
auto
manager
=
DNNLScratchPadManager
::
get_dnnl_scratchpad_manager
();
manager
->
realloc
(
desc
.
scratchpad_desc
().
get_size
());
return
dnnl
::
matmul
(
desc
);
});
}
...
...
@@ -300,6 +352,11 @@ void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
dnnl
::
memory
({{
b_n_size_
},
dnnl
::
memory
::
data_type
::
f32
,
{
1
}},
default_engine
(),
nullptr
);
set_runtime_memory_ptr
(
4
,
memory_cache_
[
DNNL_ARG_BIAS
].
get
());
memory_cache_
[
DNNL_ARG_SCRATCHPAD
]
=
dnnl
::
memory
({{
b_n_size_
},
dnnl
::
memory
::
data_type
::
f32
,
{
1
}},
default_engine
(),
nullptr
);
set_runtime_memory_ptr
(
5
,
memory_cache_
[
DNNL_ARG_SCRATCHPAD
].
get
());
}
dnnl
::
matmul
::
primitive_desc
W8A8MatMulPrimitiveHandler
::
create_primitive_desc
(
...
...
@@ -319,6 +376,9 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
dnnl
::
memory
::
format_tag
::
ab
);
dnnl
::
primitive_attr
attr
;
attr
.
set_scratchpad_mode
(
dnnl
::
scratchpad_mode
::
user
);
// For PER_TOKEN, scales will be applied in outside epilogue
if
(
a_qs_
==
QuantizationStrategy
::
PER_TENSOR
)
{
attr
.
set_scales_mask
(
DNNL_ARG_SRC
,
0
);
...
...
@@ -344,3 +404,120 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
attr
);
}
}
MatMulPrimitiveHandler
::
MatMulPrimitiveHandler
(
const
Args
&
args
)
:
DNNLMatMulPrimitiveHandler
(
static_cast
<
DNNLMatMulPrimitiveHandler
::
Args
>
(
args
),
args
.
ab_type
),
m_size_cache_
(
nullptr
)
{
assert
(
ab_type_
==
dnnl
::
memory
::
data_type
::
f32
||
ab_type_
==
dnnl
::
memory
::
data_type
::
bf16
||
ab_type_
==
dnnl
::
memory
::
data_type
::
f16
);
prepack_weight
(
args
.
b_ptr
,
create_primitive_desc
(
MSizeCacheKey
{.
a_m_size
=
DNNL_RUNTIME_DIM_VAL
,
.
a_m_stride
=
DNNL_RUNTIME_DIM_VAL
,
.
use_bias
=
false
,
.
bias_type
=
dnnl
::
memory
::
data_type
::
undef
},
true
)
.
weights_desc
());
init_runtime_memory_cache
(
args
);
}
static
std
::
shared_ptr
<
MatMulPrimitiveHandler
::
MSizeCache
>
get_matul_class_primitive_cache
(
const
MatMulPrimitiveHandler
::
ClassMatmulCacheKey
&
key
,
int64_t
cache_size
)
{
static
MatMulPrimitiveHandler
::
ClassMatmulCache
cache
(
128
);
assert
(
cache_size
>
0
);
return
cache
.
get_or_create
(
key
,
[
&
]()
{
return
std
::
make_shared
<
MatMulPrimitiveHandler
::
MSizeCache
>
(
cache_size
);
});
}
void
MatMulPrimitiveHandler
::
execute
(
ExecArgs
&
args
)
{
auto
&&
[
a_storage
,
a_mem_desc
]
=
get_runtime_memory_ptr
(
0
);
auto
&&
[
c_storage
,
c_mem_desc
]
=
get_runtime_memory_ptr
(
1
);
a_storage
->
set_data_handle
((
void
*
)
args
.
a_ptr
);
a_mem_desc
->
dims
[
0
]
=
args
.
a_m_size
;
a_mem_desc
->
format_desc
.
blocking
.
strides
[
0
]
=
args
.
a_m_stride
;
c_storage
->
set_data_handle
((
void
*
)
args
.
c_ptr
);
c_mem_desc
->
dims
[
0
]
=
args
.
a_m_size
;
if
(
args
.
use_bias
)
{
auto
&&
[
bias_storage
,
bias_mem_desc
]
=
get_runtime_memory_ptr
(
2
);
bias_storage
->
set_data_handle
((
void
*
)
args
.
bias_ptr
);
}
dnnl
::
matmul
matmul
=
get_matmul_cache
(
args
);
auto
&&
[
scratchpad_storage
,
scratchpad_mem_desc
]
=
get_runtime_memory_ptr
(
3
);
scratchpad_storage
->
set_data_handle
(
DNNLScratchPadManager
::
get_dnnl_scratchpad_manager
()
->
get_data
<
void
>
());
matmul
.
execute
(
default_stream
(),
memory_cache_
);
default_stream
().
wait
();
}
dnnl
::
matmul
MatMulPrimitiveHandler
::
get_matmul_cache
(
const
MSizeCacheKey
&
key
)
{
if
(
m_size_cache_
.
get
()
==
nullptr
)
{
ClassMatmulCacheKey
key
=
{.
b_n_size
=
b_n_size_
,
.
b_k_size
=
b_k_size_
};
m_size_cache_
=
get_matul_class_primitive_cache
(
key
,
primitive_cache_size_
);
}
return
m_size_cache_
->
get_or_create
(
key
,
[
&
]()
{
dnnl
::
matmul
::
primitive_desc
desc
=
this
->
create_primitive_desc
(
key
,
false
);
auto
manager
=
DNNLScratchPadManager
::
get_dnnl_scratchpad_manager
();
manager
->
realloc
(
desc
.
scratchpad_desc
().
get_size
());
return
dnnl
::
matmul
(
desc
);
});
}
dnnl
::
matmul
::
primitive_desc
MatMulPrimitiveHandler
::
create_primitive_desc
(
const
MSizeCacheKey
&
key
,
bool
first_time
)
{
dnnl
::
memory
::
desc
a_md
;
dnnl
::
memory
::
desc
b_md
;
if
(
first_time
)
{
a_md
=
dnnl
::
memory
::
desc
({
key
.
a_m_size
,
b_k_size_
},
b_type_
,
dnnl
::
memory
::
format_tag
::
ab
);
b_md
=
dnnl
::
memory
::
desc
({
b_k_size_
,
b_n_size_
},
b_type_
,
dnnl
::
memory
::
format_tag
::
any
);
}
else
{
a_md
=
dnnl
::
memory
::
desc
({
key
.
a_m_size
,
b_k_size_
},
b_type_
,
{
key
.
a_m_stride
,
1
});
b_md
=
b_target_mem_desc_
;
}
dnnl
::
memory
::
desc
c_md
({
key
.
a_m_size
,
b_n_size_
},
c_type_
,
dnnl
::
memory
::
format_tag
::
ab
);
dnnl
::
primitive_attr
attr
;
attr
.
set_scratchpad_mode
(
dnnl
::
scratchpad_mode
::
user
);
if
(
key
.
use_bias
)
{
dnnl
::
memory
::
desc
bias_md
({
1
,
b_n_size_
},
key
.
bias_type
,
{
b_n_size_
,
1
});
return
dnnl
::
matmul
::
primitive_desc
(
default_engine
(),
a_md
,
b_md
,
bias_md
,
c_md
,
attr
);
}
else
{
return
dnnl
::
matmul
::
primitive_desc
(
default_engine
(),
a_md
,
b_md
,
c_md
,
attr
);
}
}
void
MatMulPrimitiveHandler
::
init_runtime_memory_cache
(
const
Args
&
args
)
{
memory_cache_
[
DNNL_ARG_SRC
]
=
dnnl
::
memory
(
{{
1
,
b_k_size_
},
b_type_
,
{
b_k_size_
,
1
}},
default_engine
(),
nullptr
);
set_runtime_memory_ptr
(
0
,
memory_cache_
[
DNNL_ARG_SRC
].
get
());
memory_cache_
[
DNNL_ARG_DST
]
=
dnnl
::
memory
({{
1
,
b_n_size_
},
c_type_
,
dnnl
::
memory
::
format_tag
::
ab
},
default_engine
(),
nullptr
);
set_runtime_memory_ptr
(
1
,
memory_cache_
[
DNNL_ARG_DST
].
get
());
memory_cache_
[
DNNL_ARG_BIAS
]
=
dnnl
::
memory
({{
b_n_size_
},
dnnl
::
memory
::
data_type
::
f32
,
{
1
}},
default_engine
(),
nullptr
);
set_runtime_memory_ptr
(
2
,
memory_cache_
[
DNNL_ARG_BIAS
].
get
());
memory_cache_
[
DNNL_ARG_SCRATCHPAD
]
=
dnnl
::
memory
({{
b_n_size_
},
dnnl
::
memory
::
data_type
::
f32
,
{
1
}},
default_engine
(),
nullptr
);
set_runtime_memory_ptr
(
3
,
memory_cache_
[
DNNL_ARG_SCRATCHPAD
].
get
());
}
csrc/cpu/dnnl_helper.h
View file @
57b1ce94
...
...
@@ -59,6 +59,30 @@ constexpr inline dnnl::memory::data_type get_dnnl_type() {
return
DNNLType
<
std
::
decay_t
<
T
>>::
type
;
}
class
DNNLScratchPadManager
{
public:
static
constexpr
size_t
allocation_unit
=
4
*
1024
*
1024
;
// 4KB
static
DNNLScratchPadManager
*
get_dnnl_scratchpad_manager
();
DNNLScratchPadManager
();
template
<
typename
T
>
T
*
get_data
()
{
return
reinterpret_cast
<
T
*>
(
ptr_
);
}
static
size_t
round
(
size_t
size
)
{
return
((
size
+
allocation_unit
-
1
)
/
allocation_unit
)
*
allocation_unit
;
}
void
realloc
(
size_t
new_size
);
private:
size_t
size_
;
void
*
ptr_
;
};
class
DNNLMatMulPrimitiveHandler
{
public:
virtual
~
DNNLMatMulPrimitiveHandler
()
=
default
;
...
...
@@ -166,4 +190,54 @@ class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
std
::
shared_ptr
<
MSizeCache
>
m_size_cache_
;
};
class
MatMulPrimitiveHandler
:
public
DNNLMatMulPrimitiveHandler
{
public:
struct
Args
:
public
DNNLMatMulPrimitiveHandler
::
Args
{
dnnl
::
memory
::
data_type
ab_type
;
};
struct
ClassMatmulCacheKey
{
dnnl_dim_t
b_n_size
;
dnnl_dim_t
b_k_size
;
friend
bool
operator
==
(
const
ClassMatmulCacheKey
&
l
,
const
ClassMatmulCacheKey
&
r
);
};
struct
MSizeCacheKey
{
dnnl_dim_t
a_m_size
;
dnnl_dim_t
a_m_stride
;
bool
use_bias
;
dnnl
::
memory
::
data_type
bias_type
;
friend
bool
operator
==
(
const
MSizeCacheKey
&
l
,
const
MSizeCacheKey
&
r
);
};
using
MSizeCache
=
DNNLPrimitiveCache
<
MSizeCacheKey
,
dnnl
::
matmul
>
;
using
ClassMatmulCache
=
DNNLPrimitiveCache
<
ClassMatmulCacheKey
,
std
::
shared_ptr
<
MSizeCache
>>
;
struct
ExecArgs
:
public
MSizeCacheKey
{
const
void
*
a_ptr
;
const
void
*
bias_ptr
;
void
*
c_ptr
;
};
public:
MatMulPrimitiveHandler
(
const
Args
&
args
);
void
execute
(
ExecArgs
&
args
);
private:
dnnl
::
matmul
::
primitive_desc
create_primitive_desc
(
const
MSizeCacheKey
&
key
,
bool
first_time
);
void
init_runtime_memory_cache
(
const
Args
&
args
);
dnnl
::
matmul
get_matmul_cache
(
const
MSizeCacheKey
&
key
);
private:
std
::
shared_ptr
<
MSizeCache
>
m_size_cache_
;
};
#endif
csrc/cpu/dnnl_kernels.cpp
View file @
57b1ce94
...
...
@@ -379,6 +379,7 @@ void onednn_scaled_mm(
exec_args
.
a_ptr
=
a
.
data_ptr
<
int8_t
>
();
exec_args
.
a_m_size
=
a
.
size
(
0
);
exec_args
.
bias_ptr
=
nullptr
;
exec_args
.
bias_type
=
get_dnnl_type
<
void
>
();
exec_args
.
use_bias
=
false
;
exec_args
.
a_scales_ptr
=
nullptr
;
exec_args
.
a_zero_points_ptr
=
nullptr
;
...
...
@@ -492,3 +493,56 @@ void dynamic_scaled_int8_quant(
}
});
}
int64_t
create_onednn_mm_handler
(
const
torch
::
Tensor
&
b
,
int64_t
primitive_cache_size
)
{
TORCH_CHECK
(
b
.
dim
()
==
2
);
MatMulPrimitiveHandler
::
Args
args
;
args
.
primitive_cache_size
=
primitive_cache_size
;
args
.
b_k_size
=
b
.
size
(
0
);
args
.
b_k_stride
=
b
.
stride
(
0
);
args
.
b_n_size
=
b
.
size
(
1
);
args
.
b_n_stride
=
b
.
stride
(
1
);
args
.
b_ptr
=
b
.
data_ptr
();
VLLM_DISPATCH_FLOATING_TYPES
(
b
.
scalar_type
(),
"create_onednn_mm_handler"
,
[
&
]
{
args
.
c_type
=
get_dnnl_type
<
scalar_t
>
();
args
.
ab_type
=
get_dnnl_type
<
scalar_t
>
();
});
return
reinterpret_cast
<
int64_t
>
(
new
MatMulPrimitiveHandler
(
args
));
}
void
onednn_mm
(
torch
::
Tensor
&
c
,
// [M, OC], row-major
const
torch
::
Tensor
&
a
,
// [M, IC], row-major
const
std
::
optional
<
torch
::
Tensor
>&
bias
,
int64_t
handler
)
{
CPU_KERNEL_GUARD_IN
(
onednn_mm
)
TORCH_CHECK
(
a
.
dim
()
==
2
);
TORCH_CHECK
(
a
.
stride
(
-
1
)
==
1
);
TORCH_CHECK
(
c
.
is_contiguous
());
MatMulPrimitiveHandler
*
ptr
=
reinterpret_cast
<
MatMulPrimitiveHandler
*>
(
handler
);
MatMulPrimitiveHandler
::
ExecArgs
exec_args
;
exec_args
.
a_m_size
=
a
.
size
(
0
);
exec_args
.
a_m_stride
=
a
.
stride
(
0
);
VLLM_DISPATCH_FLOATING_TYPES
(
a
.
scalar_type
(),
"onednn_mm"
,
[
&
]
{
if
(
bias
.
has_value
())
{
exec_args
.
use_bias
=
true
;
exec_args
.
bias_type
=
get_dnnl_type
<
scalar_t
>
();
exec_args
.
bias_ptr
=
bias
->
data_ptr
<
scalar_t
>
();
}
else
{
exec_args
.
use_bias
=
false
;
exec_args
.
bias_type
=
get_dnnl_type
<
void
>
();
exec_args
.
bias_ptr
=
nullptr
;
}
exec_args
.
a_ptr
=
a
.
data_ptr
<
scalar_t
>
();
exec_args
.
c_ptr
=
c
.
data_ptr
<
scalar_t
>
();
ptr
->
execute
(
exec_args
);
});
}
csrc/cpu/torch_bindings.cpp
View file @
57b1ce94
...
...
@@ -21,6 +21,12 @@ void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
const
std
::
optional
<
torch
::
Tensor
>&
bias
,
int64_t
handler
);
int64_t
create_onednn_mm_handler
(
const
torch
::
Tensor
&
b
,
int64_t
primitive_cache_size
);
void
onednn_mm
(
torch
::
Tensor
&
c
,
const
torch
::
Tensor
&
a
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
,
int64_t
handler
);
void
mla_decode_kvcache
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
kv_cache
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
);
...
...
@@ -153,6 +159,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"release_dnnl_matmul_handler(int handler) -> ()"
,
&
release_dnnl_matmul_handler
);
// Create oneDNN GEMM handler
ops
.
def
(
"create_onednn_mm_handler(Tensor b, int "
"primitive_cache_size) -> int"
,
&
create_onednn_mm_handler
);
// oneDNN GEMM
ops
.
def
(
"onednn_mm(Tensor! c, Tensor a, Tensor? bias, "
"int handler) -> ()"
);
ops
.
impl
(
"onednn_mm"
,
torch
::
kCPU
,
&
onednn_mm
);
// Create oneDNN W8A8 handler
ops
.
def
(
"create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
...
...
tests/kernels/test_onednn.py
View file @
57b1ce94
...
...
@@ -111,6 +111,49 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int,
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
def
onednn_gemm_test_helper
(
primitive_cache_size
:
int
,
m
:
int
,
n
:
int
,
k
:
int
,
use_bias
:
bool
,
use_stride
:
bool
,
dtype
:
torch
.
dtype
=
torch
.
bfloat16
,
device
:
str
=
"cpu"
):
if
use_stride
:
a
=
torch
.
rand
((
m
,
2
*
k
),
dtype
=
dtype
,
device
=
device
)
*
1.5
a
=
a
[:,
:
k
]
else
:
a
=
torch
.
rand
((
m
,
k
),
dtype
=
dtype
,
device
=
device
)
*
1.5
b
=
torch
.
rand
((
n
,
k
),
dtype
=
dtype
,
device
=
device
)
*
1.5
if
use_bias
:
bias
=
torch
.
rand
((
n
,
),
device
=
device
,
dtype
=
dtype
)
*
5
bias_f32
=
bias
.
float
()
else
:
bias
=
None
bias_f32
=
None
handler
=
ops
.
create_onednn_mm
(
b
.
t
(),
primitive_cache_size
,
)
out
=
ops
.
onednn_mm
(
handler
,
a
,
bias
)
baseline
=
torch
.
nn
.
functional
.
linear
(
a
.
float
(),
b
.
float
(),
bias_f32
).
to
(
dtype
=
a
.
dtype
)
torch
.
testing
.
assert_close
(
out
,
baseline
)
if
use_bias
:
# To test runtime bias setting
out
=
ops
.
onednn_mm
(
handler
,
a
,
None
)
baseline
=
torch
.
nn
.
functional
.
linear
(
a
.
float
(),
b
.
float
(),
None
).
to
(
dtype
=
a
.
dtype
)
torch
.
testing
.
assert_close
(
out
,
baseline
)
@
pytest
.
mark
.
parametrize
(
"n,k"
,
NK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"m_list"
,
M_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"per_tensor_a_scale"
,
[
True
,
False
])
...
...
@@ -142,3 +185,30 @@ def test_onednn_int8_scaled_gemm(
use_azp
=
use_azp
,
out_dtype
=
output_type
,
)
@
pytest
.
mark
.
parametrize
(
"n,k"
,
NK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"m_list"
,
M_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_stride"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPE
)
@
pytest
.
mark
.
parametrize
(
"primitive_cache_size"
,
CACHE_SIZES
)
def
test_onednn_gemm
(
n
:
int
,
k
:
int
,
m_list
:
tuple
[
int
],
use_bias
:
bool
,
use_stride
:
bool
,
dtype
:
torch
.
dtype
,
primitive_cache_size
:
int
,
):
for
m
in
m_list
:
onednn_gemm_test_helper
(
primitive_cache_size
=
primitive_cache_size
,
m
=
m
,
n
=
n
,
k
=
k
,
use_bias
=
use_bias
,
use_stride
=
use_stride
,
dtype
=
dtype
,
)
vllm/_custom_ops.py
View file @
57b1ce94
...
...
@@ -1928,6 +1928,35 @@ class CPUDNNLGEMMHandler:
torch
.
ops
.
_C
.
release_dnnl_matmul_handler
(
self
.
handler
)
if
hasattr
(
torch
.
ops
.
_C
,
"create_onednn_mm_handler"
):
_supports_onednn
=
True
else
:
_supports_onednn
=
False
def
create_onednn_mm
(
weight
:
torch
.
Tensor
,
# [K, N]
primitive_cache_size
:
int
=
128
,
)
->
CPUDNNLGEMMHandler
:
handler
=
CPUDNNLGEMMHandler
()
handler
.
k
,
handler
.
n
=
weight
.
size
()
handler
.
handler
=
torch
.
ops
.
_C
.
create_onednn_mm_handler
(
weight
,
primitive_cache_size
)
return
handler
def
onednn_mm
(
dnnl_handler
:
CPUDNNLGEMMHandler
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
output
=
torch
.
empty
((
*
x
.
shape
[
0
:
-
1
],
dnnl_handler
.
n
),
dtype
=
x
.
dtype
)
torch
.
ops
.
_C
.
onednn_mm
(
output
,
x
.
reshape
(
-
1
,
dnnl_handler
.
k
),
bias
,
dnnl_handler
.
handler
)
return
output
def
create_onednn_scaled_mm
(
weight
:
torch
.
Tensor
,
# [K, N]
weight_scales
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/linear.py
View file @
57b1ce94
...
...
@@ -9,7 +9,6 @@ import torch
import
torch.nn
as
nn
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
from
vllm
import
envs
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
...
...
@@ -200,26 +199,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
set_weight_attrs
(
weight
,
extra_weight_attrs
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# special postprocessing for CPU SGL
if
current_platform
.
is_cpu
()
and
envs
.
VLLM_CPU_SGL_KERNEL
:
from
vllm.model_executor.layers.utils
import
check_cpu_sgl_kernel
N
,
K
=
layer
.
weight
.
size
()
dtype
=
layer
.
weight
.
dtype
if
check_cpu_sgl_kernel
(
N
,
K
,
dtype
):
packed_weight
=
torch
.
ops
.
_C
.
convert_weight_packed
(
layer
.
weight
)
assert
packed_weight
.
size
()
==
layer
.
weight
.
size
()
layer
.
weight
.
copy_
(
packed_weight
)
if
layer
.
bias
is
not
None
:
layer
.
bias
=
Parameter
(
layer
.
bias
.
to
(
torch
.
float32
),
requires_grad
=
False
)
layer
.
use_cpu_sgl
=
True
else
:
logger
.
warning
(
"CPU SGL kernels require Intel AMX support,"
" bf16/fp16/int8 weight, IC and OC are divisible by "
"32 and 16."
)
layer
.
use_cpu_sgl
=
False
if
current_platform
.
is_cpu
():
from
vllm.model_executor.layers.utils
import
(
dispatch_cpu_unquantized_gemm
)
dispatch_cpu_unquantized_gemm
(
layer
,
remove_weight
=
True
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
vllm/model_executor/layers/utils.py
View file @
57b1ce94
...
...
@@ -142,20 +142,49 @@ direct_register_custom_op(
)
def
check_cpu_sgl_kernel
(
n
:
int
,
k
:
int
,
dtype
:
torch
.
dtype
):
def
check_cpu_sgl_kernel
(
n
:
int
,
k
:
int
,
dtype
:
torch
.
dtype
)
->
bool
:
return
(
torch
.
_C
.
_cpu
.
_is_amx_tile_supported
()
and
(
dtype
in
(
torch
.
bfloat16
,
torch
.
int8
))
and
k
%
32
==
0
and
n
%
16
==
0
)
def
dispatch_cpu_unquantized_gemm
(
layer
:
torch
.
nn
.
Module
,
remove_weight
:
bool
,
)
->
None
:
N
,
K
=
layer
.
weight
.
size
()
dtype
=
layer
.
weight
.
dtype
if
envs
.
VLLM_CPU_SGL_KERNEL
and
check_cpu_sgl_kernel
(
N
,
K
,
dtype
):
packed_weight
=
torch
.
ops
.
_C
.
convert_weight_packed
(
layer
.
weight
)
if
getattr
(
layer
,
"bias"
,
None
)
is
not
None
:
bias_f32
=
layer
.
bias
.
to
(
torch
.
float32
)
else
:
bias_f32
=
None
layer
.
cpu_linear
=
(
lambda
x
,
weight
,
bias
:
torch
.
ops
.
_C
.
weight_packed_linear
(
x
,
packed_weight
,
bias_f32
if
bias
is
not
None
else
None
,
True
))
if
remove_weight
:
layer
.
weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
0
),
requires_grad
=
False
)
elif
ops
.
_supports_onednn
:
origin_weight
=
layer
.
weight
if
remove_weight
:
layer
.
weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
0
),
requires_grad
=
False
)
handler
=
ops
.
create_onednn_mm
(
origin_weight
.
t
(),
32
)
layer
.
cpu_linear
=
lambda
x
,
weight
,
bias
:
ops
.
onednn_mm
(
handler
,
x
,
bias
)
else
:
layer
.
cpu_linear
=
lambda
x
,
weight
,
bias
:
torch
.
nn
.
functional
.
linear
(
x
,
weight
,
bias
)
def
cpu_unquantized_gemm
(
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
):
if
getattr
(
layer
,
"use_cpu_sgl"
,
False
):
return
torch
.
ops
.
_C
.
weight_packed_linear
(
x
,
weight
,
bias
,
True
)
else
:
return
torch
.
nn
.
functional
.
linear
(
x
,
weight
,
bias
)
return
layer
.
cpu_linear
(
x
,
weight
,
bias
)
def
dispatch_unquantized_gemm
()
->
Callable
[...,
torch
.
Tensor
]:
...
...
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
57b1ce94
...
...
@@ -40,6 +40,12 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
extra_weight_attrs
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
if
current_platform
.
is_cpu
():
from
vllm.model_executor.layers.utils
import
(
dispatch_cpu_unquantized_gemm
)
dispatch_cpu_unquantized_gemm
(
layer
,
remove_weight
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment