Commit d589e598 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.6.2-dev_wm' into 'v0.6.2-dev'

[feat]优化medusa代码,通过VLLM_TREE_DECODING环境变量控制是否采用tree-style解码,计算逻辑与主干隔离

See merge request dcutoolkit/deeplearing/vllm!51
parents 54b92ba4 0bb491f8
...@@ -62,16 +62,6 @@ class OpenVINOAttentionBackend(AttentionBackend): ...@@ -62,16 +62,6 @@ class OpenVINOAttentionBackend(AttentionBackend):
key_cache.data[dst, :] = key_cache.data[src, :] key_cache.data[dst, :] = key_cache.data[src, :]
value_cache.data[dst, :] = value_cache.data[src, :] value_cache.data[dst, :] = value_cache.data[src, :]
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Reject KV-cache slot moves; this backend has no implementation.

    Args:
        kv_caches: Per-layer KV-cache tensors (unused).
        src_to_dists: Source/destination slot pairs (unused).
        kv_cache_dtype: KV-cache dtype string (unused).
        num_kv_heads: Number of KV heads (unused).
        head_size: Size of each attention head (unused).

    Raises:
        NotImplementedError: Always.
    """
    # Raise explicitly: a bare ``NotImplementedError`` expression is a no-op
    # and would let callers continue as if the cache had been moved.
    raise NotImplementedError(
        "move_cache is not supported by the OpenVINO attention backend")
@dataclass @dataclass
class OpenVINOAttentionMetadata: class OpenVINOAttentionMetadata:
......
...@@ -53,16 +53,6 @@ class PallasAttentionBackend(AttentionBackend): ...@@ -53,16 +53,6 @@ class PallasAttentionBackend(AttentionBackend):
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
v_cache[:, dst_indices] = v_cache[:, src_indices] v_cache[:, dst_indices] = v_cache[:, src_indices]
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Reject KV-cache slot moves; this backend has no implementation.

    Args:
        kv_caches: Per-layer KV-cache tensors (unused).
        src_to_dists: Source/destination slot pairs (unused).
        kv_cache_dtype: KV-cache dtype string (unused).
        num_kv_heads: Number of KV heads (unused).
        head_size: Size of each attention head (unused).

    Raises:
        NotImplementedError: Always.
    """
    # Raise explicitly: a bare ``NotImplementedError`` expression is a no-op
    # and would let callers continue as if the cache had been moved.
    raise NotImplementedError(
        "move_cache is not supported by the Pallas attention backend")
@dataclass @dataclass
class PallasMetadata(AttentionMetadata): class PallasMetadata(AttentionMetadata):
......
...@@ -72,50 +72,6 @@ class ROCmFlashAttentionBackend(AttentionBackend): ...@@ -72,50 +72,6 @@ class ROCmFlashAttentionBackend(AttentionBackend):
) -> None: ) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists) PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move KV-cache entries between slots across all layers.

    A staging buffer is filled by ``ops.read_cache`` using column 0 of
    ``src_to_dists`` and then handed to ``ops.write_cache_multi_layers``
    together with column 1.

    Args:
        kv_caches: One KV-cache tensor per layer; the first one supplies
            the dtype and device of the staging buffer.
        src_to_dists: 2-D tensor with one row per move; column 0 is passed
            to the read op and column 1 to the write op (presumably
            source and destination slot ids -- confirm against kernels).
        kv_cache_dtype: Forwarded unchanged to both custom ops.
        num_kv_heads: Number of KV heads, used to split each cache and to
            size the staging buffer.
        head_size: Head dimension, used the same way.
    """
    layer_count = len(kv_caches)
    move_count = src_to_dists.shape[0]

    # Index 0 of the staging buffer holds keys, index 1 holds values.
    staging = torch.empty(
        (2, layer_count, move_count, num_kv_heads, head_size),
        dtype=kv_caches[0].dtype,
        device=kv_caches[0].device,
    )
    keys = staging[0].contiguous()
    values = staging[1].contiguous()

    # Split every layer's cache into its key part and its value part.
    split_pairs = [
        PagedAttention.split_kv_cache(layer_cache, num_kv_heads, head_size)
        for layer_cache in kv_caches
    ]
    key_caches = [pair[0] for pair in split_pairs]
    value_caches = [pair[1] for pair in split_pairs]

    ops.read_cache(keys, values, key_caches, value_caches,
                   src_to_dists[:, 0].contiguous(), kv_cache_dtype)
    ops.write_cache_multi_layers(keys, values, key_caches, value_caches,
                                 src_to_dists[:, 1].contiguous(),
                                 kv_cache_dtype)
@dataclass @dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
......
...@@ -65,16 +65,6 @@ class TorchSDPABackend(AttentionBackend): ...@@ -65,16 +65,6 @@ class TorchSDPABackend(AttentionBackend):
) -> None: ) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists) PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Reject KV-cache slot moves; this backend has no implementation.

    Args:
        kv_caches: Per-layer KV-cache tensors (unused).
        src_to_dists: Source/destination slot pairs (unused).
        kv_cache_dtype: KV-cache dtype string (unused).
        num_kv_heads: Number of KV heads (unused).
        head_size: Size of each attention head (unused).

    Raises:
        NotImplementedError: Always.
    """
    # Raise explicitly: a bare ``NotImplementedError`` expression is a no-op
    # and would let callers continue as if the cache had been moved.
    raise NotImplementedError(
        "move_cache is not supported by the Torch SDPA attention backend")
@dataclass @dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
......
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union, Optional
import torch
from vllm.attention.backends.blocksparse_attn import BlocksparseFlashAttentionImpl
from vllm import _custom_ops as ops
from vllm.attention.ops.paged_attn import PagedAttention
def move_cache(
    backend,
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move KV-cache entries between slots across all layers.

    Only backends whose ``get_name()`` returns ``"rocm-flash-attn"`` or
    ``"xformers"`` are handled. For those, a staging buffer is filled by
    ``ops.read_cache`` using column 0 of ``src_to_dists`` and then handed to
    ``ops.write_cache_multi_layers`` together with column 1.

    Args:
        backend: Attention backend; only its ``get_name()`` is consulted.
        kv_caches: One KV-cache tensor per layer; the first one supplies
            the dtype and device of the staging buffer.
        src_to_dists: 2-D tensor with one row per move; column 0 is passed
            to the read op and column 1 to the write op (presumably
            source and destination slot ids -- confirm against kernels).
        kv_cache_dtype: Forwarded unchanged to both custom ops.
        num_kv_heads: Number of KV heads, used to split each cache and to
            size the staging buffer.
        head_size: Head dimension, used the same way.

    Raises:
        NotImplementedError: If the backend is not one of the two
            supported names.
    """
    backend_name = backend.get_name()
    if backend_name not in ("rocm-flash-attn", "xformers"):
        # Report exactly the backends this function accepts, and which one
        # was rejected, so the failure is actionable.
        raise NotImplementedError(
            "Only ROCmFlash/XFormers backends support move cache for now! "
            f"Got backend: {backend_name!r}")

    num_layers = len(kv_caches)
    token_num = src_to_dists.shape[0]

    # Index 0 of the staging buffer holds keys, index 1 holds values.
    tmp_store_kv = torch.empty(
        (2, num_layers, token_num, num_kv_heads, head_size),
        dtype=kv_caches[0].dtype,
        device=kv_caches[0].device,
    )
    keys = tmp_store_kv[0].contiguous()
    values = tmp_store_kv[1].contiguous()

    # Split every layer's cache into its key part and its value part.
    key_caches = []
    value_caches = []
    for kv_cache in kv_caches:
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, num_kv_heads, head_size)
        key_caches.append(key_cache)
        value_caches.append(value_cache)

    ops.read_cache(
        keys,
        values,
        key_caches,
        value_caches,
        src_to_dists[:, 0].contiguous(),
        kv_cache_dtype,
    )
    ops.write_cache_multi_layers(
        keys,
        values,
        key_caches,
        value_caches,
        src_to_dists[:, 1].contiguous(),
        kv_cache_dtype,
    )
\ No newline at end of file
...@@ -9,6 +9,8 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, ...@@ -9,6 +9,8 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState) AttentionState)
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.model_runner_base import ModelRunnerBase
...@@ -188,8 +190,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -188,8 +190,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.block_size, inter_data.block_tables) self.block_size, inter_data.block_tables)
def build(self, seq_lens: List[int], query_lens: List[int], def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int, cuda_graph_pad_size: int, batch_size: int):
tree_attention_masks_tensor: Optional[torch.Tensor] = None):
"""Build attention metadata with on-device tensors. """Build attention metadata with on-device tensors.
Args: Args:
...@@ -272,7 +273,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -272,7 +273,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
context_lens_tensor=context_lens_tensor, context_lens_tensor=context_lens_tensor,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
tree_attention_masks_tensor=tree_attention_masks_tensor,
block_tables_list=self.block_tables block_tables_list=self.block_tables
) )
......
...@@ -68,50 +68,6 @@ class XFormersBackend(AttentionBackend): ...@@ -68,50 +68,6 @@ class XFormersBackend(AttentionBackend):
) -> None: ) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists) PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move KV-cache entries between slots across all layers.

    A staging buffer is filled by ``ops.read_cache`` using column 0 of
    ``src_to_dists`` and then handed to ``ops.write_cache_multi_layers``
    together with column 1.

    Args:
        kv_caches: One KV-cache tensor per layer; the first one supplies
            the dtype and device of the staging buffer.
        src_to_dists: 2-D tensor with one row per move; column 0 is passed
            to the read op and column 1 to the write op (presumably
            source and destination slot ids -- confirm against kernels).
        kv_cache_dtype: Forwarded unchanged to both custom ops.
        num_kv_heads: Number of KV heads, used to split each cache and to
            size the staging buffer.
        head_size: Head dimension, used the same way.
    """
    layer_count = len(kv_caches)
    move_count = src_to_dists.shape[0]

    # Index 0 of the staging buffer holds keys, index 1 holds values.
    staging = torch.empty(
        (2, layer_count, move_count, num_kv_heads, head_size),
        dtype=kv_caches[0].dtype,
        device=kv_caches[0].device,
    )
    keys = staging[0].contiguous()
    values = staging[1].contiguous()

    # Split every layer's cache into its key part and its value part.
    split_pairs = [
        PagedAttention.split_kv_cache(layer_cache, num_kv_heads, head_size)
        for layer_cache in kv_caches
    ]
    key_caches = [pair[0] for pair in split_pairs]
    value_caches = [pair[1] for pair in split_pairs]

    ops.read_cache(keys, values, key_caches, value_caches,
                   src_to_dists[:, 0].contiguous(), kv_cache_dtype)
    ops.write_cache_multi_layers(keys, values, key_caches, value_caches,
                                 src_to_dists[:, 1].contiguous(),
                                 kv_cache_dtype)
@dataclass @dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
......
...@@ -142,7 +142,102 @@ class PagedAttention: ...@@ -142,7 +142,102 @@ class PagedAttention:
if envs.VLLM_USE_OPT_OP: if envs.VLLM_USE_OPT_OP:
if envs.VLLM_USE_TC_PAGED_ATTN: if envs.VLLM_USE_TC_PAGED_ATTN:
ops.paged_attention_v1_opt_tc( if attn_masks is None:
ops.paged_attention_v1_opt_tc(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step
)
else:
ops.paged_attention_v1_opt_tc_with_mask(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else:
if attn_masks is None:
ops.paged_attention_v1_opt(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step
)
else:
ops.paged_attention_v1_opt_with_mask(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else:
if attn_masks is None:
ops.paged_attention_v1(
output, output,
query, query,
key_cache, key_cache,
...@@ -161,12 +256,10 @@ class PagedAttention: ...@@ -161,12 +256,10 @@ class PagedAttention:
blocksparse_local_blocks, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_block_size,
blocksparse_head_sliding_step, blocksparse_head_sliding_step
attn_masks,
attn_masks_stride
) )
else: else:
ops.paged_attention_v1_opt( ops.paged_attention_v1_with_mask(
output, output,
query, query,
key_cache, key_cache,
...@@ -189,30 +282,6 @@ class PagedAttention: ...@@ -189,30 +282,6 @@ class PagedAttention:
attn_masks, attn_masks,
attn_masks_stride attn_masks_stride
) )
else:
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0 assert _PARTITION_SIZE % block_size == 0
...@@ -236,7 +305,114 @@ class PagedAttention: ...@@ -236,7 +305,114 @@ class PagedAttention:
if envs.VLLM_USE_OPT_OP: if envs.VLLM_USE_OPT_OP:
if envs.VLLM_USE_TC_PAGED_ATTN: if envs.VLLM_USE_TC_PAGED_ATTN:
ops.paged_attention_v2_opt_tc( if attn_masks is None:
ops.paged_attention_v2_opt_tc(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step
)
else:
ops.paged_attention_v2_opt_tc_with_mask(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else:
if attn_masks is None:
ops.paged_attention_v2_opt(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step
)
else:
ops.paged_attention_v2_opt_with_mask(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else:
if attn_masks is None:
ops.paged_attention_v2(
output, output,
exp_sums, exp_sums,
max_logits, max_logits,
...@@ -258,12 +434,10 @@ class PagedAttention: ...@@ -258,12 +434,10 @@ class PagedAttention:
blocksparse_local_blocks, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_block_size,
blocksparse_head_sliding_step, blocksparse_head_sliding_step
attn_masks,
attn_masks_stride
) )
else: else:
ops.paged_attention_v2_opt( ops.paged_attention_v2_with_mask(
output, output,
exp_sums, exp_sums,
max_logits, max_logits,
...@@ -289,33 +463,6 @@ class PagedAttention: ...@@ -289,33 +463,6 @@ class PagedAttention:
attn_masks, attn_masks,
attn_masks_stride attn_masks_stride
) )
else:
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
return output return output
@staticmethod @staticmethod
......
...@@ -1130,7 +1130,6 @@ class SpeculativeConfig: ...@@ -1130,7 +1130,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha: Optional[float], typical_acceptance_sampler_posterior_alpha: Optional[float],
disable_logprobs: Optional[bool], disable_logprobs: Optional[bool],
num_speculative_heads: Optional[int], num_speculative_heads: Optional[int],
tree_style_spec_decoding: Optional[bool]=None,
) -> Optional["SpeculativeConfig"]: ) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None. """Create a SpeculativeConfig if possible, else return None.
...@@ -1191,9 +1190,6 @@ class SpeculativeConfig: ...@@ -1191,9 +1190,6 @@ class SpeculativeConfig:
num_speculative_heads (Optional[int]): It will be used in tree-style num_speculative_heads (Optional[int]): It will be used in tree-style
speculative generation, representing how many heads the draft model speculative generation, representing how many heads the draft model
has. has.
tree_style_spec_decoding (Optional[bool]): If set to True,
tree-style generation will be activated. If not specified,
it defaults to False.
Returns: Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
...@@ -1308,9 +1304,9 @@ class SpeculativeConfig: ...@@ -1308,9 +1304,9 @@ class SpeculativeConfig:
"n_predict parameter.") "n_predict parameter.")
if typical_acceptance_sampler_posterior_threshold is None: if typical_acceptance_sampler_posterior_threshold is None:
typical_acceptance_sampler_posterior_threshold = 0.3 typical_acceptance_sampler_posterior_threshold = 0.09
if typical_acceptance_sampler_posterior_alpha is None: if typical_acceptance_sampler_posterior_alpha is None:
typical_acceptance_sampler_posterior_alpha = 0.09 typical_acceptance_sampler_posterior_alpha = 0.3
if disable_logprobs is None: if disable_logprobs is None:
disable_logprobs = True disable_logprobs = True
...@@ -1328,7 +1324,6 @@ class SpeculativeConfig: ...@@ -1328,7 +1324,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha, typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs, disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats, disable_log_stats=disable_log_stats,
tree_style_spec_decoding=tree_style_spec_decoding
) )
@staticmethod @staticmethod
...@@ -1423,7 +1418,6 @@ class SpeculativeConfig: ...@@ -1423,7 +1418,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha: float, typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool, disable_logprobs: bool,
disable_log_stats: bool, disable_log_stats: bool,
tree_style_spec_decoding: bool,
): ):
"""Create a SpeculativeConfig object. """Create a SpeculativeConfig object.
...@@ -1458,7 +1452,6 @@ class SpeculativeConfig: ...@@ -1458,7 +1452,6 @@ class SpeculativeConfig:
returned. returned.
disable_log_stats: Whether to disable periodic printing of stage disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding. times in speculative decoding.
tree_style_spec_decoding: Whether to use tree-style generation.
""" """
self.draft_model_config = draft_model_config self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config self.draft_parallel_config = draft_parallel_config
...@@ -1474,7 +1467,6 @@ class SpeculativeConfig: ...@@ -1474,7 +1467,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha typical_acceptance_sampler_posterior_alpha
self.disable_logprobs = disable_logprobs self.disable_logprobs = disable_logprobs
self.disable_log_stats = disable_log_stats self.disable_log_stats = disable_log_stats
self.tree_style_spec_decoding = tree_style_spec_decoding
self._verify_args() self._verify_args()
......
...@@ -176,7 +176,6 @@ class EngineArgs: ...@@ -176,7 +176,6 @@ class EngineArgs:
disable_logprobs_during_spec_decoding: Optional[bool] = None disable_logprobs_during_spec_decoding: Optional[bool] = None
otlp_traces_endpoint: Optional[str] = None otlp_traces_endpoint: Optional[str] = None
tree_style_spec_decoding: Optional[bool] = None
collect_detailed_traces: Optional[str] = None collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False disable_async_output_proc: bool = False
override_neuron_config: Optional[Dict[str, Any]] = None override_neuron_config: Optional[Dict[str, Any]] = None
...@@ -707,11 +706,6 @@ class EngineArgs: ...@@ -707,11 +706,6 @@ class EngineArgs:
'2) TypicalAcceptanceSampler which is configurable, allowing for ' '2) TypicalAcceptanceSampler which is configurable, allowing for '
'a higher acceptance rate at the cost of lower quality, ' 'a higher acceptance rate at the cost of lower quality, '
'and vice versa.') 'and vice versa.')
parser.add_argument('--tree-style-spec-decoding',
type=bool,
default=False,
help='If set to True, tree-style generation will be activated.')
parser.add_argument( parser.add_argument(
'--typical-acceptance-sampler-posterior-threshold', '--typical-acceptance-sampler-posterior-threshold',
...@@ -997,7 +991,6 @@ class EngineArgs: ...@@ -997,7 +991,6 @@ class EngineArgs:
typical_acceptance_sampler_posterior_alpha=self. typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha, typical_acceptance_sampler_posterior_alpha,
disable_logprobs=self.disable_logprobs_during_spec_decoding, disable_logprobs=self.disable_logprobs_during_spec_decoding,
tree_style_spec_decoding=self.tree_style_spec_decoding,
num_speculative_heads=self.num_speculative_heads num_speculative_heads=self.num_speculative_heads
) )
......
import os
import time import time
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
...@@ -463,6 +464,8 @@ class LLMEngine: ...@@ -463,6 +464,8 @@ class LLMEngine:
get_tokenizer_for_seq, get_tokenizer_for_seq,
), ),
)) ))
self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'
def _initialize_kv_caches(self) -> None: def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s). """Initialize the KV cache in the worker(s).
...@@ -989,16 +992,15 @@ class LLMEngine: ...@@ -989,16 +992,15 @@ class LLMEngine:
output = [outputs_by_sequence_group[0][i]] output = [outputs_by_sequence_group[0][i]]
# tree style speculative decoding may generate empty output in first step # tree style speculative decoding may generate empty output in first step
if self.speculative_config and self.speculative_config.tree_style_spec_decoding: if self.tree_decoding and outputs and isinstance(output[0], CompletionSequenceGroupOutput):
if outputs and isinstance(output[0], CompletionSequenceGroupOutput): samples = [o.samples[0] for o in output]
samples = [o.samples[0] for o in output] valid_samples = [
valid_samples = [ sample for sample in samples
sample for sample in samples if sample.output_token != VLLM_INVALID_TOKEN_ID
if sample.output_token != VLLM_INVALID_TOKEN_ID ]
] if len(valid_samples) == 0:
if len(valid_samples) == 0: empty_seq_indices.append(i)
empty_seq_indices.append(i) continue
continue
if not is_async: if not is_async:
seq_group.update_num_computed_tokens( seq_group.update_num_computed_tokens(
......
...@@ -68,6 +68,7 @@ if TYPE_CHECKING: ...@@ -68,6 +68,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_AWQ: bool = False VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: bool = False VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: bool = False
VLLM_TREE_DECODING: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -453,6 +454,12 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -453,6 +454,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: lambda:
(os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in (os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in
("1", "true")), ("1", "true")),
# If set, vLLM will use tree-style speculative decoding.
"VLLM_TREE_DECODING":
lambda:
(os.environ.get("VLLM_TREE_DECODING", "0").strip().lower() in
("1", "true"))
} }
# end-env-vars-definition # end-env-vars-definition
......
...@@ -117,8 +117,6 @@ class LoRAModel(AdapterModel): ...@@ -117,8 +117,6 @@ class LoRAModel(AdapterModel):
pin_memory = str(device) == "cpu" and is_pin_memory_available() pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: Dict[str, LoRALayerWeights] = {} loras: Dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items(): for tensor_name, tensor in tensors.items():
if "lora_A" not in tensor_name and "lora_B" not in tensor_name:
continue
module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
if module_name not in loras: if module_name not in loras:
lora_embeddings_tensor = None lora_embeddings_tensor = None
......
import os
from typing import Optional, List from typing import Optional, List
import torch import torch
import torch.jit import torch.jit
...@@ -39,6 +40,8 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -39,6 +40,8 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
self._posterior_alpha = posterior_alpha self._posterior_alpha = posterior_alpha
super().__init__(strict_mode=strict_mode) super().__init__(strict_mode=strict_mode)
self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def forward( def forward(
self, self,
target_with_bonus_probs: torch.Tensor, target_with_bonus_probs: torch.Tensor,
...@@ -92,7 +95,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -92,7 +95,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
self._raise_if_incorrect_input(target_with_bonus_probs, self._raise_if_incorrect_input(target_with_bonus_probs,
draft_token_ids, bonus_token_ids) draft_token_ids, bonus_token_ids)
if cart_candidates is None: if not self.tree_decoding:
target_probs = target_with_bonus_probs[:, :-1] target_probs = target_with_bonus_probs[:, :-1]
accepted = self._evaluate_accepted_tokens(target_probs, accepted = self._evaluate_accepted_tokens(target_probs,
draft_token_ids) draft_token_ids)
...@@ -101,6 +104,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -101,6 +104,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
draft_token_ids, draft_token_ids,
bonus_token_ids) bonus_token_ids)
else: else:
assert cart_candidates is not None
target_probs = target_with_bonus_probs target_probs = target_with_bonus_probs
output_token_ids = self._evaluate_accepted_tokens_tree_style(target_probs, output_token_ids = self._evaluate_accepted_tokens_tree_style(target_probs,
draft_token_ids, draft_token_ids,
...@@ -199,7 +203,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -199,7 +203,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
output_token_id_list = [] output_token_id_list = []
accept_length_list = accept_length.cpu().tolist() accept_length_list = accept_length.cpu().tolist()
logger.info("accept_length:%s", accept_length_list) #logger.info("accept_length:%s", accept_length_list)
for i in range(batch_size): for i in range(batch_size):
output_best_candidates.append(best_candidate[i]) output_best_candidates.append(best_candidate[i])
accept_lengths.append(accept_length_list[i]) accept_lengths.append(accept_length_list[i])
......
import os
import weakref import weakref
from typing import List, Optional, Set, Tuple, Dict from typing import List, Optional, Set, Tuple, Dict
...@@ -29,6 +30,8 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker): ...@@ -29,6 +30,8 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
# Lazy initialization list. # Lazy initialization list.
self._proposer: SpeculativeProposer self._proposer: SpeculativeProposer
self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def init_device(self): def init_device(self):
super().init_device() super().init_device()
...@@ -38,7 +41,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker): ...@@ -38,7 +41,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
# get medusa choices and generate medusa_buffers # get medusa choices and generate medusa_buffers
self.medusa_buffers = None self.medusa_buffers = None
if hasattr(self.model_runner.model, 'medusa_choices'): if self.tree_decoding and hasattr(self.model_runner.model, 'medusa_choices'):
self.medusa_choices = self.model_runner.model.medusa_choices self.medusa_choices = self.model_runner.model.medusa_choices
if self.medusa_choices is not None: if self.medusa_choices is not None:
self.medusa_buffers = self.generate_medusa_buffers( self.medusa_buffers = self.generate_medusa_buffers(
......
import os
from collections import defaultdict from collections import defaultdict
from functools import cached_property from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple from typing import Any, Dict, List, Optional, Set, Tuple
...@@ -82,9 +83,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": ...@@ -82,9 +83,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_alpha=speculative_config. typical_acceptance_sampler_posterior_alpha=speculative_config.
typical_acceptance_sampler_posterior_alpha, typical_acceptance_sampler_posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs, disable_logprobs=speculative_config.disable_logprobs,
disable_log_stats=speculative_config.disable_log_stats, disable_log_stats=speculative_config.disable_log_stats)
tree_style_spec_decoding=speculative_config.tree_style_spec_decoding,
)
return spec_decode_worker return spec_decode_worker
...@@ -126,7 +125,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -126,7 +125,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_alpha: float, typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool, disable_logprobs: bool,
disable_log_stats: bool, disable_log_stats: bool,
tree_style_spec_decoding: bool,
) -> "SpecDecodeWorker": ) -> "SpecDecodeWorker":
allow_zero_draft_token_step = True allow_zero_draft_token_step = True
...@@ -191,8 +189,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -191,8 +189,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable_log_stats=disable_log_stats, disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size, disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler, spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step, allow_zero_draft_token_step=allow_zero_draft_token_step)
tree_style_spec_decoding=tree_style_spec_decoding)
def __init__( def __init__(
self, self,
...@@ -204,7 +201,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -204,7 +201,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
metrics_collector: Optional[AsyncMetricsCollector] = None, metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None, disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True, allow_zero_draft_token_step: Optional[bool] = True,
tree_style_spec_decoding: bool = False,
): ):
""" """
Create a SpecDecodeWorker. Create a SpecDecodeWorker.
...@@ -233,7 +229,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -233,7 +229,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
allow_zero_draft_token_step: whether to allow a step where the draft allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814) draft model is larger than 1 (TODO: #5814)
tree_style_spec_decoding: Whether to use tree-style generation.
""" """
self.proposer_worker = proposer_worker self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker self.scorer_worker = scorer_worker
...@@ -268,7 +263,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -268,7 +263,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self._disable_logprobs = disable_logprobs self._disable_logprobs = disable_logprobs
self._disable_log_stats = disable_log_stats self._disable_log_stats = disable_log_stats
self.tree_style_spec_decoding = tree_style_spec_decoding self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def init_device(self) -> None: def init_device(self) -> None:
"""Initialize both scorer and proposer models. """Initialize both scorer and proposer models.
...@@ -285,7 +280,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -285,7 +280,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self._metrics.init_gpu_tensors(self.rank) self._metrics.init_gpu_tensors(self.rank)
self.spec_decode_sampler.init_gpu_tensors(self.rank) self.spec_decode_sampler.init_gpu_tensors(self.rank)
if not self.tree_style_spec_decoding: if not self.tree_decoding:
self.scorer = BatchExpansionTop1Scorer( self.scorer = BatchExpansionTop1Scorer(
scorer_worker=self.scorer_worker, scorer_worker=self.scorer_worker,
device=self.device, device=self.device,
...@@ -324,7 +319,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -324,7 +319,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
) = True ) = True
# tree-style decoding modifies probs in _verify_tokens # tree-style decoding modifies probs in _verify_tokens
if not self.tree_style_spec_decoding: if not self.tree_decoding:
(self.scorer_worker.model_runner.model.sampler. (self.scorer_worker.model_runner.model.sampler.
should_modify_greedy_probs_inplace) = True should_modify_greedy_probs_inplace) = True
self.proposer_worker.set_include_gpu_probs_tensor() self.proposer_worker.set_include_gpu_probs_tensor()
...@@ -535,7 +530,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -535,7 +530,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
not called, meaning that the kv-cache in proposer for requests is not not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding. updated, so they cannot enable spec decode in the rest decoding.
""" """
if self.tree_style_spec_decoding and self.kvcache_slot_to_be_moved is not None: if self.tree_decoding and self.kvcache_slot_to_be_moved is not None:
execute_model_req.kvcache_slot_to_be_moved = self.kvcache_slot_to_be_moved execute_model_req.kvcache_slot_to_be_moved = self.kvcache_slot_to_be_moved
self.kvcache_slot_to_be_moved = None self.kvcache_slot_to_be_moved = None
...@@ -560,7 +555,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -560,7 +555,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
hidden_states, execute_model_req.seq_group_metadata_list) hidden_states, execute_model_req.seq_group_metadata_list)
# Store logits from target model execution. # Store logits from target model execution.
if self.tree_style_spec_decoding: if self.tree_decoding:
logits = sampler_output.logits logits = sampler_output.logits
if logits is not None: if logits is not None:
if self.previous_logits is None: if self.previous_logits is None:
...@@ -612,7 +607,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -612,7 +607,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.scorer_worker.execute_model() self.scorer_worker.execute_model()
if not data["disable_all_speculation"]: if not data["disable_all_speculation"]:
# if not self.tree_style_spec_decoding: # if not self.tree_decoding:
# # Even if num_lookahead_slots is zero, we want to run the # # Even if num_lookahead_slots is zero, we want to run the
# # proposer model as it may have KV. # # proposer model as it may have KV.
# # # #
...@@ -677,7 +672,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -677,7 +672,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"workers generate no tokens") "workers generate no tokens")
# Pass tree attention mask and positions to target model # Pass tree attention mask and positions to target model
if self.tree_style_spec_decoding: if self.tree_decoding:
execute_model_req.tree_attn_masks = proposals.tree_attn_masks execute_model_req.tree_attn_masks = proposals.tree_attn_masks
execute_model_req.tree_position_ids = proposals.tree_position_ids execute_model_req.tree_position_ids = proposals.tree_position_ids
...@@ -695,7 +690,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -695,7 +690,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals, execute_model_req.num_lookahead_slots) proposals, execute_model_req.num_lookahead_slots)
# move kv_caches of selected tokens to right positions # move kv_caches of selected tokens to right positions
if self.tree_style_spec_decoding: if self.tree_decoding:
self.move_caches(execute_model_req, select_indices_list, accept_lengths) self.move_caches(execute_model_req, select_indices_list, accept_lengths)
stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots, stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
...@@ -739,7 +734,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -739,7 +734,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
else: else:
proposal_verifier_probs = proposal_scores.probs proposal_verifier_probs = proposal_scores.probs
if self.tree_style_spec_decoding: if self.tree_decoding:
retrieve_indices = proposals.retrieve_indices retrieve_indices = proposals.retrieve_indices
proposal_verifier_probs = proposal_verifier_probs[:, retrieve_indices] proposal_verifier_probs = proposal_verifier_probs[:, retrieve_indices]
...@@ -797,7 +792,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -797,7 +792,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
) )
# Append output tokens from non-speculative sequences to # Append output tokens from non-speculative sequences to
# the accepted token ids tensor. # the accepted token ids tensor.
if not self.tree_style_spec_decoding: if not self.tree_decoding:
non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len + non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
1).clone() 1).clone()
else: else:
......
...@@ -8,6 +8,7 @@ from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig ...@@ -8,6 +8,7 @@ from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
is_pin_memory_available) is_pin_memory_available)
from vllm.attention.backends.tree_decoding_utils import move_cache
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -103,11 +104,12 @@ class CacheEngine: ...@@ -103,11 +104,12 @@ class CacheEngine:
def move_caches(self, kv_caches: List[torch.Tensor], def move_caches(self, kv_caches: List[torch.Tensor],
src_to_dsts: torch.Tensor) -> None: src_to_dsts: torch.Tensor) -> None:
self.attn_backend.move_cache(kv_caches, move_cache(self.attn_backend,
src_to_dsts, kv_caches,
self.cache_config.cache_dtype, src_to_dsts,
self.num_kv_heads, self.cache_config.cache_dtype,
self.head_size) self.num_kv_heads,
self.head_size)
@staticmethod @staticmethod
def get_cache_block_size( def get_cache_block_size(
......
...@@ -198,7 +198,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -198,7 +198,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.lora_requests.clear() # type: ignore self.lora_requests.clear() # type: ignore
self.prompt_adapter_index_mapping.clear() # type: ignore self.prompt_adapter_index_mapping.clear() # type: ignore
self.prompt_adapter_prompt_mapping.clear() # type: ignore self.prompt_adapter_prompt_mapping.clear() # type: ignore
self.tree_attn_masks[0] = None # type: ignore
def __init__( def __init__(
self, self,
...@@ -246,9 +245,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -246,9 +245,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
reinit: bool = False, reinit: bool = False,
reinit_use_defaults: bool = False, reinit_use_defaults: bool = False,
encoder_seq_len: int = 0, encoder_seq_len: int = 0,
# attention mask used in tree-style generation
tree_attn_masks: Optional[List[torch.Tensor]] = None,
): ):
if reinit: if reinit:
assert len(self.seq_ids) == len(seq_ids) # type: ignore assert len(self.seq_ids) == len(seq_ids) # type: ignore
...@@ -339,12 +335,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -339,12 +335,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_prompt_mapping prompt_adapter_prompt_mapping
else: else:
self.prompt_adapter_prompt_mapping.clear() self.prompt_adapter_prompt_mapping.clear()
if tree_attn_masks:
self.tree_attn_masks = tree_attn_masks
else:
self.tree_attn_masks.clear()
else: else:
self.input_tokens = input_tokens or [] self.input_tokens = input_tokens or []
self.input_positions = input_positions or [] self.input_positions = input_positions or []
...@@ -364,7 +354,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -364,7 +354,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_index_mapping or []) prompt_adapter_index_mapping or [])
self.prompt_adapter_prompt_mapping = ( self.prompt_adapter_prompt_mapping = (
prompt_adapter_prompt_mapping or []) prompt_adapter_prompt_mapping or [])
self.tree_attn_masks = tree_attn_masks or []
self.prompt_adapter_request = prompt_adapter_request self.prompt_adapter_request = prompt_adapter_request
self.multi_modal_inputs = multi_modal_inputs self.multi_modal_inputs = multi_modal_inputs
...@@ -380,7 +369,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -380,7 +369,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens = [[] for _ in range(self.n_seqs)] self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)] self.input_positions = [[] for _ in range(self.n_seqs)]
self.tree_attn_masks = [None for _ in range(self.n_seqs)]
self.mrope_input_positions = None self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs
...@@ -469,13 +457,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -469,13 +457,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.sliding_window + self.block_size - 1) // self.block_size self.sliding_window + self.block_size - 1) // self.block_size
self.block_aligned_sliding_window = \ self.block_aligned_sliding_window = \
self.sliding_window_blocks * self.block_size self.sliding_window_blocks * self.block_size
if hasattr(self.runner, "tree_attn_masks"):
self.tree_attn_masks = self.runner.tree_attn_masks
self.tree_position_ids = self.runner.tree_position_ids
else:
self.tree_attn_masks = None
self.tree_position_ids = None
self.is_encoder_decoder_model = self.runner.model_config.is_encoder_decoder_model self.is_encoder_decoder_model = self.runner.model_config.is_encoder_decoder_model
...@@ -853,16 +834,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -853,16 +834,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if cuda_graph_pad_size: if cuda_graph_pad_size:
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size)) seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
# prepare tree attention masks
tree_attention_masks_tensor = self.tree_attn_masks
if tree_attention_masks_tensor is not None:
tree_attention_masks_tensor = tree_attention_masks_tensor.contiguous()
input_positions_tensor = self.tree_position_ids.contiguous()
# Attention metadata. # Attention metadata.
attn_metadata = self.attn_metadata_builder.build( attn_metadata = self.attn_metadata_builder.build(
seq_lens, query_lens, cuda_graph_pad_size, batch_size, seq_lens, query_lens, cuda_graph_pad_size, batch_size)
tree_attention_masks_tensor=tree_attention_masks_tensor)
# LoRA data. # LoRA data.
lora_requests = set() lora_requests = set()
...@@ -1033,9 +1007,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1033,9 +1007,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.inter_data_cache: Dict[int, PyObjectCache] = {} self.inter_data_cache: Dict[int, PyObjectCache] = {}
self.sampling_metadata_cache: SamplingMetadataCache = \ self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache() SamplingMetadataCache()
self.tree_attn_masks: Optional[torch.Tensor] = None
self.tree_position_ids : Optional[torch.Tensor] = None
def load_model(self) -> None: def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
...@@ -1503,11 +1474,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1503,11 +1474,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
@property @property
def vocab_size(self) -> int: def vocab_size(self) -> int:
return self.model_config.get_vocab_size() return self.model_config.get_vocab_size()
def set_tree_style_args(self, tree_attn_masks: Optional[torch.Tensor],
tree_position_ids: Optional[torch.Tensor]):
self.tree_attn_masks = tree_attn_masks
self.tree_position_ids = tree_position_ids
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
......
...@@ -31,6 +31,7 @@ class WorkerBase(ABC): ...@@ -31,6 +31,7 @@ class WorkerBase(ABC):
""" """
model_input: Optional[ModelRunnerInputBase] = None model_input: Optional[ModelRunnerInputBase] = None
tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
@abstractmethod @abstractmethod
def init_device(self) -> None: def init_device(self) -> None:
...@@ -103,18 +104,6 @@ class WorkerBase(ABC): ...@@ -103,18 +104,6 @@ class WorkerBase(ABC):
def list_loras(self) -> Set[int]: def list_loras(self) -> Set[int]:
raise NotImplementedError raise NotImplementedError
@property
@abstractmethod
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
"""
Gets the list of kv caches to pass to the worker's model runner. Each
element in the list is a kv cache corresponding to a particular virtual
engine (PP stream). Used by the default `execute_model`. If the worker's
model runner does not follow the ModelRunnerBase interface, then inherit
from WorkerBase instead.
"""
raise NotImplementedError
@property @property
@abstractmethod @abstractmethod
def cache_engines(self) -> Optional[List[CacheEngine]]: def cache_engines(self) -> Optional[List[CacheEngine]]:
...@@ -138,10 +127,6 @@ class LoraNotSupportedWorkerBase(WorkerBase): ...@@ -138,10 +127,6 @@ class LoraNotSupportedWorkerBase(WorkerBase):
def list_loras(self) -> Set[int]: def list_loras(self) -> Set[int]:
raise ValueError(f"{type(self)} does not support LoRA") raise ValueError(f"{type(self)} does not support LoRA")
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return None
@property @property
def cache_engines(self) -> Optional[List[CacheEngine]]: def cache_engines(self) -> Optional[List[CacheEngine]]:
...@@ -282,10 +267,6 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -282,10 +267,6 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input: WorkerInput = self.prepare_worker_input( worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
if hasattr(self.model_runner, "set_tree_style_args"):
self.model_runner.set_tree_style_args(tree_attn_masks=execute_model_req.tree_attn_masks,
tree_position_ids=execute_model_req.tree_position_ids)
model_input: ModelRunnerInputBase = ( model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input( self.model_runner.prepare_model_input(
...@@ -293,6 +274,17 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -293,6 +274,17 @@ class LocalOrDistributedWorkerBase(WorkerBase):
execute_model_req.virtual_engine, execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids)) execute_model_req.finished_requests_ids))
if self.tree_decoding and execute_model_req.tree_position_ids is not None and \
execute_model_req.tree_attn_masks is not None:
if hasattr(model_input, "input_positions") and \
hasattr(model_input, "attn_metadata") and \
hasattr(model_input.attn_metadata, "tree_attention_masks_tensor"):
attn_metadata = model_input.attn_metadata
attn_metadata.tree_attention_masks_tensor = execute_model_req.tree_attn_masks.contiguous()
model_input = dataclasses.replace(model_input,
input_positions=execute_model_req.tree_position_ids.contiguous(),
attn_metadata=attn_metadata)
kwargs = extract_previous_hidden_states(execute_model_req) kwargs = extract_previous_hidden_states(execute_model_req)
if self.do_metadata_broadcast: if self.do_metadata_broadcast:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment