"vllm/vscode:/vscode.git/clone" did not exist on "8b45c58fe9a04864c736e4da5ca7249ebf4be3cf"
Commit a1239b53 authored by 王敏
Browse files

[feat]支持mtp模型full_cuda_graph

parent 7d4f5027
...@@ -647,13 +647,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -647,13 +647,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
rarange = np.repeat(query_lens, query_lens) - arange - 1 rarange = np.repeat(query_lens, query_lens) - arange - 1
repeats = torch.from_numpy(query_lens).pin_memory().to( repeats = torch.from_numpy(query_lens).pin_memory().to(
block_table_tensor.device, non_blocking=True) block_table_tensor.device, non_blocking=True).contiguous()
decode_block_table_tensor = torch.repeat_interleave( decode_block_table_tensor = torch.repeat_interleave(
block_table_tensor[:self._num_decodes, ...], block_table_tensor[:self._num_decodes, ...],
repeats, dim=0) repeats, dim=0).contiguous()
decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0) decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous()
seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to( seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
seq_lens.device, non_blocking=True) seq_lens.device, non_blocking=True).contiguous()
decode_seq_lens = decode_seq_lens - seq_lens_minus decode_seq_lens = decode_seq_lens - seq_lens_minus
if self.spec_decode_block_table_tensor is not None: if self.spec_decode_block_table_tensor is not None:
......
...@@ -269,7 +269,7 @@ class EagleProposer: ...@@ -269,7 +269,7 @@ class EagleProposer:
block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...] block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...]
attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode( attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode(
block_table_tensor=block_table, block_table_tensor=block_table,
seq_lens=(seq_lens + 1), seq_lens=seq_lens,
) )
for i in range(self.num_speculative_tokens - 1): for i in range(self.num_speculative_tokens - 1):
......
...@@ -1548,8 +1548,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1548,8 +1548,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata, attn_metadata,
) )
spec_token_ids = spec_token_ids.tolist()
# Clear KVConnector state after all KVs are generated. # Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata() get_kv_transfer_group().clear_connector_metadata()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment