Commit 44e3ca68 authored by 王敏's avatar 王敏
Browse files

[feat]优化medusa代码,通过VLLM_TREE_DECODING环境变量控制是否采用tree-style解码,计算逻辑主干隔离

parent 54b92ba4
...@@ -53,16 +53,6 @@ class PallasAttentionBackend(AttentionBackend): ...@@ -53,16 +53,6 @@ class PallasAttentionBackend(AttentionBackend):
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
v_cache[:, dst_indices] = v_cache[:, src_indices] v_cache[:, dst_indices] = v_cache[:, src_indices]
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Reject KV-cache moves: this backend provides no implementation.

    All arguments are accepted only to keep the signature parallel with the
    other backends' ``move_cache``; none of them is inspected.

    Raises:
        NotImplementedError: always.
    """
    # The previous body was the bare expression ``NotImplementedError``,
    # which merely evaluates the class and throws the result away, so the
    # call returned None without signalling anything to the caller.
    raise NotImplementedError
@dataclass @dataclass
class PallasMetadata(AttentionMetadata): class PallasMetadata(AttentionMetadata):
......
...@@ -72,50 +72,6 @@ class ROCmFlashAttentionBackend(AttentionBackend): ...@@ -72,50 +72,6 @@ class ROCmFlashAttentionBackend(AttentionBackend):
) -> None: ) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists) PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move KV-cache entries in every layer by staging them in a scratch tensor.

    Column 0 of ``src_to_dists`` is passed to ``ops.read_cache`` and column 1
    to ``ops.write_cache_multi_layers``.
    NOTE(review): presumably these columns are source and destination slot
    indices -- confirm against the kernel implementations.

    Args:
        kv_caches: one cache tensor per layer; each is split into key and
            value parts with ``PagedAttention.split_kv_cache``.
        src_to_dists: 2-D tensor indexed as ``[:, 0]`` and ``[:, 1]``; its
            first dimension sizes the scratch buffer.
        kv_cache_dtype: forwarded unchanged to both kernels.
        num_kv_heads: used for the scratch shape and the cache split.
        head_size: used for the scratch shape and the cache split.
    """
    layer_count = len(kv_caches)
    moved_tokens = src_to_dists.shape[0]

    # One allocation holds both halves: index 0 for keys, index 1 for values.
    scratch = torch.empty(
        (2, layer_count, moved_tokens, num_kv_heads, head_size),
        dtype=kv_caches[0].dtype,
        device=kv_caches[0].device,
    )
    staged_keys = scratch[0].contiguous()
    staged_values = scratch[1].contiguous()

    split_layers = [
        PagedAttention.split_kv_cache(layer_cache, num_kv_heads, head_size)
        for layer_cache in kv_caches
    ]
    key_caches = [key_part for key_part, _ in split_layers]
    value_caches = [value_part for _, value_part in split_layers]

    ops.read_cache(
        staged_keys,
        staged_values,
        key_caches,
        value_caches,
        src_to_dists[:, 0].contiguous(),
        kv_cache_dtype,
    )
    ops.write_cache_multi_layers(
        staged_keys,
        staged_values,
        key_caches,
        value_caches,
        src_to_dists[:, 1].contiguous(),
        kv_cache_dtype,
    )
@dataclass @dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
......
...@@ -65,16 +65,6 @@ class TorchSDPABackend(AttentionBackend): ...@@ -65,16 +65,6 @@ class TorchSDPABackend(AttentionBackend):
) -> None: ) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists) PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Reject KV-cache moves: this backend provides no implementation.

    The parameters exist only so the signature lines up with the other
    backends' ``move_cache``; the body never reads them.

    Raises:
        NotImplementedError: always.
    """
    # A bare ``NotImplementedError`` expression (the previous body) is a
    # no-op statement; the exception has to be raised to reach the caller.
    raise NotImplementedError
@dataclass @dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
......
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union, Optional
import torch
from vllm.attention.backends.blocksparse_attn import BlocksparseFlashAttentionImpl
from vllm import _custom_ops as ops
from vllm.attention.ops.paged_attn import PagedAttention
def move_cache(
    backend,
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move KV-cache entries in every layer for backends that support it.

    The move is staged through a scratch tensor: ``ops.read_cache`` is given
    column 0 of ``src_to_dists`` and ``ops.write_cache_multi_layers`` is given
    column 1.
    NOTE(review): presumably these columns are source and destination slot
    indices -- confirm against the kernel implementations.

    Args:
        backend: object exposing ``get_name()``; only the names
            ``"rocm-flash-attn"`` and ``"xformers"`` are accepted.
        kv_caches: one cache tensor per layer; each is split into key and
            value parts with ``PagedAttention.split_kv_cache``.
        src_to_dists: 2-D tensor indexed as ``[:, 0]`` and ``[:, 1]``; its
            first dimension sizes the scratch buffer.
        kv_cache_dtype: forwarded unchanged to both kernels.
        num_kv_heads: used for the scratch shape and the cache split.
        head_size: used for the scratch shape and the cache split.

    Raises:
        NotImplementedError: if ``backend.get_name()`` is any other name.
    """
    backend_name = backend.get_name()
    if backend_name not in ("rocm-flash-attn", "xformers"):
        # NOTE(review): the message lists BlocksparseFlashAttention, but the
        # check above does not admit it -- confirm which of the two is
        # intended.
        raise NotImplementedError("Only BlocksparseFlashAttention/ROCmFlash/XFormers backends support move cache for now!")

    layer_count = len(kv_caches)
    moved_tokens = src_to_dists.shape[0]

    # One allocation holds both halves: index 0 for keys, index 1 for values.
    scratch = torch.empty(
        (2, layer_count, moved_tokens, num_kv_heads, head_size),
        dtype=kv_caches[0].dtype,
        device=kv_caches[0].device,
    )
    staged_keys = scratch[0].contiguous()
    staged_values = scratch[1].contiguous()

    split_layers = [
        PagedAttention.split_kv_cache(layer_cache, num_kv_heads, head_size)
        for layer_cache in kv_caches
    ]
    key_caches = [key_part for key_part, _ in split_layers]
    value_caches = [value_part for _, value_part in split_layers]

    ops.read_cache(
        staged_keys,
        staged_values,
        key_caches,
        value_caches,
        src_to_dists[:, 0].contiguous(),
        kv_cache_dtype,
    )
    ops.write_cache_multi_layers(
        staged_keys,
        staged_values,
        key_caches,
        value_caches,
        src_to_dists[:, 1].contiguous(),
        kv_cache_dtype,
    )
\ No newline at end of file
...@@ -9,6 +9,8 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, ...@@ -9,6 +9,8 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState) AttentionState)
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.model_runner_base import ModelRunnerBase
...@@ -188,8 +190,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -188,8 +190,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.block_size, inter_data.block_tables) self.block_size, inter_data.block_tables)
def build(self, seq_lens: List[int], query_lens: List[int], def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int, cuda_graph_pad_size: int, batch_size: int):
tree_attention_masks_tensor: Optional[torch.Tensor] = None):
"""Build attention metadata with on-device tensors. """Build attention metadata with on-device tensors.
Args: Args:
...@@ -272,7 +273,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -272,7 +273,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
context_lens_tensor=context_lens_tensor, context_lens_tensor=context_lens_tensor,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
tree_attention_masks_tensor=tree_attention_masks_tensor,
block_tables_list=self.block_tables block_tables_list=self.block_tables
) )
......
...@@ -68,50 +68,6 @@ class XFormersBackend(AttentionBackend): ...@@ -68,50 +68,6 @@ class XFormersBackend(AttentionBackend):
) -> None: ) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists) PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move KV-cache entries in every layer by staging them in a scratch tensor.

    ``ops.read_cache`` receives column 0 of ``src_to_dists`` and
    ``ops.write_cache_multi_layers`` receives column 1.
    NOTE(review): presumably these columns are source and destination slot
    indices -- confirm against the kernel implementations.

    Args:
        kv_caches: one cache tensor per layer; each is split into key and
            value parts with ``PagedAttention.split_kv_cache``.
        src_to_dists: 2-D tensor indexed as ``[:, 0]`` and ``[:, 1]``; its
            first dimension sizes the scratch buffer.
        kv_cache_dtype: forwarded unchanged to both kernels.
        num_kv_heads: used for the scratch shape and the cache split.
        head_size: used for the scratch shape and the cache split.
    """
    layer_count = len(kv_caches)
    moved_tokens = src_to_dists.shape[0]

    # Keys live at index 0 and values at index 1 of a single allocation.
    scratch = torch.empty(
        (2, layer_count, moved_tokens, num_kv_heads, head_size),
        dtype=kv_caches[0].dtype,
        device=kv_caches[0].device,
    )
    staged_keys = scratch[0].contiguous()
    staged_values = scratch[1].contiguous()

    split_layers = [
        PagedAttention.split_kv_cache(layer_cache, num_kv_heads, head_size)
        for layer_cache in kv_caches
    ]
    key_caches = [key_part for key_part, _ in split_layers]
    value_caches = [value_part for _, value_part in split_layers]

    ops.read_cache(
        staged_keys,
        staged_values,
        key_caches,
        value_caches,
        src_to_dists[:, 0].contiguous(),
        kv_cache_dtype,
    )
    ops.write_cache_multi_layers(
        staged_keys,
        staged_values,
        key_caches,
        value_caches,
        src_to_dists[:, 1].contiguous(),
        kv_cache_dtype,
    )
@dataclass @dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
......
This diff is collapsed.
This diff is collapsed.
...@@ -176,7 +176,6 @@ class EngineArgs: ...@@ -176,7 +176,6 @@ class EngineArgs:
disable_logprobs_during_spec_decoding: Optional[bool] = None disable_logprobs_during_spec_decoding: Optional[bool] = None
otlp_traces_endpoint: Optional[str] = None otlp_traces_endpoint: Optional[str] = None
tree_style_spec_decoding: Optional[bool] = None
collect_detailed_traces: Optional[str] = None collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False disable_async_output_proc: bool = False
override_neuron_config: Optional[Dict[str, Any]] = None override_neuron_config: Optional[Dict[str, Any]] = None
...@@ -708,11 +707,6 @@ class EngineArgs: ...@@ -708,11 +707,6 @@ class EngineArgs:
'a higher acceptance rate at the cost of lower quality, ' 'a higher acceptance rate at the cost of lower quality, '
'and vice versa.') 'and vice versa.')
parser.add_argument('--tree-style-spec-decoding',
type=bool,
default=False,
help='If set to True, tree-style generation will be activated.')
parser.add_argument( parser.add_argument(
'--typical-acceptance-sampler-posterior-threshold', '--typical-acceptance-sampler-posterior-threshold',
type=float, type=float,
...@@ -997,7 +991,6 @@ class EngineArgs: ...@@ -997,7 +991,6 @@ class EngineArgs:
typical_acceptance_sampler_posterior_alpha=self. typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha, typical_acceptance_sampler_posterior_alpha,
disable_logprobs=self.disable_logprobs_during_spec_decoding, disable_logprobs=self.disable_logprobs_during_spec_decoding,
tree_style_spec_decoding=self.tree_style_spec_decoding,
num_speculative_heads=self.num_speculative_heads num_speculative_heads=self.num_speculative_heads
) )
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment