Commit f9060e6b authored by 王敏's avatar 王敏
Browse files

1.解决medusa模型多卡推理结果异常问题

2.examples中添加medusa readme
3.修复model_runner中input_positions配置错误的笔误,解决多个模型运行失败问题
parent 3f42b83d
...@@ -297,18 +297,19 @@ __device__ void paged_attention_kernel( ...@@ -297,18 +297,19 @@ __device__ void paged_attention_kernel(
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
bool mask = token_idx >= seq_len;
// used for tree-style attention // used for tree-style attention
if (attn_masks != nullptr) { if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride; const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
mask |= attn_masks_ptr[token_idx] == 0; if (attn_masks_ptr[token_idx] == 0) {
qk = -FLT_MAX;
}
} }
logits[token_idx - start_token_idx] = mask ? 0.f : qk; if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len;
logits[token_idx - start_token_idx] = mask ? 0.f: qk;
// Update the max value. // Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk); qk_max = mask ? qk_max : fmaxf(qk_max, qk);
} }
......
...@@ -327,17 +327,26 @@ __device__ void paged_attention_kernel_opt( ...@@ -327,17 +327,26 @@ __device__ void paged_attention_kernel_opt(
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset], k_vecs); float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
// used for tree-style attention
if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
if (attn_masks_ptr[token_idx] == 0) {
qk = -FLT_MAX;
}
}
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// Store the partial reductions to shared memory. // Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits. // NOTE(woosuk): It is required to zero out the masked logits.
bool mask = token_idx >= seq_len; const bool mask = token_idx >= seq_len;
// used for tree-style attention // used for tree-style attention
if (attn_masks != nullptr) { /*if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride; const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
mask |= attn_masks_ptr[token_idx] == 0; mask |= attn_masks_ptr[token_idx] == 0;
} }*/
logits[(reuse_kv_idx * partition_size) + (token_idx - start_token_idx)] = mask ? 0.f : qk; logits[(reuse_kv_idx * partition_size) + (token_idx - start_token_idx)] = mask ? 0.f : qk;
// Update the max value. // Update the max value.
qk_max[reuse_kv_idx] = mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk); qk_max[reuse_kv_idx] = mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk);
......
...@@ -290,15 +290,18 @@ __device__ void paged_attention_kernel_TC( ...@@ -290,15 +290,18 @@ __device__ void paged_attention_kernel_TC(
const int token_idx = block_idx * BLOCK_SIZE+rowid; const int token_idx = block_idx * BLOCK_SIZE+rowid;
if(alibi_slope[i] != 0){ if(alibi_slope[i] != 0){
float alibi=alibi_slope[i]* (token_idx - seq_len + 1); float alibi=alibi_slope[i]* (token_idx - seq_len + 1);
qk_vec[i] += alibi; qk_vec[i] = alibi;
} }
bool mask = (token_idx >= seq_len);
// used for tree-style attention // used for tree-style attention
if (attn_masks != nullptr) { if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride; const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
mask |= attn_masks_ptr[token_idx] == 0; if (attn_masks_ptr[token_idx] == 0) {
qk_vec[i] = -FLT_MAX;
} }
}
const bool mask = (token_idx >= seq_len);
if(mask){ if(mask){
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f); from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
} }
......
# Medusa Decoding
本文说明如何使用vllm构建和运行medusa模型,目前medusa支持tree-style generation,target model和draft model均可多卡推理
## Overview
与其他模型不同,medusa解码需要一个base model和若干Medusa heads.
Vllm medusa model的实现在[vllm/model_executor/models/medusa.py]
## Support Matrix
* FP16
* BF16
* PAGED_KV_CACHE
* Tensor Parallel
### convert Medusa model weights
# medusa 模型需要转换为vllm中Medusa的模型格式
```bash
python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --medusa_model_path /work/medusa/qwen2_72b_head_4/adapter_model.bin --vocab_size 152064 --hidden_size 8192 --output_dir /work/medusa/sugon/vllm-medusa-qwen2-72b-head-4 --medusa_choices="[(0,), (0, 0), (0, 0, 0), (0, 0, 0, 0), (0, 1), (1,), (1, 0), (0, 0, 1), (0, 1, 0), (0, 2), (1, 0, 0), (2,), (2, 0), (0, 3), (0, 0, 2), (0, 2, 0), (0, 4), (0, 0, 1, 0), (0, 1, 0, 0), (2, 0, 0), (3,), (0, 5), (0, 0, 0, 1), (3, 0), (0, 0, 3), (1, 0, 0, 0), (0, 3, 0), (0, 6), (0, 0, 4), (0, 4, 0), (1, 1), (4,)]"
```
此处qwen2_72b_head_4是medusa模型使用peft lora训练后保存的权重,其他格式也可参考[medusa_weight_converter.py]修改进行权重转换
### Run
```bash
python3 -m vllm.entrypoints.openai.api_server \
--served-model-name qwen_medusa \
--model /models/Qwen2-72B-Instruct/ -tp 4 \
--max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.7 \
--speculative-model /work/medusa/vllm-medusa-qwen2-72b-head-4 \
--speculative-draft-tensor-parallel-size 4 \
--speculative-disable-by-batch-size 4 \
--use-v2-block-manager \
--spec-decoding-acceptance-method typical_acceptance_sampler \
--enforce-eager --dtype float16 --trust-remote-code --port 8086\
--enable-lora --lora-modules medusa-lora=/work/qwen2_72b_head_4 \
--max-lora-rank 32 --lora-extra-vocab-size 0 --merge-lora True \
--lora-target-modules qkv_proj \
--tree-style-spec-decoding True\
--num-speculative-heads 4 --num-speculative-tokens 33
```
merge-lora可以将lora权重和base model权重融合,提升整体推理速度,若对精度有严格要求,可不设置此参数
num-speculative-tokens和medusa choices的个数相关,num_speculative_tokens = len(medusa_choices) + 2
# do request
```bash
curl http://localhost:8086/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "medusa-lora",
"prompt": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n帮我写一个C++的快速排序算法<|im_end|>\n<|im_start|>assistant\n",
"max_tokens": 256,
"temperature": 0.0
}'
```bash
import os import os
import ast
from pathlib import Path from pathlib import Path
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Tuple, Union
from addict import Dict from addict import Dict
...@@ -21,12 +22,12 @@ DEFAULT_VOCAB_PADDING_SIZE = 64 ...@@ -21,12 +22,12 @@ DEFAULT_VOCAB_PADDING_SIZE = 64
TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE = 'base_model.model.medusa_head.{}.{}.linear.weight' TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE = 'base_model.model.medusa_head.{}.{}.linear.weight'
TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE = 'base_model.model.medusa_head.{}.1.weight' TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE = 'base_model.model.medusa_head.{}.1.weight'
TRAINED_BLOCK_BIAS_NAME_TEMPLATE = 'base_model.model.medusa_head.{}.{}.linear.bias'
VLLM_BLOCK_WEIGHT_NAME_TEMPLATE = 'blocks.{}.layers.{}.weight' VLLM_BLOCK_WEIGHT_NAME_TEMPLATE = 'blocks.{}.layers.{}.weight'
VLLM_BLOCK_BIAS_NAME_TEMPLATE = 'blocks.{}.layers.{}.bias'
VLLM_MEDUSA_HEADS_WEIGHT_NAME_TEMPLATE = 'lm_heads.{}.weight' VLLM_MEDUSA_HEADS_WEIGHT_NAME_TEMPLATE = 'lm_heads.{}.weight'
MEDUSA_CHOICES = [(0,), (0, 0), (0, 0, 0), (1,), (0, 1), (1, 0), (0, 2), (2,), (0, 0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 3), (2, 0), (1, 0, 0), (3,), (0, 0, 2), (0, 4), (0, 2, 0), (0, 5), (4,), (1, 1), (0, 0, 3), (3, 0), (0, 6), (0, 0, 0, 1), (0, 3, 0), (0, 0, 4), (0, 0, 1, 0), (2, 0, 0), (5,), (0, 1, 0, 0), (0, 7)]
def default_weight_loader(param: torch.Tensor, def default_weight_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None: loaded_weight: torch.Tensor) -> None:
...@@ -218,7 +219,7 @@ class ResidualBlock(nn.Module): ...@@ -218,7 +219,7 @@ class ResidualBlock(nn.Module):
super().__init__() super().__init__()
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
nn.Linear(hidden_size, hidden_size, bias=False) nn.Linear(hidden_size, hidden_size)
for _ in range(num_layers) for _ in range(num_layers)
]) ])
self.act = nn.SiLU() self.act = nn.SiLU()
...@@ -319,13 +320,8 @@ class CustomMedusaConfig(PretrainedConfig): ...@@ -319,13 +320,8 @@ class CustomMedusaConfig(PretrainedConfig):
def main(args): def main(args):
# load the medusa config from the yaml file medusa_head_num = args.medusa_num_heads
medusa_config_path=args.medusa_config_path medusa_num_layers = args.medusa_num_layers
with open(medusa_config_path, encoding="utf-8") as file:
medusa_cfg: Dict = Dict(yaml.safe_load(file))
medusa_head_num = medusa_cfg.medusa_num_heads
medusa_num_layers = medusa_cfg.medusa_num_layers
config = MedusaConfig(hidden_size=args.hidden_size, vocab_size=args.vocab_size, num_heads=medusa_head_num) config = MedusaConfig(hidden_size=args.hidden_size, vocab_size=args.vocab_size, num_heads=medusa_head_num)
medusa_model = Medusa(config) medusa_model = Medusa(config)
...@@ -346,6 +342,7 @@ def main(args): ...@@ -346,6 +342,7 @@ def main(args):
for i in range(medusa_head_num): for i in range(medusa_head_num):
for j in range(medusa_num_layers): for j in range(medusa_num_layers):
# load linear weight
vllm_medusa_block_weight_name = VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j) vllm_medusa_block_weight_name = VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j)
trained_medusa_block_weight_name = TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j) trained_medusa_block_weight_name = TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j)
...@@ -356,38 +353,59 @@ def main(args): ...@@ -356,38 +353,59 @@ def main(args):
default_weight_loader) default_weight_loader)
weight_loader(vllm_medusa_block_param, trained_medusa_block_param) weight_loader(vllm_medusa_block_param, trained_medusa_block_param)
# load linear bias
vllm_medusa_block_bias_name = VLLM_BLOCK_BIAS_NAME_TEMPLATE.format(i, j)
trained_medusa_block_bias_name = TRAINED_BLOCK_BIAS_NAME_TEMPLATE.format(i, j)
vllm_medusa_block_bias_param = params_dict[vllm_medusa_block_bias_name]
trained_medusa_block_bias_param = trained_medusa_model[trained_medusa_block_bias_name]
weight_loader = getattr(vllm_medusa_block_bias_param, "weight_loader",
default_weight_loader)
weight_loader(vllm_medusa_block_bias_param, trained_medusa_block_bias_param)
if not Path(args.output_dir).is_dir(): if not Path(args.output_dir).is_dir():
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
save_model(medusa_model, os.path.join(args.output_dir, "model.safetensors")) save_model(medusa_model, os.path.join(args.output_dir, "model.safetensors"))
medusa_choices = ast.literal_eval(args.medusa_choices)
to_save_config = CustomMedusaConfig(name_or_path=os.path.join(args.output_dir, "config.json"), to_save_config = CustomMedusaConfig(name_or_path=os.path.join(args.output_dir, "config.json"),
hidden_size=args.hidden_size, hidden_size=args.hidden_size,
num_heads=medusa_head_num, num_heads=medusa_head_num,
num_hidden_layers=medusa_num_layers, num_hidden_layers=medusa_num_layers,
vocab_size=args.vocab_size, vocab_size=args.vocab_size,
medusa_choices=MEDUSA_CHOICES) medusa_choices=medusa_choices)
to_save_config.save_pretrained(args.output_dir) to_save_config.save_pretrained(args.output_dir)
# validate weight # validate weight
# with safe_open("model.safetensors", framework="pt") as f: # with safe_open(os.path.join(args.output_dir, "model.safetensors"), framework="pt") as f:
# param = f.get_tensor(VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(0, 0)) # param = f.get_tensor(VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(3, 0))
# trained_param = trained_medusa_model[TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(0, 0)] # trained_param = trained_medusa_model[TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(3, 0)]
# mse_value = torch.nn.functional.mse_loss(param.cpu(), trained_param.cpu()) # mse_value = torch.nn.functional.mse_loss(param.cpu(), trained_param.cpu())
# print("weight mes:", mse_value) # print("weight mes:", mse_value)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Medusa Model Evaluator") parser = argparse.ArgumentParser(description="Medusa Model Evaluator")
parser.add_argument("--medusa_config_path", type=str, required=True,
help="Path to the medusa config file.")
parser.add_argument("--medusa_model_path", type=str, required=True, parser.add_argument("--medusa_model_path", type=str, required=True,
help="Path to the medusa model file.") help="Path to the medusa model file.")
parser.add_argument("--vocab_size", type=int, required=True, parser.add_argument("--vocab_size", type=int, required=True,
help="Vocab size") help="Vocab size")
parser.add_argument("--medusa_num_heads", type=int, required=True,
help="Number of Medusa heads")
parser.add_argument("--medusa_num_layers", type=int, required=True,
help="Number of Medusa layers")
parser.add_argument("--hidden_size", type=int, required=True, parser.add_argument("--hidden_size", type=int, required=True,
help="Hidden size") help="Hidden size")
parser.add_argument("--output_dir", type=str, required=True, parser.add_argument("--output_dir", type=str, required=True,
help="Output dir") help="Output dir")
parser.add_argument(
'--medusa_choices',
type=str,
required=True,
help="Medusa choice to use, if not none, will use Medusa decoding."
" E.g.: [[0, 0, 0, 0], [0, 1, 0], [1, 0], [1, 1]] for 9 medusa tokens."
)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
...@@ -1386,11 +1386,11 @@ class SpeculativeConfig: ...@@ -1386,11 +1386,11 @@ class SpeculativeConfig:
else: else:
speculative_draft_tensor_parallel_size = \ speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size target_parallel_config.tensor_parallel_size
elif speculative_draft_tensor_parallel_size != 1: # elif speculative_draft_tensor_parallel_size != 1:
# TODO(wooyeon): allow tp values larger than 1 # # TODO(wooyeon): allow tp values larger than 1
raise ValueError( # raise ValueError(
f"{speculative_draft_tensor_parallel_size=} cannot be " # f"{speculative_draft_tensor_parallel_size=} cannot be "
f"other value than 1") # f"other value than 1")
draft_parallel_config = ParallelConfig( draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config. pipeline_parallel_size=target_parallel_config.
......
...@@ -41,7 +41,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest ...@@ -41,7 +41,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata, Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus) SequenceStatus, VLLM_INVALID_TOKEN_ID)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer) init_tracer)
from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.config import try_get_generation_config
...@@ -945,8 +945,10 @@ class LLMEngine: ...@@ -945,8 +945,10 @@ class LLMEngine:
if len(outputs) > 1: if len(outputs) > 1:
outputs_by_sequence_group = create_output_by_sequence_group( outputs_by_sequence_group = create_output_by_sequence_group(
outputs, num_seq_groups=len(seq_group_metadata_list)) outputs, num_seq_groups=len(seq_group_metadata_list))
else: elif len(outputs) == 1:
outputs_by_sequence_group = outputs outputs_by_sequence_group = outputs
else:
return None
# Determine the requests we need to operate on # Determine the requests we need to operate on
if request_id: if request_id:
...@@ -967,6 +969,7 @@ class LLMEngine: ...@@ -967,6 +969,7 @@ class LLMEngine:
finished_before: List[int] = [] finished_before: List[int] = []
finished_now: List[int] = [] finished_now: List[int] = []
empty_seq_indices: List[int] = []
for i in indices: for i in indices:
if i in skip: if i in skip:
continue continue
...@@ -985,6 +988,16 @@ class LLMEngine: ...@@ -985,6 +988,16 @@ class LLMEngine:
else: else:
output = [outputs_by_sequence_group[0][i]] output = [outputs_by_sequence_group[0][i]]
# tree style speculative decoding may generate empty output in first step
samples = [o.samples[0] for o in output]
valid_samples = [
sample for sample in samples
if sample.output_token != VLLM_INVALID_TOKEN_ID
]
if len(valid_samples) == 0:
empty_seq_indices.append(i)
continue
if not is_async: if not is_async:
seq_group.update_num_computed_tokens( seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size) scheduled_seq_group.token_chunk_size)
...@@ -1056,7 +1069,7 @@ class LLMEngine: ...@@ -1056,7 +1069,7 @@ class LLMEngine:
# Create the outputs # Create the outputs
for i in indices: for i in indices:
if i in skip or i in finished_before or i in finished_now: if i in skip or i in finished_before or i in finished_now or i in empty_seq_indices:
continue # Avoids double processing continue # Avoids double processing
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
......
...@@ -131,7 +131,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -131,7 +131,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.base_layer = base_layer self.base_layer = base_layer
self.embeddings_slice: Optional[Tuple[int, int]] self.embeddings_slice: Optional[Tuple[int, int]]
self.embeddings_weights: Optional[torch.Tensor] self.embeddings_weights: Optional[torch.Tensor]
self.device = _get_lora_device(self.base_layer)
def create_lora_weights( def create_lora_weights(
self, self,
...@@ -208,18 +207,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -208,18 +207,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked[index, self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_( 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True) lora_b.T, non_blocking=True)
self.lora_a = lora_a.to(self.device)
self.lora_b = lora_b.to(self.device)
if self.lora_config.merge_lora:
merged_weights = torch.matmul(self.lora_a, self.lora_b)
if merged_weights.shape != self.base_layer.weight.data:
merged_weights = merged_weights.T + self.base_layer.weight
else:
merged_weights = merged_weights + self.base_layer.weight
self.base_layer.weight.data.copy_(merged_weights)
if embeddings_tensor is not None: if embeddings_tensor is not None:
self.embeddings_tensors[ self.embeddings_tensors[
index, :embeddings_tensor.shape[0], :embeddings_tensor. index, :embeddings_tensor.shape[0], :embeddings_tensor.
...@@ -238,15 +225,12 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -238,15 +225,12 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1 added_tokens_mask = x > self.base_layer.org_vocab_size - 1
embeddings_indices = self.punica_wrapper.embeddings_indices embeddings_indices = self.punica_wrapper.embeddings_indices
indices = embeddings_indices[0].view_as(x) indices = embeddings_indices[1].view_as(x)
if not self.lora_config.merge_lora:
indices_0 = embeddings_indices[1].view_as(x)
full_lora_a_embeddings = F.embedding( full_lora_a_embeddings = F.embedding(
x + indices_0, x + indices,
self.lora_a_stacked_2d, self.lora_a_stacked_2d,
) )
indices = embeddings_indices[0].view_as(x)
full_output = self.base_layer.forward( full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask)) x.add_(indices * added_tokens_mask))
...@@ -267,10 +251,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -267,10 +251,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked, self.lora_b_stacked,
add_input=True) add_input=True)
return full_output.view_as(full_output_org) return full_output.view_as(full_output_org)
else:
full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask))
return full_output
@classmethod @classmethod
def can_replace_layer( def can_replace_layer(
...@@ -729,7 +709,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -729,7 +709,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.q_shard_id = self.tp_rank self.q_shard_id = self.tp_rank
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
lora_a_output_size_per_partition = ( self.lora_a_output_size_per_partition = (
lora_config.max_lora_rank if not lora_config.fully_sharded_loras lora_config.max_lora_rank if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size)) else divide(lora_config.max_lora_rank, self.tp_size))
# q, k, v # q, k, v
...@@ -737,7 +717,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -737,7 +717,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch.zeros( torch.zeros(
max_loras, max_loras,
1, 1,
lora_a_output_size_per_partition, self.lora_a_output_size_per_partition,
self.input_size, self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
...@@ -745,7 +725,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -745,7 +725,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch.zeros( torch.zeros(
max_loras, max_loras,
1, 1,
lora_a_output_size_per_partition, self.lora_a_output_size_per_partition,
self.input_size, self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
...@@ -753,7 +733,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -753,7 +733,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch.zeros( torch.zeros(
max_loras, max_loras,
1, 1,
lora_a_output_size_per_partition, self.lora_a_output_size_per_partition,
self.input_size, self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
...@@ -869,6 +849,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -869,6 +849,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
qkv_weights = qkv_weights + self.base_layer.weight.data qkv_weights = qkv_weights + self.base_layer.weight.data
self.base_layer.weight.data.copy_(qkv_weights) self.base_layer.weight.data.copy_(qkv_weights)
qkv_weights_list.clear()
else: else:
if lora_b[0] is not None: if lora_b[0] is not None:
lora_b_q = lora_b[0] lora_b_q = lora_b[0]
......
...@@ -5,6 +5,9 @@ import torch.nn.functional as F ...@@ -5,6 +5,9 @@ import torch.nn.functional as F
from vllm.model_executor.layers.spec_decode_base_sampler import ( from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeDeterministicBaseSampler) SpecDecodeDeterministicBaseSampler)
from vllm.logger import init_logger
logger = init_logger(__name__)
class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...@@ -195,7 +198,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -195,7 +198,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
k = draft_token_ids.shape[-1] k = draft_token_ids.shape[-1]
output_token_id_list = [] output_token_id_list = []
print("####################################accept_length:", accept_length) logger.info("accept_length:%s", accept_length)
for i in range(batch_size): for i in range(batch_size):
output_best_candidates.append(best_candidate[i]) output_best_candidates.append(best_candidate[i])
accept_lengths.append(accept_length[i]) accept_lengths.append(accept_length[i])
......
...@@ -12,6 +12,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -12,6 +12,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.transformers_utils.configs.medusa import MedusaConfig from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather)
TOPK=10 # topk for sparse tree (10 is a placeholder and it is sufficient) TOPK=10 # topk for sparse tree (10 is a placeholder and it is sufficient)
...@@ -22,7 +24,7 @@ class ResidualBlock(nn.Module): ...@@ -22,7 +24,7 @@ class ResidualBlock(nn.Module):
super().__init__() super().__init__()
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
nn.Linear(hidden_size, hidden_size, bias=False) nn.Linear(hidden_size, hidden_size)
for _ in range(num_layers) for _ in range(num_layers)
]) ])
self.act = nn.SiLU() self.act = nn.SiLU()
...@@ -89,12 +91,13 @@ class Medusa(nn.Module): ...@@ -89,12 +91,13 @@ class Medusa(nn.Module):
return [block(hidden_states) for block in self.blocks] return [block(hidden_states) for block in self.blocks]
def compute_logits( def compute_logits(
self, hidden_states: List[torch.Tensor], self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
logits_lst: List[torch.Tensor] = [] logits_lst: List[torch.Tensor] = []
for hs, lm_head in zip(hidden_states, self.lm_heads): for hs, lm_head in zip(hidden_states, self.lm_heads):
_logits = self.logits_processor(lm_head, hs, sampling_metadata) #_logits = self.logits_processor(lm_head, hs, sampling_metadata)
_logits = lm_head.linear_method.apply(lm_head, hs, bias=None)
_logits = tensor_model_parallel_all_gather(_logits)
if _logits is None: if _logits is None:
# _logits should only be None on rank > 0, in which case # _logits should only be None on rank > 0, in which case
...@@ -117,7 +120,7 @@ class Medusa(nn.Module): ...@@ -117,7 +120,7 @@ class Medusa(nn.Module):
def sample( def sample(
self, self,
logits: List[torch.Tensor], logits: List[torch.Tensor],
sampling_metadata: SamplingMetadata, sample_indices_list: List[List[int]],
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
logits = torch.stack(logits, dim=0).float() logits = torch.stack(logits, dim=0).float()
logprobs = torch.log_softmax(logits, dim=-1) logprobs = torch.log_softmax(logits, dim=-1)
...@@ -128,13 +131,13 @@ class Medusa(nn.Module): ...@@ -128,13 +131,13 @@ class Medusa(nn.Module):
token_prob_list = [] token_prob_list = []
token_logprob_list = [] token_logprob_list = []
for idx, seq_group in enumerate(sampling_metadata.seq_groups): for idx, sample_indices in enumerate(sample_indices_list):
token_id_list.append(token_ids[:, seq_group.sample_indices]) token_id_list.append(token_ids[:, sample_indices])
token_prob_list.append(probs[:, seq_group.sample_indices]) token_prob_list.append(probs[:, sample_indices])
token_logprob_list.append(logprobs[:, seq_group.sample_indices]) token_logprob_list.append(logprobs[:, sample_indices])
outputs: List[Optional[SamplerOutput]] = [] outputs: List[Optional[SamplerOutput]] = []
for idx in range(len(sampling_metadata.seq_groups)): for idx in range(len(sample_indices_list)):
outputs.append( outputs.append(
SamplerOutput( SamplerOutput(
outputs=None, outputs=None,
...@@ -148,7 +151,7 @@ class Medusa(nn.Module): ...@@ -148,7 +151,7 @@ class Medusa(nn.Module):
def medusa_sample( def medusa_sample(
self, self,
medusa_logits: List[torch.Tensor], medusa_logits: List[torch.Tensor],
sampling_metadata: SamplingMetadata, sample_indices_list: List[List[int]],
logits: torch.Tensor, logits: torch.Tensor,
medusa_buffers: Dict[str, Any] medusa_buffers: Dict[str, Any]
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
...@@ -178,12 +181,12 @@ class Medusa(nn.Module): ...@@ -178,12 +181,12 @@ class Medusa(nn.Module):
token_id_list = [] token_id_list = []
cart_candidate_list = [] cart_candidate_list = []
for idx, seq_group in enumerate(sampling_metadata.seq_groups): for sample_indices in sample_indices_list:
token_id_list.append(tree_candidates[seq_group.sample_indices, :]) token_id_list.append(tree_candidates[sample_indices, :])
cart_candidate_list.append(cart_candidates[seq_group.sample_indices, :]) cart_candidate_list.append(cart_candidates[sample_indices, :])
outputs: List[Optional[SamplerOutput]] = [] outputs: List[Optional[SamplerOutput]] = []
for idx in range(len(sampling_metadata.seq_groups)): for idx in range(len(sample_indices_list)):
outputs.append( outputs.append(
SamplerOutput( SamplerOutput(
outputs=None, outputs=None,
...@@ -196,24 +199,22 @@ class Medusa(nn.Module): ...@@ -196,24 +199,22 @@ class Medusa(nn.Module):
def generate_proposals( def generate_proposals(
self, self,
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sample_indices_list: List[List[int]],
previous_logits: torch.Tensor=None, previous_logits: torch.Tensor=None,
medusa_buffers: Dict[str, Any]=None medusa_buffers: Dict[str, Any]=None
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
if medusa_buffers is None: if medusa_buffers is None:
return self.sample( return self.sample(
logits=self.compute_logits( logits=self.compute_logits(
hidden_states=self.forward(previous_hidden_states), hidden_states=self.forward(previous_hidden_states)
sampling_metadata=sampling_metadata,
), ),
sampling_metadata=sampling_metadata, sample_indices_list=sample_indices_list,
) )
else: else:
return self.medusa_sample(medusa_logits=self.compute_logits( return self.medusa_sample(medusa_logits=self.compute_logits(
hidden_states=self.forward(previous_hidden_states), hidden_states=self.forward(previous_hidden_states)
sampling_metadata=sampling_metadata,
), ),
sampling_metadata=sampling_metadata, sample_indices_list=sample_indices_list,
logits=previous_logits, logits=previous_logits,
medusa_buffers=medusa_buffers) medusa_buffers=medusa_buffers)
......
...@@ -172,6 +172,8 @@ class SequenceData(msgspec.Struct, ...@@ -172,6 +172,8 @@ class SequenceData(msgspec.Struct,
# It is used to compute mrope_position_ids. # It is used to compute mrope_position_ids.
_mrope_position_delta: Optional[int] = None _mrope_position_delta: Optional[int] = None
_first_step_flag: bool = True
@staticmethod @staticmethod
def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData": def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData":
if len(token_counts) == 0: if len(token_counts) == 0:
...@@ -351,6 +353,12 @@ class SequenceData(msgspec.Struct, ...@@ -351,6 +353,12 @@ class SequenceData(msgspec.Struct,
def stage(self) -> SequenceStage: def stage(self) -> SequenceStage:
return self._stage return self._stage
def get_first_step_flag(self):
return self._first_step_flag
def set_first_step_flag(self, flag: bool):
self._first_step_flag = flag
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"SequenceData(" return (f"SequenceData("
f"prompt_token_ids={self._prompt_token_ids}, " f"prompt_token_ids={self._prompt_token_ids}, "
...@@ -1371,6 +1379,9 @@ class ExecuteModelRequest( ...@@ -1371,6 +1379,9 @@ class ExecuteModelRequest(
# Optional tree position ids from draft model. # Optional tree position ids from draft model.
tree_position_ids: Optional[torch.Tensor] = None tree_position_ids: Optional[torch.Tensor] = None
# Optional slot mapping of kvcache that pending to be moved generated from draft model.
kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
@property @property
def is_first_multi_step(self) -> bool: def is_first_multi_step(self) -> bool:
# TODO(will) make this be able to handle batches with variable number of # TODO(will) make this be able to handle batches with variable number of
...@@ -1419,4 +1430,5 @@ class ExecuteModelRequest( ...@@ -1419,4 +1430,5 @@ class ExecuteModelRequest(
if self.last_sampled_token_ids is not None else None, if self.last_sampled_token_ids is not None else None,
async_callback=self.async_callback, async_callback=self.async_callback,
tree_attn_masks=self.tree_attn_masks, tree_attn_masks=self.tree_attn_masks,
tree_position_ids=self.tree_position_ids) tree_position_ids=self.tree_position_ids,
kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved)
...@@ -592,7 +592,7 @@ class BatchExpansionTreeStyleScorer(BatchExpansionTop1Scorer): ...@@ -592,7 +592,7 @@ class BatchExpansionTreeStyleScorer(BatchExpansionTop1Scorer):
prompt_token_ids = seq_data.prompt_token_ids_array prompt_token_ids = seq_data.prompt_token_ids_array
# first step need to ignore output token generated by prefill phase # first step need to ignore output token generated by prefill phase
if len(seq_data.get_output_token_ids()) == 1: if seq_data.get_first_step_flag():
new_output_token_ids = [*seq_data.get_output_token_ids()[:-1], *token_ids] new_output_token_ids = [*seq_data.get_output_token_ids()[:-1], *token_ids]
else: else:
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
......
import weakref import weakref
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple, Dict
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -12,6 +12,7 @@ from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase ...@@ -12,6 +12,7 @@ from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.spec_decode.tree_style_proposer import TreeStyleProposer from vllm.spec_decode.tree_style_proposer import TreeStyleProposer
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
from vllm.distributed import broadcast_tensor_dict
TOPK=10 # topk for sparse tree (10 is a placeholder and it is sufficient) TOPK=10 # topk for sparse tree (10 is a placeholder and it is sufficient)
...@@ -66,6 +67,51 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker): ...@@ -66,6 +67,51 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
def set_should_modify_greedy_probs_inplace(self): def set_should_modify_greedy_probs_inplace(self):
pass pass
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Dict[str, torch.Tensor]:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
seq_lens, query_lens = self._prepare_input_tensors(
seq_group_metadata_list)
generators = self.model_runner.get_generators(
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory, generators)
sample_indices_list = []
for seq_group in sampling_metadata.seq_groups:
sample_indices_list.append(seq_group.sample_indices)
previous_hidden_states = execute_model_req.previous_hidden_states.hidden_states
previous_logits = execute_model_req.previous_logits.logits if \
execute_model_req.previous_logits is not None else None
tensor_dict = {
"previous_hidden_states": previous_hidden_states,
"previous_logits": previous_logits,
"sample_indices_list": sample_indices_list,
"seq_lens": seq_lens
}
if self.do_metadata_broadcast:
broadcast_tensor_dict(tensor_dict, src=0)
return tensor_dict
def _get_worker_input_from_broadcast(
self
) -> Optional[Dict[str, torch.Tensor]]:
""" Get the worker input from the broadcasted tensor dict. """
assert self.do_metadata_broadcast
assert not self.is_driver_worker
broadcast_data = broadcast_tensor_dict(src=0)
return broadcast_data
@torch.inference_mode() @torch.inference_mode()
def sampler_output( def sampler_output(
self, self,
...@@ -83,26 +129,22 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker): ...@@ -83,26 +129,22 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
""" """
self._raise_if_unsupported(execute_model_req) self._raise_if_unsupported(execute_model_req)
seq_group_metadata_list = execute_model_req.seq_group_metadata_list if self.is_driver_worker:
tensor_dict = self._get_driver_input_and_broadcast(execute_model_req)
seq_lens, query_lens = self._prepare_input_tensors( else:
seq_group_metadata_list) tensor_dict = self._get_worker_input_from_broadcast()
if tensor_dict is None:
generators = self.model_runner.get_generators( raise ValueError("Can not get inputs of medusa worker!!!")
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory, generators)
model_outputs = self.model_runner.model.generate_proposals( model_outputs = self.model_runner.model.generate_proposals(
previous_hidden_states=execute_model_req.previous_hidden_states. previous_hidden_states=tensor_dict["previous_hidden_states"],
hidden_states, sample_indices_list=tensor_dict["sample_indices_list"],
sampling_metadata=sampling_metadata, previous_logits=tensor_dict["previous_logits"],
previous_logits=execute_model_req.previous_logits.logits if execute_model_req.previous_logits is not None else None,
medusa_buffers=self.medusa_buffers) medusa_buffers=self.medusa_buffers)
# create tree attn masks # create tree attn masks
if self.medusa_buffers is not None: if self.medusa_buffers is not None:
seq_lens = tensor_dict["seq_lens"]
max_context_len = max(seq_lens) max_context_len = max(seq_lens)
for sampler_output, seq_len in zip(model_outputs, seq_lens): for sampler_output, seq_len in zip(model_outputs, seq_lens):
context_len = seq_len context_len = seq_len
...@@ -142,7 +184,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker): ...@@ -142,7 +184,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
query_lens.append(seq_len - context_len) query_lens.append(seq_len - context_len)
else: else:
# first step of tree decoding need to ignore first token # first step of tree decoding need to ignore first token
if self.medusa_buffers is not None and seq_data.get_output_len() == 1: if self.medusa_buffers is not None and seq_data.get_first_step_flag():
seq_data_len -= 1 seq_data_len -= 1
seq_lens.append(seq_data_len) seq_lens.append(seq_data_len)
query_lens.append(1) query_lens.append(1)
...@@ -168,6 +210,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker): ...@@ -168,6 +210,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
"""MedusaWorker does not yet implement support for cache swap """MedusaWorker does not yet implement support for cache swap
operations or beam search. operations or beam search.
""" """
if execute_model_req is None:
return None
if any([ if any([
execute_model_req.blocks_to_swap_in, execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out, execute_model_req.blocks_to_swap_out,
......
...@@ -260,6 +260,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -260,6 +260,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# in the subsequent step. # in the subsequent step.
self.previous_hidden_states: Optional[HiddenStates] = None self.previous_hidden_states: Optional[HiddenStates] = None
self.previous_logits: Optional[Logits] = None self.previous_logits: Optional[Logits] = None
self.kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
self._disable_logprobs = disable_logprobs self._disable_logprobs = disable_logprobs
self._disable_log_stats = disable_log_stats self._disable_log_stats = disable_log_stats
...@@ -600,6 +601,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -600,6 +601,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.scorer_worker.execute_model() self.scorer_worker.execute_model()
if not data["disable_all_speculation"]: if not data["disable_all_speculation"]:
if not self.tree_style_spec_decoding:
# Even if num_lookahead_slots is zero, we want to run the # Even if num_lookahead_slots is zero, we want to run the
# proposer model as it may have KV. # proposer model as it may have KV.
# #
...@@ -607,6 +609,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -607,6 +609,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# should delegate how many times it runs to the proposer. # should delegate how many times it runs to the proposer.
for _ in range(max(num_lookahead_slots, 1)): for _ in range(max(num_lookahead_slots, 1)):
self.proposer_worker.execute_model() self.proposer_worker.execute_model()
else:
if not data["no_spec"]:
self.proposer_worker.sampler_output(None, None, None)
if not data["no_spec"]: if not data["no_spec"]:
self.scorer_worker.execute_model() self.scorer_worker.execute_model()
...@@ -635,6 +640,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -635,6 +640,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
execute_model_req.previous_logits = self.previous_logits execute_model_req.previous_logits = self.previous_logits
self.previous_logits = None self.previous_logits = None
execute_model_req.kvcache_slot_to_be_moved = self.kvcache_slot_to_be_moved
self.kvcache_slot_to_be_moved = None
with Timer() as proposal_timer: with Timer() as proposal_timer:
# Generate proposals using draft worker. # Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals( proposals = self.proposer_worker.get_spec_proposals(
...@@ -743,7 +751,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -743,7 +751,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
first_step_flags = [] first_step_flags = []
for i, sgm in enumerate(seq_group_metadata_list): for i, sgm in enumerate(seq_group_metadata_list):
seq = next(iter(sgm.seq_data.values())) seq = next(iter(sgm.seq_data.values()))
first_step_flags.append(True if seq.get_output_len() == 1 else False) first_step_flags.append(True if seq.get_first_step_flag() else False)
sampler_extra_kwargs["first_step_flags"] = first_step_flags sampler_extra_kwargs["first_step_flags"] = first_step_flags
...@@ -847,8 +855,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -847,8 +855,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
seq_len = min(seq_len, context_len + token_chunk_size) seq_len = min(seq_len, context_len + token_chunk_size)
# first step of tree-style decoding need to ignore first generated token # first step of tree-style decoding need to ignore first generated token
if seq_data.get_output_len() == 1: if seq_data.get_first_step_flag():
seq_len -= 1 seq_len -= 1
# move cache is the last step of tree decoding, so set first_step_flag to false
seq_data.set_first_step_flag(False)
seq_lens.append(seq_len) seq_lens.append(seq_len)
model_input = self.scorer._scorer_worker.model_input model_input = self.scorer._scorer_worker.model_input
...@@ -888,12 +899,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -888,12 +899,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
dtype=torch.long, dtype=torch.long,
device=self.device).view(-1, 1) device=self.device).view(-1, 1)
src_dst_tensor = torch.cat([select_indices_slot_tensor, target_slot_mapping_tensor], dim=-1) #[batch_size*T, 2] src_dst_tensor = torch.cat([select_indices_slot_tensor, target_slot_mapping_tensor], dim=-1) #[batch_size*T, 2]
kv_caches = self.scorer._scorer_worker.kv_cache[execute_model_req.virtual_engine] # kv_caches = self.scorer._scorer_worker.kv_cache[execute_model_req.virtual_engine]
kv_cache_dtype = cache_engine.cache_config.cache_dtype # kv_cache_dtype = cache_engine.cache_config.cache_dtype
backend = cache_engine.attn_backend # backend = cache_engine.attn_backend
num_kv_heads = cache_engine.num_kv_heads # num_kv_heads = cache_engine.num_kv_heads
head_size = cache_engine.head_size # head_size = cache_engine.head_size
backend.move_cache(kv_caches, src_dst_tensor, kv_cache_dtype, num_kv_heads, head_size) # backend.move_cache(kv_caches, src_dst_tensor, kv_cache_dtype, num_kv_heads*4, head_size)
self.kvcache_slot_to_be_moved = src_dst_tensor
def compute_slot_mapping(self, slot_mapping: List[int], def compute_slot_mapping(self, slot_mapping: List[int],
seq_id: int, select_indices: List[int], block_size: int, seq_id: int, select_indices: List[int], block_size: int,
...@@ -996,7 +1008,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -996,7 +1008,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_token_ids_by_step) accepted_token_ids_by_step)
maybe_rejsample_metrics = ( maybe_rejsample_metrics = (
self._metrics.maybe_collect_rejsample_metrics(k)) self._metrics.maybe_collect_rejsample_metrics(k))
if maybe_rejsample_metrics is not None: if maybe_rejsample_metrics is not None and sampler_output_list:
sampler_output_list[ sampler_output_list[
0].spec_decode_worker_metrics = maybe_rejsample_metrics 0].spec_decode_worker_metrics = maybe_rejsample_metrics
......
...@@ -115,7 +115,7 @@ class TreeStyleProposer(SpeculativeProposer): ...@@ -115,7 +115,7 @@ class TreeStyleProposer(SpeculativeProposer):
tree_position_ids_list = [] tree_position_ids_list = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
seq_data = next(iter(seq_group_metadata.seq_data.values())) seq_data = next(iter(seq_group_metadata.seq_data.values()))
if seq_data.get_output_len() == 1: if seq_data.get_first_step_flag():
seq_len = seq_data.get_len() -1 seq_len = seq_data.get_len() -1
else: else:
seq_len = seq_data.get_len() seq_len = seq_data.get_len()
......
...@@ -101,6 +101,14 @@ class CacheEngine: ...@@ -101,6 +101,14 @@ class CacheEngine:
def copy(self, src_to_dsts: torch.Tensor) -> None: def copy(self, src_to_dsts: torch.Tensor) -> None:
self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
def move_caches(self, kv_caches: List[torch.Tensor],
src_to_dsts: torch.Tensor) -> None:
self.attn_backend.move_cache(kv_caches,
src_to_dsts,
self.cache_config.cache_dtype,
self.num_kv_heads,
self.head_size)
@staticmethod @staticmethod
def get_cache_block_size( def get_cache_block_size(
cache_config: CacheConfig, cache_config: CacheConfig,
......
...@@ -509,20 +509,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -509,20 +509,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else: else:
inter_data.input_tokens[seq_idx].append(tokens) inter_data.input_tokens[seq_idx].append(tokens)
if (seq_len - context_len) == 1: inter_data.input_positions[seq_idx] = list(range(context_len, seq_len))
inter_data.input_positions[seq_idx].append(seq_len - 1)
else:
inter_data.input_positions[seq_idx].extend(
range(context_len, seq_len))
if seq_group_metadata.tree_position_ids is None: if seq_group_metadata.tree_position_ids is not None:
if (seq_len - context_len) == 1: inter_data.input_positions[seq_idx] = seq_group_metadata.tree_position_ids.contiguous().cpu().tolist()
inter_data.input_positions[seq_idx].append(seq_len - 1)
else:
inter_data.input_positions[seq_idx].extend(
range(context_len, seq_len))
else:
inter_data.input_positions[seq_idx] = seq_group_metadata.tree_position_ids.cpu().tolist()
inter_data.tree_attn_masks[seq_idx] = seq_group_metadata.tree_attn_masks inter_data.tree_attn_masks[seq_idx] = seq_group_metadata.tree_attn_masks
inter_data.query_lens[ inter_data.query_lens[
...@@ -871,6 +861,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -871,6 +861,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
tree_attention_masks_tensor = None tree_attention_masks_tensor = None
if len(tree_attention_masks_list) > 0: if len(tree_attention_masks_list) > 0:
tree_attention_masks_tensor = torch.stack(tree_attention_masks_list, dim=0) tree_attention_masks_tensor = torch.stack(tree_attention_masks_list, dim=0)
tree_attention_masks_tensor = tree_attention_masks_tensor.contiguous()
# Attention metadata. # Attention metadata.
attn_metadata = self.attn_metadata_builder.build( attn_metadata = self.attn_metadata_builder.build(
......
...@@ -324,6 +324,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -324,6 +324,7 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine, virtual_engine=virtual_engine,
num_steps=num_steps, num_steps=num_steps,
kvcache_slot_to_be_moved=execute_model_req.kvcache_slot_to_be_moved
) )
@torch.inference_mode() @torch.inference_mode()
...@@ -342,6 +343,11 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -342,6 +343,11 @@ class Worker(LocalOrDistributedWorkerBase):
and worker_input.blocks_to_copy.numel() > 0): and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
# tree-style generation need to move kvcache to correct position
if worker_input.kvcache_slot_to_be_moved is not None:
self.cache_engine[virtual_engine].move_caches(self.kv_cache[virtual_engine],
worker_input.kvcache_slot_to_be_moved)
def _get_cached_seq_group_metadata( def _get_cached_seq_group_metadata(
self, self,
seq_group_metadata_list: List[Union[SequenceGroupMetadata, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
......
...@@ -161,6 +161,9 @@ class WorkerInput: ...@@ -161,6 +161,9 @@ class WorkerInput:
virtual_engine: int = 0 virtual_engine: int = 0
num_steps: int = 1 num_steps: int = 1
# Optional slot mapping of kvcache that pending to be moved generated from draft model.
kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
@classmethod @classmethod
def from_broadcasted_tensor_dict( def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"], cls: Type["WorkerInput"],
...@@ -177,6 +180,7 @@ class WorkerInput: ...@@ -177,6 +180,7 @@ class WorkerInput:
blocks_to_copy=tensor_dict.pop("blocks_to_copy"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"], virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"), num_steps=tensor_dict.pop("num_steps"),
kvcache_slot_to_be_moved=tensor_dict.pop("kvcache_slot_to_be_moved"),
) )
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
...@@ -191,6 +195,7 @@ class WorkerInput: ...@@ -191,6 +195,7 @@ class WorkerInput:
"blocks_to_copy": self.blocks_to_copy, "blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine, "virtual_engine": self.virtual_engine,
"num_steps": self.num_steps, "num_steps": self.num_steps,
"kvcache_slot_to_be_moved": self.kvcache_slot_to_be_moved
} }
return tensor_dict return tensor_dict
...@@ -281,8 +286,6 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -281,8 +286,6 @@ class LocalOrDistributedWorkerBase(WorkerBase):
# set tree_attn_masks and position ids to seq_group_metadata_list # set tree_attn_masks and position ids to seq_group_metadata_list
if execute_model_req.tree_attn_masks is not None: if execute_model_req.tree_attn_masks is not None:
for i, seq_group_metadata in enumerate(execute_model_req.seq_group_metadata_list): for i, seq_group_metadata in enumerate(execute_model_req.seq_group_metadata_list):
# seq_group_metadata.tree_attn_masks = execute_model_req.tree_attn_masks[i]
# seq_group_metadata.tree_position_ids = execute_model_req.tree_position_ids[i]
seq_group_metadata.set_tree_style_args(tree_attn_masks=execute_model_req.tree_attn_masks[i], seq_group_metadata.set_tree_style_args(tree_attn_masks=execute_model_req.tree_attn_masks[i],
tree_position_ids=execute_model_req.tree_position_ids[i]) tree_position_ids=execute_model_req.tree_position_ids[i])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment