Commit 19bc93d9 authored by 王敏's avatar 王敏
Browse files

增加medusa并行解码功能,后续增加使用说明和测试文档

parent aba40fda
...@@ -16,6 +16,7 @@ from vllm.attention.backends.utils import (CommonAttentionState, ...@@ -16,6 +16,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm import _custom_ops as ops
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -67,6 +68,50 @@ class XFormersBackend(AttentionBackend): ...@@ -67,6 +68,50 @@ class XFormersBackend(AttentionBackend):
) -> None: ) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists) PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move cached key/value entries between locations in every layer.

    The entries addressed by column 0 of ``src_to_dists`` are first read
    from all layers into a temporary staging buffer; the staged entries are
    then written to the locations addressed by column 1. The read call is
    issued before the write call.

    Args:
        kv_caches: One KV-cache tensor per layer; each is split into its
            key and value parts with ``PagedAttention.split_kv_cache``.
        src_to_dists: 2-D tensor with one row per entry to move. Column 0
            is handed to ``ops.read_cache`` and column 1 to
            ``ops.write_cache_multi_layers`` (presumably source and
            destination slot indices -- confirm against the kernels).
        kv_cache_dtype: Cache dtype string forwarded to both kernels.
        num_kv_heads: Number of KV heads, used to size the staging buffer
            and to split each cache.
        head_size: Per-head dimension, used the same way.

    Returns:
        None. The caches in ``kv_caches`` are updated by the write kernel.
    """
    num_layers = len(kv_caches)
    token_num = src_to_dists.shape[0]
    # Nothing to move: avoid indexing an empty layer list and avoid
    # launching kernels over zero entries.
    if num_layers == 0 or token_num == 0:
        return

    # Staging buffer: index 0 holds keys, index 1 holds values. Slicing the
    # leading dimension of a freshly allocated tensor yields contiguous
    # views, so no extra copy is needed.
    staging = torch.empty(
        (2, num_layers, token_num, num_kv_heads, head_size),
        dtype=kv_caches[0].dtype,
        device=kv_caches[0].device)
    keys = staging[0]
    values = staging[1]

    key_caches = []
    value_caches = []
    for kv_cache in kv_caches:
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, num_kv_heads, head_size)
        key_caches.append(key_cache)
        value_caches.append(value_cache)

    # Gather from the source locations ...
    ops.read_cache(
        keys,
        values,
        key_caches,
        value_caches,
        src_to_dists[:, 0].contiguous(),
        kv_cache_dtype,
    )
    # ... then scatter to the destination locations in all layers.
    ops.write_cache_multi_layers(
        keys,
        values,
        key_caches,
        value_caches,
        src_to_dists[:, 1].contiguous(),
        kv_cache_dtype,
    )
@dataclass @dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...@@ -144,6 +189,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -144,6 +189,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
cross_slot_mapping: Optional[torch.Tensor] = None cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None cross_block_tables: Optional[torch.Tensor] = None
tree_attention_masks_tensor: Optional[torch.Tensor] = None
def __post_init__(self): def __post_init__(self):
# Set during the execution of the first attention op. # Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt # It is a list because it is needed to set per prompt
...@@ -223,7 +270,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -223,7 +270,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len, max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping, cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables) cross_block_tables=self.cross_block_tables,
tree_attention_masks_tensor=self.tree_attention_masks_tensor)
return self._cached_prefill_metadata return self._cached_prefill_metadata
@property @property
...@@ -262,7 +310,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -262,7 +310,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len, max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping, cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables) cross_block_tables=self.cross_block_tables,
tree_attention_masks_tensor=self.tree_attention_masks_tensor)
return self._cached_decode_metadata return self._cached_decode_metadata
...@@ -633,6 +682,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -633,6 +682,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
block_tables_arg, block_tables_arg,
) = _get_seq_len_block_table_args(decode_meta, False, attn_type) ) = _get_seq_len_block_table_args(decode_meta, False, attn_type)
tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor
output[num_prefill_tokens:] = PagedAttention.forward_decode( output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query, decode_query,
key_cache, key_cache,
...@@ -646,6 +697,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -646,6 +697,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
self.alibi_slopes, self.alibi_slopes,
k_scale, k_scale,
v_scale, v_scale,
attn_masks=tree_attention_masks_tensor,
attn_masks_stride=tree_attention_masks_tensor.stride(0) if tree_attention_masks_tensor is not None else 0
) )
# Reshape the output tensor. # Reshape the output tensor.
......
...@@ -103,6 +103,8 @@ class PagedAttention: ...@@ -103,6 +103,8 @@ class PagedAttention:
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64, blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0, blocksparse_head_sliding_step: int = 0,
attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0
) -> torch.Tensor: ) -> torch.Tensor:
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention # use blocksparse paged attention
...@@ -160,6 +162,8 @@ class PagedAttention: ...@@ -160,6 +162,8 @@ class PagedAttention:
blocksparse_vert_stride, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_block_size,
blocksparse_head_sliding_step, blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
) )
else: else:
ops.paged_attention_v1_opt( ops.paged_attention_v1_opt(
...@@ -182,6 +186,8 @@ class PagedAttention: ...@@ -182,6 +186,8 @@ class PagedAttention:
blocksparse_vert_stride, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_block_size,
blocksparse_head_sliding_step, blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
) )
else: else:
ops.paged_attention_v1( ops.paged_attention_v1(
...@@ -204,6 +210,8 @@ class PagedAttention: ...@@ -204,6 +210,8 @@ class PagedAttention:
blocksparse_vert_stride, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_block_size,
blocksparse_head_sliding_step, blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
) )
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
...@@ -251,6 +259,8 @@ class PagedAttention: ...@@ -251,6 +259,8 @@ class PagedAttention:
blocksparse_vert_stride, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_block_size,
blocksparse_head_sliding_step, blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
) )
else: else:
ops.paged_attention_v2_opt( ops.paged_attention_v2_opt(
...@@ -276,6 +286,8 @@ class PagedAttention: ...@@ -276,6 +286,8 @@ class PagedAttention:
blocksparse_vert_stride, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_block_size,
blocksparse_head_sliding_step, blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
) )
else: else:
ops.paged_attention_v2( ops.paged_attention_v2(
...@@ -301,6 +313,8 @@ class PagedAttention: ...@@ -301,6 +313,8 @@ class PagedAttention:
blocksparse_vert_stride, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_block_size,
blocksparse_head_sliding_step, blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
) )
return output return output
......
...@@ -1129,6 +1129,8 @@ class SpeculativeConfig: ...@@ -1129,6 +1129,8 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold: Optional[float], typical_acceptance_sampler_posterior_threshold: Optional[float],
typical_acceptance_sampler_posterior_alpha: Optional[float], typical_acceptance_sampler_posterior_alpha: Optional[float],
disable_logprobs: Optional[bool], disable_logprobs: Optional[bool],
num_speculative_heads: Optional[int],
tree_style_spec_decoding: Optional[bool]=None,
) -> Optional["SpeculativeConfig"]: ) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None. """Create a SpeculativeConfig if possible, else return None.
...@@ -1186,6 +1188,12 @@ class SpeculativeConfig: ...@@ -1186,6 +1188,12 @@ class SpeculativeConfig:
If set to False, token log probabilities are returned If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams. according to the log probability settings in SamplingParams.
If not specified, it defaults to True. If not specified, it defaults to True.
num_speculative_heads (Optional[int]): It will be used in tree-style
speculative generation, representing how many heads the draft model
has.
tree_style_spec_decoding (Optional[bool]): If set to True,
tree-style generation will be activated. If not specified,
it defaults to False.
Returns: Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
...@@ -1264,6 +1272,10 @@ class SpeculativeConfig: ...@@ -1264,6 +1272,10 @@ class SpeculativeConfig:
and hasattr(draft_hf_config, "num_lookahead_tokens")): and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens draft_hf_config.num_lookahead_tokens = num_speculative_tokens
if (num_speculative_heads is not None
and hasattr(draft_hf_config, "num_lookahead_heads")):
draft_hf_config.num_lookahead_heads = num_speculative_heads
n_predict = getattr(draft_hf_config, "n_predict", None) n_predict = getattr(draft_hf_config, "n_predict", None)
if n_predict is not None: if n_predict is not None:
if num_speculative_tokens is None: if num_speculative_tokens is None:
...@@ -1296,9 +1308,9 @@ class SpeculativeConfig: ...@@ -1296,9 +1308,9 @@ class SpeculativeConfig:
"n_predict parameter.") "n_predict parameter.")
if typical_acceptance_sampler_posterior_threshold is None: if typical_acceptance_sampler_posterior_threshold is None:
typical_acceptance_sampler_posterior_threshold = 0.09 typical_acceptance_sampler_posterior_threshold = 0.3
if typical_acceptance_sampler_posterior_alpha is None: if typical_acceptance_sampler_posterior_alpha is None:
typical_acceptance_sampler_posterior_alpha = 0.3 typical_acceptance_sampler_posterior_alpha = 0.09
if disable_logprobs is None: if disable_logprobs is None:
disable_logprobs = True disable_logprobs = True
...@@ -1316,6 +1328,7 @@ class SpeculativeConfig: ...@@ -1316,6 +1328,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha, typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs, disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats, disable_log_stats=disable_log_stats,
tree_style_spec_decoding=tree_style_spec_decoding
) )
@staticmethod @staticmethod
...@@ -1410,6 +1423,7 @@ class SpeculativeConfig: ...@@ -1410,6 +1423,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha: float, typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool, disable_logprobs: bool,
disable_log_stats: bool, disable_log_stats: bool,
tree_style_spec_decoding: bool,
): ):
"""Create a SpeculativeConfig object. """Create a SpeculativeConfig object.
...@@ -1444,6 +1458,7 @@ class SpeculativeConfig: ...@@ -1444,6 +1458,7 @@ class SpeculativeConfig:
returned. returned.
disable_log_stats: Whether to disable periodic printing of stage disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding. times in speculative decoding.
tree_style_spec_decoding: Whether to use tree-style generation.
""" """
self.draft_model_config = draft_model_config self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config self.draft_parallel_config = draft_parallel_config
...@@ -1459,6 +1474,7 @@ class SpeculativeConfig: ...@@ -1459,6 +1474,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha typical_acceptance_sampler_posterior_alpha
self.disable_logprobs = disable_logprobs self.disable_logprobs = disable_logprobs
self.disable_log_stats = disable_log_stats self.disable_log_stats = disable_log_stats
self.tree_style_spec_decoding = tree_style_spec_decoding
self._verify_args() self._verify_args()
...@@ -1526,6 +1542,8 @@ class LoRAConfig: ...@@ -1526,6 +1542,8 @@ class LoRAConfig:
# This is a constant. # This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256 lora_vocab_padding_size: ClassVar[int] = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None long_lora_scaling_factors: Optional[Tuple[float]] = None
merge_lora: bool = False
lora_target_modules: Optional[List[str]] = None
def __post_init__(self): def __post_init__(self):
# Setting the maximum rank to 256 should be able to satisfy the vast # Setting the maximum rank to 256 should be able to satisfy the vast
...@@ -1548,6 +1566,10 @@ class LoRAConfig: ...@@ -1548,6 +1566,10 @@ class LoRAConfig:
raise ValueError( raise ValueError(
f"max_cpu_loras ({self.max_cpu_loras}) must be >= " f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
f"max_loras ({self.max_loras})") f"max_loras ({self.max_loras})")
if self.merge_lora and self.max_loras > 1:
raise ValueError(
f"merge_lora ({self.merge_lora}) can only be used when "
f"max_loras ({self.max_loras}) is 1")
def verify_with_model_config(self, model_config: ModelConfig): def verify_with_model_config(self, model_config: ModelConfig):
if self.lora_dtype in (None, "auto"): if self.lora_dtype in (None, "auto"):
......
...@@ -143,6 +143,8 @@ class EngineArgs: ...@@ -143,6 +143,8 @@ class EngineArgs:
long_lora_scaling_factors: Optional[Tuple[float]] = None long_lora_scaling_factors: Optional[Tuple[float]] = None
lora_dtype: Optional[Union[str, torch.dtype]] = 'auto' lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
max_cpu_loras: Optional[int] = None max_cpu_loras: Optional[int] = None
merge_lora: bool = False
lora_target_modules: Optional[List[str]] = None
device: str = 'auto' device: str = 'auto'
num_scheduler_steps: int = 1 num_scheduler_steps: int = 1
multi_step_stream_outputs: bool = False multi_step_stream_outputs: bool = False
...@@ -162,6 +164,7 @@ class EngineArgs: ...@@ -162,6 +164,7 @@ class EngineArgs:
speculative_model_quantization: Optional[str] = None speculative_model_quantization: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None speculative_draft_tensor_parallel_size: Optional[int] = None
num_speculative_tokens: Optional[int] = None num_speculative_tokens: Optional[int] = None
num_speculative_heads: Optional[int] = None
speculative_max_model_len: Optional[int] = None speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None
...@@ -173,6 +176,7 @@ class EngineArgs: ...@@ -173,6 +176,7 @@ class EngineArgs:
disable_logprobs_during_spec_decoding: Optional[bool] = None disable_logprobs_during_spec_decoding: Optional[bool] = None
otlp_traces_endpoint: Optional[str] = None otlp_traces_endpoint: Optional[str] = None
tree_style_spec_decoding: Optional[bool] = None
collect_detailed_traces: Optional[str] = None collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False disable_async_output_proc: bool = False
override_neuron_config: Optional[Dict[str, Any]] = None override_neuron_config: Optional[Dict[str, Any]] = None
...@@ -534,6 +538,14 @@ class EngineArgs: ...@@ -534,6 +538,14 @@ class EngineArgs:
type=int, type=int,
default=EngineArgs.max_lora_rank, default=EngineArgs.max_lora_rank,
help='Max LoRA rank.') help='Max LoRA rank.')
parser.add_argument(
    '--merge-lora',
    # argparse's type=bool calls bool() on the raw string, so any non-empty
    # value (including "False") would enable merging. Parse the text
    # explicitly; a bare `--merge-lora` also turns the option on.
    type=lambda value: str(value).strip().lower() in ('1', 'true', 'yes',
                                                      'y', 'on'),
    nargs='?',
    const=True,
    default=False,
    help='If set to True, the weights of the base layer will be merged with the weights of Lora.')
parser.add_argument('--lora-target-modules',
nargs='*',
default=None,
help='List of lora module name, If not specified, modules will be chosen according to the model architecture.')
parser.add_argument( parser.add_argument(
'--lora-extra-vocab-size', '--lora-extra-vocab-size',
type=int, type=int,
...@@ -639,6 +651,12 @@ class EngineArgs: ...@@ -639,6 +651,12 @@ class EngineArgs:
default=EngineArgs.num_speculative_tokens, default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from ' help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.') 'the draft model in speculative decoding.')
parser.add_argument(
'--num-speculative-heads',
type=int,
default=EngineArgs.num_speculative_heads,
help='The number of speculative heads to sample from '
'the draft model in speculative decoding.')
parser.add_argument( parser.add_argument(
'--speculative-draft-tensor-parallel-size', '--speculative-draft-tensor-parallel-size',
'-spec-draft-tp', '-spec-draft-tp',
...@@ -690,6 +708,11 @@ class EngineArgs: ...@@ -690,6 +708,11 @@ class EngineArgs:
'a higher acceptance rate at the cost of lower quality, ' 'a higher acceptance rate at the cost of lower quality, '
'and vice versa.') 'and vice versa.')
parser.add_argument(
    '--tree-style-spec-decoding',
    # argparse's type=bool calls bool() on the raw string, so any non-empty
    # value (including "False") would activate tree-style decoding. Parse
    # the text explicitly; a bare flag also turns the option on.
    type=lambda value: str(value).strip().lower() in ('1', 'true', 'yes',
                                                      'y', 'on'),
    nargs='?',
    const=True,
    default=False,
    help='If set to True, tree-style generation will be activated.')
parser.add_argument( parser.add_argument(
'--typical-acceptance-sampler-posterior-threshold', '--typical-acceptance-sampler-posterior-threshold',
type=float, type=float,
...@@ -974,6 +997,8 @@ class EngineArgs: ...@@ -974,6 +997,8 @@ class EngineArgs:
typical_acceptance_sampler_posterior_alpha=self. typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha, typical_acceptance_sampler_posterior_alpha,
disable_logprobs=self.disable_logprobs_during_spec_decoding, disable_logprobs=self.disable_logprobs_during_spec_decoding,
tree_style_spec_decoding=self.tree_style_spec_decoding,
num_speculative_heads=self.num_speculative_heads
) )
if self.num_scheduler_steps > 1: if self.num_scheduler_steps > 1:
...@@ -1016,7 +1041,9 @@ class EngineArgs: ...@@ -1016,7 +1041,9 @@ class EngineArgs:
long_lora_scaling_factors=self.long_lora_scaling_factors, long_lora_scaling_factors=self.long_lora_scaling_factors,
lora_dtype=self.lora_dtype, lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None and self.max_cpu_loras > 0 else None,
merge_lora=self.merge_lora,
lora_target_modules=self.lora_target_modules) if self.enable_lora else None
if self.qlora_adapter_name_or_path is not None and \ if self.qlora_adapter_name_or_path is not None and \
self.qlora_adapter_name_or_path != "": self.qlora_adapter_name_or_path != "":
......
...@@ -131,6 +131,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -131,6 +131,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.base_layer = base_layer self.base_layer = base_layer
self.embeddings_slice: Optional[Tuple[int, int]] self.embeddings_slice: Optional[Tuple[int, int]]
self.embeddings_weights: Optional[torch.Tensor] self.embeddings_weights: Optional[torch.Tensor]
self.device = _get_lora_device(self.base_layer)
def create_lora_weights( def create_lora_weights(
self, self,
...@@ -207,6 +208,18 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -207,6 +208,18 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked[index, self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_( 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True) lora_b.T, non_blocking=True)
# Keep copies of the raw LoRA factors on the layer's device.
self.lora_a = lora_a.to(self.device)
self.lora_b = lora_b.to(self.device)

if self.lora_config.merge_lora:
    # Fold the low-rank update (A @ B) into the base layer's weight
    # in place.
    base_weight = self.base_layer.weight.data
    merged_weights = torch.matmul(self.lora_a, self.lora_b)
    # Compare shape against shape (the previous check compared a
    # torch.Size with the weight tensor itself, so it was not a layout
    # test). Transpose the update when its layout differs from the base
    # weight's, matching the merged-QKV path.
    if merged_weights.shape != base_weight.shape:
        merged_weights = merged_weights.T
    base_weight.copy_(merged_weights + base_weight)
if embeddings_tensor is not None: if embeddings_tensor is not None:
self.embeddings_tensors[ self.embeddings_tensors[
index, :embeddings_tensor.shape[0], :embeddings_tensor. index, :embeddings_tensor.shape[0], :embeddings_tensor.
...@@ -225,12 +238,15 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -225,12 +238,15 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1 added_tokens_mask = x > self.base_layer.org_vocab_size - 1
embeddings_indices = self.punica_wrapper.embeddings_indices embeddings_indices = self.punica_wrapper.embeddings_indices
indices = embeddings_indices[1].view_as(x) indices = embeddings_indices[0].view_as(x)
if not self.lora_config.merge_lora:
indices_0 = embeddings_indices[1].view_as(x)
full_lora_a_embeddings = F.embedding( full_lora_a_embeddings = F.embedding(
x + indices, x + indices_0,
self.lora_a_stacked_2d, self.lora_a_stacked_2d,
) )
indices = embeddings_indices[0].view_as(x)
full_output = self.base_layer.forward( full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask)) x.add_(indices * added_tokens_mask))
...@@ -251,6 +267,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -251,6 +267,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked, self.lora_b_stacked,
add_input=True) add_input=True)
return full_output.view_as(full_output_org) return full_output.view_as(full_output_org)
else:
full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask))
return full_output
@classmethod @classmethod
def can_replace_layer( def can_replace_layer(
...@@ -778,6 +798,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -778,6 +798,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.indices_len: List[int] self.indices_len: List[int]
def reset_lora(self, index: int): def reset_lora(self, index: int):
if not self.lora_config.merge_lora:
self.lora_a_stacked[0][index] = 0 self.lora_a_stacked[0][index] = 0
self.lora_b_stacked[0][index] = 0 self.lora_b_stacked[0][index] = 0
self.lora_a_stacked[1][index] = 0 self.lora_a_stacked[1][index] = 0
...@@ -822,6 +843,33 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -822,6 +843,33 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
lora_a = self.slice_lora_a(lora_a) lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b) lora_b = self.slice_lora_b(lora_b)
if self.lora_config.merge_lora:
qkv_weights_list = []
for i in range(len(self.output_slices)):
if lora_a[i] is not None:
if lora_a[i].numel() == 0 or lora_b[i].numel() == 0:
continue
weight_A = lora_a[i].to(self.device)
weight_B = lora_b[i].to(self.device)
delta_weight = torch.matmul(weight_A, weight_B)
qkv_weights_list.append(delta_weight)
else:
if i == 0:
qkv_weights_list.append(torch.zeros(self.input_size, self.q_proj_shard_size,
dtype=self.base_layer.weight.dtype, device=self.device))
else:
qkv_weights_list.append(torch.zeros(self.input_size, self.kv_proj_shard_size,
dtype=self.base_layer.weight.dtype, device=self.device))
if len(qkv_weights_list) > 0:
qkv_weights = torch.cat(qkv_weights_list, dim=-1)
if qkv_weights.shape != self.base_layer.weight.shape:
qkv_weights = qkv_weights.T + self.base_layer.weight.data
else:
qkv_weights = qkv_weights + self.base_layer.weight.data
self.base_layer.weight.data.copy_(qkv_weights)
else:
if lora_b[0] is not None: if lora_b[0] is not None:
lora_b_q = lora_b[0] lora_b_q = lora_b[0]
self.lora_b_stacked[0][ self.lora_b_stacked[0][
...@@ -853,11 +901,14 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -853,11 +901,14 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
def apply(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
if not self.lora_config.merge_lora:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias) output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
self.punica_wrapper.add_lora_packed_nslice(output, x, self.punica_wrapper.add_lora_packed_nslice(output, x,
self.lora_a_stacked, self.lora_a_stacked,
self.lora_b_stacked, 1.0, self.lora_b_stacked, 1.0,
self.output_slices) self.output_slices)
else:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
return output return output
@classmethod @classmethod
......
...@@ -117,6 +117,8 @@ class LoRAModel(AdapterModel): ...@@ -117,6 +117,8 @@ class LoRAModel(AdapterModel):
pin_memory = str(device) == "cpu" and is_pin_memory_available() pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: Dict[str, LoRALayerWeights] = {} loras: Dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items(): for tensor_name, tensor in tensors.items():
if "lora_A" not in tensor_name and "lora_B" not in tensor_name:
continue
module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
if module_name not in loras: if module_name not in loras:
lora_embeddings_tensor = None lora_embeddings_tensor = None
...@@ -324,6 +326,9 @@ class LoRAModelManager(AdapterModelManager): ...@@ -324,6 +326,9 @@ class LoRAModelManager(AdapterModelManager):
self.scaling_factor_to_offset: Dict[float, int] = {} self.scaling_factor_to_offset: Dict[float, int] = {}
super().__init__(model) super().__init__(model)
if hasattr(self.model, "supported_lora_modules"): if hasattr(self.model, "supported_lora_modules"):
if lora_config.lora_target_modules is not None:
self.supported_lora_modules = lora_config.lora_target_modules
else:
self.supported_lora_modules = copy.deepcopy( self.supported_lora_modules = copy.deepcopy(
self.model.supported_lora_modules) self.model.supported_lora_modules)
if lora_config.long_lora_scaling_factors: if lora_config.long_lora_scaling_factors:
......
...@@ -117,6 +117,15 @@ class SamplerOutput( ...@@ -117,6 +117,15 @@ class SamplerOutput(
# block/sync across workers, cpu-gpu sync time and sampling time. # block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time: Optional[float] = None model_execute_time: Optional[float] = None
# Optional lm_head logits from the model.
logits: Optional[torch.Tensor] = None
# tree-style cartesian candidates
cart_candidates: Optional[torch.Tensor] = None
# tree-style attention masks
tree_attn_masks: Optional[torch.Tensor] = None
def __getitem__(self, idx: int): def __getitem__(self, idx: int):
return self.outputs[idx] return self.outputs[idx]
...@@ -141,7 +150,9 @@ class SamplerOutput( ...@@ -141,7 +150,9 @@ class SamplerOutput(
f"SamplerOutput(outputs={self.outputs}, " f"SamplerOutput(outputs={self.outputs}, "
f"sampled_token_probs={sampled_token_probs_repr}, " f"sampled_token_probs={sampled_token_probs_repr}, "
f"sampled_token_ids={sampled_token_ids_repr}, " f"sampled_token_ids={sampled_token_ids_repr}, "
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") f"spec_decode_worker_metrics={self.spec_decode_worker_metrics}, "
f"logits={self.logits}, "
f"tree_attn_masks={self.tree_attn_masks})")
class Sampler(nn.Module): class Sampler(nn.Module):
...@@ -224,6 +235,7 @@ class Sampler(nn.Module): ...@@ -224,6 +235,7 @@ class Sampler(nn.Module):
sampling_metadata: Metadata for sampling. sampling_metadata: Metadata for sampling.
""" """
assert logits is not None assert logits is not None
original_logits = logits.clone()
_, vocab_size = logits.shape _, vocab_size = logits.shape
# Prepare sampling tensors with pinned memory to avoid blocking. # Prepare sampling tensors with pinned memory to avoid blocking.
...@@ -307,7 +319,8 @@ class Sampler(nn.Module): ...@@ -307,7 +319,8 @@ class Sampler(nn.Module):
prompt_logprobs, prompt_logprobs,
sample_logprobs, sample_logprobs,
on_device_tensors=on_device_tensors, on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
logits=original_logits)
@property @property
def _should_modify_greedy_probs_inplace(self) -> bool: def _should_modify_greedy_probs_inplace(self) -> bool:
...@@ -1237,6 +1250,7 @@ def _build_sampler_output( ...@@ -1237,6 +1250,7 @@ def _build_sampler_output(
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
torch.Tensor]], torch.Tensor]],
skip_sampler_cpu_output: bool = False, skip_sampler_cpu_output: bool = False,
logits: Optional[torch.Tensor] = None
) -> SamplerOutput: ) -> SamplerOutput:
"""Construct Python objects with the output of sampling. """Construct Python objects with the output of sampling.
...@@ -1287,7 +1301,8 @@ def _build_sampler_output( ...@@ -1287,7 +1301,8 @@ def _build_sampler_output(
sampled_token_probs=sampled_token_probs, sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
logprobs=logprobs_tensor, logprobs=logprobs_tensor,
deferred_sample_results_args=deferred_sample_results_args) deferred_sample_results_args=deferred_sample_results_args,
logits=logits)
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
......
from typing import Optional, List
import torch import torch
import torch.jit import torch.jit
import torch.nn.functional as F
from vllm.model_executor.layers.spec_decode_base_sampler import ( from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeDeterministicBaseSampler) SpecDecodeDeterministicBaseSampler)
...@@ -40,6 +42,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -40,6 +42,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
bonus_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor, draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor, draft_token_ids: torch.Tensor,
cart_candidates: Optional[torch.Tensor] = None,
best_candidates: Optional[List] = None,
accept_lengths: Optional[List] = None,
first_step_flags: Optional[List] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Sample token ids using typical acceptance sampling. This accepts """Sample token ids using typical acceptance sampling. This accepts
or rejects tokens proposed by the draft model using the probability or rejects tokens proposed by the draft model using the probability
...@@ -66,6 +72,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -66,6 +72,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
probabilities. probabilities.
shape = [batch_size, num_speculative_tokens] shape = [batch_size, num_speculative_tokens]
cart_candidates: tree-style cartesian candidates
best_candidates: pending to write best candidates index
accept_lengths: pending to write accept lengths
first_step_flags: whether this is the first decoding step
Returns: Returns:
output_token_ids: The token ids sampled via rejection sampling, output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token or -1 if unable to sample a token because the previous token
...@@ -77,6 +88,8 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -77,6 +88,8 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
if self._strict_mode: if self._strict_mode:
self._raise_if_incorrect_input(target_with_bonus_probs, self._raise_if_incorrect_input(target_with_bonus_probs,
draft_token_ids, bonus_token_ids) draft_token_ids, bonus_token_ids)
if cart_candidates is None:
target_probs = target_with_bonus_probs[:, :-1] target_probs = target_with_bonus_probs[:, :-1]
accepted = self._evaluate_accepted_tokens(target_probs, accepted = self._evaluate_accepted_tokens(target_probs,
draft_token_ids) draft_token_ids)
...@@ -84,8 +97,120 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -84,8 +97,120 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
output_token_ids = self._create_output(accepted, recovered_token_ids, output_token_ids = self._create_output(accepted, recovered_token_ids,
draft_token_ids, draft_token_ids,
bonus_token_ids) bonus_token_ids)
else:
target_probs = target_with_bonus_probs
output_token_ids = self._evaluate_accepted_tokens_tree_style(target_probs,
draft_token_ids,
cart_candidates,
best_candidates,
accept_lengths,
first_step_flags)
return output_token_ids return output_token_ids
def _evaluate_accepted_tokens_tree_style(self, target_probs, draft_token_ids,
                                         cart_candidates,
                                         output_best_candidates,
                                         accept_lengths, first_step_flags):
    r"""
    Pick the best tree candidate path per sequence with typical
    acceptance and return the accepted token ids.

    A candidate token x_{n+k} is accepted if it satisfies the
    following condition

    .. math::
        p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
        \min \left( \epsilon, \delta * \exp \left(
            -H(p_{\text{original}}(
                \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)

    where :math:`p_{\text{original}}` corresponds to target_probs
    and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters
    specified using self._posterior_threshold and self._posterior_alpha

    For every candidate path the number of leading accepted tokens is
    counted; the longest count per sequence becomes its accept length.
    Among the paths reaching that length, the one with the highest summed
    log-probability over the accepted prefix is chosen.

    Parameters:
    ----------
    target_probs : torch.Tensor
        Target-model probabilities for every candidate path. It is
        gathered with a 4-D index derived from ``cart_candidates``, so it
        is presumably of shape
        (batch_size, retrieve_size, max_depth, vocab_size) -- TODO confirm
        against the caller.
    draft_token_ids : torch.Tensor
        Only its last dimension ``k`` is read; it fixes the width of the
        returned tensor.
    cart_candidates : torch.Tensor
        A tensor of shape (batch_size, retrieve_size, max_depth) holding
        the token ids of each candidate path of the tree proposal.
    output_best_candidates : list
        Output argument. The index of the chosen path (a 0-dim tensor) is
        appended for every sequence.
    accept_lengths : list
        Output argument. The accept length (a 0-dim tensor) is appended
        for every sequence.
    first_step_flags : sequence of bool
        Per-sequence flag. When true, the path's first token (position 0)
        is left out of the returned ids.

    Returns:
    -------
    torch.Tensor
        A tensor of shape (batch_size, k) with the accepted token ids of
        the chosen path for each sequence, right-padded with -1.
    """
    # After this slice, position j of the probabilities is the one used to
    # judge candidate token j + 1 of each path (see the gather index below).
    target_probs = target_probs[:, :, :-1]
    device = target_probs.device
    batch_size = cart_candidates.shape[0]

    # Probability the target model assigns to each candidate token.
    # [batch_size, retrieve_size, max_depth - 1]
    candidates_prob = torch.gather(
        target_probs, dim=-1,
        index=cart_candidates[:, :, 1:].unsqueeze(-1)).squeeze(-1)

    # Entropy of the target distribution at every judged position.
    posterior_entropy = -torch.sum(
        target_probs * torch.log(target_probs + 1e-5), dim=-1)
    threshold = torch.minimum(
        torch.ones_like(posterior_entropy) * self._posterior_threshold,
        torch.exp(-posterior_entropy) * self._posterior_alpha,
    )
    posterior_mask = candidates_prob > threshold

    # cumprod zeroes everything after the first rejection, so the sum is
    # the number of leading accepted tokens. [batch_size, retrieve_size]
    candidates_accept_length = torch.cumprod(posterior_mask,
                                             dim=2).sum(dim=-1)

    # Longest accepted prefix over all paths. [batch_size]
    accept_length, _ = candidates_accept_length.max(dim=-1)
    if torch.any(accept_length > 0):
        # Keep only the paths that reach the per-sequence maximum ...
        reaches_max = (candidates_accept_length ==
                       accept_length.unsqueeze(-1)).unsqueeze(-1)
        candidates_prob = candidates_prob * reaches_max
        # ... and only the positions inside the accepted prefix.
        depth_index = torch.arange(candidates_prob.shape[-1], device=device)
        inside_prefix = depth_index < accept_length[:, None, None]
        candidates_prob = candidates_prob * inside_prefix
        # add 1e-3 to avoid zero value
        likelihood = torch.sum(torch.log(candidates_prob + 1e-3), dim=-1)
        best_candidate = torch.argmax(likelihood, dim=-1)  # [batch_size]
    else:
        # Nothing was accepted anywhere: fall back to the first path.
        best_candidate = torch.zeros((batch_size),
                                     dtype=torch.long,
                                     device=device)

    k = draft_token_ids.shape[-1]
    # Convert once to Python ints for slicing and padding below.
    best_rows = best_candidate.tolist()
    lengths = accept_length.tolist()

    output_token_id_list = []
    for i, (row, length) in enumerate(zip(best_rows, lengths)):
        # The output lists keep receiving 0-dim tensors, as before.
        output_best_candidates.append(best_candidate[i])
        accept_lengths.append(accept_length[i])
        if first_step_flags[i]:
            # First step: the path's first token is not emitted.
            start, pad = 1, k - length
        else:
            start, pad = 0, k - 1 - length
        path_tokens = cart_candidates[i, row, start:length + 1]
        output_token_id_list.append(
            F.pad(path_tokens, (0, pad), 'constant', -1))

    return torch.stack(output_token_id_list, dim=0)
def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
r""" r"""
Evaluates and returns a mask of accepted tokens based on the Evaluates and returns a mask of accepted tokens based on the
......
...@@ -23,7 +23,7 @@ def get_model_architecture( ...@@ -23,7 +23,7 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
visual = getattr(model_config.hf_config, "visual", []) visual = getattr(model_config.hf_config, "visual", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM'] support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
if architectures == ['QWenLMHeadModel'] and visual != []: if architectures == ['QWenLMHeadModel'] and visual != []:
......
from typing import Iterable, List, Optional, Tuple import os
from typing import Iterable, List, Optional, Tuple, Any, Dict
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -10,6 +11,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -10,6 +11,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.transformers_utils.configs.medusa import MedusaConfig from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm import _custom_ops as ops
TOPK=10 # topk for sparse tree (10 is a placeholder and it is sufficient)
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
...@@ -46,6 +50,8 @@ class Medusa(nn.Module): ...@@ -46,6 +50,8 @@ class Medusa(nn.Module):
def __init__(self, config: MedusaConfig, **_) -> None: def __init__(self, config: MedusaConfig, **_) -> None:
super().__init__() super().__init__()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.config = config self.config = config
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([
ResidualBlock(hidden_size=self.config.hidden_size, ResidualBlock(hidden_size=self.config.hidden_size,
...@@ -55,6 +61,7 @@ class Medusa(nn.Module): ...@@ -55,6 +61,7 @@ class Medusa(nn.Module):
self.orig_vocab_size = config.vocab_size self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size self.unpadded_vocab_size = self.truncated_vocab_size
self.medusa_choices = config.medusa_choices
self.lm_heads = nn.ModuleList([ self.lm_heads = nn.ModuleList([
ParallelLMHead( ParallelLMHead(
...@@ -138,11 +145,62 @@ class Medusa(nn.Module): ...@@ -138,11 +145,62 @@ class Medusa(nn.Module):
return outputs return outputs
def medusa_sample(
    self,
    medusa_logits: List[torch.Tensor],
    sampling_metadata: SamplingMetadata,
    logits: torch.Tensor,
    medusa_buffers: Dict[str, Any]
) -> List[SamplerOutput]:
    """Build tree-style candidates from the base logits and medusa heads.

    Args:
        medusa_logits: One logits tensor per medusa head; they are
            stacked along a new leading dimension.
        sampling_metadata: Supplies ``seq_groups``; each group's
            ``sample_indices`` selects its rows of the candidates.
        logits: Base-model logits; the arg-max over the last dimension
            is the first candidate of every row.
        medusa_buffers: Must contain ``'tree_indices'`` and
            ``'retrieve_indices'`` index tensors.

    Returns:
        One ``SamplerOutput`` per sequence group, carrying the tree
        candidates as ``sampled_token_ids`` and the per-path candidates
        as ``cart_candidates``.
    """
    num_rows = logits.shape[0]

    # Greedy token from the base logits. [batch_size, 1]
    base_token = logits.argmax(dim=-1).view(num_rows, -1)

    # TOPK token ids per head, flattened per row.
    # [medusa_heads, batch_size, TOPK] -> [batch_size, medusa_heads * TOPK]
    stacked_heads = torch.stack(medusa_logits, dim=0)
    head_topk = torch.topk(stacked_heads, TOPK, dim=-1).indices
    head_topk = head_topk.permute(1, 0, 2).reshape(num_rows, -1)

    # [batch_size, 1 + medusa_heads * TOPK]
    flat_candidates = torch.cat([base_token, head_topk], dim=-1)

    # Lay the flat candidates out in tree order. [batch_size, choices]
    tree_candidates = torch.index_select(
        flat_candidates, dim=-1, index=medusa_buffers['tree_indices'])

    # One extra zero column is appended before the retrieve lookup --
    # presumably the slot that padded retrieve indices point at; verify
    # against the buffer generation code. [batch_size, choices + 1]
    zero_column = torch.zeros((num_rows, 1),
                              dtype=torch.long,
                              device=tree_candidates.device)
    tree_candidates_ext = torch.cat([tree_candidates, zero_column], dim=-1)

    # [batch_size, retrieve_size, max_depth]
    cart_candidates = tree_candidates_ext[:,
                                          medusa_buffers['retrieve_indices']]

    outputs: List[Optional[SamplerOutput]] = []
    for seq_group in sampling_metadata.seq_groups:
        rows = seq_group.sample_indices
        outputs.append(
            SamplerOutput(
                outputs=None,
                sampled_token_ids=tree_candidates[rows, :].squeeze(1),
                cart_candidates=cart_candidates[rows, :]
            ))
    return outputs
def generate_proposals( def generate_proposals(
self, self,
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
previous_logits: torch.Tensor=None,
medusa_buffers: Dict[str, Any]=None
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
if medusa_buffers is None:
return self.sample( return self.sample(
logits=self.compute_logits( logits=self.compute_logits(
hidden_states=self.forward(previous_hidden_states), hidden_states=self.forward(previous_hidden_states),
...@@ -150,6 +208,14 @@ class Medusa(nn.Module): ...@@ -150,6 +208,14 @@ class Medusa(nn.Module):
), ),
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
else:
return self.medusa_sample(medusa_logits=self.compute_logits(
hidden_states=self.forward(previous_hidden_states),
sampling_metadata=sampling_metadata,
),
sampling_metadata=sampling_metadata,
logits=previous_logits,
medusa_buffers=medusa_buffers)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
...@@ -177,6 +243,15 @@ class Medusa(nn.Module): ...@@ -177,6 +243,15 @@ class Medusa(nn.Module):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn and "lm_head" in name:
_weight = torch.zeros_like(param.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, param.data, _weight.shape[0], _weight.shape[1])
param.data.copy_(_weight)
param.data=param.data.reshape(ori_shape[1],-1)
if self.token_map is not None: if self.token_map is not None:
self.token_map.to(device=self.lm_heads[0].weight.device) self.token_map.to(device=self.lm_heads[0].weight.device)
......
...@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -329,6 +329,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -329,6 +329,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
"o_proj", "o_proj",
"gate_up_proj", "gate_up_proj",
"down_proj", "down_proj",
"lm_head"
] ]
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
...@@ -363,9 +364,18 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -363,9 +364,18 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens self.lm_head = self.model.embed_tokens
else: else:
self.lm_head = ParallelLMHead(config.vocab_size, # self.lm_head = ParallelLMHead(config.vocab_size,
# config.hidden_size,
# quant_config=quant_config)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -988,6 +988,9 @@ class SequenceGroupMetadata( ...@@ -988,6 +988,9 @@ class SequenceGroupMetadata(
# TODO: We should maintain this states out of the sequence group. # TODO: We should maintain this states out of the sequence group.
num_speculative_tokens: Optional[int] = None num_speculative_tokens: Optional[int] = None
tree_attn_masks : Optional[torch.Tensor] = None
tree_position_ids : Optional[torch.Tensor] = None
def __post_init__(self): def __post_init__(self):
if self.seq_data is not None and self.token_chunk_size is None: if self.seq_data is not None and self.token_chunk_size is None:
if self.is_prompt: if self.is_prompt:
...@@ -1025,6 +1028,11 @@ class SequenceGroupMetadata( ...@@ -1025,6 +1028,11 @@ class SequenceGroupMetadata(
assert self.state.current_step < self.state.num_steps assert self.state.current_step < self.state.num_steps
self.state.current_step += 1 self.state.current_step += 1
def set_tree_style_args(self, tree_attn_masks: Optional[torch.Tensor],
                        tree_position_ids: Optional[torch.Tensor]):
    """Store the tree-style attention masks and position ids on self.

    Both values are assigned as given; passing ``None`` clears the
    corresponding attribute.
    """
    self.tree_position_ids = tree_position_ids
    self.tree_attn_masks = tree_attn_masks
class SequenceOutput( class SequenceOutput(
msgspec.Struct, msgspec.Struct,
...@@ -1271,6 +1279,56 @@ class HiddenStates(msgspec.Struct, array_like=True, ...@@ -1271,6 +1279,56 @@ class HiddenStates(msgspec.Struct, array_like=True,
[self.hidden_states, self.second_last_token_hidden_states])[index] [self.hidden_states, self.second_last_token_hidden_states])[index]
class Logits(msgspec.Struct, array_like=True,
             omit_defaults=True):  # type: ignore[call-arg]
    """Logits corresponding to in-progress sequences.

    Used in speculative decoding to pass lm_head logits from
    the target model to the proposer model in the subsequent step.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the logits tensor"""
    # lm_head logits. The length of the first dimension is checked against
    # the number of sequence group metadata entries.
    logits: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Sequence ids, filled in from seq_group_metadata_list when it is given.
    _seq_ids: List[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        metadata_list = self.seq_group_metadata_list
        if metadata_list is None:
            return
        assert len(metadata_list) == len(self.logits)
        self._seq_ids = get_all_seq_ids(metadata_list)

    @property
    def seq_ids(self) -> List[int]:
        """Sequence ids currently tracked by this object."""
        return self._seq_ids

    def update(self,
               logits: torch.Tensor,
               seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Append logits from a target model invocation. Only used for
        decode steps"""
        assert len(seq_group_metadata_list) == len(logits)
        new_ids = get_all_seq_ids(seq_group_metadata_list)
        self._seq_ids.extend(new_ids)
        self.logits = torch.cat([self.logits, logits])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        kept_ids = get_all_seq_ids(seq_group_metadata_list)
        if kept_ids == self._seq_ids:
            # Batch contents unchanged - nothing to prune.
            return
        # Batch contents changed - keep only the rows of the listed ids,
        # in the order they are listed.
        rows = [self._seq_ids.index(seq_id) for seq_id in kept_ids]
        self.logits = self.logits[rows]
        self._seq_ids = kept_ids
class ExecuteModelRequest( class ExecuteModelRequest(
msgspec.Struct, msgspec.Struct,
array_like=True, # type: ignore[call-arg] array_like=True, # type: ignore[call-arg]
...@@ -1296,6 +1354,8 @@ class ExecuteModelRequest( ...@@ -1296,6 +1354,8 @@ class ExecuteModelRequest(
running_queue_size: int = 0 running_queue_size: int = 0
# Optional hidden states from prior step. # Optional hidden states from prior step.
previous_hidden_states: Optional[HiddenStates] = None previous_hidden_states: Optional[HiddenStates] = None
# Optional logits from prior step.
previous_logits: Optional[Logits] = None
# The number of forward steps to run. # The number of forward steps to run.
num_steps: int = 1 num_steps: int = 1
# Finished request ids since last step. # Finished request ids since last step.
...@@ -1305,6 +1365,12 @@ class ExecuteModelRequest( ...@@ -1305,6 +1365,12 @@ class ExecuteModelRequest(
# Async callback # Async callback
async_callback: Optional[Callable] = None async_callback: Optional[Callable] = None
# Optional tree attention mask from draft model.
tree_attn_masks: Optional[torch.Tensor] = None
# Optional tree position ids from draft model.
tree_position_ids: Optional[torch.Tensor] = None
@property @property
def is_first_multi_step(self) -> bool: def is_first_multi_step(self) -> bool:
# TODO(will) make this be able to handle batches with variable number of # TODO(will) make this be able to handle batches with variable number of
...@@ -1346,8 +1412,11 @@ class ExecuteModelRequest( ...@@ -1346,8 +1412,11 @@ class ExecuteModelRequest(
num_lookahead_slots=self.num_lookahead_slots, num_lookahead_slots=self.num_lookahead_slots,
running_queue_size=self.running_queue_size, running_queue_size=self.running_queue_size,
previous_hidden_states=self.previous_hidden_states, previous_hidden_states=self.previous_hidden_states,
previous_logits=self.previous_logits,
num_steps=self.num_steps, num_steps=self.num_steps,
finished_requests_ids=self.finished_requests_ids, finished_requests_ids=self.finished_requests_ids,
last_sampled_token_ids=self.last_sampled_token_ids.clone() last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None, if self.last_sampled_token_ids is not None else None,
async_callback=self.async_callback) async_callback=self.async_callback,
tree_attn_masks=self.tree_attn_masks,
tree_position_ids=self.tree_position_ids)
...@@ -112,6 +112,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -112,6 +112,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_ids=all_tokens, token_ids=all_tokens,
logprobs=spec_logprobs, logprobs=spec_logprobs,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
logits=target_sampler_output.logits,
) )
def _expand_batch( def _expand_batch(
...@@ -460,3 +461,191 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -460,3 +461,191 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_ids_to_score.extend(full_spec_token_ids[:i + 1] token_ids_to_score.extend(full_spec_token_ids[:i + 1]
for i in range(len(full_spec_token_ids))) for i in range(len(full_spec_token_ids)))
return token_ids_to_score return token_ids_to_score
# Batch-expansion scorer variant added for tree-style proposals.
# NOTE(review): the methods below presumably override the same-named
# BatchExpansionTop1Scorer helpers (parent not visible here) -- confirm.
class BatchExpansionTreeStyleScorer(BatchExpansionTop1Scorer):
# The constructor only forwards its arguments to the parent class.
def __init__(self, scorer_worker: WorkerBase, device: str,
vocab_size: int):
super().__init__(scorer_worker, device, vocab_size)
def _contract_batch(
self, contracted_bs: int, target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals, num_scoring_tokens: int,
non_spec_indices: List[int], spec_indices: List[int], k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
Returns the tuple (all_tokens, all_probs, all_logprobs,
all_hidden_states); all_hidden_states is None when
target_sampler_output.hidden_states is None.
"""
(target_token_ids, target_probs, target_logprobs, target_hidden_states,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs,
non_spec_target_hidden_states) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)
# Map distinct sequences used to score each token
# of shape [batch_size * k] back to [batch_size, k].
# NOTE: this rebinds the `k` argument to the second dimension of
# proposals.proposal_token_ids.
expanded_batch_size, k = proposals.proposal_token_ids.shape
# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences.
non_spec_expanded_bs = len(non_spec_target_token_ids)
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k)
target_probs = target_probs.reshape(*target_token_ids.shape,
self._vocab_size)
target_logprobs = target_logprobs.reshape(target_probs.shape)
if target_hidden_states is not None:
target_hidden_states = target_hidden_states.reshape(
*target_token_ids.shape, target_hidden_states.shape[-1])
# Output buffers for the contracted batch. Entries that are not written
# below keep their fill values: -1 for tokens, 0 for probs, -inf for
# logprobs and 0 for hidden states.
all_tokens = target_token_ids.new_full(size=(contracted_bs, k),
fill_value=-1)
all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
all_logprobs = target_logprobs.new_full(size=all_probs.shape,
fill_value=-float("inf"))
if target_sampler_output.hidden_states is not None:
all_hidden_states = target_hidden_states.new_zeros(
size=(contracted_bs, k, target_hidden_states.shape[-1]))
else:
all_hidden_states = None
# Non-speculative sequences only fill the first position of their row.
if non_spec_indices:
all_tokens[non_spec_indices, :1] = \
non_spec_target_token_ids.unsqueeze(1)
all_probs[non_spec_indices, :1, :] = \
non_spec_target_probs.unsqueeze(1)
all_logprobs[non_spec_indices, :1, :] = \
non_spec_target_logprobs.unsqueeze(1)
if all_hidden_states is not None:
assert non_spec_target_hidden_states is not None
all_hidden_states[non_spec_indices, :1, :] = \
non_spec_target_hidden_states.unsqueeze(1)
# Speculative sequences fill their whole row of k positions.
if spec_indices:
all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs
all_logprobs[spec_indices] = target_logprobs
if all_hidden_states is not None:
all_hidden_states[spec_indices] = target_hidden_states
return all_tokens, all_probs, all_logprobs, all_hidden_states
def _contract_batch_all_spec(
self,
target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
It assumes all sequences in the batch were previously expanded.
Returns the tuple (target_token_ids, target_probs, target_logprobs,
target_hidden_states).
"""
# Map distinct sequences used to score each token
# of shape [batch_size * k] back to [batch_size, k].
contracted_bs, k = proposals.proposal_token_ids.shape
# Reshape tensors to original batch size
target_token_ids = target_sampler_output.sampled_token_ids.reshape(
contracted_bs, k)
target_probs = target_sampler_output.sampled_token_probs.reshape(
*target_token_ids.shape, self._vocab_size)
target_logprobs = target_sampler_output.logprobs.reshape(
target_probs.shape)
target_hidden_states = target_sampler_output.hidden_states
if target_hidden_states is not None:
target_hidden_states = target_hidden_states.reshape(
*target_token_ids.shape, target_hidden_states.shape[-1])
return (target_token_ids, target_probs, target_logprobs,
target_hidden_states)
@staticmethod
def _create_single_target_seq_group_metadata(
seq_group_metadata: SequenceGroupMetadata,
seq_id: SeqId,
target_seq_id: TargetSeqId,
token_ids: List[TokenId],
sampling_params: SamplingParams,
) -> SequenceGroupMetadata:
"""Create a single target SequenceGroupMetadata.
Args:
seq_group_metadata: The metadata for the input sequence.
seq_id: The input sequence ID.
target_seq_id: The corresponding target sequence ID.
token_ids: The list of token ids that are to be appended to the
input sequence.
sampling_params: The sampling params set on the returned metadata.
"""
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_token_ids = seq_data.prompt_token_ids_array
# When the sequence has exactly one output token, that token is dropped
# before token_ids are appended (first step: the token generated by the
# prefill phase is ignored).
if len(seq_data.get_output_token_ids()) == 1:
new_output_token_ids = [*seq_data.get_output_token_ids()[:-1], *token_ids]
else:
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
new_seq_data_dict = {
target_seq_id:
SequenceData(
prompt_token_ids,
_output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
new_output_token_ids),
),
}
# This is a hack. Technically, spec decoding should compute
# num_lookahead slots at one shot, but instead, it expands the batch
# and evaluate one by one right now. context_len is seq_len - 1 because
# the kv cache is filled by a previous batch in the batch expansion.
for data in new_seq_data_dict.values():
data.update_num_computed_tokens(data.get_len() - 1)
return SequenceGroupMetadata(
request_id=seq_group_metadata.request_id,
is_prompt=seq_group_metadata.is_prompt,
seq_data=new_seq_data_dict,
sampling_params=sampling_params,
block_tables={
target_seq_id: seq_group_metadata.block_tables[seq_id],
},
lora_request=None,
token_chunk_size=1,
)
def _get_token_ids_to_score(
self,
full_spec_token_ids: List[TokenId] # shape: [k]
) -> List[List[TokenId]]:
"""Given a list of proposal token ids, return the list of
token id prefixes that should be scored.
Returns k output lists, one per non-empty prefix of the input. No
additional entry for a bonus token is produced.
Example:
Input: [0, 1, 2, 3] (k=4)
Output: (k lists)
[0]
[0, 1]
[0, 1, 2]
[0, 1, 2, 3]
"""
token_ids_to_score = []
token_ids_to_score.extend([
full_spec_token_ids[:i + 1]
for i in range(len(full_spec_token_ids))
])
return token_ids_to_score
...@@ -25,6 +25,19 @@ class SpeculativeProposals: ...@@ -25,6 +25,19 @@ class SpeculativeProposals:
# A flag to mark that there's no available proposals # A flag to mark that there's no available proposals
no_proposals: bool = False no_proposals: bool = False
# The cart_candidates used in tree-style generation
cart_candidates: Optional[torch.Tensor] = None
# The retrieve indices used in tree-style generation
retrieve_indices: Optional[torch.Tensor] = None
# tree-style attention masks
tree_attn_masks: Optional[torch.Tensor] = None
# tree-style position ids
tree_position_ids: Optional[torch.Tensor] = None
def __repr__(self): def __repr__(self):
return (f"SpeculativeProposals(" return (f"SpeculativeProposals("
f"proposal_token_ids={self.proposal_token_ids}, " f"proposal_token_ids={self.proposal_token_ids}, "
...@@ -53,6 +66,9 @@ class SpeculativeScores: ...@@ -53,6 +66,9 @@ class SpeculativeScores:
# Optional last hidden states from the scoring model. # Optional last hidden states from the scoring model.
hidden_states: Optional[torch.Tensor] = None hidden_states: Optional[torch.Tensor] = None
# Optional lm_head logits from the scoring model.
logits: Optional[torch.Tensor] = None
def __repr__(self): def __repr__(self):
return (f"SpeculativeScores(" return (f"SpeculativeScores("
f"probs={self.probs.shape}, " f"probs={self.probs.shape}, "
......
...@@ -2,35 +2,63 @@ import weakref ...@@ -2,35 +2,63 @@ import weakref
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
import torch import torch
import torch.nn.functional as F
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.spec_decode.tree_style_proposer import TreeStyleProposer
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
TOPK=10 # topk for sparse tree (10 is a placeholder and it is sufficient)
class MedusaWorker(NonLLMProposerWorkerBase, Worker): class MedusaWorker(NonLLMProposerWorkerBase, Worker):
"""Worker for Medusa. """Worker for Medusa.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) # skip lora config in medusa
kwargs_copy = kwargs.copy()
kwargs_copy['lora_config'] = None
super().__init__(*args, **kwargs_copy)
# Lazy initialization list. # Lazy initialization list.
self._proposer: Top1Proposer self._proposer: SpeculativeProposer
def init_device(self): def init_device(self):
super().init_device() super().init_device()
def load_model(self):
super().load_model()
# get medusa choices and generate medusa_buffers
self.medusa_buffers = None
if hasattr(self.model_runner.model, 'medusa_choices'):
self.medusa_choices = self.model_runner.model.medusa_choices
self.medusa_buffers = self.generate_medusa_buffers(
self.medusa_choices, device=self.device
)
if self.medusa_buffers is None:
self._proposer = Top1Proposer( self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type] weakref.proxy(self), # type: ignore[arg-type]
self.device, self.device,
self.vocab_size, self.vocab_size,
max_proposal_len=self.max_model_len, max_proposal_len=self.max_model_len,
) )
else:
self._proposer = TreeStyleProposer(
weakref.proxy(self), # type: ignore[arg-type]
self.device,
self.vocab_size,
self.medusa_buffers,
max_proposal_len=self.max_model_len,
)
def set_include_gpu_probs_tensor(self):
    """Do nothing; this worker has no action to take for this hook."""
    return None
...@@ -69,7 +97,24 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker): ...@@ -69,7 +97,24 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
model_outputs = self.model_runner.model.generate_proposals( model_outputs = self.model_runner.model.generate_proposals(
previous_hidden_states=execute_model_req.previous_hidden_states. previous_hidden_states=execute_model_req.previous_hidden_states.
hidden_states, hidden_states,
sampling_metadata=sampling_metadata) sampling_metadata=sampling_metadata,
previous_logits=execute_model_req.previous_logits.logits if execute_model_req.previous_logits is not None else None,
medusa_buffers=self.medusa_buffers)
# create tree attn masks
if self.medusa_buffers is not None:
max_context_len = max(seq_lens)
for sampler_output, seq_len in zip(model_outputs, seq_lens):
context_len = seq_len
attn_masks = self.medusa_buffers['tree_attn_masks']
left_mask = torch.ones(attn_masks.shape[0], context_len,
dtype=attn_masks.dtype,
device=attn_masks.device)
attn_masks = torch.cat([left_mask, attn_masks], dim=-1)
right_pad = max_context_len - context_len
if right_pad > 0:
attn_masks = F.pad(attn_masks, (0, right_pad), "constant", 0)
sampler_output.tree_attn_masks = attn_masks
return model_outputs, False return model_outputs, False
...@@ -96,6 +141,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker): ...@@ -96,6 +141,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
seq_lens.append(seq_len) seq_lens.append(seq_len)
query_lens.append(seq_len - context_len) query_lens.append(seq_len - context_len)
else: else:
# first step of tree decoding need to ignore first token
if self.medusa_buffers is not None and seq_data.get_output_len() == 1:
seq_data_len -= 1
seq_lens.append(seq_data_len) seq_lens.append(seq_data_len)
query_lens.append(1) query_lens.append(1)
...@@ -134,3 +182,124 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker): ...@@ -134,3 +182,124 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
execute_model_req.seq_group_metadata_list): execute_model_req.seq_group_metadata_list):
raise NotImplementedError( raise NotImplementedError(
"MedusaWorker does not support beam search.") "MedusaWorker does not support beam search.")
def pad_path(self, path, length, pad_value=-2):
    """Return ``path`` right-padded with ``pad_value`` up to ``length`` items.

    Parameters:
    - path (list): The list to pad. It is not modified.
    - length (int): The desired length of the result.
    - pad_value (optional, default=-2): The filler value.

    Returns:
    - list: A new list that starts with the items of ``path``.

    Example:
    >>> pad_path([1,2,3], 5)
    [1, 2, 3, -2, -2]

    Note:
    When ``path`` already has ``length`` items or more, nothing is appended
    and the result is an equal copy of ``path``.
    """
    # A non-positive count yields an empty filler list, so over-long paths
    # come back unchanged.
    missing = length - len(path)
    filler = [pad_value] * missing
    return path + filler
def generate_medusa_buffers(self, medusa_choices, device="cuda"):
    """
    Build the tensors that describe the Medusa tree for the given choices.

    Node 0 of every buffer is the root; the remaining nodes follow the
    choices sorted by (depth, path).

    Parameters:
    - medusa_choices (list): A nested list; each inner list is one path in
      the tree.
    - device (str): Device the resulting tensors are moved to. Default is
      "cuda".

    Returns:
    - dict: Tensors keyed by "tree_attn_masks", "tree_indices",
      "tree_position_ids" and "retrieve_indices".
    """
    # Shallow paths first, ties broken by the path values themselves.
    ordered = sorted(medusa_choices, key=lambda x: (len(x), x))
    num_nodes = len(ordered) + 1

    # depth_counts[d] holds how many paths have length d + 1.
    depth_counts = []
    last_depth = 0
    for node in ordered:
        node_depth = len(node)
        if node_depth != last_depth:
            depth_counts.append(0)
        depth_counts[node_depth - 1] += 1
        last_depth = node_depth

    # Attention mask: each node sees itself, the root (column 0) and every
    # proper prefix of its own path.
    attn_mask = torch.eye(num_nodes, num_nodes)
    attn_mask[:, 0] = 1
    offset = 0
    for count in depth_counts:
        for j in range(count):
            node = ordered[offset + j]
            if len(node) == 1:
                # Depth-1 nodes have no ancestor besides the root.
                continue
            ancestors = [
                ordered.index(node[:c + 1]) + 1
                for c in range(len(node) - 1)
            ]
            attn_mask[offset + j + 1, ancestors] = 1
        offset += count

    # Tree indices: last path element, shifted by TOPK per depth level and
    # by one for the root slot.
    tree_indices = torch.zeros(num_nodes, dtype=torch.long)
    tree_indices[0] = 0
    offset = 0
    for level, count in enumerate(depth_counts):
        for j in range(count):
            node = ordered[offset + j]
            tree_indices[offset + j + 1] = node[-1] + TOPK * level + 1
        offset += count

    # Position ids: the depth of every node (root is 0).
    position_ids = torch.zeros(num_nodes, dtype=torch.long)
    offset = 0
    for level, count in enumerate(depth_counts):
        position_ids[offset + 1: offset + count + 1] = level + 1
        offset += count

    # Retrieval indices: walking from the last path backwards, record the
    # chain of prefix positions for every path not already covered as a
    # prefix of an earlier chain.
    chains = []
    covered = []
    for node in reversed(ordered):
        if node in covered:
            continue
        chain = []
        for c in range(len(node)):
            prefix = node[:c + 1]
            chain.append(ordered.index(prefix))
            covered.append(prefix)
        chains.append(chain)
    longest = max(len(chain) for chain in chains)
    padded = [self.pad_path(chain, longest) for chain in chains]
    # Shift by one to make room for the root; padding (-2) becomes -1.
    retrieve_indices = torch.tensor(padded, dtype=torch.long) + 1
    root_column = torch.zeros((retrieve_indices.shape[0], 1),
                              dtype=torch.long)
    retrieve_indices = torch.cat([root_column, retrieve_indices], dim=1)

    collected = {
        "tree_attn_masks": attn_mask.int(),
        "tree_indices": tree_indices,
        "tree_position_ids": position_ids,
        "retrieve_indices": retrieve_indices,
    }

    # Copy every entry onto the requested device.
    medusa_buffers = {}
    for key, value in collected.items():
        if isinstance(value, torch.Tensor):
            medusa_buffers[key] = value.clone().to(device)
        else:
            medusa_buffers[key] = torch.tensor(value, device=device)
    return medusa_buffers
...@@ -16,8 +16,8 @@ from vllm.model_executor.layers.typical_acceptance_sampler import ( ...@@ -16,8 +16,8 @@ from vllm.model_executor.layers.typical_acceptance_sampler import (
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, ExecuteModelRequest, CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SequenceGroupMetadata, HiddenStates, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids) get_all_seq_ids_and_request_ids, Logits)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer, BatchExpansionTreeStyleScorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
...@@ -36,6 +36,8 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output, ...@@ -36,6 +36,8 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
split_batch_by_proposal_len) split_batch_by_proposal_len)
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -80,6 +82,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": ...@@ -80,6 +82,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_alpha, typical_acceptance_sampler_posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs, disable_logprobs=speculative_config.disable_logprobs,
disable_log_stats=speculative_config.disable_log_stats, disable_log_stats=speculative_config.disable_log_stats,
tree_style_spec_decoding=speculative_config.tree_style_spec_decoding,
) )
return spec_decode_worker return spec_decode_worker
...@@ -122,6 +125,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -122,6 +125,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_alpha: float, typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool, disable_logprobs: bool,
disable_log_stats: bool, disable_log_stats: bool,
tree_style_spec_decoding: bool,
) -> "SpecDecodeWorker": ) -> "SpecDecodeWorker":
allow_zero_draft_token_step = True allow_zero_draft_token_step = True
...@@ -183,7 +187,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -183,7 +187,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable_log_stats=disable_log_stats, disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size, disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler, spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step) allow_zero_draft_token_step=allow_zero_draft_token_step,
tree_style_spec_decoding=tree_style_spec_decoding)
def __init__( def __init__(
self, self,
...@@ -195,6 +200,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -195,6 +200,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
metrics_collector: Optional[AsyncMetricsCollector] = None, metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None, disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True, allow_zero_draft_token_step: Optional[bool] = True,
tree_style_spec_decoding: bool = False,
): ):
""" """
Create a SpecDecodeWorker. Create a SpecDecodeWorker.
...@@ -223,6 +229,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -223,6 +229,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
allow_zero_draft_token_step: whether to allow a step where the draft allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814) draft model is larger than 1 (TODO: #5814)
tree_style_spec_decoding: Whether to use tree-style generation.
""" """
self.proposer_worker = proposer_worker self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker self.scorer_worker = scorer_worker
...@@ -247,14 +254,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -247,14 +254,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.probs_dtype = self.spec_decode_sampler.probs_dtype self.probs_dtype = self.spec_decode_sampler.probs_dtype
self.token_id_dtype = self.spec_decode_sampler.token_id_dtype self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
# Lazy initialization. # Lazy initialization.
self.scorer: SpeculativeScorer self.scorer: BatchExpansionTop1Scorer
# Hidden states from target model to pass to proposer # Hidden states from target model to pass to proposer
# in the subsequent step. # in the subsequent step.
self.previous_hidden_states: Optional[HiddenStates] = None self.previous_hidden_states: Optional[HiddenStates] = None
self.previous_logits: Optional[Logits] = None
self._disable_logprobs = disable_logprobs self._disable_logprobs = disable_logprobs
self._disable_log_stats = disable_log_stats self._disable_log_stats = disable_log_stats
self.tree_style_spec_decoding = tree_style_spec_decoding
def init_device(self) -> None: def init_device(self) -> None:
"""Initialize both scorer and proposer models. """Initialize both scorer and proposer models.
""" """
...@@ -270,10 +280,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -270,10 +280,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self._metrics.init_gpu_tensors(self.rank) self._metrics.init_gpu_tensors(self.rank)
self.spec_decode_sampler.init_gpu_tensors(self.rank) self.spec_decode_sampler.init_gpu_tensors(self.rank)
if not self.tree_style_spec_decoding:
self.scorer = BatchExpansionTop1Scorer( self.scorer = BatchExpansionTop1Scorer(
scorer_worker=self.scorer_worker, scorer_worker=self.scorer_worker,
device=self.device, device=self.device,
vocab_size=self._vocab_size) vocab_size=self._vocab_size)
else:
self.scorer = BatchExpansionTreeStyleScorer(
scorer_worker=self.scorer_worker,
device=self.device,
vocab_size=self._vocab_size)
self._configure_model_sampler_for_spec_decode() self._configure_model_sampler_for_spec_decode()
...@@ -532,6 +548,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -532,6 +548,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.previous_hidden_states.update( self.previous_hidden_states.update(
hidden_states, execute_model_req.seq_group_metadata_list) hidden_states, execute_model_req.seq_group_metadata_list)
# Store logits from target model execution.
logits = sampler_output.logits
if logits is not None:
if self.previous_logits is None:
self.previous_logits = Logits(
logits, execute_model_req.seq_group_metadata_list)
else:
self.previous_logits.update(
logits, execute_model_req.seq_group_metadata_list)
if not skip_proposer: if not skip_proposer:
# We prepare the prefill hidden states here so that there no # We prepare the prefill hidden states here so that there no
# additional complexity in worker for spec_decode vs non_spec_decode # additional complexity in worker for spec_decode vs non_spec_decode
...@@ -605,6 +631,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -605,6 +631,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
execute_model_req.previous_hidden_states = self.previous_hidden_states execute_model_req.previous_hidden_states = self.previous_hidden_states
self.previous_hidden_states = None self.previous_hidden_states = None
# Pass last logits from target model to proposer
execute_model_req.previous_logits = self.previous_logits
self.previous_logits = None
with Timer() as proposal_timer: with Timer() as proposal_timer:
# Generate proposals using draft worker. # Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals( proposals = self.proposer_worker.get_spec_proposals(
...@@ -615,6 +645,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -615,6 +645,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
raise RuntimeError("Cannot handle cases where distributed draft " raise RuntimeError("Cannot handle cases where distributed draft "
"workers generate no tokens") "workers generate no tokens")
# Pass tree attention mask and positions to target model
if self.tree_style_spec_decoding:
execute_model_req.tree_attn_masks = proposals.tree_attn_masks
execute_model_req.tree_position_ids = proposals.tree_position_ids
execute_model_req.previous_hidden_states = None execute_model_req.previous_hidden_states = None
with Timer() as scoring_timer: with Timer() as scoring_timer:
...@@ -624,10 +659,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -624,10 +659,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
) )
with Timer() as verification_timer: with Timer() as verification_timer:
accepted_token_ids, target_logprobs = self._verify_tokens( accepted_token_ids, target_logprobs, select_indices_list, accept_lengths = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores, execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots) proposals, execute_model_req.num_lookahead_slots)
# move kv_caches of selected tokens to right positions
if self.tree_style_spec_decoding:
self.move_caches(execute_model_req, select_indices_list, accept_lengths)
stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots, stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
scoring_timer.elapsed_time_ms, scoring_timer.elapsed_time_ms,
verification_timer.elapsed_time_ms) verification_timer.elapsed_time_ms)
...@@ -646,7 +685,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -646,7 +685,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_scores: SpeculativeScores, proposal_scores: SpeculativeScores,
proposals: SpeculativeProposals, proposals: SpeculativeProposals,
max_proposal_len: int, max_proposal_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], List[int]]:
"""Determine which speculative tokens are accepted using the """Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models. probabilities of each token according to the proposer and scorer models.
...@@ -666,6 +705,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -666,6 +705,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Get probabilities of target model, including bonus tokens. # Get probabilities of target model, including bonus tokens.
proposal_verifier_probs = proposal_scores.probs[spec_indices] proposal_verifier_probs = proposal_scores.probs[spec_indices]
if self.tree_style_spec_decoding:
retrieve_indices = proposals.retrieve_indices
proposal_verifier_probs = proposal_verifier_probs[:, retrieve_indices]
# Get non-speculative sampled tokens from target model. # Get non-speculative sampled tokens from target model.
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices] non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
...@@ -673,11 +716,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -673,11 +716,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:] bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
# Get probabilities according to proposal method. # Get probabilities according to proposal method.
proposal_probs = proposals.proposal_probs[spec_indices] proposal_probs = proposals.proposal_probs[spec_indices] \
if proposals.proposal_probs is not None else None
# Get proposed tokens. # Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids[spec_indices] proposal_token_ids = proposals.proposal_token_ids[spec_indices]
# Get tree buffers.
cart_candidates = proposals.cart_candidates[spec_indices] if proposals.cart_candidates is not None else None
# Sampler arguments # Sampler arguments
sampler_extra_kwargs: Dict[str, Any] = {} sampler_extra_kwargs: Dict[str, Any] = {}
if self.generators and isinstance(self.spec_decode_sampler, if self.generators and isinstance(self.spec_decode_sampler,
...@@ -688,6 +735,18 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -688,6 +735,18 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if sgm.sampling_params.seed is not None if sgm.sampling_params.seed is not None
} }
if isinstance(self.spec_decode_sampler, TypicalAcceptanceSampler):
sampler_extra_kwargs["cart_candidates"] = cart_candidates
sampler_extra_kwargs["best_candidates"] = []
sampler_extra_kwargs["accept_lengths"] = []
first_step_flags = []
for i, sgm in enumerate(seq_group_metadata_list):
seq = next(iter(sgm.seq_data.values()))
first_step_flags.append(True if seq.get_output_len() == 1 else False)
sampler_extra_kwargs["first_step_flags"] = first_step_flags
accepted_token_ids = self.spec_decode_sampler( accepted_token_ids = self.spec_decode_sampler(
target_with_bonus_probs=proposal_verifier_probs, target_with_bonus_probs=proposal_verifier_probs,
bonus_token_ids=bonus_token_ids, bonus_token_ids=bonus_token_ids,
...@@ -697,8 +756,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -697,8 +756,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
) )
# Append output tokens from non-speculative sequences to # Append output tokens from non-speculative sequences to
# the accepted token ids tensor. # the accepted token ids tensor.
if not self.tree_style_spec_decoding:
non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len + non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
1).clone() 1).clone()
else:
non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len).clone()
non_spec_token_ids[:, 1:] = -1 non_spec_token_ids[:, 1:] = -1
accepted_token_ids = torch.cat( accepted_token_ids = torch.cat(
[accepted_token_ids, non_spec_token_ids]) [accepted_token_ids, non_spec_token_ids])
...@@ -708,6 +771,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -708,6 +771,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_token_ids[original_indices] = accepted_token_ids.clone() accepted_token_ids[original_indices] = accepted_token_ids.clone()
hidden_states = proposal_scores.hidden_states hidden_states = proposal_scores.hidden_states
select_indices = None
accept_lengths = None
select_indices_list = []
if cart_candidates is None:
if hidden_states is not None: if hidden_states is not None:
# Contract hidden states based on accepted tokens # Contract hidden states based on accepted tokens
hs_size = hidden_states.shape[-1] hs_size = hidden_states.shape[-1]
...@@ -721,8 +791,129 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -721,8 +791,129 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.previous_hidden_states = HiddenStates( self.previous_hidden_states = HiddenStates(
hidden_states, seq_group_metadata_list, hidden_states, seq_group_metadata_list,
second_last_token_hidden_states) second_last_token_hidden_states)
else:
retrieve_indices = proposals.retrieve_indices
batch_size = len(seq_group_metadata_list)
best_candidates = sampler_extra_kwargs["best_candidates"]
accept_lengths = sampler_extra_kwargs["accept_lengths"]
# Contract hidden states based on accepted tokens
hs_size = hidden_states.shape[-1]
hidden_states = hidden_states.view(batch_size, -1, hs_size)
# Store logits from target model for subsequent proposal
logits = proposal_scores.logits
logits = logits.view(batch_size, -1, logits.shape[-1])
logits = logits[:, retrieve_indices] # [batch_size, retrieve_size, max_depth, vocab_size]
previous_logits_list = []
previous_hidden_state_list = []
for i in range(batch_size):
logit = logits[i, best_candidates[i], accept_lengths[i]].unsqueeze(0)
previous_logits_list.append(logit)
select_indices = retrieve_indices[best_candidates[i], :accept_lengths[i]+1]
hidden_state = hidden_states[i, select_indices[-1]].unsqueeze(0)
select_indices_list.append(select_indices)
previous_hidden_state_list.append(hidden_state)
logits = torch.cat(previous_logits_list, dim=0)
self.previous_logits = Logits(logits, seq_group_metadata_list)
hidden_states = torch.cat(previous_hidden_state_list, dim=0) # [batch_size, 1, vocab_size]
self.previous_hidden_states = HiddenStates(hidden_states,
seq_group_metadata_list,)
return accepted_token_ids, logprobs, select_indices_list, accept_lengths
def move_caches(self, execute_model_req: ExecuteModelRequest,
                select_indices_list: List[torch.Tensor],
                accept_lengths: List[int]):
    """Move the KV-cache entries of accepted tree tokens into place.

    For each sequence with at least one accepted token, the cache slots
    addressed by its selected tree indices are copied to the consecutive
    slots that directly follow the sequence's current length.

    Args:
        execute_model_req: The request that was just scored; supplies the
            sequence metadata and the virtual engine index.
        select_indices_list: One tensor of selected tree indices per
            sequence. The first entry of each tensor is skipped.
        accept_lengths: Number of accepted tokens per sequence, aligned
            with ``select_indices_list``.

    Raises:
        RuntimeError: If the scorer worker's model input exposes no block
            tables.
    """
    # Verification hands back an empty selection when there were no tree
    # candidates. Bail out before the per-sequence stride is computed,
    # which would otherwise divide by zero.
    if not select_indices_list:
        return

    seq_lens = []
    for sg in execute_model_req.seq_group_metadata_list:
        for seq_data in sg.seq_data.values():
            seq_len = seq_data.get_len()
            context_len = seq_len - 1
            seq_len = min(seq_len, context_len + sg.token_chunk_size)
            # first step of tree-style decoding need to ignore first
            # generated token
            if seq_data.get_output_len() == 1:
                seq_len -= 1
            seq_lens.append(seq_len)

    scorer_worker = self.scorer._scorer_worker
    attn_metadata = getattr(scorer_worker.model_input, 'attn_metadata', None)
    block_tables = getattr(attn_metadata, 'block_tables', None)
    if block_tables is None:
        raise RuntimeError("Can not get block_tables from model_input.")
    block_tables = block_tables.cpu().tolist()

    cache_engine = scorer_worker.cache_engines[
        execute_model_req.virtual_engine]
    block_size = cache_engine.block_size
    batch_size = len(select_indices_list)
    # The block table has a fixed number of rows per sequence; only the
    # first row of each group is used.
    # NOTE(review): presumably the rows come from the scorer's batch
    # expansion - confirm.
    block_table_stride = len(block_tables) // batch_size

    select_indices_slot_mapping: List[int] = []
    target_slot_mapping: List[int] = []
    for i, selected in enumerate(select_indices_list):
        accept_length = accept_lengths[i]
        if accept_length <= 0:
            continue
        row = i * block_table_stride
        # Source: where the accepted tree nodes were written.
        select_indices = selected[1:] + seq_lens[i]
        self.compute_slot_mapping(select_indices_slot_mapping, row,
                                  select_indices, block_size, block_tables)
        # Destination: consecutive positions right after the sequence.
        target_indices = torch.arange(1, accept_length + 1) + seq_lens[i]
        self.compute_slot_mapping(target_slot_mapping, row,
                                  target_indices, block_size, block_tables)

    if not select_indices_slot_mapping:
        return

    select_indices_slot_tensor = torch.tensor(select_indices_slot_mapping,
                                              dtype=torch.long,
                                              device=self.device).view(-1, 1)
    target_slot_mapping_tensor = torch.tensor(target_slot_mapping,
                                              dtype=torch.long,
                                              device=self.device).view(-1, 1)
    # One (source slot, destination slot) row per moved token.
    src_dst_tensor = torch.cat(
        [select_indices_slot_tensor, target_slot_mapping_tensor], dim=-1)

    kv_caches = scorer_worker.kv_cache[execute_model_req.virtual_engine]
    cache_engine.attn_backend.move_cache(
        kv_caches,
        src_dst_tensor,
        cache_engine.cache_config.cache_dtype,
        cache_engine.num_kv_heads,
        cache_engine.head_size,
    )
def compute_slot_mapping(self, slot_mapping: List[int],
                         seq_id: int, select_indices: List[int],
                         block_size: int,
                         block_tables: List[List[int]]):
    """Append the cache slot of each token position to ``slot_mapping``.

    A position ``p`` lives in block ``block_table[p // block_size]`` at
    offset ``p % block_size``, so its slot is
    ``block_number * block_size + offset``.

    Args:
        slot_mapping: Output list, extended in place with one plain ``int``
            per entry of ``select_indices``.
        seq_id: Row of ``block_tables`` to use.
        select_indices: Token positions to translate. Entries may be any
            integer-like value (including 0-d tensors); they are converted
            with ``int()``.
        block_size: Number of slots per block.
        block_tables: Block numbers, one row per sequence.
    """
    block_table = block_tables[seq_id]
    for index in select_indices:
        # Normalise first so tensor elements never end up in the output.
        block_idx, block_offset = divmod(int(index), block_size)
        slot_mapping.append(block_table[block_idx] * block_size +
                            block_offset)
return accepted_token_ids, logprobs
def _create_output_sampler_list( def _create_output_sampler_list(
self, self,
......
from typing import List, Optional, Set, Tuple, Any, Dict
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (ExecuteModelRequest,
SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.util import sampler_output_to_torch
class TreeStyleProposer(SpeculativeProposer):
"""Helper class which separates out sequences which would exceed the max
model length when speculated upon.
This allows combinations of models such as JackFram/llama-68m draft with
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
2048 while Llama2-13b has max_position_embeddings of 4096.
We treat the sequences which exceed the proposal draft model length as
"non-spec sequences". Essentially they skip the draft model and go through
normal decoding in the target model.
Currently, only proposal_lens of 0 and k are supported, where k is a global
batch proposal length. In the future vLLM should support per-sequence
proposal lengths.
"""
def __init__(
    self,
    worker: ProposerWorkerBase,
    device: str,
    vocab_size: int,
    tree_buffers: Dict[str, Any],
    max_proposal_len: Optional[int] = None,
):
    """Store the collaborators and settings used when proposing.

    Args:
        worker: Draft worker whose ``sampler_output`` produces proposals.
        device: Device identifier kept for later use.
        vocab_size: Vocabulary size kept for later use.
        tree_buffers: Tree description; ``get_spec_proposals`` reads its
            "tree_indices", "tree_position_ids" and "retrieve_indices"
            entries.
        max_proposal_len: Optional length bound used when deciding which
            sequences may be speculated on; ``None`` disables the bound.
    """
    # Draft worker plus the tree it speculates over.
    self._worker = worker
    self.tree_buffers = tree_buffers

    # Static settings.
    self._device = device
    self._vocab_size = vocab_size
    self.max_proposal_len = max_proposal_len
def get_spec_proposals(
    self,
    execute_model_req: ExecuteModelRequest,
    seq_ids_with_bonus_token_in_last_step: Set[int],
) -> SpeculativeProposals:
    """Produce tree-style speculative proposals for the input batch.

    The proposal length is the number of entries in the "tree_indices"
    buffer. Sequences that ``_split_by_proposal_len`` rejects are not sent
    to the draft worker. The result also carries per-sequence tree position
    ids, the retrieve indices buffer and whatever candidates and attention
    masks ``_merge_outputs`` returns.
    """
    tree_size = self.tree_buffers["tree_indices"].shape[0]
    all_groups = execute_model_req.seq_group_metadata_list

    # Separate the sequences that will be speculated on from the rest.
    proposal_lens, spec_groups, spec_indices = self._split_by_proposal_len(
        all_groups, tree_size)

    if not spec_groups:
        # Nothing to speculate on: there is no sampler output at all.
        maybe_sampler_output = None
        transposed = False
    else:
        # Keep only the target-model state of the speculated sequences.
        prev_hidden = execute_model_req.previous_hidden_states
        if prev_hidden is not None:
            prev_hidden.prune(spec_groups)
        prev_logits = execute_model_req.previous_logits
        if prev_logits is not None:
            prev_logits.prune(spec_groups)

        spec_request = ExecuteModelRequest(
            seq_group_metadata_list=spec_groups,
            num_lookahead_slots=tree_size,
            previous_hidden_states=prev_hidden,
            previous_logits=prev_logits,
        )
        # ``transposed`` reports the layout of the sampler output and is
        # forwarded unchanged to the helpers below.
        maybe_sampler_output, transposed = self._worker.sampler_output(
            execute_model_req=spec_request,
            sample_len=tree_size,
            seq_ids_with_bonus_token_in_last_step=(
                seq_ids_with_bonus_token_in_last_step),
        )
        (
            proposal_lens,
            maybe_sampler_output,
            spec_indices,
        ) = self._remove_no_proposal_seqs(proposal_lens,
                                          maybe_sampler_output,
                                          spec_indices,
                                          transposed)

    # Bring speculative and non-speculative sequences into one layout.
    merged = self._merge_outputs(
        batch_size=len(all_groups),
        proposal_len=tree_size,
        maybe_sampler_output=maybe_sampler_output,
        proposal_lens=proposal_lens,
        nonzero_proposal_len_indices=spec_indices,
        sampler_transposed=transposed,
    )
    (proposal_tokens, proposal_probs, proposal_lens, cart_candidates,
     tree_attn_masks) = merged

    # Offset the tree's relative positions by each sequence's length. A
    # sequence with exactly one output token is treated as one shorter.
    per_seq_positions = []
    for group in all_groups:
        seq_data = next(iter(group.seq_data.values()))
        if seq_data.get_output_len() == 1:
            base_len = seq_data.get_len() - 1
        else:
            base_len = seq_data.get_len()
        per_seq_positions.append(
            self.tree_buffers['tree_position_ids'] + base_len)
    tree_position_ids = torch.stack(per_seq_positions,
                                    dim=0).reshape(-1, 1)

    return SpeculativeProposals(
        proposal_token_ids=proposal_tokens,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens,
        no_proposals=maybe_sampler_output is None,
        cart_candidates=cart_candidates,
        retrieve_indices=self.tree_buffers['retrieve_indices'],
        tree_attn_masks=tree_attn_masks,
        tree_position_ids=tree_position_ids)
def _split_by_proposal_len(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_len: int,
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
"""Split sequences by two groups:
1. Sequences with non-zero proposal length.
2. Sequences with zero proposal length (due to disabled speculation
or exceed the maximum model length).
"""
proposal_lens: List[int] = []
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
nonzero_proposal_len_indices: List[int] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
# The speculative decoding for this request has been disabled
# (e.g. due to high traffic).
if seq_group_metadata.num_speculative_tokens == 0:
proposal_lens.append(0)
continue
seq_data = next(iter(seq_group_metadata.seq_data.values()))
seq_len = seq_data.get_len()
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
# If max_proposal_len is defined, then we shall no exceed this
# quota for nonzero_proposal
new_k = 0
if (self.max_proposal_len is None
or seq_len + proposal_len < self.max_proposal_len):
new_k = proposal_len
nonzero_proposal_len_seqs.append(seq_group_metadata)
nonzero_proposal_len_indices.append(i)
proposal_lens.append(new_k)
seq_group_metadata.num_speculative_tokens = new_k
return (
proposal_lens,
nonzero_proposal_len_seqs,
nonzero_proposal_len_indices,
)
@staticmethod
def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
nonzero_proposal_len_indices, transposed):
"""Remove sequences from nonzero_proposal_len_indices and reset
their proposal_len to 0 the draft worker does not provide a proposal
(maybe_sampler_output=None). This can avoid scoring overheads.
"""
# If maybe_sampler_output is None, then the draft worker did not
# provide a proposal for any sequence and thus no action needed.
# Also we do not support transposed maybe_sampler_output for now
# because it seems not straightforward for draft workers outputting
# transposed sampler outputs to handle the case of no proposal.
if maybe_sampler_output is None or transposed:
return (proposal_lens, maybe_sampler_output,
nonzero_proposal_len_indices)
new_proposal_lens: List[int] = []
new_nonzero_proposal_len_indices: List[int] = []
new_maybe_sampler_output: List[SamplerOutput] = []
nonzero_proposal_len_idx_ptr = 0
seq_idx = 0
while seq_idx < len(
proposal_lens) and nonzero_proposal_len_idx_ptr < len(
nonzero_proposal_len_indices):
if seq_idx < nonzero_proposal_len_indices[
nonzero_proposal_len_idx_ptr]:
# Sequence is not in the original nonzero_proposal_len_indices,
# meaning that it has a proposal length of 0 before sending to
# the draft worker.
assert proposal_lens[seq_idx] == 0
new_proposal_lens.append(0)
else:
# Sequence is in the original nonzero_proposal_len_indices
if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None:
# but does not have a proposal from the draft worker.
new_proposal_lens.append(0)
else:
# and has a proposal from the draft worker. Add it to the
# new nonzero proposal list and keep the sampler output.
new_proposal_lens.append(proposal_lens[seq_idx])
new_nonzero_proposal_len_indices.append(seq_idx)
new_maybe_sampler_output.append(
maybe_sampler_output[nonzero_proposal_len_idx_ptr])
nonzero_proposal_len_idx_ptr += 1
seq_idx += 1
# The remaining sequences should have proposal length of 0.
new_proposal_lens.extend(proposal_lens[seq_idx:])
# We assume sampler_output will not be a list of all Nones.
# In this case this function should not be called.
assert new_maybe_sampler_output
return (new_proposal_lens, new_maybe_sampler_output,
new_nonzero_proposal_len_indices)
def _merge_outputs(
self,
batch_size: int,
proposal_len: int,
maybe_sampler_output: Optional[List[SamplerOutput]],
proposal_lens: List[int],
nonzero_proposal_len_indices: List[int],
sampler_transposed: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
retrieve_indices = self.tree_buffers["retrieve_indices"]
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens = torch.tensor(-1,
dtype=torch.long,
device=self._device).expand(
batch_size, proposal_len)
proposal_probs = torch.tensor(0,
dtype=torch.float32,
device=self._device).expand(
batch_size, proposal_len,
self._vocab_size)
proposal_lens_tensor = torch.tensor(0,
dtype=torch.long,
device=self._device).expand(
len(proposal_lens))
cart_candidates_tensor = torch.tensor(0,
dtype=torch.long,
device=self._device).expand(
batch_size, retrieve_indices.shape[0],
retrieve_indices.shape[1])
tree_attn_masks_tensor = torch.tensor(0,
dtype=torch.int32,
device=self._device).expand(
batch_size, self.tree_buffers["tree_attn_masks"].shape[0],
self.tree_buffers["tree_attn_masks"].shape[1])
return proposal_tokens, proposal_probs, proposal_lens_tensor, cart_candidates_tensor, tree_attn_masks_tensor
sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs, _, _, cart_candidates, tree_attn_masks = sampler_output_to_torch(
sampler_output, sampler_transposed)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens = proposal_tokens.new_full(
size=(batch_size, *proposal_tokens.shape[1:]),
fill_value=-1,
)
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
entire_proposal_probs = None
proposal_tokens, proposal_probs = (
entire_proposal_tokens,
entire_proposal_probs,
)
proposal_lens_tensor = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len
entire_cart_candidates = cart_candidates.new_zeros(
batch_size,
*cart_candidates.shape[1:],
)
entire_cart_candidates[nonzero_proposal_len_indices] = cart_candidates
entire_tree_attn_masks = tree_attn_masks.new_zeros(
batch_size,
*tree_attn_masks.shape[1:],
)
entire_tree_attn_masks[nonzero_proposal_len_indices] = tree_attn_masks
entire_tree_attn_masks = entire_tree_attn_masks.reshape(-1, tree_attn_masks.shape[-1])
return proposal_tokens, proposal_probs, proposal_lens_tensor, entire_cart_candidates, entire_tree_attn_masks
...@@ -147,7 +147,8 @@ def split_batch_by_proposal_len( ...@@ -147,7 +147,8 @@ def split_batch_by_proposal_len(
def sampler_output_to_torch( def sampler_output_to_torch(
sampler_output_list: Sequence[SamplerOutput], sampler_transposed: bool sampler_output_list: Sequence[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Utility function which converts a list of SamplerOutput to tensors. """Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether sampler_transposed here is used as the indicator for whether
...@@ -162,6 +163,8 @@ def sampler_output_to_torch( ...@@ -162,6 +163,8 @@ def sampler_output_to_torch(
""" """
# shape: [batch_size, num_sampler_output, vocab_size] # shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_probs = None
if sampler_output_list[0].sampled_token_probs is not None:
sampled_token_probs = torch.stack( sampled_token_probs = torch.stack(
[ [
sampler_output.sampled_token_probs sampler_output.sampled_token_probs
...@@ -171,6 +174,8 @@ def sampler_output_to_torch( ...@@ -171,6 +174,8 @@ def sampler_output_to_torch(
) )
# shape: [batch_size, num_sampler_output, vocab_size] # shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs = None
if sampler_output_list[0].logprobs is not None:
sampled_token_logprobs = torch.stack( sampled_token_logprobs = torch.stack(
[sampler_output.logprobs for sampler_output in sampler_output_list], [sampler_output.logprobs for sampler_output in sampler_output_list],
dim=0, dim=0,
...@@ -205,8 +210,33 @@ def sampler_output_to_torch( ...@@ -205,8 +210,33 @@ def sampler_output_to_torch(
else: else:
sampled_hidden_states = None sampled_hidden_states = None
sampled_cart_candidates = None
if sampler_output_list[0].cart_candidates is not None:
sampled_cart_candidates = torch.cat(
[
sampler_output.cart_candidates
for sampler_output in sampler_output_list
],
dim=0,
)
if sampler_transposed:
sampled_cart_candidates = sampled_cart_candidates.transpose(0, 1)
sampled_tree_attn_masks = None
if sampler_output_list[0].tree_attn_masks is not None:
sampled_tree_attn_masks = torch.stack(
[
sampler_output.tree_attn_masks
for sampler_output in sampler_output_list
],
dim=0,
)
if sampler_transposed:
sampled_tree_attn_masks = sampled_tree_attn_masks.transpose(0, 1)
return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs, return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs,
sampled_hidden_states) sampled_hidden_states, sampled_cart_candidates, sampled_tree_attn_masks)
def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
......
import os import os
from typing import Optional, Union from typing import Optional, Union, List
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -15,6 +15,7 @@ class MedusaConfig(PretrainedConfig): ...@@ -15,6 +15,7 @@ class MedusaConfig(PretrainedConfig):
max_paths: int = 64, max_paths: int = 64,
topk: int = 10, topk: int = 10,
truncated_vocab_size: Optional[int] = None, truncated_vocab_size: Optional[int] = None,
medusa_choices:List[List[int]] = None,
**kwargs): **kwargs):
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -26,6 +27,7 @@ class MedusaConfig(PretrainedConfig): ...@@ -26,6 +27,7 @@ class MedusaConfig(PretrainedConfig):
self.max_seq_len = int(2**20) self.max_seq_len = int(2**20)
self.truncated_vocab_size = vocab_size if truncated_vocab_size is None\ self.truncated_vocab_size = vocab_size if truncated_vocab_size is None\
else truncated_vocab_size else truncated_vocab_size
self.medusa_choices = medusa_choices
if "architectures" not in kwargs: if "architectures" not in kwargs:
kwargs["architectures"] = ["MedusaModel"] kwargs["architectures"] = ["MedusaModel"]
...@@ -51,6 +53,14 @@ class MedusaConfig(PretrainedConfig): ...@@ -51,6 +53,14 @@ class MedusaConfig(PretrainedConfig):
def num_attention_heads(self): def num_attention_heads(self):
return 0 return 0
@property
def num_lookahead_heads(self):
return self.num_heads
@num_lookahead_heads.setter
def num_lookahead_heads(self, num_lookahead_heads: int):
self.num_heads = num_lookahead_heads
@property @property
def num_lookahead_tokens(self): def num_lookahead_tokens(self):
return self.num_heads return self.num_heads
......
...@@ -13,6 +13,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, ...@@ -13,6 +13,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.worker.cache_engine import CacheEngine
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner
...@@ -309,6 +310,10 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -309,6 +310,10 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.cpu_cache return self.cpu_cache
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return None
def execute_worker( def execute_worker(
self, self,
worker_input: WorkerInput, worker_input: WorkerInput,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment