"tests/vscode:/vscode.git/clone" did not exist on "5eec6110c336a4ff897cc52fdd55b6e8e31ee036"
Commit c17574bd authored by zhuwenwen's avatar zhuwenwen
Browse files

Revert "Revert "[feat]优化mtp相关函数返回类型""

This reverts commit c34fa0bf.
parent 467490e6
......@@ -98,7 +98,7 @@ class EagleProposer:
next_token_ids: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
......@@ -194,7 +194,7 @@ class EagleProposer:
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1), draft_probs_list
return draft_token_ids.view(-1, 1), draft_prob.view(-1, 1, draft_prob.shape[-1])
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
......
......@@ -1687,7 +1687,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states: Optional[torch.Tensor],
spec_decode_metadata: Optional[SpecDecodeMetadata],
common_attn_metadata: CommonAttentionMetadata,
) -> list[list[int]]:
) -> tuple[list[list[int]], torch.Tensor]:
draft_probs = None
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment