Commit a1239b53 authored by 王敏
Browse files

[feat]支持mtp模型full_cuda_graph

parent 7d4f5027
......@@ -647,13 +647,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
rarange = np.repeat(query_lens, query_lens) - arange - 1
repeats = torch.from_numpy(query_lens).pin_memory().to(
block_table_tensor.device, non_blocking=True)
block_table_tensor.device, non_blocking=True).contiguous()
decode_block_table_tensor = torch.repeat_interleave(
block_table_tensor[:self._num_decodes, ...],
repeats, dim=0)
decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0)
repeats, dim=0).contiguous()
decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous()
seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
seq_lens.device, non_blocking=True)
seq_lens.device, non_blocking=True).contiguous()
decode_seq_lens = decode_seq_lens - seq_lens_minus
if self.spec_decode_block_table_tensor is not None:
......
......@@ -77,7 +77,7 @@ class EagleProposer:
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=device)
# attention metadata captured in full cudagraph mode
self.attn_metadata_cudagraph = None
# We need +1 here because the arange is used to set query_start_loc,
......@@ -210,7 +210,7 @@ class EagleProposer:
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
if attn_metadata.decode is not None:
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.block_table)
......@@ -269,7 +269,7 @@ class EagleProposer:
block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...]
attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode(
block_table_tensor=block_table,
seq_lens=(seq_lens + 1),
seq_lens=seq_lens,
)
for i in range(self.num_speculative_tokens - 1):
......@@ -354,7 +354,7 @@ class EagleProposer:
attn_metadata.num_prefills)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
......
......@@ -1548,8 +1548,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata,
)
spec_token_ids = spec_token_ids.tolist()
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment