Commit d589e598 authored by zhuwenwen
Browse files

Merge branch 'v0.6.2-dev_wm' into 'v0.6.2-dev'

[feat]优化medusa代码,通过VLLM_TREE_DECODING环境变量控制是否采用tree-style解码,计算逻辑与主干隔离

See merge request dcutoolkit/deeplearing/vllm!51
parents 54b92ba4 0bb491f8
......@@ -62,16 +62,6 @@ class OpenVINOAttentionBackend(AttentionBackend):
key_cache.data[dst, :] = key_cache.data[src, :]
value_cache.data[dst, :] = value_cache.data[src, :]
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move KV-cache entries between slots; not implemented for this backend.

    The arguments mirror the ``move_cache`` signature of the other attention
    backends and are accepted only for interface compatibility.

    Raises:
        NotImplementedError: Always. The original body was the bare
            expression ``NotImplementedError``, which evaluates the class
            and discards it, so the call silently did nothing.
    """
    raise NotImplementedError(
        "move_cache is not implemented for the OpenVINO attention backend")
@dataclass
class OpenVINOAttentionMetadata:
......
......@@ -53,16 +53,6 @@ class PallasAttentionBackend(AttentionBackend):
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
v_cache[:, dst_indices] = v_cache[:, src_indices]
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move KV-cache entries between slots; not implemented for this backend.

    The arguments mirror the ``move_cache`` signature of the other attention
    backends and are accepted only for interface compatibility.

    Raises:
        NotImplementedError: Always. The original body was the bare
            expression ``NotImplementedError``, which evaluates the class
            and discards it, so the call silently did nothing.
    """
    raise NotImplementedError(
        "move_cache is not implemented for the Pallas attention backend")
@dataclass
class PallasMetadata(AttentionMetadata):
......
......@@ -72,50 +72,6 @@ class ROCmFlashAttentionBackend(AttentionBackend):
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Relocate cached keys/values for every layer via a staging buffer.

    Column 0 of ``src_to_dists`` is handed to ``ops.read_cache`` and column 1
    to ``ops.write_cache_multi_layers``; presumably each row is a
    (source slot, destination slot) pair -- confirm against the kernels.

    Args:
        kv_caches: One KV-cache tensor per layer; the first entry supplies
            the dtype and device of the staging buffer.
        src_to_dists: 2-D index tensor; only columns 0 and 1 are read.
        kv_cache_dtype: Forwarded unchanged to both cache ops.
        num_kv_heads: Used to split each cache and to size the buffer.
        head_size: Used to split each cache and to size the buffer.
    """
    reference = kv_caches[0]
    # One staging allocation holding keys (index 0) and values (index 1).
    staging = torch.empty(
        (2, len(kv_caches), src_to_dists.shape[0], num_kv_heads, head_size),
        dtype=reference.dtype,
        device=reference.device,
    )
    staged_keys = staging[0].contiguous()
    staged_values = staging[1].contiguous()

    # Split every layer's cache into its key part and its value part.
    per_layer = [
        PagedAttention.split_kv_cache(layer_cache, num_kv_heads, head_size)
        for layer_cache in kv_caches
    ]
    key_caches = [pair[0] for pair in per_layer]
    value_caches = [pair[1] for pair in per_layer]

    source_index = src_to_dists[:, 0].contiguous()
    ops.read_cache(staged_keys, staged_values, key_caches, value_caches,
                   source_index, kv_cache_dtype)

    target_index = src_to_dists[:, 1].contiguous()
    ops.write_cache_multi_layers(staged_keys, staged_values, key_caches,
                                 value_caches, target_index, kv_cache_dtype)
@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
......
......@@ -65,16 +65,6 @@ class TorchSDPABackend(AttentionBackend):
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move KV-cache entries between slots; not implemented for this backend.

    The arguments mirror the ``move_cache`` signature of the other attention
    backends and are accepted only for interface compatibility.

    Raises:
        NotImplementedError: Always. The original body was the bare
            expression ``NotImplementedError``, which evaluates the class
            and discards it, so the call silently did nothing.
    """
    raise NotImplementedError(
        "move_cache is not implemented for the Torch SDPA attention backend")
@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
......
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union, Optional
import torch
from vllm.attention.backends.blocksparse_attn import BlocksparseFlashAttentionImpl
from vllm import _custom_ops as ops
from vllm.attention.ops.paged_attn import PagedAttention
def move_cache(
    backend,
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move KV-cache entries between slots for tree-style decoding.

    The entries are staged in a temporary tensor of shape
    ``(2, num_layers, token_num, num_kv_heads, head_size)``: index 0 holds
    keys and index 1 holds values. ``ops.read_cache`` is called with column 0
    of ``src_to_dists`` and ``ops.write_cache_multi_layers`` with column 1.
    Presumably each row of ``src_to_dists`` is a (source slot, destination
    slot) pair -- confirm against the kernel implementations.

    Args:
        backend: Attention backend; only its ``get_name()`` is consulted.
        kv_caches: One KV-cache tensor per layer; the first entry supplies
            the dtype and device of the staging buffer.
        src_to_dists: 2-D index tensor; only columns 0 and 1 are read.
        kv_cache_dtype: Forwarded unchanged to both cache ops.
        num_kv_heads: Used to split each cache and to size the buffer.
        head_size: Used to split each cache and to size the buffer.

    Raises:
        NotImplementedError: If the backend name is neither
            ``"rocm-flash-attn"`` nor ``"xformers"``.
    """
    backend_name = backend.get_name()
    if backend_name not in ("rocm-flash-attn", "xformers"):
        # The message lists exactly the backends accepted above; the previous
        # text also advertised BlocksparseFlashAttention, which this check
        # has never accepted.
        raise NotImplementedError(
            "Only ROCmFlash/XFormers backends support move cache for now! "
            f"Got backend: {backend_name}")

    num_layers = len(kv_caches)
    token_num = src_to_dists.shape[0]
    tmp_store_kv = torch.empty(
        (2, num_layers, token_num, num_kv_heads, head_size),
        dtype=kv_caches[0].dtype,
        device=kv_caches[0].device)
    keys = tmp_store_kv[0].contiguous()
    values = tmp_store_kv[1].contiguous()

    key_caches = []
    value_caches = []
    for kv_cache in kv_caches:
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, num_kv_heads, head_size)
        key_caches.append(key_cache)
        value_caches.append(value_cache)

    # Stage the selected entries first, then write them back, so the read of
    # every layer completes before any slot is overwritten.
    ops.read_cache(
        keys,
        values,
        key_caches,
        value_caches,
        src_to_dists[:, 0].contiguous(),
        kv_cache_dtype,
    )
    ops.write_cache_multi_layers(
        keys,
        values,
        key_caches,
        value_caches,
        src_to_dists[:, 1].contiguous(),
        kv_cache_dtype,
    )
\ No newline at end of file
......@@ -9,6 +9,8 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState)
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerBase
......@@ -188,8 +190,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.block_size, inter_data.block_tables)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int,
tree_attention_masks_tensor: Optional[torch.Tensor] = None):
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
......@@ -272,7 +273,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
tree_attention_masks_tensor=tree_attention_masks_tensor,
block_tables_list=self.block_tables
)
......
......@@ -68,50 +68,6 @@ class XFormersBackend(AttentionBackend):
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Relocate cached keys/values for every layer via a staging buffer.

    Column 0 of ``src_to_dists`` is handed to ``ops.read_cache`` and column 1
    to ``ops.write_cache_multi_layers``; presumably each row is a
    (source slot, destination slot) pair -- confirm against the kernels.

    Args:
        kv_caches: One KV-cache tensor per layer; the first entry supplies
            the dtype and device of the staging buffer.
        src_to_dists: 2-D index tensor; only columns 0 and 1 are read.
        kv_cache_dtype: Forwarded unchanged to both cache ops.
        num_kv_heads: Used to split each cache and to size the buffer.
        head_size: Used to split each cache and to size the buffer.
    """
    reference = kv_caches[0]
    # One staging allocation holding keys (index 0) and values (index 1).
    staging = torch.empty(
        (2, len(kv_caches), src_to_dists.shape[0], num_kv_heads, head_size),
        dtype=reference.dtype,
        device=reference.device,
    )
    staged_keys = staging[0].contiguous()
    staged_values = staging[1].contiguous()

    # Split every layer's cache into its key part and its value part.
    per_layer = [
        PagedAttention.split_kv_cache(layer_cache, num_kv_heads, head_size)
        for layer_cache in kv_caches
    ]
    key_caches = [pair[0] for pair in per_layer]
    value_caches = [pair[1] for pair in per_layer]

    source_index = src_to_dists[:, 0].contiguous()
    ops.read_cache(staged_keys, staged_values, key_caches, value_caches,
                   source_index, kv_cache_dtype)

    target_index = src_to_dists[:, 1].contiguous()
    ops.write_cache_multi_layers(staged_keys, staged_values, key_caches,
                                 value_caches, target_index, kv_cache_dtype)
@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
......
......@@ -142,7 +142,102 @@ class PagedAttention:
if envs.VLLM_USE_OPT_OP:
if envs.VLLM_USE_TC_PAGED_ATTN:
ops.paged_attention_v1_opt_tc(
if attn_masks is None:
ops.paged_attention_v1_opt_tc(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step
)
else:
ops.paged_attention_v1_opt_tc_with_mask(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else:
if attn_masks is None:
ops.paged_attention_v1_opt(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step
)
else:
ops.paged_attention_v1_opt_with_mask(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else:
if attn_masks is None:
ops.paged_attention_v1(
output,
query,
key_cache,
......@@ -161,12 +256,10 @@ class PagedAttention:
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
blocksparse_head_sliding_step
)
else:
ops.paged_attention_v1_opt(
ops.paged_attention_v1_with_mask(
output,
query,
key_cache,
......@@ -189,30 +282,6 @@ class PagedAttention:
attn_masks,
attn_masks_stride
)
else:
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
......@@ -236,7 +305,114 @@ class PagedAttention:
if envs.VLLM_USE_OPT_OP:
if envs.VLLM_USE_TC_PAGED_ATTN:
ops.paged_attention_v2_opt_tc(
if attn_masks is None:
ops.paged_attention_v2_opt_tc(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step
)
else:
ops.paged_attention_v2_opt_tc_with_mask(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else:
if attn_masks is None:
ops.paged_attention_v2_opt(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step
)
else:
ops.paged_attention_v2_opt_with_mask(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else:
if attn_masks is None:
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
......@@ -258,12 +434,10 @@ class PagedAttention:
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
blocksparse_head_sliding_step
)
else:
ops.paged_attention_v2_opt(
ops.paged_attention_v2_with_mask(
output,
exp_sums,
max_logits,
......@@ -289,33 +463,6 @@ class PagedAttention:
attn_masks,
attn_masks_stride
)
else:
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
return output
@staticmethod
......
......@@ -1130,7 +1130,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha: Optional[float],
disable_logprobs: Optional[bool],
num_speculative_heads: Optional[int],
tree_style_spec_decoding: Optional[bool]=None,
) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None.
......@@ -1191,9 +1190,6 @@ class SpeculativeConfig:
num_speculative_heads (Optional[int]): It will be used in tree-style
speculative generation, representing how many heads the draft model
has.
tree_style_spec_decoding (Optional[bool]): If set to True,
tree-style generation will be activated. If not specified,
it defaults to False.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
......@@ -1308,9 +1304,9 @@ class SpeculativeConfig:
"n_predict parameter.")
if typical_acceptance_sampler_posterior_threshold is None:
typical_acceptance_sampler_posterior_threshold = 0.3
typical_acceptance_sampler_posterior_threshold = 0.09
if typical_acceptance_sampler_posterior_alpha is None:
typical_acceptance_sampler_posterior_alpha = 0.09
typical_acceptance_sampler_posterior_alpha = 0.3
if disable_logprobs is None:
disable_logprobs = True
......@@ -1328,7 +1324,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
tree_style_spec_decoding=tree_style_spec_decoding
)
@staticmethod
......@@ -1423,7 +1418,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
disable_log_stats: bool,
tree_style_spec_decoding: bool,
):
"""Create a SpeculativeConfig object.
......@@ -1458,7 +1452,6 @@ class SpeculativeConfig:
returned.
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
tree_style_spec_decoding: Whether to use tree-style generation.
"""
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
......@@ -1474,7 +1467,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha
self.disable_logprobs = disable_logprobs
self.disable_log_stats = disable_log_stats
self.tree_style_spec_decoding = tree_style_spec_decoding
self._verify_args()
......
......@@ -176,7 +176,6 @@ class EngineArgs:
disable_logprobs_during_spec_decoding: Optional[bool] = None
otlp_traces_endpoint: Optional[str] = None
tree_style_spec_decoding: Optional[bool] = None
collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False
override_neuron_config: Optional[Dict[str, Any]] = None
......@@ -707,11 +706,6 @@ class EngineArgs:
'2) TypicalAcceptanceSampler which is configurable, allowing for '
'a higher acceptance rate at the cost of lower quality, '
'and vice versa.')
parser.add_argument('--tree-style-spec-decoding',
type=bool,
default=False,
help='If set to True, tree-style generation will be activated.')
parser.add_argument(
'--typical-acceptance-sampler-posterior-threshold',
......@@ -997,7 +991,6 @@ class EngineArgs:
typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=self.disable_logprobs_during_spec_decoding,
tree_style_spec_decoding=self.tree_style_spec_decoding,
num_speculative_heads=self.num_speculative_heads
)
......
import os
import time
from collections import deque
from contextlib import contextmanager
......@@ -463,6 +464,8 @@ class LLMEngine:
get_tokenizer_for_seq,
),
))
self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
......@@ -989,16 +992,15 @@ class LLMEngine:
output = [outputs_by_sequence_group[0][i]]
# tree style speculative decoding may generate empty output in first step
if self.speculative_config and self.speculative_config.tree_style_spec_decoding:
if outputs and isinstance(output[0], CompletionSequenceGroupOutput):
samples = [o.samples[0] for o in output]
valid_samples = [
sample for sample in samples
if sample.output_token != VLLM_INVALID_TOKEN_ID
]
if len(valid_samples) == 0:
empty_seq_indices.append(i)
continue
if self.tree_decoding and outputs and isinstance(output[0], CompletionSequenceGroupOutput):
samples = [o.samples[0] for o in output]
valid_samples = [
sample for sample in samples
if sample.output_token != VLLM_INVALID_TOKEN_ID
]
if len(valid_samples) == 0:
empty_seq_indices.append(i)
continue
if not is_async:
seq_group.update_num_computed_tokens(
......
......@@ -68,6 +68,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: bool = False
VLLM_TREE_DECODING: bool = False
def get_default_cache_root():
......@@ -453,6 +454,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda:
(os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in
("1", "true")),
# If set, vLLM will use tree-style speculative decoding.
"VLLM_TREE_DECODING":
lambda:
(os.environ.get("VLLM_TREE_DECODING", "0").strip().lower() in
("1", "true"))
}
# end-env-vars-definition
......
......@@ -117,8 +117,6 @@ class LoRAModel(AdapterModel):
pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: Dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
if "lora_A" not in tensor_name and "lora_B" not in tensor_name:
continue
module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
if module_name not in loras:
lora_embeddings_tensor = None
......
import os
from typing import Optional, List
import torch
import torch.jit
......@@ -39,6 +40,8 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
self._posterior_alpha = posterior_alpha
super().__init__(strict_mode=strict_mode)
self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def forward(
self,
target_with_bonus_probs: torch.Tensor,
......@@ -92,7 +95,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
self._raise_if_incorrect_input(target_with_bonus_probs,
draft_token_ids, bonus_token_ids)
if cart_candidates is None:
if not self.tree_decoding:
target_probs = target_with_bonus_probs[:, :-1]
accepted = self._evaluate_accepted_tokens(target_probs,
draft_token_ids)
......@@ -101,6 +104,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
draft_token_ids,
bonus_token_ids)
else:
assert cart_candidates is not None
target_probs = target_with_bonus_probs
output_token_ids = self._evaluate_accepted_tokens_tree_style(target_probs,
draft_token_ids,
......@@ -199,7 +203,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
output_token_id_list = []
accept_length_list = accept_length.cpu().tolist()
logger.info("accept_length:%s", accept_length_list)
#logger.info("accept_length:%s", accept_length_list)
for i in range(batch_size):
output_best_candidates.append(best_candidate[i])
accept_lengths.append(accept_length_list[i])
......
import os
import weakref
from typing import List, Optional, Set, Tuple, Dict
......@@ -29,6 +30,8 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
# Lazy initialization list.
self._proposer: SpeculativeProposer
self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def init_device(self):
super().init_device()
......@@ -38,7 +41,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
# get medusa choices and generate medusa_buffers
self.medusa_buffers = None
if hasattr(self.model_runner.model, 'medusa_choices'):
if self.tree_decoding and hasattr(self.model_runner.model, 'medusa_choices'):
self.medusa_choices = self.model_runner.model.medusa_choices
if self.medusa_choices is not None:
self.medusa_buffers = self.generate_medusa_buffers(
......
import os
from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple
......@@ -82,9 +83,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_alpha=speculative_config.
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs,
disable_log_stats=speculative_config.disable_log_stats,
tree_style_spec_decoding=speculative_config.tree_style_spec_decoding,
)
disable_log_stats=speculative_config.disable_log_stats)
return spec_decode_worker
......@@ -126,7 +125,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
disable_log_stats: bool,
tree_style_spec_decoding: bool,
) -> "SpecDecodeWorker":
allow_zero_draft_token_step = True
......@@ -191,8 +189,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step,
tree_style_spec_decoding=tree_style_spec_decoding)
allow_zero_draft_token_step=allow_zero_draft_token_step)
def __init__(
self,
......@@ -204,7 +201,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True,
tree_style_spec_decoding: bool = False,
):
"""
Create a SpecDecodeWorker.
......@@ -233,7 +229,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
tree_style_spec_decoding: Whether to use tree-style generation.
"""
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
......@@ -268,7 +263,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self._disable_logprobs = disable_logprobs
self._disable_log_stats = disable_log_stats
self.tree_style_spec_decoding = tree_style_spec_decoding
self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def init_device(self) -> None:
"""Initialize both scorer and proposer models.
......@@ -285,7 +280,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self._metrics.init_gpu_tensors(self.rank)
self.spec_decode_sampler.init_gpu_tensors(self.rank)
if not self.tree_style_spec_decoding:
if not self.tree_decoding:
self.scorer = BatchExpansionTop1Scorer(
scorer_worker=self.scorer_worker,
device=self.device,
......@@ -324,7 +319,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
) = True
# tree_style decoding modify probs in _verify_tokens
if not self.tree_style_spec_decoding:
if not self.tree_decoding:
(self.scorer_worker.model_runner.model.sampler.
should_modify_greedy_probs_inplace) = True
self.proposer_worker.set_include_gpu_probs_tensor()
......@@ -535,7 +530,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding.
"""
if self.tree_style_spec_decoding and self.kvcache_slot_to_be_moved is not None:
if self.tree_decoding and self.kvcache_slot_to_be_moved is not None:
execute_model_req.kvcache_slot_to_be_moved = self.kvcache_slot_to_be_moved
self.kvcache_slot_to_be_moved = None
......@@ -560,7 +555,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
hidden_states, execute_model_req.seq_group_metadata_list)
# Store logits from target model execution.
if self.tree_style_spec_decoding:
if self.tree_decoding:
logits = sampler_output.logits
if logits is not None:
if self.previous_logits is None:
......@@ -612,7 +607,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.scorer_worker.execute_model()
if not data["disable_all_speculation"]:
# if not self.tree_style_spec_decoding:
# if not self.tree_decoding:
# # Even if num_lookahead_slots is zero, we want to run the
# # proposer model as it may have KV.
# #
......@@ -677,7 +672,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"workers generate no tokens")
# Pass tree attention mask and positions to target model
if self.tree_style_spec_decoding:
if self.tree_decoding:
execute_model_req.tree_attn_masks = proposals.tree_attn_masks
execute_model_req.tree_position_ids = proposals.tree_position_ids
......@@ -695,7 +690,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals, execute_model_req.num_lookahead_slots)
# move kv_caches of selected tokens to right positions
if self.tree_style_spec_decoding:
if self.tree_decoding:
self.move_caches(execute_model_req, select_indices_list, accept_lengths)
stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
......@@ -739,7 +734,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
else:
proposal_verifier_probs = proposal_scores.probs
if self.tree_style_spec_decoding:
if self.tree_decoding:
retrieve_indices = proposals.retrieve_indices
proposal_verifier_probs = proposal_verifier_probs[:, retrieve_indices]
......@@ -797,7 +792,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
)
# Append output tokens from non-speculative sequences to
# the accepted token ids tensor.
if not self.tree_style_spec_decoding:
if not self.tree_decoding:
non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
1).clone()
else:
......
......@@ -8,6 +8,7 @@ from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
is_pin_memory_available)
from vllm.attention.backends.tree_decoding_utils import move_cache
logger = init_logger(__name__)
......@@ -103,11 +104,12 @@ class CacheEngine:
def move_caches(self, kv_caches: List[torch.Tensor],
src_to_dsts: torch.Tensor) -> None:
self.attn_backend.move_cache(kv_caches,
src_to_dsts,
self.cache_config.cache_dtype,
self.num_kv_heads,
self.head_size)
move_cache(self.attn_backend,
kv_caches,
src_to_dsts,
self.cache_config.cache_dtype,
self.num_kv_heads,
self.head_size)
@staticmethod
def get_cache_block_size(
......
......@@ -198,7 +198,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.lora_requests.clear() # type: ignore
self.prompt_adapter_index_mapping.clear() # type: ignore
self.prompt_adapter_prompt_mapping.clear() # type: ignore
self.tree_attn_masks[0] = None # type: ignore
def __init__(
self,
......@@ -246,9 +245,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
reinit: bool = False,
reinit_use_defaults: bool = False,
encoder_seq_len: int = 0,
# attention mask used in tree-style generation
tree_attn_masks: Optional[List[torch.Tensor]] = None,
):
if reinit:
assert len(self.seq_ids) == len(seq_ids) # type: ignore
......@@ -339,12 +335,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_prompt_mapping
else:
self.prompt_adapter_prompt_mapping.clear()
if tree_attn_masks:
self.tree_attn_masks = tree_attn_masks
else:
self.tree_attn_masks.clear()
else:
self.input_tokens = input_tokens or []
self.input_positions = input_positions or []
......@@ -364,7 +354,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_index_mapping or [])
self.prompt_adapter_prompt_mapping = (
prompt_adapter_prompt_mapping or [])
self.tree_attn_masks = tree_attn_masks or []
self.prompt_adapter_request = prompt_adapter_request
self.multi_modal_inputs = multi_modal_inputs
......@@ -380,7 +369,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)]
self.tree_attn_masks = [None for _ in range(self.n_seqs)]
self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs
......@@ -469,13 +457,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.sliding_window + self.block_size - 1) // self.block_size
self.block_aligned_sliding_window = \
self.sliding_window_blocks * self.block_size
if hasattr(self.runner, "tree_attn_masks"):
self.tree_attn_masks = self.runner.tree_attn_masks
self.tree_position_ids = self.runner.tree_position_ids
else:
self.tree_attn_masks = None
self.tree_position_ids = None
self.is_encoder_decoder_model = self.runner.model_config.is_encoder_decoder_model
......@@ -853,16 +834,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if cuda_graph_pad_size:
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
# prepare tree attention masks
tree_attention_masks_tensor = self.tree_attn_masks
if tree_attention_masks_tensor is not None:
tree_attention_masks_tensor = tree_attention_masks_tensor.contiguous()
input_positions_tensor = self.tree_position_ids.contiguous()
# Attention metadata.
attn_metadata = self.attn_metadata_builder.build(
seq_lens, query_lens, cuda_graph_pad_size, batch_size,
tree_attention_masks_tensor=tree_attention_masks_tensor)
seq_lens, query_lens, cuda_graph_pad_size, batch_size)
# LoRA data.
lora_requests = set()
......@@ -1033,9 +1007,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.inter_data_cache: Dict[int, PyObjectCache] = {}
self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache()
self.tree_attn_masks: Optional[torch.Tensor] = None
self.tree_position_ids : Optional[torch.Tensor] = None
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
......@@ -1503,11 +1474,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
def set_tree_style_args(self, tree_attn_masks: Optional[torch.Tensor],
tree_position_ids: Optional[torch.Tensor]):
self.tree_attn_masks = tree_attn_masks
self.tree_position_ids = tree_position_ids
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
......
......@@ -31,6 +31,7 @@ class WorkerBase(ABC):
"""
model_input: Optional[ModelRunnerInputBase] = None
tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
@abstractmethod
def init_device(self) -> None:
......@@ -103,18 +104,6 @@ class WorkerBase(ABC):
def list_loras(self) -> Set[int]:
raise NotImplementedError
@property
@abstractmethod
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
"""
Gets the list of kv caches to pass to the worker's model runner. Each
element in the list is a kv cache corresponding to a particular virtual
engine (PP stream). Used by the default `execute_model`. If the worker's
model runner does not follow the ModelRunnerBase interface, then inherit
from WorkerBase instead.
"""
raise NotImplementedError
@property
@abstractmethod
def cache_engines(self) -> Optional[List[CacheEngine]]:
......@@ -138,10 +127,6 @@ class LoraNotSupportedWorkerBase(WorkerBase):
def list_loras(self) -> Set[int]:
raise ValueError(f"{type(self)} does not support LoRA")
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return None
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
......@@ -282,10 +267,6 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
if hasattr(self.model_runner, "set_tree_style_args"):
self.model_runner.set_tree_style_args(tree_attn_masks=execute_model_req.tree_attn_masks,
tree_position_ids=execute_model_req.tree_position_ids)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
......@@ -293,6 +274,17 @@ class LocalOrDistributedWorkerBase(WorkerBase):
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
if self.tree_decoding and execute_model_req.tree_position_ids is not None and \
execute_model_req.tree_attn_masks is not None:
if hasattr(model_input, "input_positions") and \
hasattr(model_input, "attn_metadata") and \
hasattr(model_input.attn_metadata, "tree_attention_masks_tensor"):
attn_metadata = model_input.attn_metadata
attn_metadata.tree_attention_masks_tensor = execute_model_req.tree_attn_masks.contiguous()
model_input = dataclasses.replace(model_input,
input_positions=execute_model_req.tree_position_ids.contiguous(),
attn_metadata=attn_metadata)
kwargs = extract_previous_hidden_states(execute_model_req)
if self.do_metadata_broadcast:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment