Unverified Commit 042da732 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Core] Refactor `_build_attention_metadata` (#29628)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent b5945d49
...@@ -1534,28 +1534,13 @@ class GPUModelRunner( ...@@ -1534,28 +1534,13 @@ class GPUModelRunner(
""" """
:return: tuple[attn_metadata, spec_decode_common_attn_metadata] :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
""" """
# Attention metadata is not needed for attention free models
if len(self.kv_cache_config.kv_cache_groups) == 0:
return {}, None
num_tokens_padded = num_tokens_padded or num_tokens num_tokens_padded = num_tokens_padded or num_tokens
num_reqs_padded = num_reqs_padded or num_reqs num_reqs_padded = num_reqs_padded or num_reqs
assert num_reqs_padded is not None and num_tokens_padded is not None
logits_indices_padded = None
num_logits_indices = None
if logits_indices is not None:
num_logits_indices = logits_indices.size(0)
if self.cache_config.kv_sharing_fast_prefill:
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
logits_indices
)
# update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
self.seq_lens.cpu[:num_reqs],
self.dcp_world_size,
self.dcp_rank,
self.parallel_config.cp_kv_cache_interleave_size,
)
self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0)
self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded)
attn_metadata: PerLayerAttnMetadata = {} attn_metadata: PerLayerAttnMetadata = {}
if ubatch_slices is not None: if ubatch_slices is not None:
...@@ -1576,36 +1561,12 @@ class GPUModelRunner( ...@@ -1576,36 +1561,12 @@ class GPUModelRunner(
self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu() self.num_accepted_tokens.copy_to_gpu()
# Used in the below loop, uses padded shapes kv_cache_groups = self.kv_cache_config.kv_cache_groups
query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1]
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1]
seq_lens = self.seq_lens.gpu[:num_reqs_padded]
seq_lens_cpu = self.seq_lens.cpu[:num_reqs_padded]
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs_padded
]
dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
if self.dcp_world_size > 1:
dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded]
dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs_padded]
spec_decode_common_attn_metadata = None
# Prepare the attention metadata for each KV cache group and make layers def _get_block_table_and_slot_mapping(kv_cache_gid: int):
# in the same group share the same metadata. assert num_reqs_padded is not None and num_tokens_padded is not None
for kv_cache_gid, kv_cache_group in enumerate( kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec
self.kv_cache_config.kv_cache_groups if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
):
encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens(
num_scheduled_tokens or {},
kv_cache_group.kv_cache_spec,
num_reqs_padded,
)
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
# Encoder-only layers do not have KV cache, so we need to
# create a dummy block table and slot mapping for them.
blk_table_tensor = torch.zeros( blk_table_tensor = torch.zeros(
(num_reqs_padded, 1), (num_reqs_padded, 1),
dtype=torch.int32, dtype=torch.int32,
...@@ -1626,61 +1587,71 @@ class GPUModelRunner( ...@@ -1626,61 +1587,71 @@ class GPUModelRunner(
slot_mapping[num_tokens:num_tokens_padded].fill_(-1) slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1) blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
common_attn_metadata = CommonAttentionMetadata( return blk_table_tensor, slot_mapping
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu, block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
seq_lens=seq_lens, cm_base = CommonAttentionMetadata(
_seq_lens_cpu=seq_lens_cpu, query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
_num_computed_tokens_cpu=num_computed_tokens_cpu, query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
num_actual_tokens=num_tokens_padded, seq_lens=self.seq_lens.gpu[:num_reqs_padded],
_seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded],
_num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs_padded
],
num_reqs=num_reqs_padded, num_reqs=num_reqs_padded,
num_actual_tokens=num_tokens_padded,
max_query_len=max_query_len, max_query_len=max_query_len,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
block_table_tensor=blk_table_tensor, block_table_tensor=block_table_gid_0,
slot_mapping=slot_mapping, slot_mapping=slot_mapping_gid_0,
logits_indices_padded=logits_indices_padded,
num_logits_indices=num_logits_indices,
causal=True, causal=True,
encoder_seq_lens=encoder_seq_lens,
encoder_seq_lens_cpu=encoder_seq_lens_cpu,
dcp_local_seq_lens=dcp_local_seq_lens,
dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
) )
if self.speculative_config and spec_decode_common_attn_metadata is None: if self.dcp_world_size > 1:
if isinstance(self.drafter, EagleProposer): self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: self.seq_lens.cpu[:num_reqs],
spec_decode_common_attn_metadata = common_attn_metadata self.dcp_world_size,
else: self.dcp_rank,
spec_decode_common_attn_metadata = common_attn_metadata self.parallel_config.cp_kv_cache_interleave_size,
)
self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0)
self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded)
cm_base.dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded]
cm_base.dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[
:num_reqs_padded
]
if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill:
cm_base.num_logits_indices = logits_indices.size(0)
cm_base.logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
logits_indices
)
for attn_gid, attn_group in enumerate(self.attn_groups[kv_cache_gid]): def _build_attn_group_metadata(
kv_cache_gid: int,
attn_gid: int,
common_attn_metadata: CommonAttentionMetadata,
ubid: int | None = None,
) -> None:
attn_group = self.attn_groups[kv_cache_gid][attn_gid]
cascade_attn_prefix_len = ( cascade_attn_prefix_len = (
cascade_attn_prefix_lens[kv_cache_gid][attn_gid] cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
if cascade_attn_prefix_lens if cascade_attn_prefix_lens
else 0 else 0
) )
builder = attn_group.get_metadata_builder()
builder = attn_group.get_metadata_builder(ubid or 0)
extra_attn_metadata_args = {} extra_attn_metadata_args = {}
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
assert ubid is None, "UBatching not supported with GDN yet"
extra_attn_metadata_args = dict( extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.gpu[ num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded],
:num_reqs_padded
],
num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[ num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[
:num_reqs_padded :num_reqs_padded
], ],
) )
if ubatch_slices is not None:
common_attn_metadata_list = split_attn_metadata(
ubatch_slices, common_attn_metadata
)
for ubid, common_attn_metadata in enumerate(
common_attn_metadata_list
):
builder = attn_group.get_metadata_builder(ubatch_id=ubid)
if for_cudagraph_capture: if for_cudagraph_capture:
attn_metadata_i = builder.build_for_cudagraph_capture( attn_metadata_i = builder.build_for_cudagraph_capture(
common_attn_metadata common_attn_metadata
...@@ -1689,24 +1660,51 @@ class GPUModelRunner( ...@@ -1689,24 +1660,51 @@ class GPUModelRunner(
attn_metadata_i = builder.build( attn_metadata_i = builder.build(
common_prefix_len=cascade_attn_prefix_len, common_prefix_len=cascade_attn_prefix_len,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args,
) )
for layer_name in kv_cache_group.layer_names:
assert type(attn_metadata) is list if ubid is None:
attn_metadata[ubid][layer_name] = attn_metadata_i
else:
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
if for_cudagraph_capture: attn_metadata_dict = attn_metadata
attn_metadata_i = builder.build_for_cudagraph_capture(
common_attn_metadata
)
else: else:
attn_metadata_i = builder.build( assert isinstance(attn_metadata, list)
common_prefix_len=cascade_attn_prefix_len, attn_metadata_dict = attn_metadata[ubid]
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args,
)
for layer_name in attn_group.layer_names: for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i attn_metadata_dict[layer_name] = attn_metadata_i
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
spec_decode_common_attn_metadata = None
for kv_cache_gid, kv_cache_group in enumerate(kv_cache_groups):
cm = copy(cm_base) # shallow copy
# Basically only the encoder seq_lens, block_table and slot_mapping change
# for each kv_cache_group.
cm.encoder_seq_lens, cm.encoder_seq_lens_cpu = self._get_encoder_seq_lens(
num_scheduled_tokens or {},
kv_cache_group.kv_cache_spec,
num_reqs_padded,
)
if kv_cache_gid > 0:
cm.block_table_tensor, cm.slot_mapping = (
_get_block_table_and_slot_mapping(kv_cache_gid)
)
if self.speculative_config and spec_decode_common_attn_metadata is None:
if isinstance(self.drafter, EagleProposer):
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
spec_decode_common_attn_metadata = cm
else:
spec_decode_common_attn_metadata = cm
for attn_gid in range(len(self.attn_groups[kv_cache_gid])):
if ubatch_slices is not None:
for ubid, _cm in enumerate(split_attn_metadata(ubatch_slices, cm)):
_build_attn_group_metadata(kv_cache_gid, attn_gid, _cm, ubid)
else:
_build_attn_group_metadata(kv_cache_gid, attn_gid, cm)
if self.is_mm_prefix_lm: if self.is_mm_prefix_lm:
req_doc_ranges = {} req_doc_ranges = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment