Commit ffe9e7db authored by zhuwenwen's avatar zhuwenwen
Browse files

[feat]支持mtp模型full_cuda_graph

parent f2218895
...@@ -756,13 +756,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -756,13 +756,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
rarange = np.repeat(query_lens, query_lens) - arange - 1 rarange = np.repeat(query_lens, query_lens) - arange - 1
repeats = torch.from_numpy(query_lens).pin_memory().to( repeats = torch.from_numpy(query_lens).pin_memory().to(
block_table_tensor.device, non_blocking=True) block_table_tensor.device, non_blocking=True).contiguous()
decode_block_table_tensor = torch.repeat_interleave( decode_block_table_tensor = torch.repeat_interleave(
block_table_tensor[:num_decodes, ...], block_table_tensor[:num_decodes, ...],
repeats, dim=0) repeats, dim=0).contiguous()
decode_seq_lens = torch.repeat_interleave(seq_lens[:num_decodes], repeats, dim=0) decode_seq_lens = torch.repeat_interleave(seq_lens[:num_decodes], repeats, dim=0).contiguous()
seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to( seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
seq_lens.device, non_blocking=True) seq_lens.device, non_blocking=True).contiguous()
decode_seq_lens = decode_seq_lens - seq_lens_minus decode_seq_lens = decode_seq_lens - seq_lens_minus
if self.spec_decode_block_table_tensor is not None: if self.spec_decode_block_table_tensor is not None:
......
...@@ -115,6 +115,8 @@ class EagleProposer: ...@@ -115,6 +115,8 @@ class EagleProposer:
# Replace the last token with the next token. # Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids self.input_ids[last_token_indices] = next_token_ids
seq_lens = (target_positions[last_token_indices] + 1).int()
assert self.runner is not None assert self.runner is not None
...@@ -186,7 +188,7 @@ class EagleProposer: ...@@ -186,7 +188,7 @@ class EagleProposer:
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = torch.argmax(logits, dim=-1) draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated. # Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1: if self.num_speculative_tokens == 1:
...@@ -228,7 +230,7 @@ class EagleProposer: ...@@ -228,7 +230,7 @@ class EagleProposer:
block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...] block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...]
attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode( attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode(
block_table_tensor=block_table, block_table_tensor=block_table,
seq_lens=(seq_lens + 1), seq_lens=seq_lens,
) )
for i in range(self.num_speculative_tokens - 1): for i in range(self.num_speculative_tokens - 1):
......
...@@ -1640,8 +1640,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1640,8 +1640,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_decode_common_attn_metadata, spec_decode_common_attn_metadata,
) )
spec_token_ids = spec_token_ids.tolist()
self.eplb_step() self.eplb_step()
return ModelRunnerOutput( return ModelRunnerOutput(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment