Commit 66072b36 (unverified), authored by Asaf Joseph Gardin, committed by GitHub
Browse files

[Bugfix][Mamba] - Fix Conv State Kernel FP32 Support (#24883)


Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
parent 3ed1ec4a
...@@ -418,7 +418,9 @@ def test_full_cuda_graph( ...@@ -418,7 +418,9 @@ def test_full_cuda_graph(
@pytest.mark.parametrize("model", FP32_STATE_MODELS) @pytest.mark.parametrize("model", FP32_STATE_MODELS)
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
def test_fp32_state( @pytest.mark.parametrize("cache_dtype_param",
["mamba_ssm_cache_dtype", "mamba_cache_dtype"])
def test_fp32_cache_state(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
example_prompts, example_prompts,
...@@ -426,6 +428,7 @@ def test_fp32_state( ...@@ -426,6 +428,7 @@ def test_fp32_state(
model: str, model: str,
max_tokens: int, max_tokens: int,
num_logprobs: int, num_logprobs: int,
cache_dtype_param: str,
) -> None: ) -> None:
try: try:
...@@ -443,13 +446,13 @@ def test_fp32_state( ...@@ -443,13 +446,13 @@ def test_fp32_state(
m.setenv("VLLM_USE_V1", "0") m.setenv("VLLM_USE_V1", "0")
with vllm_runner(model, with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS, max_num_seqs=MAX_NUM_SEQS,
mamba_ssm_cache_dtype="float32") as vllm_model: **{cache_dtype_param: "float32"}) as vllm_model:
vllm_v0_outputs = vllm_model.generate_greedy_logprobs( vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
with vllm_runner(model, with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS, max_num_seqs=MAX_NUM_SEQS,
mamba_ssm_cache_dtype="float32") as vllm_model: **{cache_dtype_param: "float32"}) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs( vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
......
...@@ -415,6 +415,9 @@ def causal_conv1d_fn( ...@@ -415,6 +415,9 @@ def causal_conv1d_fn(
activation = "silu" activation = "silu"
args = None args = None
# Store original dtype to cast back at the end
original_x_dtype = x.dtype
x = x.to(conv_states.dtype)
out = torch.empty_like(x) out = torch.empty_like(x)
if metadata is not None: if metadata is not None:
cu_seqlen = metadata.cu_seqlen cu_seqlen = metadata.cu_seqlen
...@@ -613,7 +616,7 @@ def causal_conv1d_fn( ...@@ -613,7 +616,7 @@ def causal_conv1d_fn(
BLOCK_N=256, BLOCK_N=256,
num_stages=2, num_stages=2,
) )
return out return out.to(original_x_dtype)
@triton.jit() @triton.jit()
...@@ -973,6 +976,9 @@ def causal_conv1d_update( ...@@ -973,6 +976,9 @@ def causal_conv1d_update(
activation = "silu" if activation is True else None activation = "silu" if activation is True else None
elif activation is not None: elif activation is not None:
assert activation in ["silu", "swish"] assert activation in ["silu", "swish"]
original_x_dtype = x.dtype
x = x.to(conv_state.dtype)
unsqueeze = query_start_loc is None and x.dim() == 2 unsqueeze = query_start_loc is None and x.dim() == 2
if unsqueeze: if unsqueeze:
# make it (batch, dim, seqlen) with seqlen == 1 # make it (batch, dim, seqlen) with seqlen == 1
...@@ -1081,4 +1087,4 @@ def causal_conv1d_update( ...@@ -1081,4 +1087,4 @@ def causal_conv1d_update(
) )
if unsqueeze: if unsqueeze:
out = out.squeeze(-1) out = out.squeeze(-1)
return out return out.to(original_x_dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment