Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6a515304
Unverified
Commit 6a515304, authored Aug 05, 2025 by Michael Goin; committed by GitHub on Aug 06, 2025
Browse files
[Bugfix] Fix 3D input passed into cutlass_scaled_mm (#22278)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent
35509fc5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing 1 changed file with 20 additions and 15 deletions
+20
-15
vllm/_custom_ops.py
vllm/_custom_ops.py
+20
-15
No files found.
vllm/_custom_ops.py
View file @
6a515304
...
...
@@ -710,23 +710,25 @@ def cutlass_scaled_mm(a: torch.Tensor,
scale_b.shape * [128, 128] == b.shape
"""
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
assert
bias
is
None
or
bias
.
shape
[
0
]
==
b
.
shape
[
1
]
and
bias
.
dtype
==
out_dtype
assert
bias
is
None
or
bias
.
numel
(
)
==
b
.
shape
[
1
]
and
bias
.
dtype
==
out_dtype
m
=
a
.
shape
[
0
]
n
=
b
.
shape
[
1
]
# Massage the input to be 2D
target_shape
=
(
*
a
.
shape
[:
-
1
],
b
.
shape
[
1
])
a
=
a
.
view
(
-
1
,
a
.
shape
[
-
1
])
cutlass_compatible_b
=
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
if
current_platform
.
is_rocm
()
or
not
cutlass_compatible_b
:
from
vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm
import
(
# noqa
triton_scaled_mm
)
return
triton_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
out
=
torch
.
empty
((
m
,
n
),
dtype
=
out_dtype
,
device
=
a
.
device
)
torch
.
ops
.
_C
.
cutlass_scaled_mm
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
)
out
=
triton_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
else
:
out
=
torch
.
empty
((
a
.
shape
[
0
],
b
.
shape
[
1
]),
dtype
=
out_dtype
,
device
=
a
.
device
)
torch
.
ops
.
_C
.
cutlass_scaled_mm
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
)
return
out
return
out
.
view
(
*
target_shape
)
def
cutlass_scaled_mm_azp
(
a
:
torch
.
Tensor
,
...
...
@@ -746,15 +748,18 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
assert
bias
is
None
or
bias
.
numel
(
)
==
b
.
shape
[
1
]
and
bias
.
dtype
==
out_dtype
assert
azp
is
None
or
azp
.
numel
()
==
a
.
shape
[
0
]
m
=
a
.
shape
[
0
]
n
=
b
.
shape
[
1
]
out
=
torch
.
empty
((
m
,
n
),
dtype
=
out_dtype
,
device
=
a
.
device
)
# Massage the input to be 2D
target_shape
=
(
*
a
.
shape
[:
-
1
],
b
.
shape
[
1
])
a
=
a
.
view
(
-
1
,
a
.
shape
[
-
1
])
assert
azp
is
None
or
azp
.
numel
()
==
a
.
shape
[
0
]
out
=
torch
.
empty
((
a
.
shape
[
0
],
b
.
shape
[
1
]),
dtype
=
out_dtype
,
device
=
a
.
device
)
torch
.
ops
.
_C
.
cutlass_scaled_mm_azp
(
out
,
a
,
b
,
scale_a
,
scale_b
,
azp_adj
,
azp
,
bias
)
return
out
return
out
.
view
(
*
target_shape
)
def
cutlass_sparse_scaled_mm_supported
(
cuda_device_capability
:
int
)
->
bool
:
...
...
Write
Preview
Markdown
is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment.