Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
11b6af52
Unverified
Commit
11b6af52
authored
Jan 12, 2026
by
Andreas Karatzas
Committed by
GitHub
Jan 13, 2026
Browse files
[ROCm][Bugfix] Fix Mamba batched decode producing incorrect output (#32099)
Signed-off-by:
Andreas Karatzas
<
akaratza@amd.com
>
parent
2a719e08
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
5 deletions
+15
-5
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+7
-5
vllm/model_executor/models/plamo2.py
vllm/model_executor/models/plamo2.py
+8
-0
No files found.
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
11b6af52
...
...
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_state_update
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.mamba1_attn
import
Mamba1AttentionMetadata
...
...
@@ -195,10 +196,11 @@ class MambaMixer(MambaBase, CustomOp):
def
_ssm_transform
(
self
,
x
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
if
self
.
is_lora_enabled
:
# Lora kernel requires contiguous tensor.
ssm_params
=
self
.
x_proj
(
x
.
contiguous
())[
0
]
else
:
# LoRA kernel requires contiguous tensor.
# ROCm: Non-contiguous tensors cause incorrect GEMM
# results when batch > 1.
if
self
.
is_lora_enabled
or
current_platform
.
is_rocm
():
x
=
x
.
contiguous
()
ssm_params
=
self
.
x_proj
(
x
)[
0
]
time_step
,
B
,
C
=
torch
.
split
(
ssm_params
,
...
...
vllm/model_executor/models/plamo2.py
View file @
11b6af52
...
...
@@ -63,6 +63,7 @@ from vllm.model_executor.models.utils import (
maybe_prefix
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.v1.attention.backend
import
AttentionMetadata
...
...
@@ -414,6 +415,13 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
conv_state_indices
=
state_indices_tensor_d
,
)
# ROCm: Ensure contiguous tensor for bcdt_proj linear layer.
# causal_conv1d_update returns a non-contiguous view (stride 8192
# instead of 4096 for shape [batch, 4096]), causing incorrect GEMM
# results when batch > 1 on ROCm.
if
current_platform
.
is_rocm
():
hidden_states_d
=
hidden_states_d
.
contiguous
()
B
,
C
,
dt
=
self
.
_project_ssm_parameters
(
hidden_states_d
)
# 3. State Space Model sequence transformation
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment