Commit 66b809cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.2' into v0.7.2-dev

parents 37b63c24 0408efc6
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple
import torch
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Tuple
......
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, TypeVar
......
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, Optional, Set
......
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Optional, Set
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Literal
from urllib.parse import urljoin
......
# SPDX-License-Identifier: Apache-2.0
from functools import lru_cache
from pathlib import Path
from typing import Optional
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Literal
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Literal
......
# SPDX-License-Identifier: Apache-2.0
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataBuilder,
......
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type
......
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with FlashAttention."""
from collections import defaultdict
from dataclasses import dataclass
......
# SPDX-License-Identifier: Apache-2.0
import dataclasses
from collections import defaultdict
from contextlib import contextmanager
......
# SPDX-License-Identifier: Apache-2.0
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
......@@ -8,7 +10,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
VLLMKVCache)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
......@@ -135,9 +138,17 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'0').lower() in ['1', 'true']
self.fused_scaled_dot_product_attention = None
if self.prefill_usefusedsdpa:
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'
try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
FusedSDPA)
except ImportError:
logger().warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")
suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
......@@ -225,6 +236,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
fsdpa_op=self.fused_scaled_dot_product_attention,
)
output = out.reshape(batch_size, seq_len, hidden_size)
else:
......
# SPDX-License-Identifier: Apache-2.0
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from dataclasses import dataclass
......
# SPDX-License-Identifier: Apache-2.0
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, Tuple
......@@ -24,8 +26,13 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_dequantize, scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.vllm_flash_attn import flash_attn_varlen_func
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
except ImportError:
from flash_attn import flash_attn_varlen_func
@dataclass
......@@ -168,6 +175,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.v_head_dim = v_head_dim
self.rotary_emb = rotary_emb
self.use_yarn_rope = isinstance(rotary_emb,
DeepseekScalingRotaryEmbedding)
self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
......@@ -414,6 +423,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
) -> torch.Tensor:
raise NotImplementedError
def apply_pure_rope(
self,
input_positions: torch.Tensor,
q_pe: torch.Tensor,
k_pe: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
seq_len = input_positions.size(0)
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
q_pe, k_pe = self.rotary_emb(
input_positions,
q_pe.reshape(seq_len, -1),
k_pe.reshape(seq_len, -1),
)
q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
return q_pe, k_pe
def forward(
self,
layer: AttentionLayer,
......@@ -438,13 +465,14 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# Restore head dim (for rotary embedding)
k_pe = k_pe.unsqueeze(1)
assert hasattr(attn_metadata, "input_positions")
rope_fn = (self.rotary_emb
if self.use_yarn_rope else self.apply_pure_rope)
if is_decode:
q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
.view(-1, self.num_heads, self.qk_rope_head_dim)
q_pe, k_pe = \
self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe)
q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe)
else:
assert is_prefill
q = self.q_proj(hidden_states_or_q_c)[0]\
......@@ -452,7 +480,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# TODO(lucas): there must be a nicer way to write this line
q[..., self.qk_nope_head_dim:], k_pe = \
self.rotary_emb(
rope_fn(
attn_metadata.input_positions,
q[..., self.qk_nope_head_dim:], k_pe)
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
......@@ -138,3 +140,7 @@ class OpenVINOAttentionMetadata:
# `model_executable`.
multi_modal_placeholder_index_maps: Optional[Dict[
str, MultiModalPlaceholderMap.IndexMap]]
# Enable/disable KV scales calculation. This is so that we can disable the
# calculation until after prefill and cuda graph capture.
enable_kv_scales_calculation: bool
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment