Commit 1acca3a2, authored by Lifu Huang, committed by GitHub

FA3 speed up: skip len operation and get batch size directly from forward batch (#5969)


Signed-off-by: Lifu Huang <lifu.hlf@gmail.com>
parent 6ea1e6ac
@@ -338,7 +338,7 @@ class FlashAttentionBackend(AttentionBackend):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
         seqlens_in_batch = forward_batch.seq_lens
-        batch_size = len(seqlens_in_batch)
+        batch_size = forward_batch.batch_size
         device = seqlens_in_batch.device
         if forward_batch.forward_mode.is_decode_or_idle():
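A minimal sketch of the pattern this change relies on, for illustration only. `ForwardBatchSketch`, `make_batch`, `batch_size_before`, and `batch_size_after` are hypothetical names, not the project's classes, and a plain Python list stands in for the `seq_lens` tensor. The sketch assumes the stored `batch_size` always equals the number of entries in `seq_lens`, which is the invariant the one-line replacement depends on.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ForwardBatchSketch:
    """Hypothetical stand-in for the forward batch; not the project's class."""

    seq_lens: List[int]  # one sequence length per request (a tensor in the real code)
    batch_size: int      # stored once when the batch is built


def make_batch(seq_lens: List[int]) -> ForwardBatchSketch:
    # Record the size a single time, at construction.
    return ForwardBatchSketch(seq_lens=seq_lens, batch_size=len(seq_lens))


def batch_size_before(batch: ForwardBatchSketch) -> int:
    # Old pattern: recompute the size from seq_lens on every call.
    return len(batch.seq_lens)


def batch_size_after(batch: ForwardBatchSketch) -> int:
    # New pattern: read the value already stored on the batch.
    return batch.batch_size


batch = make_batch([5, 12, 7])
```

Both helpers return the same value as long as the batch is built consistently; the second simply avoids calling `len` on each metadata initialization.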