Unverified Commit 57c629e9 authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[Bugfix] Fix block_size for hybrid model MTP (#36036)


Signed-off-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
parent d106bf39
......@@ -37,6 +37,8 @@ eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
ar_draft_model_dir = "amd/PARD-Llama-3.2-1B" # Compatible with parallel and AR drafting
BLOCK_SIZE = 16
def _create_proposer(
method: str,
......@@ -91,9 +93,11 @@ def _create_proposer(
)
if "eagle" in method:
return EagleProposer(vllm_config=vllm_config, device=device)
proposer = EagleProposer(vllm_config=vllm_config, device=device)
else:
return DraftModelProposer(vllm_config=vllm_config, device=device)
proposer = DraftModelProposer(vllm_config=vllm_config, device=device)
proposer.block_size = BLOCK_SIZE
return proposer
def test_prepare_next_token_ids():
......@@ -163,7 +167,7 @@ def test_prepare_next_token_ids():
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
block_size=BLOCK_SIZE,
device=device,
)
......@@ -207,7 +211,7 @@ def test_prepare_inputs():
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
block_size=BLOCK_SIZE,
device=device,
)
......@@ -302,7 +306,7 @@ def test_prepare_inputs_padded():
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
block_size=BLOCK_SIZE,
device=device,
)
......@@ -371,7 +375,7 @@ def test_set_inputs_first_pass_default_eagle():
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
block_size=BLOCK_SIZE,
device=device,
)
......@@ -462,7 +466,7 @@ def test_set_inputs_first_pass_draft_model():
device = torch.device(current_platform.device_type)
num_speculative_tokens = 2
block_size = 16
block_size = BLOCK_SIZE
# Create a proposer configured as a draft model (pass_hidden_states=False)
# We need to mock this since _create_proposer defaults to EAGLE
......@@ -600,7 +604,7 @@ def test_set_inputs_first_pass_parallel_drafting():
device = torch.device(current_platform.device_type)
num_speculative_tokens = 3
block_size = 16
block_size = BLOCK_SIZE
proposer = _create_proposer("eagle", num_speculative_tokens, parallel_drafting=True)
......@@ -926,7 +930,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
block_size=BLOCK_SIZE,
device=device,
)
......@@ -1123,7 +1127,7 @@ def test_propose_tree(spec_token_tree):
)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
block_size=BLOCK_SIZE,
device=device,
)
sampling_metadata = mock.MagicMock()
......
......@@ -162,6 +162,9 @@ class SpecDecodeBaseProposer:
(self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
)
# Will be set when we initialize the attention backend
self.block_size: int = -1
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
......@@ -583,8 +586,8 @@ class SpecDecodeBaseProposer:
common_attn_metadata._num_computed_tokens_cpu += 1
# Compute the slot mapping.
# Use the first draft attention group's kv_cache_spec for block_size
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
block_size = self.block_size
assert block_size > 0, "block_size has not been initialized."
if self.uses_mrope:
# all dimensions of positions are the same
block_numbers = clamped_positions[0] // block_size
......@@ -778,17 +781,14 @@ class SpecDecodeBaseProposer:
# 2.
# Recompute the slot mapping based on the new positions and
# rejection mask.
# Use the first draft attention group's kv_cache_spec for block_size
# (all draft layers share the same kv-cache group)
assert len(self.draft_attn_groups) > 0
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
assert self.block_size > 0, "block_size has not been initialized."
new_slot_mapping = compute_new_slot_mapping(
cad=cad,
new_positions=self.positions[:total_num_output_tokens],
is_rejected_token_mask=self.is_rejected_token_mask[
:total_num_output_tokens
],
block_size=block_size,
block_size=self.block_size,
num_new_tokens=self.net_num_new_slots_per_request,
max_model_len=self.max_model_len,
)
......@@ -1635,6 +1635,10 @@ class SpecDecodeBaseProposer:
attention_groups[backend_key].layer_names.append(layer_name)
self.draft_attn_groups = list(attention_groups.values())
self.block_size = (
self.draft_attn_groups[0].get_metadata_builder().kv_cache_spec.block_size
)
logger.debug("Using block size %d for drafting layers", self.block_size)
def _determine_batch_execution_and_padding(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment