Commit 0ae9ce75 authored by jujl1's avatar jujl1
Browse files

feat: pp mtp加入零消耗调度,减少空泡

parent d8ea775f
...@@ -88,7 +88,7 @@ else: ...@@ -88,7 +88,7 @@ else:
"xgrammar.kernels.apply_token_bitmask_inplace_torch_compile") "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile")
logger = init_logger(__name__) logger = init_logger(__name__)
from vllm.zero_overhead.v1.eagle import V1ZeroEagleProposer
class GPUModelRunner(LoRAModelRunnerMixin): class GPUModelRunner(LoRAModelRunnerMixin):
...@@ -134,7 +134,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -134,7 +134,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.max_model_len = model_config.max_model_len self.max_model_len = model_config.max_model_len
self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs self.max_num_reqs = scheduler_config.max_num_seqs
self.spec_sampler_event = torch.cuda.Event(enable_timing=False)
self.spec_scheduler_max_num_tokens = 0
# Model-related. # Model-related.
self.num_query_heads = model_config.get_num_attention_heads( self.num_query_heads = model_config.get_num_attention_heads(
parallel_config) parallel_config)
...@@ -182,20 +183,22 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -182,20 +183,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# the last PP rank. This is not ideal if there are many # the last PP rank. This is not ideal if there are many
# layers in the draft model. # layers in the draft model.
if self.speculative_config and get_pp_group().is_last_rank: if self.speculative_config and get_pp_group().is_last_rank:
if self.speculative_config.method == "ngram": self.drafter = V1ZeroEagleProposer(self.vllm_config, self.device,
self.drafter = NgramProposer(self.vllm_config) self)
elif self.speculative_config.use_eagle(): # if self.speculative_config.method == "ngram":
self.drafter = EagleProposer(self.vllm_config, self.device, # self.drafter = NgramProposer(self.vllm_config)
self) # type: ignore # elif self.speculative_config.use_eagle():
if self.speculative_config.method == "eagle3": # self.drafter = EagleProposer(self.vllm_config, self.device,
self.use_aux_hidden_state_outputs = True # self) # type: ignore
elif self.speculative_config.method == "medusa": # if self.speculative_config.method == "eagle3":
self.drafter = MedusaProposer( # self.use_aux_hidden_state_outputs = True
vllm_config=self.vllm_config, # elif self.speculative_config.method == "medusa":
device=self.device) # type: ignore # self.drafter = MedusaProposer(
else: # vllm_config=self.vllm_config,
raise ValueError("Unknown speculative decoding method: " # device=self.device) # type: ignore
f"{self.speculative_config.method}") # else:
# raise ValueError("Unknown speculative decoding method: "
# f"{self.speculative_config.method}")
self.rejection_sampler = RejectionSampler() self.rejection_sampler = RejectionSampler()
# Request states. # Request states.
...@@ -609,7 +612,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -609,7 +612,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32) num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens) max_num_scheduled_tokens = max(tokens)
self.spec_scheduler_max_num_tokens = max_num_scheduled_tokens
# Get request indices. # Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(self.arange_np[:num_reqs], req_indices = np.repeat(self.arange_np[:num_reqs],
...@@ -1543,18 +1546,39 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1543,18 +1546,39 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states[:num_scheduled_tokens], hidden_states[:num_scheduled_tokens],
scheduler_output, scheduler_output,
) )
#-----------------------------------
# Get the valid generated tokens. # Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1] max_gen_len = sampled_token_ids.shape[-1]
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
else:
self.spec_sampler_event.record()
mask = (sampled_token_ids == -1)
mask_int = mask.int()
first_neg_one_indices = torch.argmax(mask_int, dim=1)
num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1
spec_token_ids = self.zero_propose_draft_token_ids(
scheduler_output,
num_accepted_tokens_tensor,
sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
)
if max_gen_len == 1: if max_gen_len == 1:
# No spec decode tokens. # No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist() valid_sampled_token_ids = sampled_token_ids_cpu.tolist()
else: else:
# Includes spec decode tokens. # Includes spec decode tokens.
self.spec_sampler_event.synchronize()
valid_sampled_token_ids = self.rejection_sampler.parse_output( valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids, sampled_token_ids_cpu,
self.input_batch.vocab_size, self.input_batch.vocab_size,
) )
# Mask out the sampled tokens that should not be sampled. # Mask out the sampled tokens that should not be sampled.
...@@ -1585,20 +1609,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1585,20 +1609,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
req_state = self.requests[req_id] req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids) req_state.output_token_ids.extend(sampled_ids)
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
else:
spec_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
)
# Clear KVConnector state after all KVs are generated. # Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group(): if has_kv_transfer_group():
...@@ -1619,6 +1629,109 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1619,6 +1629,109 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_nans_in_logits=num_nans_in_logits, num_nans_in_logits=num_nans_in_logits,
) )
def zero_propose_draft_token_ids(
    self,
    scheduler_output: "SchedulerOutput",
    num_accepted_tokens_tensor: torch.Tensor,
    sampled_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    hidden_states: torch.Tensor,
    sample_hidden_states: torch.Tensor,
    aux_hidden_states: Optional[torch.Tensor],
    spec_decode_metadata: Optional[SpecDecodeMetadata],
    attn_metadata: dict[str, Any],
) -> list[list[int]]:
    """Propose draft tokens directly from the sampler's tensor output.

    Unlike ``propose_draft_token_ids``, this variant takes the sampled
    token ids as a tensor plus a per-request column index
    (``num_accepted_tokens_tensor``) instead of per-request Python lists.

    Args:
        scheduler_output: Only ``total_num_scheduled_tokens`` is read.
        num_accepted_tokens_tensor: For each request, the column of
            ``sampled_token_ids`` holding the token to continue from.
        sampled_token_ids: Sampler output, one row per request.
        sampling_metadata: Forwarded to the drafter.
        hidden_states: Target-model hidden states.
        sample_hidden_states: Hidden states at the sampled positions.
        aux_hidden_states: Auxiliary hidden states; concatenated along the
            last dim when ``self.use_aux_hidden_state_outputs`` is set.
        spec_decode_metadata: ``None`` when the target-model input carried
            no draft tokens.
        attn_metadata: Attention metadata keyed by layer name.

    Returns:
        The proposed draft token ids as nested Python lists.

    Raises:
        ValueError: If the configured speculative method is not handled.
    """
    num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
    method = self.speculative_config.method

    if method == "ngram":
        assert isinstance(self.drafter, NgramProposer)
        # NOTE(review): `propose_ngram_draft_token_ids` is not visible here;
        # confirm it accepts the raw tensor rather than per-request lists.
        spec_token_ids = self.propose_ngram_draft_token_ids(
            sampled_token_ids)
    elif method == "medusa":
        assert isinstance(self.drafter, MedusaProposer)
        if sample_hidden_states.shape[0] == len(sampled_token_ids):
            # The input to the target model does not include draft tokens.
            hidden_states = sample_hidden_states
        else:
            # Each request owns `num_draft + 1` consecutive rows of
            # `sample_hidden_states`; collect where each request starts.
            row_starts = []
            offset = 0
            for num_draft in spec_decode_metadata.num_draft_tokens:
                row_starts.append(offset)
                offset += num_draft + 1
            # Rows of the tensor all have the same length, so the position
            # within each request must come from the per-request index,
            # exactly as the eagle branch picks `next_token_ids`.
            indices = (torch.tensor(row_starts, device=self.device) +
                       num_accepted_tokens_tensor.to(self.device))
            hidden_states = sample_hidden_states[indices]
        spec_token_ids = self.drafter.propose(
            target_hidden_states=hidden_states,
            sampling_metadata=sampling_metadata,
        )
    elif self.speculative_config.use_eagle():
        assert isinstance(self.drafter, EagleProposer)
        # Pick, per request, the token at its accepted column.
        row_indices = torch.arange(sampled_token_ids.size(0),
                                   device=sampled_token_ids.device)
        next_token_ids = sampled_token_ids[
            row_indices, num_accepted_tokens_tensor].flatten()

        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        eagle_attn_metadata = attn_metadata[
            self.drafter.attn_layer_names[0]]

        # NOTE: deepseek_mtp uses MLA which does not have `block_table`
        block_table = getattr(eagle_attn_metadata, "block_table", None)

        spec_scheduler_max_num_tokens = self.spec_scheduler_max_num_tokens
        if spec_decode_metadata is None:
            # input_ids can be None for multimodal models.
            target_token_ids = self.input_ids[:num_scheduled_tokens]
            # TODO(woosuk): Support M-RoPE.
            target_positions = self.positions[:num_scheduled_tokens]
            if self.use_aux_hidden_state_outputs:
                target_hidden_states = torch.cat(
                    [h[:num_scheduled_tokens] for h in aux_hidden_states],
                    dim=-1)
            else:
                target_hidden_states = hidden_states[:num_scheduled_tokens]
            target_slot_mapping = eagle_attn_metadata.slot_mapping
            cu_num_tokens = eagle_attn_metadata.query_start_loc
        else:
            # TODO(woosuk): Refactor this.
            cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                eagle_attn_metadata.query_start_loc,
                num_accepted_tokens_tensor,
            )
            # With draft tokens in the input, the drafter is told the max
            # scheduled tokens per request is 1.
            spec_scheduler_max_num_tokens = 1
            target_token_ids = self.input_ids[token_indices]
            # TODO(woosuk): Support M-RoPE.
            target_positions = self.positions[token_indices]
            if self.use_aux_hidden_state_outputs:
                target_hidden_states = torch.cat(
                    [h[token_indices] for h in aux_hidden_states], dim=-1)
            else:
                target_hidden_states = hidden_states[token_indices]
            target_slot_mapping = eagle_attn_metadata.slot_mapping[
                token_indices]

        self.drafter.spec_scheduler_max_num_tokens = (
            spec_scheduler_max_num_tokens)
        draft_token_ids = self.drafter.propose(
            target_token_ids=target_token_ids,
            target_positions=target_positions,
            target_hidden_states=target_hidden_states,
            target_slot_mapping=target_slot_mapping,
            next_token_ids=next_token_ids,
            cu_num_tokens=cu_num_tokens,
            block_table=block_table,
            sampling_metadata=sampling_metadata,
            decoding=spec_decode_metadata is not None,
        )
        # NOTE(review): presumably `draft_token_ids` is a device tensor, in
        # which case `tolist()` waits for the draft pass -- confirm.
        spec_token_ids = draft_token_ids.tolist()
    else:
        # Fail loudly instead of hitting an unbound local at `return`.
        raise ValueError("Unknown speculative decoding method: "
                         f"{method}")
    return spec_token_ids
def propose_draft_token_ids( def propose_draft_token_ids(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment