Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
836dee3b
Commit
836dee3b
authored
Nov 11, 2025
by
wanglong3
Committed by
zhuwenwen
Nov 11, 2025
Browse files
Support blaslt w8a8 GEMM op.
parent
aefb81d8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
40 additions
and
18 deletions
+40
-18
vllm/_custom_ops.py
vllm/_custom_ops.py
+12
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+2
-0
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+26
-18
No files found.
vllm/_custom_ops.py
View file @
836dee3b
...
...
@@ -1140,6 +1140,18 @@ def rocblas_scaled_mm(a: torch.Tensor,
return
quant_ops
.
rocblas_scaled_mm_nn
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
def
blaslt_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
m
=
a
.
shape
[
0
]
n
=
b
.
shape
[
0
]
k
=
a
.
shape
[
1
]
_
,
out
=
quant_ops
.
hipblaslt_w8a8_gemm
(
a
,
b
,
scale_a
,
scale_b
,
m
,
n
,
k
,
'NT'
,
out_dtype
)
return
out
def
triton_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
836dee3b
...
...
@@ -635,6 +635,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
for
key
,
value
in
configs_dict
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
device
=
layer
.
weight
.
device
,
best_config
=
value
)
elif
self
.
w8a8_strategy
==
3
:
layer
.
weight
.
data
=
layer
.
weight
.
data
.
T
else
:
weight_data
=
layer
.
weight
.
data
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
836dee3b
...
...
@@ -461,31 +461,39 @@ def apply_int8_linear(
else
:
best_config
=
None
# if best_config==None:
# print("m:{},n:{},k:{}".format(m,n,k))
# print("config not found!")
return
ops
.
triton_scaled_mm
(
x_q
,
weight
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
,
bias
=
bias
,
best_config
=
best_config
)
return
ops
.
triton_scaled_mm
(
x_q
,
weight
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
,
bias
=
bias
,
best_config
=
best_config
)
elif
w8a8_strategy
==
2
:
return
ops
.
cutlass_scaled_mm
(
x_q
,
return
ops
.
cutlass_scaled_mm
(
x_q
,
weight
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
,
bias
=
bias
)
elif
w8a8_strategy
==
3
:
# x_q: shape (m, k) stride (k, 1)
# weight: shape (n, k) stride (k, 1)
return
ops
.
blaslt_scaled_mm
(
x_q
,
weight
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
,
bias
=
bias
)
bias
=
None
)
else
:
return
ops
.
rocblas_scaled_mm
(
x_q
,
weight
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
,
bias
=
bias
)
return
ops
.
rocblas_scaled_mm
(
x_q
,
weight
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
,
bias
=
bias
)
def
normalize_e4m3fn_to_e4m3fnuz
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment