Commit 01df9361 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.6.2-dev_medusa' into 'v0.6.2-dev'

修复medusa多卡推理和其他模型推理报错的问题

See merge request dcutoolkit/deeplearing/vllm!31
parents 505c5ff3 f9060e6b
......@@ -297,18 +297,19 @@ __device__ void paged_attention_kernel(
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
bool mask = token_idx >= seq_len;
// used for tree-style attention
if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
mask |= attn_masks_ptr[token_idx] == 0;
if (attn_masks_ptr[token_idx] == 0) {
qk = -FLT_MAX;
}
}
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len;
logits[token_idx - start_token_idx] = mask ? 0.f: qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
......
......@@ -327,17 +327,26 @@ __device__ void paged_attention_kernel_opt(
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
// used for tree-style attention
if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
if (attn_masks_ptr[token_idx] == 0) {
qk = -FLT_MAX;
}
}
__builtin_amdgcn_sched_barrier(0);
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
bool mask = token_idx >= seq_len;
const bool mask = token_idx >= seq_len;
// used for tree-style attention
if (attn_masks != nullptr) {
/*if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
mask |= attn_masks_ptr[token_idx] == 0;
}
}*/
logits[(reuse_kv_idx * partition_size) + (token_idx - start_token_idx)] = mask ? 0.f : qk;
// Update the max value.
qk_max[reuse_kv_idx] = mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk);
......
......@@ -290,15 +290,18 @@ __device__ void paged_attention_kernel_TC(
const int token_idx = block_idx * BLOCK_SIZE+rowid;
if(alibi_slope[i] != 0){
float alibi=alibi_slope[i]* (token_idx - seq_len + 1);
qk_vec[i] += alibi;
qk_vec[i] = alibi;
}
bool mask = (token_idx >= seq_len);
// used for tree-style attention
if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
mask |= attn_masks_ptr[token_idx] == 0;
if (attn_masks_ptr[token_idx] == 0) {
qk_vec[i] = -FLT_MAX;
}
}
const bool mask = (token_idx >= seq_len);
if(mask){
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
}
......
# Medusa Decoding
本文说明如何使用 vLLM 构建和运行 Medusa 模型。目前 Medusa 支持 tree-style generation,target model 和 draft model 均可多卡推理。
## Overview
与其他模型不同,Medusa 解码需要一个 base model 和若干 Medusa heads。
vLLM 中 Medusa 模型的实现位于 [vllm/model_executor/models/medusa.py]。
## Support Matrix
* FP16
* BF16
* PAGED_KV_CACHE
* Tensor Parallel
### convert Medusa model weights
Medusa 模型权重需要先转换为 vLLM 中 Medusa 的模型格式:
```bash
python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --medusa_model_path /work/medusa/qwen2_72b_head_4/adapter_model.bin --vocab_size 152064 --hidden_size 8192 --output_dir /work/medusa/sugon/vllm-medusa-qwen2-72b-head-4 --medusa_choices="[(0,), (0, 0), (0, 0, 0), (0, 0, 0, 0), (0, 1), (1,), (1, 0), (0, 0, 1), (0, 1, 0), (0, 2), (1, 0, 0), (2,), (2, 0), (0, 3), (0, 0, 2), (0, 2, 0), (0, 4), (0, 0, 1, 0), (0, 1, 0, 0), (2, 0, 0), (3,), (0, 5), (0, 0, 0, 1), (3, 0), (0, 0, 3), (1, 0, 0, 0), (0, 3, 0), (0, 6), (0, 0, 4), (0, 4, 0), (1, 1), (4,)]"
```
此处 qwen2_72b_head_4 是 Medusa 模型使用 PEFT LoRA 训练后保存的权重;其他格式的权重可参考并修改 [medusa_weight_converter.py] 进行转换。
### Run
```bash
python3 -m vllm.entrypoints.openai.api_server \
--served-model-name qwen_medusa \
--model /models/Qwen2-72B-Instruct/ -tp 4 \
--max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.7 \
--speculative-model /work/medusa/vllm-medusa-qwen2-72b-head-4 \
--speculative-draft-tensor-parallel-size 4 \
--speculative-disable-by-batch-size 4 \
--use-v2-block-manager \
--spec-decoding-acceptance-method typical_acceptance_sampler \
--enforce-eager --dtype float16 --trust-remote-code --port 8086 \
--enable-lora --lora-modules medusa-lora=/work/qwen2_72b_head_4 \
--max-lora-rank 32 --lora-extra-vocab-size 0 --merge-lora True \
--lora-target-modules qkv_proj \
--tree-style-spec-decoding True \
--num-speculative-heads 4 --num-speculative-tokens 33
```
merge-lora可以将lora权重和base model权重融合,提升整体推理速度,若对精度有严格要求,可不设置此参数
num-speculative-tokens 与 medusa choices 的个数相关:num_speculative_tokens = len(medusa_choices) + 1(上文示例中 medusa_choices 共 32 项,因此 num-speculative-tokens 设置为 33)。
### Do request
```bash
curl http://localhost:8086/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "medusa-lora",
"prompt": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n帮我写一个C++的快速排序算法<|im_end|>\n<|im_start|>assistant\n",
"max_tokens": 256,
"temperature": 0.0
}'
```
import os
import ast
from pathlib import Path
from typing import Iterable, List, Optional, Tuple, Union
from addict import Dict
......@@ -21,12 +22,12 @@ DEFAULT_VOCAB_PADDING_SIZE = 64
TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE = 'base_model.model.medusa_head.{}.{}.linear.weight'
TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE = 'base_model.model.medusa_head.{}.1.weight'
TRAINED_BLOCK_BIAS_NAME_TEMPLATE = 'base_model.model.medusa_head.{}.{}.linear.bias'
VLLM_BLOCK_WEIGHT_NAME_TEMPLATE = 'blocks.{}.layers.{}.weight'
VLLM_BLOCK_BIAS_NAME_TEMPLATE = 'blocks.{}.layers.{}.bias'
VLLM_MEDUSA_HEADS_WEIGHT_NAME_TEMPLATE = 'lm_heads.{}.weight'
MEDUSA_CHOICES = [(0,), (0, 0), (0, 0, 0), (1,), (0, 1), (1, 0), (0, 2), (2,), (0, 0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 3), (2, 0), (1, 0, 0), (3,), (0, 0, 2), (0, 4), (0, 2, 0), (0, 5), (4,), (1, 1), (0, 0, 3), (3, 0), (0, 6), (0, 0, 0, 1), (0, 3, 0), (0, 0, 4), (0, 0, 1, 0), (2, 0, 0), (5,), (0, 1, 0, 0), (0, 7)]
def default_weight_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
......@@ -218,7 +219,7 @@ class ResidualBlock(nn.Module):
super().__init__()
self.layers = nn.ModuleList([
nn.Linear(hidden_size, hidden_size, bias=False)
nn.Linear(hidden_size, hidden_size)
for _ in range(num_layers)
])
self.act = nn.SiLU()
......@@ -319,13 +320,8 @@ class CustomMedusaConfig(PretrainedConfig):
def main(args):
# load the medusa config from the yaml file
medusa_config_path=args.medusa_config_path
with open(medusa_config_path, encoding="utf-8") as file:
medusa_cfg: Dict = Dict(yaml.safe_load(file))
medusa_head_num = medusa_cfg.medusa_num_heads
medusa_num_layers = medusa_cfg.medusa_num_layers
medusa_head_num = args.medusa_num_heads
medusa_num_layers = args.medusa_num_layers
config = MedusaConfig(hidden_size=args.hidden_size, vocab_size=args.vocab_size, num_heads=medusa_head_num)
medusa_model = Medusa(config)
......@@ -346,6 +342,7 @@ def main(args):
for i in range(medusa_head_num):
for j in range(medusa_num_layers):
# load linear weight
vllm_medusa_block_weight_name = VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j)
trained_medusa_block_weight_name = TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j)
......@@ -356,38 +353,59 @@ def main(args):
default_weight_loader)
weight_loader(vllm_medusa_block_param, trained_medusa_block_param)
# load linear bias
vllm_medusa_block_bias_name = VLLM_BLOCK_BIAS_NAME_TEMPLATE.format(i, j)
trained_medusa_block_bias_name = TRAINED_BLOCK_BIAS_NAME_TEMPLATE.format(i, j)
vllm_medusa_block_bias_param = params_dict[vllm_medusa_block_bias_name]
trained_medusa_block_bias_param = trained_medusa_model[trained_medusa_block_bias_name]
weight_loader = getattr(vllm_medusa_block_bias_param, "weight_loader",
default_weight_loader)
weight_loader(vllm_medusa_block_bias_param, trained_medusa_block_bias_param)
if not Path(args.output_dir).is_dir():
os.makedirs(args.output_dir, exist_ok=True)
save_model(medusa_model, os.path.join(args.output_dir, "model.safetensors"))
medusa_choices = ast.literal_eval(args.medusa_choices)
to_save_config = CustomMedusaConfig(name_or_path=os.path.join(args.output_dir, "config.json"),
hidden_size=args.hidden_size,
num_heads=medusa_head_num,
num_hidden_layers=medusa_num_layers,
vocab_size=args.vocab_size,
medusa_choices=MEDUSA_CHOICES)
medusa_choices=medusa_choices)
to_save_config.save_pretrained(args.output_dir)
# validate weight
# with safe_open("model.safetensors", framework="pt") as f:
# param = f.get_tensor(VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(0, 0))
# trained_param = trained_medusa_model[TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(0, 0)]
# with safe_open(os.path.join(args.output_dir, "model.safetensors"), framework="pt") as f:
# param = f.get_tensor(VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(3, 0))
# trained_param = trained_medusa_model[TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(3, 0)]
# mse_value = torch.nn.functional.mse_loss(param.cpu(), trained_param.cpu())
# print("weight mes:", mse_value)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Medusa Model Evaluator")
parser.add_argument("--medusa_config_path", type=str, required=True,
help="Path to the medusa config file.")
parser.add_argument("--medusa_model_path", type=str, required=True,
help="Path to the medusa model file.")
parser.add_argument("--vocab_size", type=int, required=True,
help="Vocab size")
parser.add_argument("--medusa_num_heads", type=int, required=True,
help="Number of Medusa heads")
parser.add_argument("--medusa_num_layers", type=int, required=True,
help="Number of Medusa layers")
parser.add_argument("--hidden_size", type=int, required=True,
help="Hidden size")
parser.add_argument("--output_dir", type=str, required=True,
help="Output dir")
parser.add_argument(
'--medusa_choices',
type=str,
required=True,
help="Medusa choice to use, if not none, will use Medusa decoding."
" E.g.: [[0, 0, 0, 0], [0, 1, 0], [1, 0], [1, 1]] for 9 medusa tokens."
)
args = parser.parse_args()
main(args)
......@@ -1386,11 +1386,11 @@ class SpeculativeConfig:
else:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
elif speculative_draft_tensor_parallel_size != 1:
# TODO(wooyeon): allow tp values larger than 1
raise ValueError(
f"{speculative_draft_tensor_parallel_size=} cannot be "
f"other value than 1")
# elif speculative_draft_tensor_parallel_size != 1:
# # TODO(wooyeon): allow tp values larger than 1
# raise ValueError(
# f"{speculative_draft_tensor_parallel_size=} cannot be "
# f"other value than 1")
draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.
......
......@@ -41,7 +41,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
SequenceStatus, VLLM_INVALID_TOKEN_ID)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
......@@ -945,8 +945,10 @@ class LLMEngine:
if len(outputs) > 1:
outputs_by_sequence_group = create_output_by_sequence_group(
outputs, num_seq_groups=len(seq_group_metadata_list))
else:
elif len(outputs) == 1:
outputs_by_sequence_group = outputs
else:
return None
# Determine the requests we need to operate on
if request_id:
......@@ -967,6 +969,7 @@ class LLMEngine:
finished_before: List[int] = []
finished_now: List[int] = []
empty_seq_indices: List[int] = []
for i in indices:
if i in skip:
continue
......@@ -985,6 +988,16 @@ class LLMEngine:
else:
output = [outputs_by_sequence_group[0][i]]
# tree style speculative decoding may generate empty output in first step
samples = [o.samples[0] for o in output]
valid_samples = [
sample for sample in samples
if sample.output_token != VLLM_INVALID_TOKEN_ID
]
if len(valid_samples) == 0:
empty_seq_indices.append(i)
continue
if not is_async:
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
......@@ -1056,7 +1069,7 @@ class LLMEngine:
# Create the outputs
for i in indices:
if i in skip or i in finished_before or i in finished_now:
if i in skip or i in finished_before or i in finished_now or i in empty_seq_indices:
continue # Avoids double processing
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
......
......@@ -131,7 +131,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.base_layer = base_layer
self.embeddings_slice: Optional[Tuple[int, int]]
self.embeddings_weights: Optional[torch.Tensor]
self.device = _get_lora_device(self.base_layer)
def create_lora_weights(
self,
......@@ -208,18 +207,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
self.lora_a = lora_a.to(self.device)
self.lora_b = lora_b.to(self.device)
if self.lora_config.merge_lora:
merged_weights = torch.matmul(self.lora_a, self.lora_b)
if merged_weights.shape != self.base_layer.weight.data:
merged_weights = merged_weights.T + self.base_layer.weight
else:
merged_weights = merged_weights + self.base_layer.weight
self.base_layer.weight.data.copy_(merged_weights)
if embeddings_tensor is not None:
self.embeddings_tensors[
index, :embeddings_tensor.shape[0], :embeddings_tensor.
......@@ -238,15 +225,12 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1
embeddings_indices = self.punica_wrapper.embeddings_indices
indices = embeddings_indices[0].view_as(x)
if not self.lora_config.merge_lora:
indices_0 = embeddings_indices[1].view_as(x)
indices = embeddings_indices[1].view_as(x)
full_lora_a_embeddings = F.embedding(
x + indices_0,
x + indices,
self.lora_a_stacked_2d,
)
indices = embeddings_indices[0].view_as(x)
full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask))
......@@ -267,10 +251,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked,
add_input=True)
return full_output.view_as(full_output_org)
else:
full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask))
return full_output
@classmethod
def can_replace_layer(
......@@ -729,7 +709,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.q_shard_id = self.tp_rank
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
lora_a_output_size_per_partition = (
self.lora_a_output_size_per_partition = (
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size))
# q, k, v
......@@ -737,7 +717,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch.zeros(
max_loras,
1,
lora_a_output_size_per_partition,
self.lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
......@@ -745,7 +725,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch.zeros(
max_loras,
1,
lora_a_output_size_per_partition,
self.lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
......@@ -753,7 +733,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch.zeros(
max_loras,
1,
lora_a_output_size_per_partition,
self.lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
......@@ -869,6 +849,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
qkv_weights = qkv_weights + self.base_layer.weight.data
self.base_layer.weight.data.copy_(qkv_weights)
qkv_weights_list.clear()
else:
if lora_b[0] is not None:
lora_b_q = lora_b[0]
......
......@@ -5,6 +5,9 @@ import torch.nn.functional as F
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeDeterministicBaseSampler)
from vllm.logger import init_logger
logger = init_logger(__name__)
class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
......@@ -195,7 +198,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
k = draft_token_ids.shape[-1]
output_token_id_list = []
print("####################################accept_length:", accept_length)
logger.info("accept_length:%s", accept_length)
for i in range(batch_size):
output_best_candidates.append(best_candidate[i])
accept_lengths.append(accept_length[i])
......
......@@ -12,6 +12,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm import _custom_ops as ops
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather)
TOPK=10 # topk for sparse tree (10 is a placeholder and it is sufficient)
......@@ -22,7 +24,7 @@ class ResidualBlock(nn.Module):
super().__init__()
self.layers = nn.ModuleList([
nn.Linear(hidden_size, hidden_size, bias=False)
nn.Linear(hidden_size, hidden_size)
for _ in range(num_layers)
])
self.act = nn.SiLU()
......@@ -89,12 +91,13 @@ class Medusa(nn.Module):
return [block(hidden_states) for block in self.blocks]
def compute_logits(
self, hidden_states: List[torch.Tensor],
sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
logits_lst: List[torch.Tensor] = []
for hs, lm_head in zip(hidden_states, self.lm_heads):
_logits = self.logits_processor(lm_head, hs, sampling_metadata)
#_logits = self.logits_processor(lm_head, hs, sampling_metadata)
_logits = lm_head.linear_method.apply(lm_head, hs, bias=None)
_logits = tensor_model_parallel_all_gather(_logits)
if _logits is None:
# _logits should only be None on rank > 0, in which case
......@@ -117,7 +120,7 @@ class Medusa(nn.Module):
def sample(
self,
logits: List[torch.Tensor],
sampling_metadata: SamplingMetadata,
sample_indices_list: List[List[int]],
) -> List[SamplerOutput]:
logits = torch.stack(logits, dim=0).float()
logprobs = torch.log_softmax(logits, dim=-1)
......@@ -128,13 +131,13 @@ class Medusa(nn.Module):
token_prob_list = []
token_logprob_list = []
for idx, seq_group in enumerate(sampling_metadata.seq_groups):
token_id_list.append(token_ids[:, seq_group.sample_indices])
token_prob_list.append(probs[:, seq_group.sample_indices])
token_logprob_list.append(logprobs[:, seq_group.sample_indices])
for idx, sample_indices in enumerate(sample_indices_list):
token_id_list.append(token_ids[:, sample_indices])
token_prob_list.append(probs[:, sample_indices])
token_logprob_list.append(logprobs[:, sample_indices])
outputs: List[Optional[SamplerOutput]] = []
for idx in range(len(sampling_metadata.seq_groups)):
for idx in range(len(sample_indices_list)):
outputs.append(
SamplerOutput(
outputs=None,
......@@ -148,7 +151,7 @@ class Medusa(nn.Module):
def medusa_sample(
self,
medusa_logits: List[torch.Tensor],
sampling_metadata: SamplingMetadata,
sample_indices_list: List[List[int]],
logits: torch.Tensor,
medusa_buffers: Dict[str, Any]
) -> List[SamplerOutput]:
......@@ -178,12 +181,12 @@ class Medusa(nn.Module):
token_id_list = []
cart_candidate_list = []
for idx, seq_group in enumerate(sampling_metadata.seq_groups):
token_id_list.append(tree_candidates[seq_group.sample_indices, :])
cart_candidate_list.append(cart_candidates[seq_group.sample_indices, :])
for sample_indices in sample_indices_list:
token_id_list.append(tree_candidates[sample_indices, :])
cart_candidate_list.append(cart_candidates[sample_indices, :])
outputs: List[Optional[SamplerOutput]] = []
for idx in range(len(sampling_metadata.seq_groups)):
for idx in range(len(sample_indices_list)):
outputs.append(
SamplerOutput(
outputs=None,
......@@ -196,24 +199,22 @@ class Medusa(nn.Module):
def generate_proposals(
self,
previous_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
sample_indices_list: List[List[int]],
previous_logits: torch.Tensor=None,
medusa_buffers: Dict[str, Any]=None
) -> List[SamplerOutput]:
if medusa_buffers is None:
return self.sample(
logits=self.compute_logits(
hidden_states=self.forward(previous_hidden_states),
sampling_metadata=sampling_metadata,
hidden_states=self.forward(previous_hidden_states)
),
sampling_metadata=sampling_metadata,
sample_indices_list=sample_indices_list,
)
else:
return self.medusa_sample(medusa_logits=self.compute_logits(
hidden_states=self.forward(previous_hidden_states),
sampling_metadata=sampling_metadata,
hidden_states=self.forward(previous_hidden_states)
),
sampling_metadata=sampling_metadata,
sample_indices_list=sample_indices_list,
logits=previous_logits,
medusa_buffers=medusa_buffers)
......
......@@ -172,6 +172,8 @@ class SequenceData(msgspec.Struct,
# It is used to compute mrope_position_ids.
_mrope_position_delta: Optional[int] = None
_first_step_flag: bool = True
@staticmethod
def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData":
if len(token_counts) == 0:
......@@ -351,6 +353,12 @@ class SequenceData(msgspec.Struct,
def stage(self) -> SequenceStage:
return self._stage
def get_first_step_flag(self) -> bool:
    """Return the stored ``_first_step_flag`` value.

    NOTE(review): call sites appear to use this to detect the first step
    of tree-style speculative decoding -- confirm against the callers.
    """
    return self._first_step_flag
def set_first_step_flag(self, flag: bool) -> None:
    """Overwrite the stored ``_first_step_flag`` with ``flag``."""
    self._first_step_flag = flag
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt_token_ids={self._prompt_token_ids}, "
......@@ -1371,6 +1379,9 @@ class ExecuteModelRequest(
# Optional tree position ids from draft model.
tree_position_ids: Optional[torch.Tensor] = None
# Optional slot mapping of kvcache that pending to be moved generated from draft model.
kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
@property
def is_first_multi_step(self) -> bool:
# TODO(will) make this be able to handle batches with variable number of
......@@ -1419,4 +1430,5 @@ class ExecuteModelRequest(
if self.last_sampled_token_ids is not None else None,
async_callback=self.async_callback,
tree_attn_masks=self.tree_attn_masks,
tree_position_ids=self.tree_position_ids)
tree_position_ids=self.tree_position_ids,
kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved)
......@@ -592,7 +592,7 @@ class BatchExpansionTreeStyleScorer(BatchExpansionTop1Scorer):
prompt_token_ids = seq_data.prompt_token_ids_array
# first step need to ignore output token generated by prefill phase
if len(seq_data.get_output_token_ids()) == 1:
if seq_data.get_first_step_flag():
new_output_token_ids = [*seq_data.get_output_token_ids()[:-1], *token_ids]
else:
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
......
import weakref
from typing import List, Optional, Set, Tuple
from typing import List, Optional, Set, Tuple, Dict
import torch
import torch.nn.functional as F
......@@ -12,6 +12,7 @@ from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.spec_decode.tree_style_proposer import TreeStyleProposer
from vllm.worker.worker import Worker
from vllm.distributed import broadcast_tensor_dict
TOPK=10 # topk for sparse tree (10 is a placeholder and it is sufficient)
......@@ -66,6 +67,51 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
def set_should_modify_greedy_probs_inplace(self):
pass
def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
) -> Dict[str, torch.Tensor]:
    """Assemble the Medusa proposal inputs on the driver and share them.

    The returned dict carries the previous hidden states, the previous
    logits (``None`` when the request has none), the per-sequence-group
    sample indices taken from freshly prepared sampling metadata, and the
    sequence lengths. When ``self.do_metadata_broadcast`` is set, the same
    dict is broadcast with ``src=0`` before being returned.

    Args:
        execute_model_req: The request whose sequence-group metadata,
            previous hidden states and previous logits are read.

    Returns:
        The dict described above.
    """
    metadata_list = execute_model_req.seq_group_metadata_list
    seq_lens, query_lens = self._prepare_input_tensors(metadata_list)

    generators = self.model_runner.get_generators(
        execute_model_req.finished_requests_ids)
    sampling_metadata = SamplingMetadata.prepare(
        metadata_list, seq_lens, query_lens, self.device,
        self.model_runner.pin_memory, generators)

    # Only the sample indices of each group are needed downstream, so the
    # full sampling metadata object is not part of the payload.
    sample_indices_list = [
        group.sample_indices for group in sampling_metadata.seq_groups
    ]

    previous_hidden_states = (
        execute_model_req.previous_hidden_states.hidden_states)
    if execute_model_req.previous_logits is not None:
        previous_logits = execute_model_req.previous_logits.logits
    else:
        previous_logits = None

    payload = {
        "previous_hidden_states": previous_hidden_states,
        "previous_logits": previous_logits,
        "sample_indices_list": sample_indices_list,
        "seq_lens": seq_lens
    }

    if self.do_metadata_broadcast:
        broadcast_tensor_dict(payload, src=0)

    return payload
def _get_worker_input_from_broadcast(
        self) -> Optional[Dict[str, torch.Tensor]]:
    """Receive the dict broadcast with ``src=0`` and return it as-is.

    Asserts that metadata broadcast is enabled and that this worker is
    not the driver before waiting on the broadcast.
    """
    # Preconditions: this path is only valid on non-driver workers that
    # take part in metadata broadcast.
    assert self.do_metadata_broadcast
    assert not self.is_driver_worker

    return broadcast_tensor_dict(src=0)
@torch.inference_mode()
def sampler_output(
self,
......@@ -83,26 +129,22 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
"""
self._raise_if_unsupported(execute_model_req)
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
seq_lens, query_lens = self._prepare_input_tensors(
seq_group_metadata_list)
generators = self.model_runner.get_generators(
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory, generators)
if self.is_driver_worker:
tensor_dict = self._get_driver_input_and_broadcast(execute_model_req)
else:
tensor_dict = self._get_worker_input_from_broadcast()
if tensor_dict is None:
raise ValueError("Can not get inputs of medusa worker!!!")
model_outputs = self.model_runner.model.generate_proposals(
previous_hidden_states=execute_model_req.previous_hidden_states.
hidden_states,
sampling_metadata=sampling_metadata,
previous_logits=execute_model_req.previous_logits.logits if execute_model_req.previous_logits is not None else None,
previous_hidden_states=tensor_dict["previous_hidden_states"],
sample_indices_list=tensor_dict["sample_indices_list"],
previous_logits=tensor_dict["previous_logits"],
medusa_buffers=self.medusa_buffers)
# create tree attn masks
if self.medusa_buffers is not None:
seq_lens = tensor_dict["seq_lens"]
max_context_len = max(seq_lens)
for sampler_output, seq_len in zip(model_outputs, seq_lens):
context_len = seq_len
......@@ -142,7 +184,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
query_lens.append(seq_len - context_len)
else:
# first step of tree decoding need to ignore first token
if self.medusa_buffers is not None and seq_data.get_output_len() == 1:
if self.medusa_buffers is not None and seq_data.get_first_step_flag():
seq_data_len -= 1
seq_lens.append(seq_data_len)
query_lens.append(1)
......@@ -168,6 +210,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
"""MedusaWorker does not yet implement support for cache swap
operations or beam search.
"""
if execute_model_req is None:
return None
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
......
......@@ -260,6 +260,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# in the subsequent step.
self.previous_hidden_states: Optional[HiddenStates] = None
self.previous_logits: Optional[Logits] = None
self.kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
self._disable_logprobs = disable_logprobs
self._disable_log_stats = disable_log_stats
......@@ -600,6 +601,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.scorer_worker.execute_model()
if not data["disable_all_speculation"]:
if not self.tree_style_spec_decoding:
# Even if num_lookahead_slots is zero, we want to run the
# proposer model as it may have KV.
#
......@@ -607,6 +609,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# should delegate how many times it runs to the proposer.
for _ in range(max(num_lookahead_slots, 1)):
self.proposer_worker.execute_model()
else:
if not data["no_spec"]:
self.proposer_worker.sampler_output(None, None, None)
if not data["no_spec"]:
self.scorer_worker.execute_model()
......@@ -635,6 +640,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
execute_model_req.previous_logits = self.previous_logits
self.previous_logits = None
execute_model_req.kvcache_slot_to_be_moved = self.kvcache_slot_to_be_moved
self.kvcache_slot_to_be_moved = None
with Timer() as proposal_timer:
# Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals(
......@@ -743,7 +751,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
first_step_flags = []
for i, sgm in enumerate(seq_group_metadata_list):
seq = next(iter(sgm.seq_data.values()))
first_step_flags.append(True if seq.get_output_len() == 1 else False)
first_step_flags.append(True if seq.get_first_step_flag() else False)
sampler_extra_kwargs["first_step_flags"] = first_step_flags
......@@ -847,8 +855,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
seq_len = min(seq_len, context_len + token_chunk_size)
# first step of tree-style decoding need to ignore first generated token
if seq_data.get_output_len() == 1:
if seq_data.get_first_step_flag():
seq_len -= 1
# move cache is the last step of tree decoding, so set first_step_flag to false
seq_data.set_first_step_flag(False)
seq_lens.append(seq_len)
model_input = self.scorer._scorer_worker.model_input
......@@ -888,12 +899,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
dtype=torch.long,
device=self.device).view(-1, 1)
src_dst_tensor = torch.cat([select_indices_slot_tensor, target_slot_mapping_tensor], dim=-1) #[batch_size*T, 2]
kv_caches = self.scorer._scorer_worker.kv_cache[execute_model_req.virtual_engine]
kv_cache_dtype = cache_engine.cache_config.cache_dtype
backend = cache_engine.attn_backend
num_kv_heads = cache_engine.num_kv_heads
head_size = cache_engine.head_size
backend.move_cache(kv_caches, src_dst_tensor, kv_cache_dtype, num_kv_heads, head_size)
# kv_caches = self.scorer._scorer_worker.kv_cache[execute_model_req.virtual_engine]
# kv_cache_dtype = cache_engine.cache_config.cache_dtype
# backend = cache_engine.attn_backend
# num_kv_heads = cache_engine.num_kv_heads
# head_size = cache_engine.head_size
# backend.move_cache(kv_caches, src_dst_tensor, kv_cache_dtype, num_kv_heads*4, head_size)
self.kvcache_slot_to_be_moved = src_dst_tensor
def compute_slot_mapping(self, slot_mapping: List[int],
seq_id: int, select_indices: List[int], block_size: int,
......@@ -996,7 +1008,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_token_ids_by_step)
maybe_rejsample_metrics = (
self._metrics.maybe_collect_rejsample_metrics(k))
if maybe_rejsample_metrics is not None:
if maybe_rejsample_metrics is not None and sampler_output_list:
sampler_output_list[
0].spec_decode_worker_metrics = maybe_rejsample_metrics
......
......@@ -115,7 +115,7 @@ class TreeStyleProposer(SpeculativeProposer):
tree_position_ids_list = []
for seq_group_metadata in seq_group_metadata_list:
seq_data = next(iter(seq_group_metadata.seq_data.values()))
if seq_data.get_output_len() == 1:
if seq_data.get_first_step_flag():
seq_len = seq_data.get_len() -1
else:
seq_len = seq_data.get_len()
......
......@@ -101,6 +101,14 @@ class CacheEngine:
def copy(self, src_to_dsts: torch.Tensor) -> None:
    """Ask the attention backend to copy blocks within ``self.gpu_cache``.

    Args:
        src_to_dsts: Passed through unchanged to ``copy_blocks``;
            presumably source/destination block pairs -- TODO confirm.
    """
    self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
def move_caches(self, kv_caches: List[torch.Tensor],
                src_to_dsts: torch.Tensor) -> None:
    """Forward a KV-cache move request to the attention backend.

    The backend's ``move_cache`` is called with the given caches and
    mapping plus this engine's cache dtype, KV head count and head size.

    Args:
        kv_caches: The KV caches to operate on, passed through unchanged.
        src_to_dsts: The move mapping, passed through unchanged;
            presumably (source slot, destination slot) pairs -- TODO
            confirm against the caller.
    """
    backend = self.attn_backend
    cache_dtype = self.cache_config.cache_dtype
    backend.move_cache(
        kv_caches,
        src_to_dsts,
        cache_dtype,
        self.num_kv_heads,
        self.head_size,
    )
@staticmethod
def get_cache_block_size(
cache_config: CacheConfig,
......
......@@ -509,20 +509,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else:
inter_data.input_tokens[seq_idx].append(tokens)
if (seq_len - context_len) == 1:
inter_data.input_positions[seq_idx].append(seq_len - 1)
else:
inter_data.input_positions[seq_idx].extend(
range(context_len, seq_len))
inter_data.input_positions[seq_idx] = list(range(context_len, seq_len))
if seq_group_metadata.tree_position_ids is None:
if (seq_len - context_len) == 1:
inter_data.input_positions[seq_idx].append(seq_len - 1)
else:
inter_data.input_positions[seq_idx].extend(
range(context_len, seq_len))
else:
inter_data.input_positions[seq_idx] = seq_group_metadata.tree_position_ids.cpu().tolist()
if seq_group_metadata.tree_position_ids is not None:
inter_data.input_positions[seq_idx] = seq_group_metadata.tree_position_ids.contiguous().cpu().tolist()
inter_data.tree_attn_masks[seq_idx] = seq_group_metadata.tree_attn_masks
inter_data.query_lens[
......@@ -871,6 +861,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
tree_attention_masks_tensor = None
if len(tree_attention_masks_list) > 0:
tree_attention_masks_tensor = torch.stack(tree_attention_masks_list, dim=0)
tree_attention_masks_tensor = tree_attention_masks_tensor.contiguous()
# Attention metadata.
attn_metadata = self.attn_metadata_builder.build(
......
......@@ -324,6 +324,7 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
num_steps=num_steps,
kvcache_slot_to_be_moved=execute_model_req.kvcache_slot_to_be_moved
)
@torch.inference_mode()
......@@ -342,6 +343,11 @@ class Worker(LocalOrDistributedWorkerBase):
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
# tree-style generation need to move kvcache to correct position
if worker_input.kvcache_slot_to_be_moved is not None:
self.cache_engine[virtual_engine].move_caches(self.kv_cache[virtual_engine],
worker_input.kvcache_slot_to_be_moved)
def _get_cached_seq_group_metadata(
self,
seq_group_metadata_list: List[Union[SequenceGroupMetadata,
......
......@@ -161,6 +161,9 @@ class WorkerInput:
virtual_engine: int = 0
num_steps: int = 1
# Optional slot mapping of kvcache that pending to be moved generated from draft model.
kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"],
......@@ -177,6 +180,7 @@ class WorkerInput:
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"),
kvcache_slot_to_be_moved=tensor_dict.pop("kvcache_slot_to_be_moved"),
)
def as_broadcastable_tensor_dict(
......@@ -191,6 +195,7 @@ class WorkerInput:
"blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
"num_steps": self.num_steps,
"kvcache_slot_to_be_moved": self.kvcache_slot_to_be_moved
}
return tensor_dict
......@@ -281,8 +286,6 @@ class LocalOrDistributedWorkerBase(WorkerBase):
# set tree_attn_masks and position ids to seq_group_metadata_list
if execute_model_req.tree_attn_masks is not None:
for i, seq_group_metadata in enumerate(execute_model_req.seq_group_metadata_list):
# seq_group_metadata.tree_attn_masks = execute_model_req.tree_attn_masks[i]
# seq_group_metadata.tree_position_ids = execute_model_req.tree_position_ids[i]
seq_group_metadata.set_tree_style_args(tree_attn_masks=execute_model_req.tree_attn_masks[i],
tree_position_ids=execute_model_req.tree_position_ids[i])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment