"plugins/amoeba/vscode:/vscode.git/clone" did not exist on "5270f8585e917ae21c66ad6ae0f8701ff7c37071"
Unverified commit a28b94e6, authored by ElizaWszola, committed by GitHub
Browse files

[Performance] Split FlashAttn attention and cache update (#25954)


Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Luka Govedič <luka.govedic@gmail.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
parent 0118cdcc
......@@ -295,6 +295,7 @@ class UBatchWrapper:
self,
ubatch_slices,
attn_metadata,
slot_mapping,
input_ids,
positions,
inputs_embeds,
......@@ -306,6 +307,9 @@ class UBatchWrapper:
) -> list[UbatchMetadata]:
# Create one forward context per ubatch
forward_contexts = []
# slot_mapping can be None, an empty dict (from create_forward_context
# converting None to {}), or a list of dicts (one per ubatch)
has_slot_mapping = slot_mapping and isinstance(slot_mapping, list)
for i, ubatch_slice in enumerate(ubatch_slices):
forward_contexts.append(
create_forward_context(
......@@ -314,6 +318,7 @@ class UBatchWrapper:
dp_metadata=dp_metadata[i],
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=slot_mapping[i] if has_slot_mapping else None,
)
)
......@@ -406,6 +411,7 @@ class UBatchWrapper:
return self.cudagraph_wrapper(*args, **kwargs)
attn_metadata = forward_context.attn_metadata
slot_mapping = forward_context.slot_mapping
num_tokens = (
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
) * 2
......@@ -440,6 +446,7 @@ class UBatchWrapper:
ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
slot_mapping=slot_mapping,
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
......@@ -462,6 +469,7 @@ class UBatchWrapper:
ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
slot_mapping=slot_mapping,
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment