Commit 059b1961 authored by yangql's avatar yangql
Browse files

增加mtp>1的pad

parent 19288a48
......@@ -282,6 +282,10 @@ class EagleProposer:
seq_lens=seq_lens,
)
#增加mtp>1的pad
num_pad, num_tokens_across_dp = self.get_dp_padding(input_batch_size)
input_batch_size += num_pad
for i in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
......
......@@ -203,6 +203,10 @@ class V1ZeroEagleProposer(EagleProposer):
seq_lens=seq_lens,
)
#增加mtp>1的pad
num_pad, num_tokens_across_dp = self.get_dp_padding(input_batch_size)
input_batch_size += num_pad
for i in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment