Unverified commit 3e472e81, authored by Andreas Karatzas, committed by GitHub
Browse files

[ROCm][Bugfix][CI] Fix hybrid models and their tests (Mamba/Jamba/Bamba) (#32710)


Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
parent 038914b7
...@@ -8,6 +8,7 @@ import pytest ...@@ -8,6 +8,7 @@ import pytest
from tests.models.registry import HF_EXAMPLE_MODELS from tests.models.registry import HF_EXAMPLE_MODELS
from tests.utils import multi_gpu_test from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
...@@ -577,6 +578,10 @@ def test_apc_multiple_prompts_all_cached_outputs( ...@@ -577,6 +578,10 @@ def test_apc_multiple_prompts_all_cached_outputs(
model, max_model_len, tensor_parallel_size=tensor_parallel_size model, max_model_len, tensor_parallel_size=tensor_parallel_size
) )
vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32" vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"
# Reduce the effects of batch variance on ROCm since batch invariance is not
# yet supported. See: https://github.com/vllm-project/vllm/issues/27433
if current_platform.is_rocm():
vllm_runner_kwargs["max_num_seqs"] = 4
vllm_outputs_no_cache, _ = _get_vLLM_output( vllm_outputs_no_cache, _ = _get_vLLM_output(
vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs
......
...@@ -214,6 +214,12 @@ class MambaMixer(MambaBase, CustomOp): ...@@ -214,6 +214,12 @@ class MambaMixer(MambaBase, CustomOp):
time_step = self.dt_layernorm(time_step.contiguous()) time_step = self.dt_layernorm(time_step.contiguous())
B = self.b_layernorm(B.contiguous()) B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous()) C = self.c_layernorm(C.contiguous())
# ROCm: tensor from split is non-contiguous, causing incorrect
# GEMM results in dt_proj.
if current_platform.is_rocm():
time_step = time_step.contiguous()
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
return discrete_time_step, B, C return discrete_time_step, B, C
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment