Unverified commit 157722da, authored by Huamin Li, committed by GitHub
Browse files

[perf] Use pinned memory for async H2D transfer in do_mamba_copy_block (#35480)


Signed-off-by: Huamin Li <3ericli@gmail.com>
parent 1d897ff0
...@@ -325,6 +325,7 @@ def get_fake_process_mamba_fn( ...@@ -325,6 +325,7 @@ def get_fake_process_mamba_fn(
requests: dict[str, CachedRequestState], requests: dict[str, CachedRequestState],
forward_context: dict[str, Any], forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: mamba_utils.MambaCopyBuffers,
): ):
nonlocal copy_info nonlocal copy_info
copy_info = None copy_info = None
...@@ -337,6 +338,7 @@ def get_fake_process_mamba_fn( ...@@ -337,6 +338,7 @@ def get_fake_process_mamba_fn(
requests, requests,
forward_context, forward_context,
mamba_state_copy_funcs, mamba_state_copy_funcs,
copy_bufs,
) )
if cur_step_action is not None: if cur_step_action is not None:
check_copy_info( check_copy_info(
...@@ -355,6 +357,7 @@ def get_fake_process_mamba_fn( ...@@ -355,6 +357,7 @@ def get_fake_process_mamba_fn(
mamba_state_idx: dict[str, int], mamba_state_idx: dict[str, int],
forward_context: dict[str, Any], forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: mamba_utils.MambaCopyBuffers,
): ):
nonlocal copy_info nonlocal copy_info
copy_info = None copy_info = None
...@@ -366,6 +369,7 @@ def get_fake_process_mamba_fn( ...@@ -366,6 +369,7 @@ def get_fake_process_mamba_fn(
mamba_state_idx, mamba_state_idx,
forward_context, forward_context,
mamba_state_copy_funcs, mamba_state_copy_funcs,
copy_bufs,
) )
if cur_step_action is not None: if cur_step_action is not None:
check_copy_info( check_copy_info(
...@@ -376,19 +380,15 @@ def get_fake_process_mamba_fn( ...@@ -376,19 +380,15 @@ def get_fake_process_mamba_fn(
) )
return ret return ret
def fake_copy_fn( def fake_copy_fn(copy_bufs: mamba_utils.MambaCopyBuffers):
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
):
nonlocal copy_info nonlocal copy_info
assert copy_info is None assert copy_info is None
n = copy_bufs.offset
src_state_list = copy_bufs.src_ptrs.cpu[:n].tolist()
dest_state_list = copy_bufs.dst_ptrs.cpu[:n].tolist()
num_elements_list = copy_bufs.sizes.cpu[:n].tolist()
copy_info = (src_state_list, dest_state_list, num_elements_list) copy_info = (src_state_list, dest_state_list, num_elements_list)
return original_copy_fn( return original_copy_fn(copy_bufs)
src_state_list,
dest_state_list,
num_elements_list,
)
return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn
......
...@@ -62,6 +62,7 @@ def test_resumed_req_ids_cleared_from_mamba_state_idx(): ...@@ -62,6 +62,7 @@ def test_resumed_req_ids_cleared_from_mamba_state_idx():
{}, {},
{}, {},
(), (),
MagicMock(),
) )
assert mamba_state_idx == {"keep": 99} assert mamba_state_idx == {"keep": 99}
...@@ -755,6 +755,7 @@ class GPUModelRunner( ...@@ -755,6 +755,7 @@ class GPUModelRunner(
self.execute_model_state: ExecuteModelState | None = None self.execute_model_state: ExecuteModelState | None = None
self.kv_connector_output: KVConnectorOutput | None = None self.kv_connector_output: KVConnectorOutput | None = None
self.mamba_state_idx: dict[str, int] = {} self.mamba_state_idx: dict[str, int] = {}
self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None
self.layerwise_nvtx_hooks_registered = False self.layerwise_nvtx_hooks_registered = False
def update_max_model_len(self, max_model_len: int) -> None: def update_max_model_len(self, max_model_len: int) -> None:
...@@ -849,6 +850,16 @@ class GPUModelRunner( ...@@ -849,6 +850,16 @@ class GPUModelRunner(
with_numpy=numpy, with_numpy=numpy,
) )
def _get_mamba_copy_bufs(self) -> mamba_utils.MambaCopyBuffers:
    """Return the cached mamba copy buffers, building them on first use.

    On the first call (or whenever ``self._mamba_copy_bufs`` is ``None``)
    the buffers are built with ``mamba_utils.MambaCopyBuffers.create`` from
    ``self.max_num_reqs``, ``self.kv_cache_config``, the model's mamba state
    copy functions and ``self._make_buffer``. The result is stored on the
    instance and returned as-is by later calls.
    """
    cached = self._mamba_copy_bufs
    if cached is not None:
        return cached

    # Nothing cached yet: build the buffers once and remember them.
    cached = mamba_utils.MambaCopyBuffers.create(
        self.max_num_reqs,
        self.kv_cache_config,
        self.model.get_mamba_state_copy_func(),
        self._make_buffer,
    )
    self._mamba_copy_bufs = cached
    return cached
def _init_model_kwargs(self): def _init_model_kwargs(self):
model_kwargs = dict[str, Any]() model_kwargs = dict[str, Any]()
...@@ -1211,6 +1222,7 @@ class GPUModelRunner( ...@@ -1211,6 +1222,7 @@ class GPUModelRunner(
self.mamba_state_idx, self.mamba_state_idx,
self.compilation_config.static_forward_context, self.compilation_config.static_forward_context,
self.model.get_mamba_state_copy_func(), self.model.get_mamba_state_copy_func(),
self._get_mamba_copy_bufs(),
) )
def _update_streaming_request( def _update_streaming_request(
...@@ -3505,6 +3517,7 @@ class GPUModelRunner( ...@@ -3505,6 +3517,7 @@ class GPUModelRunner(
self.requests, self.requests,
self.compilation_config.static_forward_context, self.compilation_config.static_forward_context,
self.model.get_mamba_state_copy_func(), self.model.get_mamba_state_copy_func(),
self._get_mamba_copy_bufs(),
) )
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
...@@ -5997,6 +6010,7 @@ class GPUModelRunner( ...@@ -5997,6 +6010,7 @@ class GPUModelRunner(
""" """
kv_cache_config = deepcopy(kv_cache_config) kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config self.kv_cache_config = kv_cache_config
self._mamba_copy_bufs = None
self.may_add_encoder_only_layers_to_kv_cache_config() self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config) self.initialize_attn_backend(kv_cache_config)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import itertools import itertools
from collections.abc import Callable
from typing import Any from typing import Any
import torch import torch
...@@ -13,6 +15,7 @@ from vllm.triton_utils import tl, triton ...@@ -13,6 +15,7 @@ from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
...@@ -59,10 +62,36 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp ...@@ -59,10 +62,36 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp
return mamba_group_ids, mamba_specs[0] return mamba_group_ids, mamba_specs[0]
@dataclasses.dataclass
class MambaCopyBuffers:
    """Preallocated buffers holding the metadata of a batch of state copies.

    Each copy occupies one slot across the three buffers; ``offset`` is the
    number of slots currently filled.
    """

    # Source address of each copy (int64 entries).
    src_ptrs: CpuGpuBuffer
    # Destination address of each copy (int64 entries).
    dst_ptrs: CpuGpuBuffer
    # Size of each copy (int32 entries).
    sizes: CpuGpuBuffer
    # Number of valid entries currently stored in the buffers.
    offset: int = 0

    @classmethod
    def create(
        cls,
        max_num_reqs: int,
        kv_cache_config: KVCacheConfig,
        copy_funcs: tuple[MambaStateCopyFunc, ...],
        make_buffer: Callable[..., CpuGpuBuffer],
    ) -> "MambaCopyBuffers":
        """Allocate buffers sized for the largest possible batch of copies.

        The capacity is ``max_num_reqs`` times the number of entries one
        request can produce: the total number of layers across all mamba
        KV-cache groups multiplied by the number of copy functions.

        Args:
            max_num_reqs: Maximum number of requests to size the buffers for.
            kv_cache_config: KV-cache configuration used to find the mamba
                groups and their layer names.
            copy_funcs: Mamba state copy functions; only their count is used.
            make_buffer: Factory called as ``make_buffer(n, dtype=...)`` to
                allocate each buffer.

        Returns:
            A new ``MambaCopyBuffers`` with ``offset`` left at its default 0.
        """
        mamba_group_ids, _ = get_mamba_groups(kv_cache_config)

        # Count the layers over every mamba group.
        num_mamba_layers = 0
        for group_id in mamba_group_ids:
            group = kv_cache_config.kv_cache_groups[group_id]
            num_mamba_layers += len(group.layer_names)

        per_request = num_mamba_layers * len(copy_funcs)
        capacity = max_num_reqs * per_request

        # Allocate in the same order as the fields: src, dst, sizes.
        src_buf = make_buffer(capacity, dtype=torch.int64)
        dst_buf = make_buffer(capacity, dtype=torch.int64)
        size_buf = make_buffer(capacity, dtype=torch.int32)
        return cls(src_ptrs=src_buf, dst_ptrs=dst_buf, sizes=size_buf)
def collect_mamba_copy_meta( def collect_mamba_copy_meta(
src_state_list: list[int], copy_bufs: MambaCopyBuffers,
dest_state_list: list[int],
num_elements_list: list[int],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
mamba_group_ids: list[int], mamba_group_ids: list[int],
...@@ -71,10 +100,15 @@ def collect_mamba_copy_meta( ...@@ -71,10 +100,15 @@ def collect_mamba_copy_meta(
accept_token_bias: int, accept_token_bias: int,
req_state: CachedRequestState, req_state: CachedRequestState,
forward_context: dict[str, Any], forward_context: dict[str, Any],
): ) -> None:
if src_block_idx == dest_block_idx and accept_token_bias == 0: if src_block_idx == dest_block_idx and accept_token_bias == 0:
return return
src_ptrs_np = copy_bufs.src_ptrs.np
dst_ptrs_np = copy_bufs.dst_ptrs.np
sizes_np = copy_bufs.sizes.np
offset = copy_bufs.offset
for mamba_group_id in mamba_group_ids: for mamba_group_id in mamba_group_ids:
block_ids = req_state.block_ids[mamba_group_id] block_ids = req_state.block_ids[mamba_group_id]
dest_block_id = block_ids[dest_block_idx] dest_block_id = block_ids[dest_block_idx]
...@@ -87,25 +121,23 @@ def collect_mamba_copy_meta( ...@@ -87,25 +121,23 @@ def collect_mamba_copy_meta(
state, block_ids, src_block_idx, accept_token_bias + 1 state, block_ids, src_block_idx, accept_token_bias + 1
) )
src_state_list.append(copy_spec.start_addr) src_ptrs_np[offset] = copy_spec.start_addr
dest_state_list.append(state[dest_block_id].data_ptr()) dst_ptrs_np[offset] = state[dest_block_id].data_ptr()
num_elements_list.append(copy_spec.num_elements * state.element_size()) sizes_np[offset] = copy_spec.num_elements * state.element_size()
offset += 1
copy_bufs.offset = offset
def do_mamba_copy_block(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
):
if len(src_state_list) == 0:
return
assert len(src_state_list) == len(dest_state_list)
assert len(src_state_list) == len(num_elements_list)
src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)
batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements) def do_mamba_copy_block(copy_bufs: MambaCopyBuffers):
n = copy_bufs.offset
if n == 0:
return
batch_memcpy(
copy_bufs.src_ptrs.copy_to_gpu(n),
copy_bufs.dst_ptrs.copy_to_gpu(n),
copy_bufs.sizes.copy_to_gpu(n),
)
def preprocess_mamba( def preprocess_mamba(
...@@ -117,6 +149,7 @@ def preprocess_mamba( ...@@ -117,6 +149,7 @@ def preprocess_mamba(
requests: dict[str, CachedRequestState], requests: dict[str, CachedRequestState],
forward_context: dict[str, Any], forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: MambaCopyBuffers,
): ):
""" """
Copy the mamba state of previous step to the last Copy the mamba state of previous step to the last
...@@ -138,9 +171,7 @@ def preprocess_mamba( ...@@ -138,9 +171,7 @@ def preprocess_mamba(
for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids): for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids):
mamba_state_idx.pop(req_id, None) mamba_state_idx.pop(req_id, None)
src_state_list: list[int] = [] copy_bufs.offset = 0
dest_state_list: list[int] = []
num_elements_list: list[int] = []
for i, req_id in enumerate(input_batch.req_ids): for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id] req_state = requests[req_id]
prev_state_idx = mamba_state_idx.get(req_id) prev_state_idx = mamba_state_idx.get(req_id)
...@@ -169,9 +200,7 @@ def preprocess_mamba( ...@@ -169,9 +200,7 @@ def preprocess_mamba(
mamba_state_idx[req_id] = curr_state_idx mamba_state_idx[req_id] = curr_state_idx
if prev_state_idx != -1 and prev_state_idx != curr_state_idx: if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
collect_mamba_copy_meta( collect_mamba_copy_meta(
src_state_list, copy_bufs,
dest_state_list,
num_elements_list,
kv_cache_config, kv_cache_config,
mamba_state_copy_funcs, mamba_state_copy_funcs,
mamba_group_ids, mamba_group_ids,
...@@ -182,7 +211,7 @@ def preprocess_mamba( ...@@ -182,7 +211,7 @@ def preprocess_mamba(
forward_context, forward_context,
) )
input_batch.num_accepted_tokens_cpu[i] = 1 input_batch.num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list) do_mamba_copy_block(copy_bufs)
def postprocess_mamba( def postprocess_mamba(
...@@ -193,6 +222,7 @@ def postprocess_mamba( ...@@ -193,6 +222,7 @@ def postprocess_mamba(
mamba_state_idx: dict[str, int], mamba_state_idx: dict[str, int],
forward_context: dict[str, Any], forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: MambaCopyBuffers,
): ):
""" """
If a blocks is converted from partial block to full block in this step, copy the If a blocks is converted from partial block to full block in this step, copy the
...@@ -203,9 +233,7 @@ def postprocess_mamba( ...@@ -203,9 +233,7 @@ def postprocess_mamba(
num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
# NOTE: can be optimized as this function always returns the same result # NOTE: can be optimized as this function always returns the same result
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config) mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
src_state_list: list[int] = [] copy_bufs.offset = 0
dest_state_list: list[int] = []
num_elements_list: list[int] = []
for i, req_id in enumerate(input_batch.req_ids): for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id] req_state = requests[req_id]
num_computed_tokens = req_state.num_computed_tokens num_computed_tokens = req_state.num_computed_tokens
...@@ -225,9 +253,7 @@ def postprocess_mamba( ...@@ -225,9 +253,7 @@ def postprocess_mamba(
src_block_idx = mamba_state_idx[req_id] src_block_idx = mamba_state_idx[req_id]
dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1 dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1
collect_mamba_copy_meta( collect_mamba_copy_meta(
src_state_list, copy_bufs,
dest_state_list,
num_elements_list,
kv_cache_config, kv_cache_config,
mamba_state_copy_funcs, mamba_state_copy_funcs,
mamba_group_ids, mamba_group_ids,
...@@ -239,4 +265,4 @@ def postprocess_mamba( ...@@ -239,4 +265,4 @@ def postprocess_mamba(
) )
if src_block_idx == dest_block_idx: if src_block_idx == dest_block_idx:
num_accepted_tokens_cpu[i] = 1 num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list) do_mamba_copy_block(copy_bufs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment