Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
58631d7c
Unverified
Commit
58631d7c
authored
Apr 20, 2026
by
nemanjaudovic
Committed by
GitHub
Apr 20, 2026
Browse files
[Bugfix] Fix scaled_mm output narrowing for 3D input tensors (#38093)
Signed-off-by:
nemanjaudovic
<
nudovic@amd.com
>
parent
a943839e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
4 deletions
+15
-4
vllm/model_executor/kernels/linear/scaled_mm/pytorch.py
vllm/model_executor/kernels/linear/scaled_mm/pytorch.py
+15
-4
No files found.
vllm/model_executor/kernels/linear/scaled_mm/pytorch.py
View file @
58631d7c
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
torch
import
torch
...
@@ -13,6 +14,13 @@ from .ScaledMMLinearKernel import (
...
@@ -13,6 +14,13 @@ from .ScaledMMLinearKernel import (
)
)
def
_get_num_tokens
(
output_shape
:
list
)
->
int
:
# torch._scaled_mm works with 2D tensors, so input tensors are
# flattened if they are 3D. If output_shape is 3D, num_tokens is
# the product of all dims except the last (hidden dim).
return
math
.
prod
(
output_shape
[:
-
1
])
class
TorchFP8ScaledMMLinearKernel
(
FP8ScaledMMLinearKernel
):
class
TorchFP8ScaledMMLinearKernel
(
FP8ScaledMMLinearKernel
):
"""
"""
Base class for FP8 linear kernels using Torch.
Base class for FP8 linear kernels using Torch.
...
@@ -78,7 +86,8 @@ class PerTensorTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
...
@@ -78,7 +86,8 @@ class PerTensorTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
if
type
(
output
)
is
tuple
and
len
(
output
)
==
2
:
if
type
(
output
)
is
tuple
and
len
(
output
)
==
2
:
output
=
output
[
0
]
output
=
output
[
0
]
return
torch
.
narrow
(
output
,
0
,
0
,
output_shape
[
0
]).
view
(
*
output_shape
)
num_tokens
=
_get_num_tokens
(
output_shape
)
return
torch
.
narrow
(
output
,
0
,
0
,
num_tokens
).
view
(
*
output_shape
)
class
RowWiseTorchFP8ScaledMMLinearKernel
(
TorchFP8ScaledMMLinearKernel
):
class
RowWiseTorchFP8ScaledMMLinearKernel
(
TorchFP8ScaledMMLinearKernel
):
...
@@ -145,7 +154,8 @@ class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
...
@@ -145,7 +154,8 @@ class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
bias
=
bias
,
bias
=
bias
,
)
)
return
torch
.
narrow
(
output
,
0
,
0
,
output_shape
[
0
]).
view
(
*
output_shape
)
num_tokens
=
_get_num_tokens
(
output_shape
)
return
torch
.
narrow
(
output
,
0
,
0
,
num_tokens
).
view
(
*
output_shape
)
class
ChannelWiseTorchFP8ScaledMMLinearKernel
(
TorchFP8ScaledMMLinearKernel
):
class
ChannelWiseTorchFP8ScaledMMLinearKernel
(
TorchFP8ScaledMMLinearKernel
):
...
@@ -206,8 +216,9 @@ class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
...
@@ -206,8 +216,9 @@ class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
output
=
output
[
0
]
output
=
output
[
0
]
# Unpad (undo num_token_padding)
# Unpad (undo num_token_padding)
output
=
torch
.
narrow
(
output
,
0
,
0
,
output_shape
[
0
])
num_tokens
=
_get_num_tokens
(
output_shape
)
x_scale
=
torch
.
narrow
(
As
,
0
,
0
,
output_shape
[
0
])
output
=
torch
.
narrow
(
output
,
0
,
0
,
num_tokens
)
x_scale
=
torch
.
narrow
(
As
,
0
,
0
,
num_tokens
)
# DQ
# DQ
# C = sw * sx * (X * W) + bias
# C = sw * sx * (X * W) + bias
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment