Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8e6293e8
Unverified
Commit
8e6293e8
authored
Mar 30, 2026
by
roikoren755
Committed by
GitHub
Mar 30, 2026
Browse files
[Mamba] Add stochastic rounding support (#35753)
Signed-off-by:
Roi Koren
<
roik@nvidia.com
>
parent
dbdd9ae0
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
166 additions
and
2 deletions
+166
-2
tests/kernels/mamba/test_mamba_ssm.py
tests/kernels/mamba/test_mamba_ssm.py
+54
-0
vllm/config/cache.py
vllm/config/cache.py
+35
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+13
-0
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+2
-0
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+2
-0
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+58
-1
vllm/model_executor/models/plamo2.py
vllm/model_executor/models/plamo2.py
+2
-0
No files found.
tests/kernels/mamba/test_mamba_ssm.py
View file @
8e6293e8
...
...
@@ -12,6 +12,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn
,
selective_state_update
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.attention.backends.utils
import
PAD_SLOT_ID
...
...
@@ -429,6 +430,59 @@ def test_selective_state_update(dim, dstate, has_z, itype):
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
@
pytest
.
mark
.
parametrize
(
"philox_rounds"
,
[
0
,
4
])
@
pytest
.
mark
.
parametrize
(
"has_z"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"dstate"
,
[
16
,
64
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
4096
])
@
pytest
.
mark
.
skipif
(
not
(
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability_family
(
100
)
),
reason
=
"Stochastic rounding in triton is only supported"
" on compute capability 10.0 CUDA devices."
,
)
def
test_selective_state_update_stochastic_rounding
(
dim
,
dstate
,
has_z
,
philox_rounds
):
device
=
"cuda"
rtol
,
atol
=
5e-3
,
1e-1
# set seed
set_random_seed
(
0
)
batch_size
=
1
state
=
torch
.
randn
(
batch_size
,
dim
,
dstate
,
dtype
=
torch
.
float16
,
device
=
device
)
x
=
torch
.
randn
(
batch_size
,
dim
,
device
=
device
,
dtype
=
torch
.
bfloat16
)
out
=
torch
.
empty_like
(
x
)
dt
=
torch
.
randn
(
batch_size
,
dim
,
device
=
device
,
dtype
=
torch
.
bfloat16
)
dt_bias
=
torch
.
rand
(
dim
,
device
=
device
)
-
4.0
A
=
-
torch
.
rand
(
dim
,
dstate
,
device
=
device
)
-
1.0
B
=
torch
.
randn
(
batch_size
,
dstate
,
device
=
device
)
C
=
torch
.
randn
(
batch_size
,
dstate
,
device
=
device
)
D
=
torch
.
randn
(
dim
,
device
=
device
)
z
=
torch
.
randn_like
(
x
)
if
has_z
else
None
# Reference uses fp32 state to get ground truth
state_ref
=
state
.
float
()
selective_state_update
(
state
,
x
,
dt
,
A
,
B
,
C
,
D
=
D
,
z
=
z
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
out
=
out
,
enable_stochastic_rounding
=
True
,
cache_philox_rounds
=
philox_rounds
,
)
out_ref
=
selective_state_update_ref
(
state_ref
,
x
,
dt
,
A
,
B
,
C
,
D
=
D
,
z
=
z
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
)
assert
state
.
dtype
==
torch
.
float16
assert
torch
.
allclose
(
state
,
state_ref
.
to
(
torch
.
float16
),
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"has_z"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"dstate"
,
[
16
,
64
])
...
...
vllm/config/cache.py
View file @
8e6293e8
...
...
@@ -115,6 +115,14 @@ class CacheConfig:
- "align": only cache the mamba state of the last token of each scheduler step and
when the token is at position i * block_size.
"""
enable_mamba_cache_stochastic_rounding
:
bool
=
False
"""Enable stochastic rounding when writing SSM state to fp16 cache.
Uses random bits to unbias the rounding error, which can improve
numerical stability for long sequences."""
mamba_cache_philox_rounds
:
int
=
0
"""Number of Philox PRNG rounds for stochastic rounding random number
generation. 0 uses the Triton default. Higher values improve randomness
quality at the cost of compute."""
# Will be set after profiling.
num_gpu_blocks
:
int
|
None
=
field
(
default
=
None
,
init
=
False
)
...
...
@@ -231,3 +239,29 @@ class CacheConfig:
"scaling factor."
)
return
cache_dtype
def
__post_init__
(
self
):
if
self
.
enable_mamba_cache_stochastic_rounding
:
from
vllm.platforms
import
current_platform
if
not
current_platform
.
is_cuda
():
raise
ValueError
(
"Stochastic rounding for Mamba cache is only supported "
"on NVIDIA CUDA platforms. Please do not specify "
"`--enable-mamba-cache-stochastic-rounding`."
)
if
not
current_platform
.
is_device_capability_family
(
100
):
raise
ValueError
(
"Stochastic rounding for Mamba cache requires compute "
"capability 10.0 (data center Blackwell). The `cvt.rs` PTX "
"instruction is not supported on your GPU. Please do not specify "
"`--enable-mamba-cache-stochastic-rounding`."
)
if
self
.
mamba_ssm_cache_dtype
!=
"float16"
:
raise
ValueError
(
"Stochastic rounding for Mamba cache requires "
"the SSM cache to be float16. Please set it explicitly, "
"by specifying `--mamba-ssm-cache-dtype float16`, or disable "
"stochastic rounding by not specifying "
"`--enable-mamba-cache-stochastic-rounding`."
)
vllm/engine/arg_utils.py
View file @
8e6293e8
...
...
@@ -604,6 +604,10 @@ class EngineArgs:
mamba_ssm_cache_dtype
:
MambaDType
=
CacheConfig
.
mamba_ssm_cache_dtype
mamba_block_size
:
int
|
None
=
get_field
(
CacheConfig
,
"mamba_block_size"
)
mamba_cache_mode
:
MambaCacheMode
=
CacheConfig
.
mamba_cache_mode
enable_mamba_cache_stochastic_rounding
:
bool
=
(
CacheConfig
.
enable_mamba_cache_stochastic_rounding
)
mamba_cache_philox_rounds
:
int
=
CacheConfig
.
mamba_cache_philox_rounds
additional_config
:
dict
[
str
,
Any
]
=
get_field
(
VllmConfig
,
"additional_config"
)
...
...
@@ -1024,6 +1028,13 @@ class EngineArgs:
cache_group
.
add_argument
(
"--mamba-cache-mode"
,
**
cache_kwargs
[
"mamba_cache_mode"
]
)
cache_group
.
add_argument
(
"--enable-mamba-cache-stochastic-rounding"
,
**
cache_kwargs
[
"enable_mamba_cache_stochastic_rounding"
],
)
cache_group
.
add_argument
(
"--mamba-cache-philox-rounds"
,
**
cache_kwargs
[
"mamba_cache_philox_rounds"
]
)
cache_group
.
add_argument
(
"--kv-offloading-size"
,
**
cache_kwargs
[
"kv_offloading_size"
]
)
...
...
@@ -1590,6 +1601,8 @@ class EngineArgs:
mamba_ssm_cache_dtype
=
self
.
mamba_ssm_cache_dtype
,
mamba_block_size
=
self
.
mamba_block_size
,
mamba_cache_mode
=
self
.
mamba_cache_mode
,
enable_mamba_cache_stochastic_rounding
=
self
.
enable_mamba_cache_stochastic_rounding
,
mamba_cache_philox_rounds
=
self
.
mamba_cache_philox_rounds
,
kv_offloading_size
=
self
.
kv_offloading_size
,
kv_offloading_backend
=
self
.
kv_offloading_backend
,
)
...
...
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
8e6293e8
...
...
@@ -428,6 +428,8 @@ class MambaMixer(MambaBase, PluggableLayer):
state_batch_indices
=
state_indices_tensor_d_input
,
dst_state_batch_indices
=
state_indices_tensor_d_output
,
out
=
scan_outputs_d
,
enable_stochastic_rounding
=
self
.
cache_config
.
enable_mamba_cache_stochastic_rounding
,
cache_philox_rounds
=
self
.
cache_config
.
mamba_cache_philox_rounds
,
)
scan_outputs_d
=
scan_outputs_d
.
transpose
(
0
,
1
)
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
8e6293e8
...
...
@@ -888,6 +888,8 @@ class MambaMixer2(MambaBase, PluggableLayer):
num_accepted_tokens
=
num_accepted_tokens
,
cu_seqlens
=
query_start_loc_d
,
is_blackwell
=
self
.
is_blackwell
,
enable_stochastic_rounding
=
self
.
cache_config
.
enable_mamba_cache_stochastic_rounding
,
cache_philox_rounds
=
self
.
cache_config
.
mamba_cache_philox_rounds
,
)
def
get_state_dtype
(
self
)
->
tuple
[
torch
.
dtype
,
torch
.
dtype
]:
...
...
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
8e6293e8
...
...
@@ -28,6 +28,21 @@ else:
return
dt
@
triton
.
jit
def
convert_rs_fp16x2
(
x
:
tl
.
tensor
,
rand
:
tl
.
tensor
)
->
tl
.
tensor
:
y
=
tl
.
inline_asm_elementwise
(
asm
=
"""{
cvt.rs.f16x2.f32 $0, $2, $1, $3;
}"""
,
constraints
=
"=r,r,r,r,r"
,
args
=
(
x
,
rand
),
dtype
=
tl
.
float16
,
is_pure
=
True
,
pack
=
2
,
)
return
y
@
triton
.
heuristics
({
"HAS_DT_BIAS"
:
lambda
args
:
args
[
"dt_bias_ptr"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_D"
:
lambda
args
:
args
[
"D_ptr"
]
is
not
None
})
@
triton
.
heuristics
({
"HAS_Z"
:
lambda
args
:
args
[
"z_ptr"
]
is
not
None
})
...
...
@@ -48,6 +63,7 @@ else:
def
_selective_scan_update_kernel
(
# Pointers to matrices
state_ptr
,
rand_seed_ptr
,
x_ptr
,
dt_ptr
,
dt_bias_ptr
,
...
...
@@ -113,6 +129,8 @@ def _selective_scan_update_kernel(
IS_SPEC_DECODING
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
BLOCK_SIZE_DSTATE
:
tl
.
constexpr
,
USE_RS_ROUNDING
:
tl
.
constexpr
,
PHILOX_ROUNDS
:
tl
.
constexpr
,
):
pid_m
=
tl
.
program_id
(
axis
=
0
)
pid_b
=
tl
.
program_id
(
axis
=
1
)
...
...
@@ -267,7 +285,35 @@ def _selective_scan_update_kernel(
z_ptr
+=
stride_z_batch
if
not
IS_SPEC_DECODING
:
tl
.
store
(
dst_state_ptrs
,
state
.
to
(
dst_state_ptrs
.
dtype
.
element_ty
),
mask
=
mask
)
if
USE_RS_ROUNDING
:
# Load random seed
rand_seed
=
tl
.
load
(
rand_seed_ptr
)
# Generate random offsets for each element in state
if
HAS_STATE_BATCH_INDICES
:
rand_offsets
=
(
state_batch_idx
*
stride_state_batch
+
pid_h
*
stride_state_head
)
else
:
rand_offsets
=
pid_b
*
stride_state_batch
+
pid_h
*
stride_state_head
rand_offsets
+=
(
offs_m
[:,
None
]
*
stride_state_dim
+
offs_n
[
None
,
:]
*
stride_state_dstate
)
# Generate random 32-bits for each element in state
if
PHILOX_ROUNDS
>
0
:
rand
=
tl
.
randint
(
rand_seed
,
rand_offsets
,
PHILOX_ROUNDS
)
else
:
rand
=
tl
.
randint
(
rand_seed
,
rand_offsets
)
# Convert state to fp16 with RS rounding
state
=
convert_rs_fp16x2
(
state
,
rand
)
tl
.
static_assert
(
state
.
dtype
==
tl
.
float16
,
"state must be fp16"
)
tl
.
static_assert
(
dst_state_ptrs
.
dtype
.
element_ty
==
tl
.
float16
,
"dst_state_ptrs must be fp16"
,
)
else
:
state
=
state
.
to
(
dst_state_ptrs
.
dtype
.
element_ty
)
tl
.
store
(
dst_state_ptrs
,
state
,
mask
=
mask
)
def
selective_state_update
(
...
...
@@ -288,6 +334,8 @@ def selective_state_update(
num_accepted_tokens
=
None
,
cu_seqlens
=
None
,
is_blackwell
=
False
,
enable_stochastic_rounding
=
False
,
cache_philox_rounds
=
0
,
):
"""
Argument:
...
...
@@ -419,9 +467,16 @@ def selective_state_update(
and
dt
.
stride
(
-
1
)
==
0
and
dt_bias
.
stride
(
-
1
)
==
0
)
rand_seed
=
(
torch
.
randint
(
0
,
2
**
32
,
(
1
,),
device
=
state
.
device
)
if
enable_stochastic_rounding
else
None
)
with
torch
.
accelerator
.
device_index
(
x
.
device
.
index
):
_selective_scan_update_kernel
[
grid
](
state
,
rand_seed
,
x
,
dt
,
dt_bias
,
...
...
@@ -476,6 +531,8 @@ def selective_state_update(
tie_hdim
,
BLOCK_SIZE_M
,
num_warps
=
num_warps
,
USE_RS_ROUNDING
=
enable_stochastic_rounding
,
PHILOX_ROUNDS
=
cache_philox_rounds
,
)
...
...
vllm/model_executor/models/plamo2.py
View file @
8e6293e8
...
...
@@ -445,6 +445,8 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
dt_softplus
=
True
,
state_batch_indices
=
state_indices_tensor_d
,
out
=
preallocated_ssm_out_d
.
view
(
num_decodes
,
-
1
,
self
.
head_dim
),
enable_stochastic_rounding
=
self
.
cache_config
.
enable_mamba_cache_stochastic_rounding
,
cache_philox_rounds
=
self
.
cache_config
.
mamba_cache_philox_rounds
,
)
# 4. Final linear projection
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment