Commit d589e598 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.6.2-dev_wm' into 'v0.6.2-dev'

[feat]优化medusa代码,通过VLLM_TREE_DECODING环境变量控制是否采用tree-style解码,计算逻辑与主干隔离

See merge request dcutoolkit/deeplearing/vllm!51
parents 54b92ba4 0bb491f8
......@@ -62,16 +62,6 @@ class OpenVINOAttentionBackend(AttentionBackend):
key_cache.data[dst, :] = key_cache.data[src, :]
value_cache.data[dst, :] = value_cache.data[src, :]
@staticmethod
def move_cache(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
head_size: int,
) -> None:
NotImplementedError
@dataclass
class OpenVINOAttentionMetadata:
......
......@@ -53,16 +53,6 @@ class PallasAttentionBackend(AttentionBackend):
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
v_cache[:, dst_indices] = v_cache[:, src_indices]
@staticmethod
def move_cache(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
head_size: int,
) -> None:
NotImplementedError
@dataclass
class PallasMetadata(AttentionMetadata):
......
......@@ -72,50 +72,6 @@ class ROCmFlashAttentionBackend(AttentionBackend):
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
head_size: int,
) -> None:
key_caches = []
value_caches = []
num_layers = len(kv_caches)
token_num = src_to_dists.shape[0]
tmp_store_kv = torch.empty(
(2, num_layers, token_num, num_kv_heads, head_size),
dtype=kv_caches[0].dtype, device=kv_caches[0].device)
keys = tmp_store_kv[0].contiguous()
values = tmp_store_kv[1].contiguous()
for kv_cache in kv_caches:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, num_kv_heads, head_size)
key_caches.append(key_cache)
value_caches.append(value_cache)
ops.read_cache(
keys,
values,
key_caches,
value_caches,
src_to_dists[:, 0].contiguous(),
kv_cache_dtype
)
ops.write_cache_multi_layers(
keys,
values,
key_caches,
value_caches,
src_to_dists[:, 1].contiguous(),
kv_cache_dtype
)
@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
......
......@@ -65,16 +65,6 @@ class TorchSDPABackend(AttentionBackend):
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
head_size: int,
) -> None:
NotImplementedError
@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
......
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union, Optional
import torch
from vllm.attention.backends.blocksparse_attn import BlocksparseFlashAttentionImpl
from vllm import _custom_ops as ops
from vllm.attention.ops.paged_attn import PagedAttention
def move_cache(
backend,
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
head_size: int,
) -> None:
if backend.get_name() == "rocm-flash-attn" or \
backend.get_name() == "xformers":
key_caches = []
value_caches = []
num_layers = len(kv_caches)
token_num = src_to_dists.shape[0]
tmp_store_kv = torch.empty(
(2, num_layers, token_num, num_kv_heads, head_size),
dtype=kv_caches[0].dtype, device=kv_caches[0].device)
keys = tmp_store_kv[0].contiguous()
values = tmp_store_kv[1].contiguous()
for kv_cache in kv_caches:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, num_kv_heads, head_size)
key_caches.append(key_cache)
value_caches.append(value_cache)
ops.read_cache(
keys,
values,
key_caches,
value_caches,
src_to_dists[:, 0].contiguous(),
kv_cache_dtype
)
ops.write_cache_multi_layers(
keys,
values,
key_caches,
value_caches,
src_to_dists[:, 1].contiguous(),
kv_cache_dtype
)
else:
raise NotImplementedError("Only BlocksparseFlashAttention/ROCmFlash/XFormers backends support move cache for now!")
\ No newline at end of file
......@@ -9,6 +9,8 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState)
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerBase
......@@ -188,8 +190,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.block_size, inter_data.block_tables)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int,
tree_attention_masks_tensor: Optional[torch.Tensor] = None):
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
......@@ -272,7 +273,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
tree_attention_masks_tensor=tree_attention_masks_tensor,
block_tables_list=self.block_tables
)
......
......@@ -68,50 +68,6 @@ class XFormersBackend(AttentionBackend):
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
head_size: int,
) -> None:
key_caches = []
value_caches = []
num_layers = len(kv_caches)
token_num = src_to_dists.shape[0]
tmp_store_kv = torch.empty(
(2, num_layers, token_num, num_kv_heads, head_size),
dtype=kv_caches[0].dtype, device=kv_caches[0].device)
keys = tmp_store_kv[0].contiguous()
values = tmp_store_kv[1].contiguous()
for kv_cache in kv_caches:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, num_kv_heads, head_size)
key_caches.append(key_cache)
value_caches.append(value_cache)
ops.read_cache(
keys,
values,
key_caches,
value_caches,
src_to_dists[:, 0].contiguous(),
kv_cache_dtype
)
ops.write_cache_multi_layers(
keys,
values,
key_caches,
value_caches,
src_to_dists[:, 1].contiguous(),
kv_cache_dtype
)
@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -117,8 +117,6 @@ class LoRAModel(AdapterModel):
pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: Dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
if "lora_A" not in tensor_name and "lora_B" not in tensor_name:
continue
module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
if module_name not in loras:
lora_embeddings_tensor = None
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment