Unverified Commit 298e5108 authored by Francesco Fusco's avatar Francesco Fusco Committed by GitHub
Browse files

[Hybrid] calling get_mamba_groups() once at MambaCopyBuffers.create() (#37318)


Signed-off-by: default avatarFrancesco Fusco <ffu@zurich.ibm.com>
parent 3982bc2c
...@@ -36,6 +36,7 @@ def test_resumed_req_ids_cleared_from_mamba_state_idx(): ...@@ -36,6 +36,7 @@ def test_resumed_req_ids_cleared_from_mamba_state_idx():
spec = MagicMock(block_size=64, num_speculative_blocks=0) spec = MagicMock(block_size=64, num_speculative_blocks=0)
cache_config = MagicMock(enable_prefix_caching=True) cache_config = MagicMock(enable_prefix_caching=True)
input_batch = MagicMock(req_ids=[]) input_batch = MagicMock(req_ids=[])
copy_bufs = MagicMock(mamba_group_ids=[0], mamba_spec=spec)
mamba_state_idx = { mamba_state_idx = {
"finished": 1, "finished": 1,
...@@ -62,7 +63,7 @@ def test_resumed_req_ids_cleared_from_mamba_state_idx(): ...@@ -62,7 +63,7 @@ def test_resumed_req_ids_cleared_from_mamba_state_idx():
{}, {},
{}, {},
(), (),
MagicMock(), copy_bufs,
) )
assert mamba_state_idx == {"keep": 99} assert mamba_state_idx == {"keep": 99}
...@@ -67,6 +67,8 @@ class MambaCopyBuffers: ...@@ -67,6 +67,8 @@ class MambaCopyBuffers:
src_ptrs: CpuGpuBuffer src_ptrs: CpuGpuBuffer
dst_ptrs: CpuGpuBuffer dst_ptrs: CpuGpuBuffer
sizes: CpuGpuBuffer sizes: CpuGpuBuffer
mamba_group_ids: list[int]
mamba_spec: MambaSpec
offset: int = 0 offset: int = 0
@classmethod @classmethod
...@@ -77,7 +79,7 @@ class MambaCopyBuffers: ...@@ -77,7 +79,7 @@ class MambaCopyBuffers:
copy_funcs: tuple[MambaStateCopyFunc, ...], copy_funcs: tuple[MambaStateCopyFunc, ...],
make_buffer: Callable[..., CpuGpuBuffer], make_buffer: Callable[..., CpuGpuBuffer],
) -> "MambaCopyBuffers": ) -> "MambaCopyBuffers":
mamba_group_ids, _ = get_mamba_groups(kv_cache_config) mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
entries_per_req = sum( entries_per_req = sum(
len(kv_cache_config.kv_cache_groups[gid].layer_names) len(kv_cache_config.kv_cache_groups[gid].layer_names)
for gid in mamba_group_ids for gid in mamba_group_ids
...@@ -87,6 +89,8 @@ class MambaCopyBuffers: ...@@ -87,6 +89,8 @@ class MambaCopyBuffers:
src_ptrs=make_buffer(n, dtype=torch.int64), src_ptrs=make_buffer(n, dtype=torch.int64),
dst_ptrs=make_buffer(n, dtype=torch.int64), dst_ptrs=make_buffer(n, dtype=torch.int64),
sizes=make_buffer(n, dtype=torch.int32), sizes=make_buffer(n, dtype=torch.int32),
mamba_group_ids=mamba_group_ids,
mamba_spec=mamba_spec,
) )
...@@ -155,7 +159,8 @@ def preprocess_mamba( ...@@ -155,7 +159,8 @@ def preprocess_mamba(
Copy the mamba state of previous step to the last Copy the mamba state of previous step to the last
(1 + num_speculative_blocks) block. (1 + num_speculative_blocks) block.
""" """
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config) mamba_group_ids = copy_bufs.mamba_group_ids
mamba_spec = copy_bufs.mamba_spec
num_speculative_blocks = mamba_spec.num_speculative_blocks num_speculative_blocks = mamba_spec.num_speculative_blocks
# TODO(Chen): we need to optimize this function a lot # TODO(Chen): we need to optimize this function a lot
assert cache_config.enable_prefix_caching assert cache_config.enable_prefix_caching
...@@ -231,8 +236,8 @@ def postprocess_mamba( ...@@ -231,8 +236,8 @@ def postprocess_mamba(
num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens
scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens
num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
# NOTE: can be optimized as this function always returns the same result mamba_group_ids = copy_bufs.mamba_group_ids
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config) mamba_spec = copy_bufs.mamba_spec
copy_bufs.offset = 0 copy_bufs.offset = 0
for i, req_id in enumerate(input_batch.req_ids): for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id] req_state = requests[req_id]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment