Unverified commit 11b6af52, authored by Andreas Karatzas, committed by GitHub
Browse files

[ROCm][Bugfix] Fix Mamba batched decode producing incorrect output (#32099)


Signed-off-by: Andreas Karatzas <akaratza@amd.com>
parent 2a719e08
......@@ -34,6 +34,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_state_update,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
......@@ -195,10 +196,11 @@ class MambaMixer(MambaBase, CustomOp):
def _ssm_transform(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if self.is_lora_enabled:
# Lora kernel requires contiguous tensor.
ssm_params = self.x_proj(x.contiguous())[0]
else:
# LoRA kernel requires contiguous tensor.
# ROCm: Non-contiguous tensors cause incorrect GEMM
# results when batch > 1.
if self.is_lora_enabled or current_platform.is_rocm():
x = x.contiguous()
ssm_params = self.x_proj(x)[0]
time_step, B, C = torch.split(
ssm_params,
......
......@@ -63,6 +63,7 @@ from vllm.model_executor.models.utils import (
maybe_prefix,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata
......@@ -414,6 +415,13 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
conv_state_indices=state_indices_tensor_d,
)
# ROCm: Ensure contiguous tensor for bcdt_proj linear layer.
# causal_conv1d_update returns a non-contiguous view (stride 8192
# instead of 4096 for shape [batch, 4096]), causing incorrect GEMM
# results when batch > 1 on ROCm.
if current_platform.is_rocm():
hidden_states_d = hidden_states_d.contiguous()
B, C, dt = self._project_ssm_parameters(hidden_states_d)
# 3. State Space Model sequence transformation
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment