Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b8e809a0
Unverified
Commit
b8e809a0
authored
Jun 11, 2025
by
artetaout
Committed by
GitHub
Jun 11, 2025
Browse files
[Kernel] Support deep_gemm for linear methods (#19085)
Signed-off-by:
artetaout
<
lulala341@gmail.com
>
parent
5039ec23
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
124 additions
and
1 deletion
+124
-1
vllm/model_executor/layers/quantization/deepgemm.py
vllm/model_executor/layers/quantization/deepgemm.py
+84
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+1
-0
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+39
-1
No files found.
vllm/model_executor/layers/quantization/deepgemm.py
0 → 100644
View file @
b8e809a0
# SPDX-License-Identifier: Apache-2.0
import
importlib.util
import
logging
import
torch
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
triton
from
vllm.utils
import
direct_register_custom_op
has_deep_gemm
=
importlib
.
util
.
find_spec
(
"deep_gemm"
)
is
not
None
if
has_deep_gemm
:
import
deep_gemm
logger
=
logging
.
getLogger
(
__name__
)
def
prepare_block_fp8_matmul_inputs
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
block_size
:
list
[
int
],
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
tuple
[
int
,
int
,
int
,
torch
.
Tensor
]:
assert
len
(
block_size
)
==
2
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
assert
A
.
shape
[
-
1
]
==
B
.
shape
[
-
1
]
assert
A
.
shape
[:
-
1
]
==
As
.
shape
[:
-
1
]
assert
A
.
is_contiguous
()
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
As
.
shape
[
-
1
]
M
=
A
.
numel
()
//
A
.
shape
[
-
1
]
assert
B
.
ndim
==
2
assert
B
.
is_contiguous
()
assert
Bs
.
ndim
==
2
N
,
K
=
B
.
shape
assert
triton
.
cdiv
(
N
,
block_n
)
==
Bs
.
shape
[
0
]
assert
triton
.
cdiv
(
K
,
block_k
)
==
Bs
.
shape
[
1
]
C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,
)
C
=
A
.
new_empty
(
C_shape
,
dtype
=
output_dtype
)
return
M
,
N
,
K
,
C
def
w8a8_block_fp8_matmul_deepgemm
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
block_size
:
list
[
int
],
output_dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
M
,
N
,
K
,
C
=
prepare_block_fp8_matmul_inputs
(
A
,
B
,
As
,
Bs
,
block_size
,
output_dtype
)
# Deepgemm only supports output tensor type as bfloat16
assert
C
.
dtype
==
torch
.
bfloat16
deep_gemm
.
gemm_fp8_fp8_bf16_nt
((
A
,
As
),
(
B
,
Bs
),
C
)
return
C
def
w8a8_block_fp8_matmul_deepgemm_fake
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
block_size
:
list
[
int
],
output_dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
M
,
N
,
K
,
C
=
prepare_block_fp8_matmul_inputs
(
A
,
B
,
As
,
Bs
,
block_size
,
output_dtype
)
return
C
direct_register_custom_op
(
op_name
=
"w8a8_block_fp8_matmul_deepgemm"
,
op_func
=
w8a8_block_fp8_matmul_deepgemm
,
mutates_args
=
[],
fake_impl
=
w8a8_block_fp8_matmul_deepgemm_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
vllm/model_executor/layers/quantization/fp8.py
View file @
b8e809a0
...
...
@@ -402,6 +402,7 @@ class Fp8LinearMethod(LinearMethodBase):
if
self
.
block_quant
:
assert
self
.
quant_config
.
weight_block_size
is
not
None
return
torch
.
ops
.
vllm
.
apply_w8a8_block_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
b8e809a0
...
...
@@ -3,12 +3,14 @@
# Adapted from https://github.com/sgl-project/sglang/pull/2575
import
functools
import
importlib.util
import
json
import
os
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
...
...
@@ -20,6 +22,7 @@ from vllm.triton_utils import tl, triton
from
vllm.utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
has_deep_gemm
=
importlib
.
util
.
find_spec
(
"deep_gemm"
)
is
not
None
def
is_fp8
(
x
:
Union
[
torch
.
dtype
,
torch
.
Tensor
])
->
bool
:
...
...
@@ -98,6 +101,19 @@ def dispatch_w8a8_blockscale_func(
return
w8a8_block_fp8_matmul
def
should_use_deepgemm
(
output_dtype
:
torch
.
dtype
,
weight
:
torch
.
Tensor
):
"""
Check if DeepGEMM should be used based on the output dtype and weight shape.
DeepGEMM is only supported for bfloat16 output dtype and weights with shape
divisible by 128.
"""
return
(
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability
(
90
)
and
has_deep_gemm
and
envs
.
VLLM_USE_DEEP_GEMM
and
output_dtype
==
torch
.
bfloat16
and
weight
.
shape
[
0
]
%
128
==
0
and
weight
.
shape
[
1
]
%
128
==
0
)
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
def
apply_w8a8_block_fp8_linear
(
...
...
@@ -114,6 +130,29 @@ def apply_w8a8_block_fp8_linear(
# View input as 2D matrix for fp8 methods
input_2d
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
output_dtype
=
input
.
dtype
if
should_use_deepgemm
(
output_dtype
,
weight
):
input_2d
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
q_input
,
x_scale
=
per_token_group_quant_fp8
(
input_2d
,
block_size
[
1
],
column_major_scales
=
True
,
)
output
=
torch
.
ops
.
vllm
.
w8a8_block_fp8_matmul_deepgemm
(
q_input
,
weight
,
x_scale
,
weight_scale
,
block_size
,
output_dtype
=
output_dtype
)
if
bias
is
not
None
:
output
+=
bias
return
output
.
to
(
dtype
=
output_dtype
).
view
(
*
output_shape
)
if
current_platform
.
is_cuda
():
if
current_platform
.
has_device_capability
(
100
):
...
...
@@ -134,7 +173,6 @@ def apply_w8a8_block_fp8_linear(
w8a8_blockscale_func
=
dispatch_w8a8_blockscale_func
(
use_cutlass
,
use_aiter_and_is_supported
)
if
use_cutlass
:
q_input
,
x_scale
=
per_token_group_quant_fp8
(
input_2d
,
block_size
[
1
],
column_major_scales
=
use_cutlass
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment