Unverified Commit 9c3c21c5 authored by Jiangyun Zhu's avatar Jiangyun Zhu Committed by GitHub
Browse files

[CI] fix mamba kernel test (#26250)


Signed-off-by: default avatarzjy0516 <riverclouds.zhu@qq.com>
parent 512b8aff
...@@ -477,6 +477,7 @@ steps: ...@@ -477,6 +477,7 @@ steps:
source_file_dependencies: source_file_dependencies:
- csrc/mamba/ - csrc/mamba/
- tests/kernels/mamba - tests/kernels/mamba
- vllm/model_executor/layers/mamba/ops
commands: commands:
- pytest -v -s kernels/mamba - pytest -v -s kernels/mamba
......
...@@ -165,7 +165,17 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity ...@@ -165,7 +165,17 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state.detach().clone() conv_state_ref = conv_state.detach().clone()
activation = None if not silu_activation else "silu" activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device)
out = causal_conv1d_update(
x,
conv_state,
weight,
bias,
activation=activation,
conv_state_indices=conv_state_indices,
)
out_ref = causal_conv1d_update_ref( out_ref = causal_conv1d_update_ref(
x_ref, conv_state_ref, weight, bias, activation=activation x_ref, conv_state_ref, weight, bias, activation=activation
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment