Unverified commit 843b2227, authored by Agata Dobrzyniewicz and committed by GitHub
Browse files

[Hardware][Intel-Gaudi] Support Automatic Prefix Caching on HPU (#17648)


Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
parent e515668e
...@@ -57,16 +57,16 @@ class HPUAttentionBackend(AttentionBackend): ...@@ -57,16 +57,16 @@ class HPUAttentionBackend(AttentionBackend):
def swap_blocks( def swap_blocks(
src_kv_cache: torch.Tensor, src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int], src_to_dsts: torch.Tensor,
) -> None: ) -> None:
HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts)
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]], src_to_dsts: torch.Tensor,
) -> None: ) -> None:
HPUPagedAttention.copy_blocks(kv_caches, src_to_dists) HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)
@dataclass @dataclass
...@@ -77,6 +77,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): ...@@ -77,6 +77,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
is_prompt: bool is_prompt: bool
attn_bias: Optional[torch.Tensor] attn_bias: Optional[torch.Tensor]
seq_lens_tensor: Optional[torch.Tensor] seq_lens_tensor: Optional[torch.Tensor]
context_lens_tensor: Optional[torch.Tensor]
class HPUAttentionImpl(AttentionImpl, torch.nn.Module): class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...@@ -198,8 +199,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -198,8 +199,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
key_cache = None key_cache = None
value_cache = None value_cache = None
if attn_metadata.is_prompt and self.attn_type \ if attn_metadata.is_prompt and self.attn_type \
is not AttentionType.ENCODER_ONLY \ is not AttentionType.ENCODER_ONLY:
and attn_metadata.block_list is None:
key = key.unflatten(0, (block_indices.size(0), -1)) key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1)) value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None and isinstance(kv_cache, tuple): if kv_cache is not None and isinstance(kv_cache, tuple):
...@@ -229,6 +229,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -229,6 +229,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias) attn_bias.add_(position_bias)
block_list = attn_metadata.block_list if attn_metadata \
and attn_metadata.block_list is not None else None
out = ops.prompt_attention( out = ops.prompt_attention(
impl=self.prefill_impl, impl=self.prefill_impl,
query=query.view(query_shape), query=query.view(query_shape),
...@@ -237,23 +240,25 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -237,23 +240,25 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
is_causal=True, is_causal=True,
attn_bias=attn_bias, attn_bias=attn_bias,
valid_seq_lengths=attn_metadata.seq_lens_tensor, valid_seq_lengths=attn_metadata.seq_lens_tensor,
**self.common_attention_args()) **self.common_attention_args(block_list, key_cache,
value_cache))
output = out.reshape(batch_size, seq_len, hidden_size) output = out.reshape(batch_size, seq_len, hidden_size)
else: else:
# Decoding run. # Decoding run.
output = HPUPagedAttention.forward_decode( output = HPUPagedAttention.forward_decode(
query=query, query=query,
key_cache=key_cache,
value_cache=value_cache,
block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping, block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias, block_bias=attn_metadata.attn_bias,
block_groups=attn_metadata.block_groups, block_groups=attn_metadata.block_groups,
**self.common_attention_args()) **self.common_attention_args(attn_metadata.block_list,
key_cache, value_cache))
# Reshape the output tensor. # Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size) return output.view(batch_size, seq_len, hidden_size)
def common_attention_args(self): def common_attention_args(self,
block_list=None,
key_cache=None,
value_cache=None):
fsdpa_op = self.fused_scaled_dot_product_attention.apply \ fsdpa_op = self.fused_scaled_dot_product_attention.apply \
if self.fused_scaled_dot_product_attention is not None else None if self.fused_scaled_dot_product_attention is not None else None
return { return {
...@@ -266,6 +271,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -266,6 +271,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
'keys_fetch_func': self.k_cache.fetch_from_cache, 'keys_fetch_func': self.k_cache.fetch_from_cache,
'values_fetch_func': self.v_cache.fetch_from_cache, 'values_fetch_func': self.v_cache.fetch_from_cache,
'softmax_op': self.softmax, 'softmax_op': self.softmax,
'block_list': block_list,
'key_cache': key_cache,
'value_cache': value_cache,
} }
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
############################################################################### ###############################################################################
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from vllm_hpu_extension import cache_ops, ops from vllm_hpu_extension import cache_ops, ops
...@@ -63,43 +63,25 @@ class HPUPagedAttention: ...@@ -63,43 +63,25 @@ class HPUPagedAttention:
def forward_decode(**kwargs) -> torch.Tensor: def forward_decode(**kwargs) -> torch.Tensor:
return ops.flat_pa(**kwargs) return ops.flat_pa(**kwargs)
@staticmethod
def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
) -> torch.Tensor:
raise NotImplementedError(
"forward_prefix is not implemented for HPUPagedAttention")
@staticmethod @staticmethod
def swap_blocks( def swap_blocks(
src_kv_cache: torch.Tensor, src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
dst_kv_cache: torch.Tensor, dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
src_to_dst: Dict[int, int], src_to_dsts: torch.Tensor,
) -> None: ) -> None:
src_key_cache = src_kv_cache[0] src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0] dst_key_cache = dst_kv_cache[0]
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dsts)
src_value_cache = src_kv_cache[1] src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1] dst_value_cache = dst_kv_cache[1]
cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dsts)
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
kv_caches: List[torch.Tensor], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
src_to_dists: Dict[int, List[int]], src_to_dsts: torch.Tensor,
) -> None: ) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches] key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches]
cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
...@@ -14,7 +14,7 @@ import math ...@@ -14,7 +14,7 @@ import math
import os import os
import time import time
from array import array from array import array
from enum import IntEnum from enum import Enum, IntEnum
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
Optional, Set, Tuple, Type, TypeVar, Union) Optional, Set, Tuple, Type, TypeVar, Union)
...@@ -75,6 +75,12 @@ LORA_WARMUP_RANK = 8 ...@@ -75,6 +75,12 @@ LORA_WARMUP_RANK = 8
DUMMY_TOKEN_ID = -1 DUMMY_TOKEN_ID = -1
class PhaseType(Enum):
    """Kind of forward pass, as classified from the attention metadata.

    The member values are the short lowercase names of each phase.
    """

    # Prompt pass whose metadata carries no block list.
    PREFILL = 'prefill'
    # Prompt pass whose metadata carries a block list (cached prefix blocks).
    PREFIX_PREFILL = 'prefix_prefill'
    # Any non-prompt pass.
    DECODE = 'decode'
def subtuple(obj: object, def subtuple(obj: object,
typename: str, typename: str,
to_copy: List[str], to_copy: List[str],
...@@ -213,20 +219,40 @@ class HpuModelAdapter: ...@@ -213,20 +219,40 @@ class HpuModelAdapter:
def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
dtype): dtype):
prefill_metadata = attn_metadata if (attn_metadata is None
if prefill_metadata is None or self.prefill_use_fusedsdpa: or (self.prefill_use_fusedsdpa \
and attn_metadata.block_list is None)
or not attn_metadata.is_prompt):
return attn_metadata return attn_metadata
prefill_metadata = attn_metadata
seq_lens_t = prefill_metadata.seq_lens_tensor seq_lens_t = prefill_metadata.seq_lens_tensor
context_lens_t = prefill_metadata.context_lens_tensor
query_lens_t = seq_lens_t - context_lens_t
block_list = attn_metadata.block_list
max_context_len = (block_list.size(-1) //
batch_size if block_list is not None else 0)
max_context_len = max_context_len * self.block_size
past_mask = torch.arange(0,
max_context_len,
dtype=torch.int32,
device=device)
past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge(
context_lens_t.view(-1, 1)).view(batch_size, 1, -1).expand(
batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1))
len_mask = (torch.arange(0, seq_len, device=device, len_mask = (torch.arange(0, seq_len, device=device,
dtype=torch.int32).view(1, seq_len).ge( dtype=torch.int32).view(1, seq_len).ge(
seq_lens_t.unsqueeze(-1)).view( query_lens_t.unsqueeze(-1)).view(
batch_size, 1, 1, seq_len)) batch_size, 1, 1, seq_len))
causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len),
device=device, device=device,
dtype=torch.bool), dtype=torch.bool),
diagonal=1) diagonal=1)
mask = causal_mask.logical_or(len_mask) mask = causal_mask.logical_or(len_mask)
mask = torch.concat((past_mask, mask), dim=-1)
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
mask, -math.inf)) mask, -math.inf))
attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) attn_metadata = prefill_metadata._replace(attn_bias=attn_bias)
...@@ -517,6 +543,11 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -517,6 +543,11 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
False, self.max_model_len) False, self.max_model_len)
self.graphed_buckets: Set[Any] = set() self.graphed_buckets: Set[Any] = set()
self._set_gc_threshold() self._set_gc_threshold()
if self.vllm_config.cache_config.enable_prefix_caching:
os.environ.setdefault("VLLM_CONTIGUOUS_PA", "False")
assert os.environ.get(
"VLLM_CONTIGUOUS_PA",
"").lower() != "true", "Contiguous PA doesn't support APC"
self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
# For multi-step scheduling # For multi-step scheduling
...@@ -702,6 +733,10 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -702,6 +733,10 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
computed_block_nums) > 0 and self.sliding_window is None: computed_block_nums) > 0 and self.sliding_window is None:
# Prefix is not supported with sliding_window # Prefix is not supported with sliding_window
context_len = len(computed_block_nums) * self.block_size context_len = len(computed_block_nums) * self.block_size
if context_len == seq_len \
and self.vllm_config.cache_config.enable_prefix_caching:
# Fully cached prompt - compute only last token
context_len = context_len - 1
prompt_tokens = prompt_tokens[context_len:] prompt_tokens = prompt_tokens[context_len:]
prefix_block_tables.append(computed_block_nums) prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled: elif self.scheduler_config.chunked_prefill_enabled:
...@@ -779,12 +814,33 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -779,12 +814,33 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
if lora_id > 0: if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request) lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping += [lora_id] * (max_prompt_len - context_len) lora_index_mapping += [lora_id] * max_prompt_len
lora_prompt_mapping.extend( lora_prompt_mapping.extend(
[lora_id] * [lora_id] *
(max_prompt_len - context_len (max_prompt_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.sampling_params.prompt_logprobs else 1))
if any(context_lens):
assert not self.scheduler_config.chunked_prefill_enabled
# prefix caching
max_num_block = max(len(bt) for bt in prefix_block_tables)
prefix_block_list = list(
itertools.chain.from_iterable(
bt if len(bt) == max_num_block else bt +
([_PAD_BLOCK_ID] * (max_num_block - len(bt)))
for bt in prefix_block_tables))
pad_len = len(prefix_block_list)
prefix_block_list = pad_list(prefix_block_list, pad_len,
_PAD_BLOCK_ID)
prefix_block_list_tensor = torch.tensor(prefix_block_list,
dtype=torch.long,
device=self.device)
else:
prefix_block_list_tensor = None
input_tokens = make_tensor_with_pad(input_tokens, input_tokens = make_tensor_with_pad(input_tokens,
max_len=max_prompt_len, max_len=max_prompt_len,
pad=0, pad=0,
...@@ -807,11 +863,15 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -807,11 +863,15 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.long,
device=self.device)
block_indices, block_offsets = precompute_indices_and_offsets( block_indices, block_offsets = precompute_indices_and_offsets(
self.block_size, slot_mapping, True) self.block_size, slot_mapping, True)
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=True, is_prompt=True,
block_list=None, block_list=prefix_block_list_tensor,
block_mapping=None, block_mapping=None,
block_usage=None, block_usage=None,
block_indices=block_indices, block_indices=block_indices,
...@@ -819,6 +879,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -819,6 +879,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_groups=None, block_groups=None,
attn_bias=None, attn_bias=None,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
context_lens_tensor=context_lens_tensor,
num_prefills=real_num_seqs, num_prefills=real_num_seqs,
num_prefill_tokens=sum_query_len, num_prefill_tokens=sum_query_len,
num_decode_tokens=0, num_decode_tokens=0,
...@@ -987,6 +1048,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -987,6 +1048,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_groups=block_groups, block_groups=block_groups,
attn_bias=None, attn_bias=None,
seq_lens_tensor=None, seq_lens_tensor=None,
context_lens_tensor=None,
num_prefills=0, num_prefills=0,
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
...@@ -1091,7 +1153,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1091,7 +1153,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
# FIXME: We need to adjust selected_token_indices to accommodate # FIXME: We need to adjust selected_token_indices to accommodate
# for padding # for padding
max_len = input_tokens.size(1) max_len = input_tokens.size(1)
paddings = [max_len - s for s in seq_lens] paddings = [max_len - q for q in query_lens]
paddings = [0] + paddings[:-1] paddings = [0] + paddings[:-1]
paddings = list(itertools.accumulate(paddings)) paddings = list(itertools.accumulate(paddings))
paddings_prompt_logprobs = [] paddings_prompt_logprobs = []
...@@ -1187,9 +1249,17 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1187,9 +1249,17 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
# input_hash(123) != input_hash(321) # input_hash(123) != input_hash(321)
# input_hash("abc") != input_hash("cba") # input_hash("abc") != input_hash("cba")
attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', 'attn_bias',
'block_usage', 'slot_mapping', 'is_prompt', 'block_indices', 'seq_lens_tensor',
'block_offsets', 'block_groups' 'context_lens_tensor',
'block_list',
'block_mapping',
'block_usage',
'slot_mapping',
'is_prompt',
'block_indices',
'block_offsets',
'block_groups',
]) ])
return attention_metadata return attention_metadata
...@@ -1733,14 +1803,44 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): ...@@ -1733,14 +1803,44 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
from neural_compressor.torch.quantization import finalize_calibration from neural_compressor.torch.quantization import finalize_calibration
finalize_calibration(self.model.model) finalize_calibration(self.model.model)
def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode): def _num_blocks(self, attn_metadata):
cfg = (batch_size, seq_len, is_prompt) if attn_metadata.block_list is None:
return 0
return attn_metadata.block_list.numel()
def _phase(self, attn_metadata):
    """Classify the pass described by ``attn_metadata`` as a ``PhaseType``.

    Args:
        attn_metadata: Attention metadata exposing ``is_prompt`` and
            ``block_list``.

    Returns:
        ``PhaseType.DECODE`` when ``is_prompt`` is falsy;
        ``PhaseType.PREFIX_PREFILL`` when it is a prompt pass and
        ``block_list`` is not ``None``; ``PhaseType.PREFILL`` otherwise.
    """
    # The three outcomes are exhaustive, so no "unrecognized phase" error
    # branch is needed: a falsy ``is_prompt`` always means decode.
    if not attn_metadata.is_prompt:
        return PhaseType.DECODE
    # ``block_list`` is only inspected for prompt passes.
    if attn_metadata.block_list is not None:
        return PhaseType.PREFIX_PREFILL
    return PhaseType.PREFILL
def _check_config(self, batch_size, seq_len, attn_metadata, warmup_mode):
is_prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
cfg: Optional[tuple] = None
assert cfg is None, "Configs changed between 2D and 3D"
if is_prefix_caching:
phase = self._phase(attn_metadata)
num_blocks = self._num_blocks(attn_metadata)
cfg = (batch_size, seq_len, num_blocks, phase)
else:
phase = 'prompt' if attn_metadata.is_prompt else 'decode'
cfg = (batch_size, seq_len, phase)
seen = cfg in self.seen_configs seen = cfg in self.seen_configs
self.seen_configs.add(cfg) self.seen_configs.add(cfg)
if not seen and not warmup_mode: if not seen and not warmup_mode:
phase = 'prompt' if is_prompt else 'decode' logger.warning("Configuration: %s was not warmed-up!",
logger.warning("Configuration: (%s, %s, %s) was not warmed-up!", (phase.value, batch_size, seq_len,
phase, batch_size, seq_len) num_blocks) if is_prefix_caching else
(phase, batch_size, seq_len))
def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int], def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
is_prompt: bool): is_prompt: bool):
...@@ -1912,7 +2012,7 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): ...@@ -1912,7 +2012,7 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
batch_size = input_tokens.size(0) batch_size = input_tokens.size(0)
seq_len = self._seq_len(attn_metadata) seq_len = self._seq_len(attn_metadata)
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
self._check_config(batch_size, seq_len, is_prompt, warmup_mode) self._check_config(batch_size, seq_len, attn_metadata, warmup_mode)
lora_mask: torch.Tensor = None lora_mask: torch.Tensor = None
lora_logits_mask: torch.Tensor = None lora_logits_mask: torch.Tensor = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment