Unverified commit f40d9879, authored by Lukas Geiger and committed by GitHub
Browse files

[Models][GDN] Remove GPU/CPU syncs in `GDNAttentionMetadata.build` during...


[Models][GDN] Remove GPU/CPU syncs in `GDNAttentionMetadata.build` during speculative decoding (#38047)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
parent 47e60509
...@@ -253,7 +253,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -253,7 +253,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
) )
# Filter by spec_sequence_masks to exclude padded sequences # Filter by spec_sequence_masks to exclude padded sequences
spec_state_indices_tensor = block_table_tensor[ spec_state_indices_tensor = block_table_tensor[
spec_sequence_masks, : self.num_spec + 1 spec_sequence_masks_cpu, : self.num_spec + 1
] ]
non_spec_state_indices_tensor = None non_spec_state_indices_tensor = None
# Padded sequences are always at the back, so the first # Padded sequences are always at the back, so the first
...@@ -264,7 +264,9 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -264,7 +264,9 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_query_start_loc_cpu = None non_spec_query_start_loc_cpu = None
else: else:
spec_token_masks = torch.repeat_interleave( spec_token_masks = torch.repeat_interleave(
spec_sequence_masks, query_lens spec_sequence_masks,
query_lens,
output_size=query_start_loc_cpu[-1].item(),
) )
index = torch.argsort(spec_token_masks, stable=True) index = torch.argsort(spec_token_masks, stable=True)
num_non_spec_tokens = num_prefill_tokens + num_decode_tokens num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
...@@ -272,10 +274,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -272,10 +274,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
spec_token_indx = index[num_non_spec_tokens:] spec_token_indx = index[num_non_spec_tokens:]
spec_state_indices_tensor = block_table_tensor[ spec_state_indices_tensor = block_table_tensor[
spec_sequence_masks, : self.num_spec + 1 spec_sequence_masks_cpu, : self.num_spec + 1
] ]
non_spec_state_indices_tensor = block_table_tensor[ non_spec_state_indices_tensor = block_table_tensor[
~spec_sequence_masks, 0 ~spec_sequence_masks_cpu, 0
] ]
spec_query_start_loc = torch.zeros( spec_query_start_loc = torch.zeros(
...@@ -284,7 +286,9 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -284,7 +286,9 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
device=query_start_loc.device, device=query_start_loc.device,
) )
torch.cumsum( torch.cumsum(
query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:] query_lens[spec_sequence_masks_cpu],
dim=0,
out=spec_query_start_loc[1:],
) )
non_spec_query_start_loc = torch.zeros( non_spec_query_start_loc = torch.zeros(
query_lens.size(0) - num_spec_decodes + 1, query_lens.size(0) - num_spec_decodes + 1,
...@@ -292,7 +296,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -292,7 +296,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
device=query_start_loc.device, device=query_start_loc.device,
) )
torch.cumsum( torch.cumsum(
query_lens[~spec_sequence_masks], query_lens[~spec_sequence_masks_cpu],
dim=0, dim=0,
out=non_spec_query_start_loc[1:], out=non_spec_query_start_loc[1:],
) )
...@@ -307,7 +311,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -307,7 +311,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
) )
assert num_accepted_tokens is not None assert num_accepted_tokens is not None
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks] num_accepted_tokens = num_accepted_tokens[spec_sequence_masks_cpu]
chunk_indices: torch.Tensor | None = None chunk_indices: torch.Tensor | None = None
chunk_offsets: torch.Tensor | None = None chunk_offsets: torch.Tensor | None = None
...@@ -331,8 +335,8 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -331,8 +335,8 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
if num_prefills > 0: if num_prefills > 0:
has_initial_state = context_lens_tensor > 0 has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None: if spec_sequence_masks_cpu is not None:
has_initial_state = has_initial_state[~spec_sequence_masks] has_initial_state = has_initial_state[~spec_sequence_masks_cpu]
assert non_spec_query_start_loc_cpu is not None assert non_spec_query_start_loc_cpu is not None
nums_dict, batch_ptr, token_chunk_offset_ptr = ( nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata( compute_causal_conv1d_metadata(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment