Unverified commit c48334d4, authored by Agata Dobrzyniewicz, committed by GitHub
Browse files

[Hardware][Intel-Gaudi] Update hpu-extension and update bucketing system for HPU device (#17186)


Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
parent 909fdaf1
...@@ -9,4 +9,4 @@ numpy==1.26.4 ...@@ -9,4 +9,4 @@ numpy==1.26.4
tabulate tabulate
setuptools>=61 setuptools>=61
setuptools-scm>=8 setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768 vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@f1f6624
...@@ -4,14 +4,14 @@ ...@@ -4,14 +4,14 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
############################################################################### ###############################################################################
import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Dict, List, Optional, Tuple, Type
import torch import torch
import vllm_hpu_extension.kernels as kernels
import vllm_hpu_extension.ops as ops import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax, from vllm_hpu_extension.flags import enabled_flags
VLLMKVCache) from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionLayer,
...@@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
self.block2batch_matmul = Matmul() self.block2batch_matmul = Matmul()
self.k_cache = VLLMKVCache() self.k_cache = VLLMKVCache()
self.v_cache = VLLMKVCache() self.v_cache = VLLMKVCache()
ops.pa_impl = ops.pa self.fused_scaled_dot_product_attention = kernels.fsdpa()
self.prefill_impl = 'naive'
if "flex_attention" in enabled_flags():
self.prefill_impl = 'flex'
if "fsdpa" in enabled_flags():
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'
self.prefill_impl = 'fsdpa'
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window self.sliding_window = sliding_window
...@@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', if self.prefill_impl == 'fsdpa':
'0').lower() in ['1', 'true']
self.fused_scaled_dot_product_attention = None
if self.prefill_usefusedsdpa:
assert alibi_slopes is None, \ assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!' 'Prefill with FusedSDPA not supported with alibi slopes!'
try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
FusedSDPA)
except ImportError:
logger.warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")
supported_head_sizes = HPUPagedAttention.get_supported_head_sizes() supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes: if head_size not in supported_head_sizes:
...@@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
f"Head size {head_size} is not supported by PagedAttention. " f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.") f"Supported head sizes are: {supported_head_sizes}.")
if attn_type != AttentionType.DECODER: self.attn_type = attn_type
if self.attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention " "encoder/decoder cross-attention "
"are not implemented for " "are not implemented for "
...@@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
batch_size, seq_len, hidden_size = query.shape batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape _, seq_len_kv, _ = key.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets block_offsets = attn_metadata.block_offsets
if attn_metadata.is_prompt: key_cache = None
value_cache = None
if attn_metadata.is_prompt and self.attn_type \
is not AttentionType.ENCODER_ONLY \
and attn_metadata.block_list is None:
key = key.unflatten(0, (block_indices.size(0), -1)) key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1)) value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None: if kv_cache is not None and isinstance(kv_cache, tuple):
key_cache, value_cache = HPUPagedAttention.split_kv_cache( key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size) kv_cache, self.num_kv_heads, self.head_size)
...@@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
if attn_metadata.is_prompt: if attn_metadata.is_prompt:
# Prompt run. # Prompt run.
if not self.prefill_usefusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward!'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads,
attn_bias.dtype,
attn_bias.shape[-1])
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
else:
attn_bias = None
query_shape = (batch_size, seq_len, self.num_heads, self.head_size) query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size) self.head_size)
attn_bias = attn_metadata.attn_bias
if attn_bias is not None and self.alibi_slopes is not None:
position_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads,
attn_bias.dtype,
attn_bias.shape[-1])
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
out = ops.prompt_attention( out = ops.prompt_attention(
query.view(query_shape), impl=self.prefill_impl,
key.view(kv_shape), query=query.view(query_shape),
value.view(kv_shape), key=key.view(kv_shape),
value=value.view(kv_shape),
is_causal=True,
attn_bias=attn_bias, attn_bias=attn_bias,
p=0.0, valid_seq_lengths=attn_metadata.seq_lens_tensor,
scale=self.scale, **self.common_attention_args())
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
fsdpa_op=self.fused_scaled_dot_product_attention,
)
output = out.reshape(batch_size, seq_len, hidden_size) output = out.reshape(batch_size, seq_len, hidden_size)
else: else:
# Decoding run. # Decoding run.
...@@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
block_list=attn_metadata.block_list, block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping, block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias, block_bias=attn_metadata.attn_bias,
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups, block_groups=attn_metadata.block_groups,
scale=self.scale, **self.common_attention_args())
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
# Reshape the output tensor. # Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size) return output.view(batch_size, seq_len, hidden_size)
def common_attention_args(self):
    """Assemble the keyword arguments spread into the attention op calls.

    The result is unpacked with ``**`` at the attention call sites, so the
    keys must stay exactly as spelled here.

    Returns:
        dict: the softmax scale, the matmul/softmax operator modules, the
        KV-cache fetch callables, and ``fsdpa_op`` -- the ``apply``
        attribute of the fused SDPA kernel object, or ``None`` when no
        such kernel object is set on the instance.
    """
    # The fused kernel may be absent; only dereference ``.apply`` when set.
    fsdpa_module = self.fused_scaled_dot_product_attention
    if fsdpa_module is None:
        fsdpa_op = None
    else:
        fsdpa_op = fsdpa_module.apply

    shared_kwargs = {
        'scale': self.scale,
        'matmul_qk_op': self.matmul_qk,
        'matmul_av_op': self.matmul_av,
        'batch2block_matmul_op': self.batch2block_matmul,
        'block2batch_matmul_op': self.block2batch_matmul,
        'fsdpa_op': fsdpa_op,
        'keys_fetch_func': self.k_cache.fetch_from_cache,
        'values_fetch_func': self.v_cache.fetch_from_cache,
        'softmax_op': self.softmax,
    }
    return shared_kwargs
def _make_alibi_bias( def _make_alibi_bias(
alibi_slopes: torch.Tensor, alibi_slopes: torch.Tensor,
......
...@@ -22,7 +22,6 @@ class HPUPagedAttentionMetadata: ...@@ -22,7 +22,6 @@ class HPUPagedAttentionMetadata:
block_usage: Optional[torch.Tensor] block_usage: Optional[torch.Tensor]
block_indices: Optional[torch.Tensor] block_indices: Optional[torch.Tensor]
block_offsets: Optional[torch.Tensor] block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor] block_groups: Optional[torch.Tensor]
......
...@@ -168,7 +168,8 @@ class RMSNorm(CustomOp): ...@@ -168,7 +168,8 @@ class RMSNorm(CustomOp):
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
from vllm_hpu_extension.ops import HPUFusedRMSNorm from vllm_hpu_extension.kernels import rms_norm
HPUFusedRMSNorm = rms_norm()
if HPUFusedRMSNorm is None: if HPUFusedRMSNorm is None:
return self.forward_native(x, residual) return self.forward_native(x, residual)
if residual is not None: if residual is not None:
......
This diff is collapsed.
...@@ -245,6 +245,7 @@ class HPUWorker(LocalOrDistributedWorkerBase): ...@@ -245,6 +245,7 @@ class HPUWorker(LocalOrDistributedWorkerBase):
cache_block_size) cache_block_size)
num_hpu_blocks = max(num_hpu_blocks, 0) num_hpu_blocks = max(num_hpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0)
self.model_runner.bucketing_ctx.num_hpu_blocks = num_hpu_blocks
if self.model_runner.lora_manager: if self.model_runner.lora_manager:
self.model_runner.remove_all_loras() self.model_runner.remove_all_loras()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment