Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
157722da
Unverified
Commit
157722da
authored
Feb 27, 2026
by
Huamin Li
Committed by
GitHub
Feb 28, 2026
Browse files
[perf] Use pinned memory for async H2D transfer in do_mamba_copy_block (#35480)
Signed-off-by:
Huamin Li
<
3ericli@gmail.com
>
parent
1d897ff0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
85 additions
and
44 deletions
+85
-44
tests/v1/e2e/test_mamba_prefix_cache.py
tests/v1/e2e/test_mamba_prefix_cache.py
+10
-10
tests/v1/worker/test_mamba_utils.py
tests/v1/worker/test_mamba_utils.py
+1
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+14
-0
vllm/v1/worker/mamba_utils.py
vllm/v1/worker/mamba_utils.py
+60
-34
No files found.
tests/v1/e2e/test_mamba_prefix_cache.py
View file @
157722da
...
@@ -325,6 +325,7 @@ def get_fake_process_mamba_fn(
...
@@ -325,6 +325,7 @@ def get_fake_process_mamba_fn(
requests
:
dict
[
str
,
CachedRequestState
],
requests
:
dict
[
str
,
CachedRequestState
],
forward_context
:
dict
[
str
,
Any
],
forward_context
:
dict
[
str
,
Any
],
mamba_state_copy_funcs
:
tuple
[
MambaStateCopyFunc
,
...],
mamba_state_copy_funcs
:
tuple
[
MambaStateCopyFunc
,
...],
copy_bufs
:
mamba_utils
.
MambaCopyBuffers
,
):
):
nonlocal
copy_info
nonlocal
copy_info
copy_info
=
None
copy_info
=
None
...
@@ -337,6 +338,7 @@ def get_fake_process_mamba_fn(
...
@@ -337,6 +338,7 @@ def get_fake_process_mamba_fn(
requests
,
requests
,
forward_context
,
forward_context
,
mamba_state_copy_funcs
,
mamba_state_copy_funcs
,
copy_bufs
,
)
)
if
cur_step_action
is
not
None
:
if
cur_step_action
is
not
None
:
check_copy_info
(
check_copy_info
(
...
@@ -355,6 +357,7 @@ def get_fake_process_mamba_fn(
...
@@ -355,6 +357,7 @@ def get_fake_process_mamba_fn(
mamba_state_idx
:
dict
[
str
,
int
],
mamba_state_idx
:
dict
[
str
,
int
],
forward_context
:
dict
[
str
,
Any
],
forward_context
:
dict
[
str
,
Any
],
mamba_state_copy_funcs
:
tuple
[
MambaStateCopyFunc
,
...],
mamba_state_copy_funcs
:
tuple
[
MambaStateCopyFunc
,
...],
copy_bufs
:
mamba_utils
.
MambaCopyBuffers
,
):
):
nonlocal
copy_info
nonlocal
copy_info
copy_info
=
None
copy_info
=
None
...
@@ -366,6 +369,7 @@ def get_fake_process_mamba_fn(
...
@@ -366,6 +369,7 @@ def get_fake_process_mamba_fn(
mamba_state_idx
,
mamba_state_idx
,
forward_context
,
forward_context
,
mamba_state_copy_funcs
,
mamba_state_copy_funcs
,
copy_bufs
,
)
)
if
cur_step_action
is
not
None
:
if
cur_step_action
is
not
None
:
check_copy_info
(
check_copy_info
(
...
@@ -376,19 +380,15 @@ def get_fake_process_mamba_fn(
...
@@ -376,19 +380,15 @@ def get_fake_process_mamba_fn(
)
)
return
ret
return
ret
def
fake_copy_fn
(
def
fake_copy_fn
(
copy_bufs
:
mamba_utils
.
MambaCopyBuffers
):
src_state_list
:
list
[
int
],
dest_state_list
:
list
[
int
],
num_elements_list
:
list
[
int
],
):
nonlocal
copy_info
nonlocal
copy_info
assert
copy_info
is
None
assert
copy_info
is
None
n
=
copy_bufs
.
offset
src_state_list
=
copy_bufs
.
src_ptrs
.
cpu
[:
n
].
tolist
()
dest_state_list
=
copy_bufs
.
dst_ptrs
.
cpu
[:
n
].
tolist
()
num_elements_list
=
copy_bufs
.
sizes
.
cpu
[:
n
].
tolist
()
copy_info
=
(
src_state_list
,
dest_state_list
,
num_elements_list
)
copy_info
=
(
src_state_list
,
dest_state_list
,
num_elements_list
)
return
original_copy_fn
(
return
original_copy_fn
(
copy_bufs
)
src_state_list
,
dest_state_list
,
num_elements_list
,
)
return
fake_preprocess_mamba_fn
,
fake_post_process_mamba_fn
,
fake_copy_fn
return
fake_preprocess_mamba_fn
,
fake_post_process_mamba_fn
,
fake_copy_fn
...
...
tests/v1/worker/test_mamba_utils.py
View file @
157722da
...
@@ -62,6 +62,7 @@ def test_resumed_req_ids_cleared_from_mamba_state_idx():
...
@@ -62,6 +62,7 @@ def test_resumed_req_ids_cleared_from_mamba_state_idx():
{},
{},
{},
{},
(),
(),
MagicMock
(),
)
)
assert
mamba_state_idx
==
{
"keep"
:
99
}
assert
mamba_state_idx
==
{
"keep"
:
99
}
vllm/v1/worker/gpu_model_runner.py
View file @
157722da
...
@@ -755,6 +755,7 @@ class GPUModelRunner(
...
@@ -755,6 +755,7 @@ class GPUModelRunner(
self
.
execute_model_state
:
ExecuteModelState
|
None
=
None
self
.
execute_model_state
:
ExecuteModelState
|
None
=
None
self
.
kv_connector_output
:
KVConnectorOutput
|
None
=
None
self
.
kv_connector_output
:
KVConnectorOutput
|
None
=
None
self
.
mamba_state_idx
:
dict
[
str
,
int
]
=
{}
self
.
mamba_state_idx
:
dict
[
str
,
int
]
=
{}
self
.
_mamba_copy_bufs
:
mamba_utils
.
MambaCopyBuffers
|
None
=
None
self
.
layerwise_nvtx_hooks_registered
=
False
self
.
layerwise_nvtx_hooks_registered
=
False
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
...
@@ -849,6 +850,16 @@ class GPUModelRunner(
...
@@ -849,6 +850,16 @@ class GPUModelRunner(
with_numpy
=
numpy
,
with_numpy
=
numpy
,
)
)
def
_get_mamba_copy_bufs
(
self
)
->
mamba_utils
.
MambaCopyBuffers
:
if
self
.
_mamba_copy_bufs
is
None
:
self
.
_mamba_copy_bufs
=
mamba_utils
.
MambaCopyBuffers
.
create
(
self
.
max_num_reqs
,
self
.
kv_cache_config
,
self
.
model
.
get_mamba_state_copy_func
(),
self
.
_make_buffer
,
)
return
self
.
_mamba_copy_bufs
def
_init_model_kwargs
(
self
):
def
_init_model_kwargs
(
self
):
model_kwargs
=
dict
[
str
,
Any
]()
model_kwargs
=
dict
[
str
,
Any
]()
...
@@ -1211,6 +1222,7 @@ class GPUModelRunner(
...
@@ -1211,6 +1222,7 @@ class GPUModelRunner(
self
.
mamba_state_idx
,
self
.
mamba_state_idx
,
self
.
compilation_config
.
static_forward_context
,
self
.
compilation_config
.
static_forward_context
,
self
.
model
.
get_mamba_state_copy_func
(),
self
.
model
.
get_mamba_state_copy_func
(),
self
.
_get_mamba_copy_bufs
(),
)
)
def
_update_streaming_request
(
def
_update_streaming_request
(
...
@@ -3505,6 +3517,7 @@ class GPUModelRunner(
...
@@ -3505,6 +3517,7 @@ class GPUModelRunner(
self
.
requests
,
self
.
requests
,
self
.
compilation_config
.
static_forward_context
,
self
.
compilation_config
.
static_forward_context
,
self
.
model
.
get_mamba_state_copy_func
(),
self
.
model
.
get_mamba_state_copy_func
(),
self
.
_get_mamba_copy_bufs
(),
)
)
use_spec_decode
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
use_spec_decode
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
...
@@ -5997,6 +6010,7 @@ class GPUModelRunner(
...
@@ -5997,6 +6010,7 @@ class GPUModelRunner(
"""
"""
kv_cache_config
=
deepcopy
(
kv_cache_config
)
kv_cache_config
=
deepcopy
(
kv_cache_config
)
self
.
kv_cache_config
=
kv_cache_config
self
.
kv_cache_config
=
kv_cache_config
self
.
_mamba_copy_bufs
=
None
self
.
may_add_encoder_only_layers_to_kv_cache_config
()
self
.
may_add_encoder_only_layers_to_kv_cache_config
()
self
.
maybe_add_kv_sharing_layers_to_kv_cache_groups
(
kv_cache_config
)
self
.
maybe_add_kv_sharing_layers_to_kv_cache_groups
(
kv_cache_config
)
self
.
initialize_attn_backend
(
kv_cache_config
)
self
.
initialize_attn_backend
(
kv_cache_config
)
...
...
vllm/v1/worker/mamba_utils.py
View file @
157722da
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
itertools
import
itertools
from
collections.abc
import
Callable
from
typing
import
Any
from
typing
import
Any
import
torch
import
torch
...
@@ -13,6 +15,7 @@ from vllm.triton_utils import tl, triton
...
@@ -13,6 +15,7 @@ from vllm.triton_utils import tl, triton
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
MambaSpec
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
MambaSpec
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
from
vllm.v1.worker.lora_model_runner_mixin
import
GPUInputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
GPUInputBatch
...
@@ -59,10 +62,36 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp
...
@@ -59,10 +62,36 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp
return
mamba_group_ids
,
mamba_specs
[
0
]
return
mamba_group_ids
,
mamba_specs
[
0
]
@
dataclasses
.
dataclass
class
MambaCopyBuffers
:
src_ptrs
:
CpuGpuBuffer
dst_ptrs
:
CpuGpuBuffer
sizes
:
CpuGpuBuffer
offset
:
int
=
0
@
classmethod
def
create
(
cls
,
max_num_reqs
:
int
,
kv_cache_config
:
KVCacheConfig
,
copy_funcs
:
tuple
[
MambaStateCopyFunc
,
...],
make_buffer
:
Callable
[...,
CpuGpuBuffer
],
)
->
"MambaCopyBuffers"
:
mamba_group_ids
,
_
=
get_mamba_groups
(
kv_cache_config
)
entries_per_req
=
sum
(
len
(
kv_cache_config
.
kv_cache_groups
[
gid
].
layer_names
)
for
gid
in
mamba_group_ids
)
*
len
(
copy_funcs
)
n
=
max_num_reqs
*
entries_per_req
return
cls
(
src_ptrs
=
make_buffer
(
n
,
dtype
=
torch
.
int64
),
dst_ptrs
=
make_buffer
(
n
,
dtype
=
torch
.
int64
),
sizes
=
make_buffer
(
n
,
dtype
=
torch
.
int32
),
)
def
collect_mamba_copy_meta
(
def
collect_mamba_copy_meta
(
src_state_list
:
list
[
int
],
copy_bufs
:
MambaCopyBuffers
,
dest_state_list
:
list
[
int
],
num_elements_list
:
list
[
int
],
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
mamba_state_copy_funcs
:
tuple
[
MambaStateCopyFunc
,
...],
mamba_state_copy_funcs
:
tuple
[
MambaStateCopyFunc
,
...],
mamba_group_ids
:
list
[
int
],
mamba_group_ids
:
list
[
int
],
...
@@ -71,10 +100,15 @@ def collect_mamba_copy_meta(
...
@@ -71,10 +100,15 @@ def collect_mamba_copy_meta(
accept_token_bias
:
int
,
accept_token_bias
:
int
,
req_state
:
CachedRequestState
,
req_state
:
CachedRequestState
,
forward_context
:
dict
[
str
,
Any
],
forward_context
:
dict
[
str
,
Any
],
):
)
->
None
:
if
src_block_idx
==
dest_block_idx
and
accept_token_bias
==
0
:
if
src_block_idx
==
dest_block_idx
and
accept_token_bias
==
0
:
return
return
src_ptrs_np
=
copy_bufs
.
src_ptrs
.
np
dst_ptrs_np
=
copy_bufs
.
dst_ptrs
.
np
sizes_np
=
copy_bufs
.
sizes
.
np
offset
=
copy_bufs
.
offset
for
mamba_group_id
in
mamba_group_ids
:
for
mamba_group_id
in
mamba_group_ids
:
block_ids
=
req_state
.
block_ids
[
mamba_group_id
]
block_ids
=
req_state
.
block_ids
[
mamba_group_id
]
dest_block_id
=
block_ids
[
dest_block_idx
]
dest_block_id
=
block_ids
[
dest_block_idx
]
...
@@ -87,25 +121,23 @@ def collect_mamba_copy_meta(
...
@@ -87,25 +121,23 @@ def collect_mamba_copy_meta(
state
,
block_ids
,
src_block_idx
,
accept_token_bias
+
1
state
,
block_ids
,
src_block_idx
,
accept_token_bias
+
1
)
)
src_state_list
.
append
(
copy_spec
.
start_addr
)
src_ptrs_np
[
offset
]
=
copy_spec
.
start_addr
dest_state_list
.
append
(
state
[
dest_block_id
].
data_ptr
())
dst_ptrs_np
[
offset
]
=
state
[
dest_block_id
].
data_ptr
()
num_elements_list
.
append
(
copy_spec
.
num_elements
*
state
.
element_size
())
sizes_np
[
offset
]
=
copy_spec
.
num_elements
*
state
.
element_size
()
offset
+=
1
copy_bufs
.
offset
=
offset
def
do_mamba_copy_block
(
src_state_list
:
list
[
int
],
dest_state_list
:
list
[
int
],
num_elements_list
:
list
[
int
],
):
if
len
(
src_state_list
)
==
0
:
return
assert
len
(
src_state_list
)
==
len
(
dest_state_list
)
assert
len
(
src_state_list
)
==
len
(
num_elements_list
)
src_state_ptrs
=
torch
.
tensor
(
src_state_list
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
dst_state_ptrs
=
torch
.
tensor
(
dest_state_list
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
num_elements
=
torch
.
tensor
(
num_elements_list
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
batch_memcpy
(
src_state_ptrs
,
dst_state_ptrs
,
num_elements
)
def
do_mamba_copy_block
(
copy_bufs
:
MambaCopyBuffers
):
n
=
copy_bufs
.
offset
if
n
==
0
:
return
batch_memcpy
(
copy_bufs
.
src_ptrs
.
copy_to_gpu
(
n
),
copy_bufs
.
dst_ptrs
.
copy_to_gpu
(
n
),
copy_bufs
.
sizes
.
copy_to_gpu
(
n
),
)
def
preprocess_mamba
(
def
preprocess_mamba
(
...
@@ -117,6 +149,7 @@ def preprocess_mamba(
...
@@ -117,6 +149,7 @@ def preprocess_mamba(
requests
:
dict
[
str
,
CachedRequestState
],
requests
:
dict
[
str
,
CachedRequestState
],
forward_context
:
dict
[
str
,
Any
],
forward_context
:
dict
[
str
,
Any
],
mamba_state_copy_funcs
:
tuple
[
MambaStateCopyFunc
,
...],
mamba_state_copy_funcs
:
tuple
[
MambaStateCopyFunc
,
...],
copy_bufs
:
MambaCopyBuffers
,
):
):
"""
"""
Copy the mamba state of previous step to the last
Copy the mamba state of previous step to the last
...
@@ -138,9 +171,7 @@ def preprocess_mamba(
...
@@ -138,9 +171,7 @@ def preprocess_mamba(
for
req_id
in
itertools
.
chain
(
finished_req_ids
,
preempted_req_ids
,
resumed_req_ids
):
for
req_id
in
itertools
.
chain
(
finished_req_ids
,
preempted_req_ids
,
resumed_req_ids
):
mamba_state_idx
.
pop
(
req_id
,
None
)
mamba_state_idx
.
pop
(
req_id
,
None
)
src_state_list
:
list
[
int
]
=
[]
copy_bufs
.
offset
=
0
dest_state_list
:
list
[
int
]
=
[]
num_elements_list
:
list
[
int
]
=
[]
for
i
,
req_id
in
enumerate
(
input_batch
.
req_ids
):
for
i
,
req_id
in
enumerate
(
input_batch
.
req_ids
):
req_state
=
requests
[
req_id
]
req_state
=
requests
[
req_id
]
prev_state_idx
=
mamba_state_idx
.
get
(
req_id
)
prev_state_idx
=
mamba_state_idx
.
get
(
req_id
)
...
@@ -169,9 +200,7 @@ def preprocess_mamba(
...
@@ -169,9 +200,7 @@ def preprocess_mamba(
mamba_state_idx
[
req_id
]
=
curr_state_idx
mamba_state_idx
[
req_id
]
=
curr_state_idx
if
prev_state_idx
!=
-
1
and
prev_state_idx
!=
curr_state_idx
:
if
prev_state_idx
!=
-
1
and
prev_state_idx
!=
curr_state_idx
:
collect_mamba_copy_meta
(
collect_mamba_copy_meta
(
src_state_list
,
copy_bufs
,
dest_state_list
,
num_elements_list
,
kv_cache_config
,
kv_cache_config
,
mamba_state_copy_funcs
,
mamba_state_copy_funcs
,
mamba_group_ids
,
mamba_group_ids
,
...
@@ -182,7 +211,7 @@ def preprocess_mamba(
...
@@ -182,7 +211,7 @@ def preprocess_mamba(
forward_context
,
forward_context
,
)
)
input_batch
.
num_accepted_tokens_cpu
[
i
]
=
1
input_batch
.
num_accepted_tokens_cpu
[
i
]
=
1
do_mamba_copy_block
(
src_state_list
,
dest_state_list
,
num_elements_list
)
do_mamba_copy_block
(
copy_bufs
)
def
postprocess_mamba
(
def
postprocess_mamba
(
...
@@ -193,6 +222,7 @@ def postprocess_mamba(
...
@@ -193,6 +222,7 @@ def postprocess_mamba(
mamba_state_idx
:
dict
[
str
,
int
],
mamba_state_idx
:
dict
[
str
,
int
],
forward_context
:
dict
[
str
,
Any
],
forward_context
:
dict
[
str
,
Any
],
mamba_state_copy_funcs
:
tuple
[
MambaStateCopyFunc
,
...],
mamba_state_copy_funcs
:
tuple
[
MambaStateCopyFunc
,
...],
copy_bufs
:
MambaCopyBuffers
,
):
):
"""
"""
If a block is converted from partial block to full block in this step, copy the
If a block is converted from partial block to full block in this step, copy the
...
@@ -203,9 +233,7 @@ def postprocess_mamba(
...
@@ -203,9 +233,7 @@ def postprocess_mamba(
num_accepted_tokens_cpu
=
input_batch
.
num_accepted_tokens_cpu
num_accepted_tokens_cpu
=
input_batch
.
num_accepted_tokens_cpu
# NOTE: can be optimized as this function always returns the same result
# NOTE: can be optimized as this function always returns the same result
mamba_group_ids
,
mamba_spec
=
get_mamba_groups
(
kv_cache_config
)
mamba_group_ids
,
mamba_spec
=
get_mamba_groups
(
kv_cache_config
)
src_state_list
:
list
[
int
]
=
[]
copy_bufs
.
offset
=
0
dest_state_list
:
list
[
int
]
=
[]
num_elements_list
:
list
[
int
]
=
[]
for
i
,
req_id
in
enumerate
(
input_batch
.
req_ids
):
for
i
,
req_id
in
enumerate
(
input_batch
.
req_ids
):
req_state
=
requests
[
req_id
]
req_state
=
requests
[
req_id
]
num_computed_tokens
=
req_state
.
num_computed_tokens
num_computed_tokens
=
req_state
.
num_computed_tokens
...
@@ -225,9 +253,7 @@ def postprocess_mamba(
...
@@ -225,9 +253,7 @@ def postprocess_mamba(
src_block_idx
=
mamba_state_idx
[
req_id
]
src_block_idx
=
mamba_state_idx
[
req_id
]
dest_block_idx
=
aligned_new_computed_tokens
//
mamba_spec
.
block_size
-
1
dest_block_idx
=
aligned_new_computed_tokens
//
mamba_spec
.
block_size
-
1
collect_mamba_copy_meta
(
collect_mamba_copy_meta
(
src_state_list
,
copy_bufs
,
dest_state_list
,
num_elements_list
,
kv_cache_config
,
kv_cache_config
,
mamba_state_copy_funcs
,
mamba_state_copy_funcs
,
mamba_group_ids
,
mamba_group_ids
,
...
@@ -239,4 +265,4 @@ def postprocess_mamba(
...
@@ -239,4 +265,4 @@ def postprocess_mamba(
)
)
if
src_block_idx
==
dest_block_idx
:
if
src_block_idx
==
dest_block_idx
:
num_accepted_tokens_cpu
[
i
]
=
1
num_accepted_tokens_cpu
[
i
]
=
1
do_mamba_copy_block
(
src_state_list
,
dest_state_list
,
num_elements_list
)
do_mamba_copy_block
(
copy_bufs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment