Commit 7e71c143 authored by 王敏's avatar 王敏
Browse files

[feat]优化mtp相关函数返回类型

parent 8e0ae19d
......@@ -107,7 +107,7 @@ class EagleProposer:
num_rejected_tokens: list[int],
# [batch_size]
sampling_metadata: SamplingMetadata
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
......@@ -240,7 +240,7 @@ class EagleProposer:
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1), draft_probs_list
return draft_token_ids.view(-1, 1), draft_prob.view(-1, 1, draft_prob.shape[-1])
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
......
......@@ -1600,7 +1600,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states: Optional[torch.Tensor],
spec_decode_metadata: Optional[SpecDecodeMetadata],
attn_metadata: dict[str, Any],
) -> list[list[int]]:
) -> tuple[list[list[int]], torch.Tensor]:
draft_probs = None
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment