Unverified commit 4dbcbbeb, authored by Yang Zheng and committed by GitHub
Browse files

[Misc] Compute query_start_loc/seq_start_loc on CPU (#9447)


Co-authored-by: Yang Zheng(SW)(Alex) <you@example.com>
parent b67feb12
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch import torch
...@@ -503,6 +504,8 @@ class FlashAttentionMetadataBuilder( ...@@ -503,6 +504,8 @@ class FlashAttentionMetadataBuilder(
max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
if use_captured_graph: if use_captured_graph:
...@@ -525,29 +528,18 @@ class FlashAttentionMetadataBuilder( ...@@ -525,29 +528,18 @@ class FlashAttentionMetadataBuilder(
device, self.runner.pin_memory) device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory) self.runner.pin_memory)
query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory) device, self.runner.pin_memory)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
dtype=torch.int32, device,
device=device) self.runner.pin_memory)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
dtype=torch.int32, device, self.runner.pin_memory)
device=device)
placeholder_index_maps = { placeholder_index_maps = {
modality: placeholder_map.index_map() modality: placeholder_map.index_map()
for modality, placeholder_map in for modality, placeholder_map in
self.multimodal_placeholder_maps.items() self.multimodal_placeholder_maps.items()
} }
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
return FlashAttentionMetadata( return FlashAttentionMetadata(
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
...@@ -561,8 +553,8 @@ class FlashAttentionMetadataBuilder( ...@@ -561,8 +553,8 @@ class FlashAttentionMetadataBuilder(
max_decode_query_len=max_decode_query_len, max_decode_query_len=max_decode_query_len,
max_prefill_seq_len=max_prefill_seq_len, max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len, max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc, query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc, seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor, context_lens_tensor=context_lens_tensor,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
......
"""Attention backend utils""" """Attention backend utils"""
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
import numpy as np import numpy as np
...@@ -216,6 +217,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -216,6 +217,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
if use_captured_graph: if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
...@@ -244,29 +247,18 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -244,29 +247,18 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
device, self.runner.pin_memory) device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory) self.runner.pin_memory)
query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory) device, self.runner.pin_memory)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
dtype=torch.int32, device,
device=device) self.runner.pin_memory)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
dtype=torch.int32, device, self.runner.pin_memory)
device=device)
placeholder_index_maps = { placeholder_index_maps = {
modality: placeholder_map.index_map() modality: placeholder_map.index_map()
for modality, placeholder_map in for modality, placeholder_map in
self.multimodal_placeholder_maps.items() self.multimodal_placeholder_maps.items()
} }
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
return self._metadata_cls( # type: ignore return self._metadata_cls( # type: ignore
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
...@@ -279,8 +271,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -279,8 +271,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
max_query_len=max_query_len, max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len, max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len, max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc, query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc, seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor, context_lens_tensor=context_lens_tensor,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment