Commit bd58c289 authored by 王敏
Browse files

[feat]支持mtp模型full_cuda_graph

parent 89eecc55
......@@ -232,7 +232,7 @@ class EagleProposer:
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
-draft_token_ids = torch.argmax(logits, dim=-1)
+draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
......@@ -380,7 +380,7 @@ class EagleProposer:
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
-# # TODO(wenlong): get more than one token for tree attention
+# TODO(wenlong): get more than one token for tree attention
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment