Unverified Commit a28b94e6 authored by ElizaWszola's avatar ElizaWszola Committed by GitHub
Browse files

[Performance] Split FlashAttn attention and cache update (#25954)


Signed-off-by: default avatarElizaWszola <ewszola@redhat.com>
Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Signed-off-by: default avatarLuka Govedič <luka.govedic@gmail.com>
Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: default avatarLuka Govedič <lgovedic@redhat.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarVarun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: default avatarLuka Govedič <luka.govedic@gmail.com>
Co-authored-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: default avatarLuka Govedič <lgovedic@redhat.com>
parent 0118cdcc
...@@ -295,6 +295,7 @@ class UBatchWrapper: ...@@ -295,6 +295,7 @@ class UBatchWrapper:
self, self,
ubatch_slices, ubatch_slices,
attn_metadata, attn_metadata,
slot_mapping,
input_ids, input_ids,
positions, positions,
inputs_embeds, inputs_embeds,
...@@ -306,6 +307,9 @@ class UBatchWrapper: ...@@ -306,6 +307,9 @@ class UBatchWrapper:
) -> list[UbatchMetadata]: ) -> list[UbatchMetadata]:
# Create one forward context per ubatch # Create one forward context per ubatch
forward_contexts = [] forward_contexts = []
# slot_mapping can be None, an empty dict (from create_forward_context
# converting None to {}), or a list of dicts (one per ubatch)
has_slot_mapping = slot_mapping and isinstance(slot_mapping, list)
for i, ubatch_slice in enumerate(ubatch_slices): for i, ubatch_slice in enumerate(ubatch_slices):
forward_contexts.append( forward_contexts.append(
create_forward_context( create_forward_context(
...@@ -314,6 +318,7 @@ class UBatchWrapper: ...@@ -314,6 +318,7 @@ class UBatchWrapper:
dp_metadata=dp_metadata[i], dp_metadata=dp_metadata[i],
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=slot_mapping[i] if has_slot_mapping else None,
) )
) )
...@@ -406,6 +411,7 @@ class UBatchWrapper: ...@@ -406,6 +411,7 @@ class UBatchWrapper:
return self.cudagraph_wrapper(*args, **kwargs) return self.cudagraph_wrapper(*args, **kwargs)
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
slot_mapping = forward_context.slot_mapping
num_tokens = ( num_tokens = (
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
) * 2 ) * 2
...@@ -440,6 +446,7 @@ class UBatchWrapper: ...@@ -440,6 +446,7 @@ class UBatchWrapper:
ubatch_metadata = self._make_ubatch_metadata( ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices, ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
slot_mapping=slot_mapping,
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
...@@ -462,6 +469,7 @@ class UBatchWrapper: ...@@ -462,6 +469,7 @@ class UBatchWrapper:
ubatch_metadata = self._make_ubatch_metadata( ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices, ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
slot_mapping=slot_mapping,
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment