Commit 201768d5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.6.2-dev_wm' into 'v0.6.2-dev'

优化medusa 推理

See merge request dcutoolkit/deeplearing/vllm!41
parents 87a2e37f 28375803
...@@ -304,7 +304,9 @@ def main(args: argparse.Namespace): ...@@ -304,7 +304,9 @@ def main(args: argparse.Namespace):
else: else:
batch_sizes = [args.batch_size] batch_sizes = [args.batch_size]
ray.init() ray.init(address=None,
ignore_reinit_error=True,
num_gpus=args.tp_size)
num_gpus = int(ray.available_resources()["GPU"]) num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
......
...@@ -17,9 +17,9 @@ Vllm medusa model的实现在[vllm/model_executor/models/medusa.py] ...@@ -17,9 +17,9 @@ Vllm medusa model的实现在[vllm/model_executor/models/medusa.py]
# medusa 模型需要转换为vllm中Medusa的模型格式 # medusa 模型需要转换为vllm中Medusa的模型格式
```bash ```bash
python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --medusa_model_path /work/medusa/qwen2_72b_head_4/adapter_model.bin --vocab_size 152064 --hidden_size 8192 --output_dir /work/medusa/sugon/vllm-medusa-qwen2-72b-head-4 --medusa_choices="[(0,), (0, 0), (0, 0, 0), (0, 0, 0, 0), (0, 1), (1,), (1, 0), (0, 0, 1), (0, 1, 0), (0, 2), (1, 0, 0), (2,), (2, 0), (0, 3), (0, 0, 2), (0, 2, 0), (0, 4), (0, 0, 1, 0), (0, 1, 0, 0), (2, 0, 0), (3,), (0, 5), (0, 0, 0, 1), (3, 0), (0, 0, 3), (1, 0, 0, 0), (0, 3, 0), (0, 6), (0, 0, 4), (0, 4, 0), (1, 1), (4,)]" python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --medusa_model_path /work/model.bin --vocab_size 152064 --hidden_size 8192 --output_dir /work/medusa/vllm-medusa-qwen2-72b-head-4 --medusa_choices="[(0), (0, 0), (0, 0, 0), (0, 1), (1), (1, 0), (0, 0, 0, 0), (0, 0, 1), (0, 2), (0, 1, 0), (2), (0, 0, 2), (0, 3), (1, 0, 0), (2, 0), (0, 2, 0), (0, 4), (0, 0, 3), (3), (0, 0, 0, 1), (0, 5), (0, 0, 1, 0), (0, 0, 4)]"
``` ```
此处qwen2_72b_head_4是medusa模型使用peft lora训练后保存的权重,其他格式也可参考[medusa_weight_converter.py]修改进行权重转换 此处model.bin是训练后保存的medusa head权重
### Run ### Run
...@@ -28,30 +28,34 @@ python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --m ...@@ -28,30 +28,34 @@ python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --m
python3 -m vllm.entrypoints.openai.api_server \ python3 -m vllm.entrypoints.openai.api_server \
--served-model-name qwen_medusa \ --served-model-name qwen_medusa \
--model /models/Qwen2-72B-Instruct/ -tp 4 \ --model /models/Qwen2-72B-Instruct/ -tp 4 \
--max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.7 \ --max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.8 \
--speculative-model /work/medusa/vllm-medusa-qwen2-72b-head-4 \ --speculative-model /work/medusa/vllm-medusa-qwen2-72b-head-4 \
--speculative-draft-tensor-parallel-size 4 \ --speculative-draft-tensor-parallel-size 4 \
--speculative-disable-by-batch-size 4 \ --speculative-disable-by-batch-size 4 \
--use-v2-block-manager \ --use-v2-block-manager \
--spec-decoding-acceptance-method typical_acceptance_sampler \ --spec-decoding-acceptance-method typical_acceptance_sampler \
--enforce-eager --dtype float16 --trust-remote-code --port 8086\ --dtype float16 --trust-remote-code --port 8086\
--enable-lora --lora-modules medusa-lora=/work/qwen2_72b_head_4 \
--max-lora-rank 32 --lora-extra-vocab-size 0 --merge-lora True \
--lora-target-modules qkv_proj \
--tree-style-spec-decoding True\ --tree-style-spec-decoding True\
--num-speculative-heads 4 --num-speculative-tokens 33 --num-speculative-heads 4 --num-speculative-tokens 24
``` ```
merge-lora可以将lora权重和base model权重融合,提升整体推理速度,若对精度有严格要求,可不设置此参数 merge-lora可以将lora权重和base model权重融合,提升整体推理速度,若对精度有严格要求,可不设置此参数
num-speculative-tokens和medusa choices的个数相关,num_speculative_tokens = len(medusa_choices) + 2 num-speculative-tokens和medusa choices的个数相关,num_speculative_tokens = len(medusa_choices) + 1
# do request # do request
```bash ```bash
curl http://localhost:8086/v1/completions \ curl http://localhost:8086/v1/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
"model": "medusa-lora", "model": "qwen_medusa",
"prompt": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n帮我写一个C++的快速排序算法<|im_end|>\n<|im_start|>assistant\n", "prompt": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n帮我写一个C++的快速排序算法<|im_end|>\n<|im_start|>assistant\n",
"max_tokens": 256, "max_tokens": 256,
"temperature": 0.0 "temperature": 0.0
}' }'
```bash ```
### benchmark
python medusa_benchmark_throughput.py --model /data/llm-models/qwen2/Qwen2-72B-Instruct/ -tp 4 --dtype float16 --trust-remote-code --max-num-seqs 1 --dataset /work/test/medusa_benchmark_data.json --max-model-len 4096 --gpu-memory-utilization 0.9
可设置max-num-seqs对不同的batch进行性能测试
"""Benchmark offline inference throughput."""
import argparse
import json
import random
import time
from typing import List, Optional, Tuple
import numpy as np
import torch
import uvloop
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from vllm.inputs import PromptInputs
from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
from vllm.lora.request import LoRARequest
def nullable_str(val: str):
if not val or val == "None":
return None
return val
def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Only keep the first two turns of each conversation.
dataset = [data["prompt"] for data in dataset]
# Shuffle the dataset.
random.shuffle(dataset)
# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
# Tokenize the prompts and completions.
prompt = dataset[i]
prompt_token_ids = tokenizer(prompt).input_ids
prompt_len = len(prompt_token_ids)
output_len = fixed_output_len
filtered_dataset.append((prompt, prompt_len, output_len))
return filtered_dataset
def run_vllm(
warmup_requests: List[Tuple[str, int, int]],
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1,
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
max_num_seqs: int = 8,
speculative_model: str=None,
speculative_draft_tensor_parallel_size: int = 1,
speculative_disable_by_batch_size: int = 4,
spec_decoding_acceptance_method: str = None,
enable_lora: bool = False,
max_lora_rank: int = 32,
merge_lora: bool = False,
lora_extra_vocab_size: int = 0,
lora_target_modules: List[str] = None,
tree_style_spec_decoding: bool = False,
num_speculative_heads: int = 5,
num_speculative_tokens: int = 64,
use_new_beam_search_impl: bool = False,
lora_modules: str = None
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
max_num_seqs=max_num_seqs,
speculative_model=speculative_model,
speculative_draft_tensor_parallel_size=speculative_draft_tensor_parallel_size,
speculative_disable_by_batch_size=speculative_disable_by_batch_size,
spec_decoding_acceptance_method=spec_decoding_acceptance_method,
enable_lora=enable_lora,
max_lora_rank=max_lora_rank,
merge_lora=merge_lora,
lora_extra_vocab_size=lora_extra_vocab_size,
lora_target_modules=lora_target_modules,
tree_style_spec_decoding=tree_style_spec_decoding,
num_speculative_heads=num_speculative_heads,
num_speculative_tokens=num_speculative_tokens
)
# Add the requests to the engine.
prompts: List[str] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=False,
max_tokens=output_len,
))
# warmup
warmup_prompts = []
warmup_sampling_params = []
for prompt, _, output_len in warmup_requests:
warmup_prompts.append(prompt)
warmup_sampling_params.append(
SamplingParams(
n=n,
temperature=0.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=False,
max_tokens=output_len,
))
print("Warming up...")
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
if lora_modules is None:
llm.generate(warmup_prompts, warmup_sampling_params, use_tqdm=True)
else:
llm.generate(warmup_prompts, warmup_sampling_params, use_tqdm=True,
lora_request=LoRARequest("medusa-lora", 1, lora_modules))
total_out_tokens = 0
start = time.perf_counter()
if lora_modules is None:
outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
else:
outputs = llm.generate(prompts, sampling_params, use_tqdm=False,
lora_request=LoRARequest("medusa-lora", 1, lora_modules))
for output in outputs:
print("token_ids len:{} text:{}".format(len(output.outputs[0].token_ids), output.outputs[0].text))
total_out_tokens += len(output.outputs[0].token_ids)
end = time.perf_counter()
return end - start, total_out_tokens
async def run_vllm_async(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1,
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
disable_frontend_multiprocessing: bool = False,
max_num_seqs: int = 8,
speculative_model: str=None,
speculative_draft_tensor_parallel_size: int = 1,
speculative_disable_by_batch_size: int = 4,
spec_decoding_acceptance_method: str = None,
enable_lora: bool = False,
max_lora_rank: int = 32,
merge_lora: bool = False,
lora_extra_vocab_size: int = 0,
lora_target_modules: List[str] = None,
tree_style_spec_decoding: bool = False,
num_speculative_heads: int = 5,
num_speculative_tokens: int = 64,
use_new_beam_search_impl: bool = False,
lora_modules: str = None
) -> float:
from vllm import SamplingParams
engine_args = AsyncEngineArgs(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
worker_use_ray=False,
disable_log_requests=True,
max_num_seqs=max_num_seqs,
speculative_model=speculative_model,
speculative_draft_tensor_parallel_size=speculative_draft_tensor_parallel_size,
speculative_disable_by_batch_size=speculative_disable_by_batch_size,
spec_decoding_acceptance_method=spec_decoding_acceptance_method,
enable_lora=enable_lora,
max_lora_rank=max_lora_rank,
merge_lora=merge_lora,
lora_extra_vocab_size=lora_extra_vocab_size,
lora_target_modules=lora_target_modules,
tree_style_spec_decoding=tree_style_spec_decoding,
num_speculative_heads=num_speculative_heads,
num_speculative_tokens=num_speculative_tokens
)
async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing) as llm:
# Add the requests to the engine.
prompts: List[str] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=False,
max_tokens=output_len,
))
generators = []
start = time.perf_counter()
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
out_dict = {}
async for i, res in all_gens:
#print("res:", res)
out_dict[res.request_id] = len(res.outputs[0].token_ids)
end = time.perf_counter()
total_out_tokens = 0
for token_num in out_dict.values():
total_out_tokens += token_num
return end - start, total_out_tokens
def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code)
warmup_prompt = "hi" * 10
warmup_requests = [(warmup_prompt, 10, 10)
for _ in range(1)]
if args.dataset is None:
# Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1)
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
args.output_len)
if args.async_engine:
run_args = [
requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc, False, args.max_num_seqs,
args.speculative_model, args.speculative_draft_tensor_parallel_size,
args.speculative_disable_by_batch_size, args.spec_decoding_acceptance_method,
args.enable_lora, args.max_lora_rank, args.merge_lora, args.lora_extra_vocab_size,
args.lora_target_modules, args.tree_style_spec_decoding, args.num_speculative_heads,
args.num_speculative_tokens
]
else:
run_args = [
warmup_requests, requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc, args.max_num_seqs,
args.speculative_model, args.speculative_draft_tensor_parallel_size,
args.speculative_disable_by_batch_size, args.spec_decoding_acceptance_method,
args.enable_lora, args.max_lora_rank, args.merge_lora, args.lora_extra_vocab_size,
args.lora_target_modules, args.tree_style_spec_decoding, args.num_speculative_heads,
args.num_speculative_tokens
]
if args.async_engine:
run_args.append(args.disable_frontend_multiprocessing)
elapsed_time, total_out_tokens = uvloop.run(run_vllm_async(*run_args))
else:
elapsed_time, total_out_tokens = run_vllm(*run_args, args.use_new_beam_search_impl, args.lora_modules)
total_num_tokens = total_out_tokens + sum(prompt_len
for _, prompt_len, _ in requests)
print(f"Latency: {elapsed_time:.2f} s")
print(f"All Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
print(f"Generate Throughput: {total_out_tokens / elapsed_time:.2f} tokens/s")
# Output JSON results if specified
if args.output_json:
results = {
"elapsed_time": elapsed_time,
"num_requests": len(requests),
"total_num_tokens": total_num_tokens,
"requests_per_second": len(requests) / elapsed_time,
"tokens_per_second": total_num_tokens / elapsed_time,
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--dataset",
type=str,
default=None,
help="Path to the dataset.")
parser.add_argument("--input-len",
type=int,
default=None,
help="Input prompt length for each request")
parser.add_argument("--output-len",
type=int,
default=256,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=[*QUANTIZATION_METHODS, None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument('--num-iters-warmup',
type=int,
default=1,
help='Number of iterations to run for warmup.')
parser.add_argument("--use-new-beam-search-impl", action="store_true")
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--max-model-len',
type=int,
default=None,
help='Maximum length of a sequence (including prompt and output). '
'If None, will be derived from the model.')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.9,
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
default=None,
help='Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument("--device",
type=str,
default="auto",
choices=DEVICE_OPTIONS,
help='device type for vLLM execution')
parser.add_argument(
"--num-scheduler-steps",
type=int,
default=1,
help="Maximum number of forward steps per scheduler call.")
parser.add_argument("--use-v2-block-manager",
action='store_true',
help="Enable block manager v2.")
parser.add_argument(
"--enable-prefix-caching",
action='store_true',
help="Enable automatic prefix caching for vLLM backend.")
parser.add_argument("--enable-chunked-prefill",
action='store_true',
help="enable chunked prefill for vLLM backend.")
parser.add_argument('--max-num-batched-tokens',
type=int,
default=None,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--download-dir',
type=str,
default=None,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
parser.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp'],
default=None,
help='Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.')
parser.add_argument(
'--load-format',
type=str,
default=EngineArgs.load_format,
choices=[
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
'bitsandbytes'
],
help='The format of the model weights to load.\n\n'
'* "auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
'is not available.\n'
'* "pt" will load the weights in the pytorch bin format.\n'
'* "safetensors" will load the weights in the safetensors format.\n'
'* "npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading.\n'
'* "dummy" will initialize the weights with random values, '
'which is mainly for profiling.\n'
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave. See the Tensorize vLLM Model script in the Examples'
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
"--disable-async-output-proc",
action='store_true',
default=False,
help="Disable async output processor for vLLM backend.")
parser.add_argument("--async-engine",
action='store_true',
default=False,
help="Use vLLM async engine rather than LLM class.")
parser.add_argument("--disable-frontend-multiprocessing",
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
help='Maximum number of sequences per iteration.')
parser.add_argument(
'--speculative-model',
type=nullable_str,
default=EngineArgs.speculative_model,
help=
'The name of the draft model to be used in speculative decoding.')
parser.add_argument(
'--speculative-draft-tensor-parallel-size',
'-spec-draft-tp',
type=int,
default=EngineArgs.speculative_draft_tensor_parallel_size,
help='Number of tensor parallel replicas for '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-disable-by-batch-size',
type=int,
default=EngineArgs.speculative_disable_by_batch_size,
help='Disable speculative decoding for new incoming requests '
'if the number of enqueue requests is larger than this value.')
parser.add_argument(
'--spec-decoding-acceptance-method',
type=str,
default=EngineArgs.spec_decoding_acceptance_method,
choices=['rejection_sampler', 'typical_acceptance_sampler'],
help='Specify the acceptance method to use during draft token '
'verification in speculative decoding. Two types of acceptance '
'routines are supported: '
'1) RejectionSampler which does not allow changing the '
'acceptance rate of draft tokens, '
'2) TypicalAcceptanceSampler which is configurable, allowing for '
'a higher acceptance rate at the cost of lower quality, '
'and vice versa.')
# LoRA related configs
parser.add_argument('--enable-lora',
action='store_true',
help='If True, enable handling of LoRA adapters.')
parser.add_argument('--max-lora-rank',
type=int,
default=EngineArgs.max_lora_rank,
help='Max LoRA rank.')
parser.add_argument('--merge-lora',
type=bool,
default=False,
help='If set to True, the weights of the base layer will be merged with the weights of Lora.')
parser.add_argument(
'--lora-extra-vocab-size',
type=int,
default=EngineArgs.lora_extra_vocab_size,
help=('Maximum size of extra vocabulary that can be '
'present in a LoRA adapter (added to the base '
'model vocabulary).'))
parser.add_argument('--lora-target-modules',
nargs='*',
default=None,
help='List of lora module name, If not specified, modules will be chosen according to the model architecture.')
parser.add_argument('--tree-style-spec-decoding',
type=bool,
default=False,
help='If set to True, tree-style generation will be activated.')
parser.add_argument(
'--num-speculative-heads',
type=int,
default=EngineArgs.num_speculative_heads,
help='The number of speculative heads to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--num-speculative-tokens',
type=int,
default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--lora-modules',
type=nullable_str,
default=None,
help=
'Path of lora model.')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
if args.dataset is None:
assert args.input_len is not None
assert args.output_len is not None
else:
assert args.input_len is None
main(args)
\ No newline at end of file
...@@ -20,9 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -20,9 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64 DEFAULT_VOCAB_PADDING_SIZE = 64
TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE = 'base_model.model.medusa_head.{}.{}.linear.weight' TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE = 'medusa_head.{}.{}.linear.weight'
TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE = 'base_model.model.medusa_head.{}.1.weight' TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE = 'medusa_head.{}.1.weight'
TRAINED_BLOCK_BIAS_NAME_TEMPLATE = 'base_model.model.medusa_head.{}.{}.linear.bias' TRAINED_BLOCK_BIAS_NAME_TEMPLATE = 'medusa_head.{}.{}.linear.bias'
VLLM_BLOCK_WEIGHT_NAME_TEMPLATE = 'blocks.{}.layers.{}.weight' VLLM_BLOCK_WEIGHT_NAME_TEMPLATE = 'blocks.{}.layers.{}.weight'
VLLM_BLOCK_BIAS_NAME_TEMPLATE = 'blocks.{}.layers.{}.bias' VLLM_BLOCK_BIAS_NAME_TEMPLATE = 'blocks.{}.layers.{}.bias'
......
...@@ -238,6 +238,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): ...@@ -238,6 +238,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
tree_attention_masks_tensor: Optional[torch.Tensor] = None tree_attention_masks_tensor: Optional[torch.Tensor] = None
block_tables_list: Optional[List[int]] = None
@property @property
def prefill_metadata( def prefill_metadata(
self) -> Optional["BlocksparseFlashAttentionMetadata"]: self) -> Optional["BlocksparseFlashAttentionMetadata"]:
...@@ -269,7 +271,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): ...@@ -269,7 +271,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor=self.context_lens_tensor[:self.num_prefills], context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False, use_cuda_graph=False,
tree_attention_masks_tensor=self.tree_attention_masks_tensor tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list
) )
return self._cached_prefill_metadata return self._cached_prefill_metadata
...@@ -298,7 +301,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): ...@@ -298,7 +301,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor=None, context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:], block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph, use_cuda_graph=self.use_cuda_graph,
tree_attention_masks_tensor=self.tree_attention_masks_tensor tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list
) )
return self._cached_decode_metadata return self._cached_decode_metadata
......
...@@ -168,6 +168,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -168,6 +168,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
tree_attention_masks_tensor: Optional[torch.Tensor] = None tree_attention_masks_tensor: Optional[torch.Tensor] = None
block_tables_list: Optional[List[int]] = None
@property @property
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
...@@ -199,7 +200,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -199,7 +200,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor=self.context_lens_tensor[:self.num_prefills], context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False, use_cuda_graph=False,
tree_attention_masks_tensor=self.tree_attention_masks_tensor tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list
) )
return self._cached_prefill_metadata return self._cached_prefill_metadata
...@@ -228,7 +230,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -228,7 +230,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor=None, context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:], block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph, use_cuda_graph=self.use_cuda_graph,
tree_attention_masks_tensor=self.tree_attention_masks_tensor tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list
) )
return self._cached_decode_metadata return self._cached_decode_metadata
......
...@@ -272,7 +272,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -272,7 +272,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
context_lens_tensor=context_lens_tensor, context_lens_tensor=context_lens_tensor,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
tree_attention_masks_tensor=tree_attention_masks_tensor tree_attention_masks_tensor=tree_attention_masks_tensor,
block_tables_list=self.block_tables
) )
......
...@@ -190,6 +190,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -190,6 +190,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
cross_block_tables: Optional[torch.Tensor] = None cross_block_tables: Optional[torch.Tensor] = None
tree_attention_masks_tensor: Optional[torch.Tensor] = None tree_attention_masks_tensor: Optional[torch.Tensor] = None
block_tables_list: Optional[List[int]] = None
def __post_init__(self): def __post_init__(self):
# Set during the execution of the first attention op. # Set during the execution of the first attention op.
...@@ -271,7 +272,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -271,7 +272,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
max_encoder_seq_len=self.max_encoder_seq_len, max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping, cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables, cross_block_tables=self.cross_block_tables,
tree_attention_masks_tensor=self.tree_attention_masks_tensor) tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list)
return self._cached_prefill_metadata return self._cached_prefill_metadata
@property @property
...@@ -311,7 +313,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -311,7 +313,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
max_encoder_seq_len=self.max_encoder_seq_len, max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping, cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables, cross_block_tables=self.cross_block_tables,
tree_attention_masks_tensor=self.tree_attention_masks_tensor) tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list)
return self._cached_decode_metadata return self._cached_decode_metadata
......
...@@ -41,7 +41,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest ...@@ -41,7 +41,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata, Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus, VLLM_INVALID_TOKEN_ID) SequenceStatus, CompletionSequenceGroupOutput, VLLM_INVALID_TOKEN_ID)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer) init_tracer)
from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.config import try_get_generation_config
...@@ -989,7 +989,7 @@ class LLMEngine: ...@@ -989,7 +989,7 @@ class LLMEngine:
output = [outputs_by_sequence_group[0][i]] output = [outputs_by_sequence_group[0][i]]
# tree style speculative decoding may generate empty output in first step # tree style speculative decoding may generate empty output in first step
if outputs and isinstance(output[0], SamplerOutput): if outputs and isinstance(output[0], CompletionSequenceGroupOutput):
samples = [o.samples[0] for o in output] samples = [o.samples[0] for o in output]
valid_samples = [ valid_samples = [
sample for sample in samples sample for sample in samples
......
...@@ -235,7 +235,6 @@ class Sampler(nn.Module): ...@@ -235,7 +235,6 @@ class Sampler(nn.Module):
sampling_metadata: Metadata for sampling. sampling_metadata: Metadata for sampling.
""" """
assert logits is not None assert logits is not None
original_logits = logits.clone()
_, vocab_size = logits.shape _, vocab_size = logits.shape
# Prepare sampling tensors with pinned memory to avoid blocking. # Prepare sampling tensors with pinned memory to avoid blocking.
...@@ -320,7 +319,7 @@ class Sampler(nn.Module): ...@@ -320,7 +319,7 @@ class Sampler(nn.Module):
sample_logprobs, sample_logprobs,
on_device_tensors=on_device_tensors, on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output, skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
logits=original_logits) logits=logits)
@property @property
def _should_modify_greedy_probs_inplace(self) -> bool: def _should_modify_greedy_probs_inplace(self) -> bool:
......
...@@ -198,10 +198,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -198,10 +198,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
k = draft_token_ids.shape[-1] k = draft_token_ids.shape[-1]
output_token_id_list = [] output_token_id_list = []
logger.info("accept_length:%s", accept_length) accept_length_list = accept_length.cpu().tolist()
logger.info("accept_length:%s", accept_length_list)
for i in range(batch_size): for i in range(batch_size):
output_best_candidates.append(best_candidate[i]) output_best_candidates.append(best_candidate[i])
accept_lengths.append(accept_length[i]) accept_lengths.append(accept_length_list[i])
if not first_step_flags[i]: if not first_step_flags[i]:
select_indices = cart_candidates[i, best_candidate[i], : accept_length[i] + 1] select_indices = cart_candidates[i, best_candidate[i], : accept_length[i] + 1]
......
...@@ -996,9 +996,6 @@ class SequenceGroupMetadata( ...@@ -996,9 +996,6 @@ class SequenceGroupMetadata(
# TODO: We should maintain this states out of the sequence group. # TODO: We should maintain this states out of the sequence group.
num_speculative_tokens: Optional[int] = None num_speculative_tokens: Optional[int] = None
tree_attn_masks : Optional[torch.Tensor] = None
tree_position_ids : Optional[torch.Tensor] = None
def __post_init__(self): def __post_init__(self):
if self.seq_data is not None and self.token_chunk_size is None: if self.seq_data is not None and self.token_chunk_size is None:
if self.is_prompt: if self.is_prompt:
...@@ -1036,11 +1033,6 @@ class SequenceGroupMetadata( ...@@ -1036,11 +1033,6 @@ class SequenceGroupMetadata(
assert self.state.current_step < self.state.num_steps assert self.state.current_step < self.state.num_steps
self.state.current_step += 1 self.state.current_step += 1
def set_tree_style_args(self, tree_attn_masks: Optional[torch.Tensor],
tree_position_ids: Optional[torch.Tensor]):
self.tree_attn_masks = tree_attn_masks
self.tree_position_ids = tree_position_ids
class SequenceOutput( class SequenceOutput(
msgspec.Struct, msgspec.Struct,
......
...@@ -343,7 +343,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker): ...@@ -343,7 +343,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
# Move the tensors in the dictionary to the specified device # Move the tensors in the dictionary to the specified device
medusa_buffers = { medusa_buffers = {
k: (v.clone().to(device) if k != "tree_position_ids" else v.clone()) k: v.clone().to(device)
if isinstance(v, torch.Tensor) if isinstance(v, torch.Tensor)
else torch.tensor(v, device=device) else torch.tensor(v, device=device)
for k, v in medusa_buffers.items() for k, v in medusa_buffers.items()
......
...@@ -318,8 +318,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -318,8 +318,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
""" """
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
) = True ) = True
(self.scorer_worker.model_runner.model.sampler.
should_modify_greedy_probs_inplace) = True # tree_style decoding modify probs in _verify_tokens
if not self.tree_style_spec_decoding:
(self.scorer_worker.model_runner.model.sampler.
should_modify_greedy_probs_inplace) = True
self.proposer_worker.set_include_gpu_probs_tensor() self.proposer_worker.set_include_gpu_probs_tensor()
self.proposer_worker.set_should_modify_greedy_probs_inplace() self.proposer_worker.set_should_modify_greedy_probs_inplace()
...@@ -694,7 +697,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -694,7 +697,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_scores: SpeculativeScores, proposal_scores: SpeculativeScores,
proposals: SpeculativeProposals, proposals: SpeculativeProposals,
max_proposal_len: int, max_proposal_len: int,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], List[int]]: ) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]], List[int]]:
"""Determine which speculative tokens are accepted using the """Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models. probabilities of each token according to the proposer and scorer models.
...@@ -712,7 +715,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -712,7 +715,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
original_indices = spec_indices + non_spec_indices original_indices = spec_indices + non_spec_indices
# Get probabilities of target model, including bonus tokens. # Get probabilities of target model, including bonus tokens.
proposal_verifier_probs = proposal_scores.probs[spec_indices] if non_spec_indices:
proposal_verifier_probs = proposal_scores.probs[spec_indices]
else:
proposal_verifier_probs = proposal_scores.probs
if self.tree_style_spec_decoding: if self.tree_style_spec_decoding:
retrieve_indices = proposals.retrieve_indices retrieve_indices = proposals.retrieve_indices
...@@ -722,17 +728,24 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -722,17 +728,24 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices] non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
# Get bonus tokens from target model. # Get bonus tokens from target model.
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:] bonus_token_ids = proposal_scores.token_ids[:, -1:]
if non_spec_indices:
bonus_token_ids = bonus_token_ids[spec_indices, :]
# Get probabilities according to proposal method. # Get probabilities according to proposal method.
proposal_probs = proposals.proposal_probs[spec_indices] \ proposal_probs = proposals.proposal_probs if proposals.proposal_probs is not None else None
if proposals.proposal_probs is not None else None if non_spec_indices:
proposal_probs = proposal_probs[spec_indices]
# Get proposed tokens. # Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids[spec_indices] proposal_token_ids = proposals.proposal_token_ids
if non_spec_indices:
proposal_token_ids = proposal_token_ids[spec_indices]
# Get tree buffers. # Get tree buffers.
cart_candidates = proposals.cart_candidates[spec_indices] if proposals.cart_candidates is not None else None cart_candidates = proposals.cart_candidates if proposals.cart_candidates is not None else None
if non_spec_indices:
cart_candidates = cart_candidates[spec_indices]
# Sampler arguments # Sampler arguments
sampler_extra_kwargs: Dict[str, Any] = {} sampler_extra_kwargs: Dict[str, Any] = {}
...@@ -820,6 +833,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -820,6 +833,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
previous_logits_list = [] previous_logits_list = []
previous_hidden_state_list = [] previous_hidden_state_list = []
retrieve_indices = retrieve_indices.cpu()
for i in range(batch_size): for i in range(batch_size):
logit = logits[i, best_candidates[i], accept_lengths[i]].unsqueeze(0) logit = logits[i, best_candidates[i], accept_lengths[i]].unsqueeze(0)
...@@ -865,13 +880,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -865,13 +880,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
model_input = self.scorer._scorer_worker.model_input model_input = self.scorer._scorer_worker.model_input
block_tables = None block_tables = None
if hasattr(model_input, 'attn_metadata') and hasattr(model_input.attn_metadata, 'block_tables'): if hasattr(model_input, 'attn_metadata') and hasattr(model_input.attn_metadata, 'block_tables_list'):
block_tables = model_input.attn_metadata.block_tables block_tables = model_input.attn_metadata.block_tables_list
if block_tables is None: if block_tables is None:
raise RuntimeError("Can not get block_tables from model_input.") raise RuntimeError("Can not get block_tables from model_input.")
block_tables = block_tables.cpu().tolist()
cache_engine = self.scorer._scorer_worker.cache_engines[execute_model_req.virtual_engine] cache_engine = self.scorer._scorer_worker.cache_engines[execute_model_req.virtual_engine]
block_size = cache_engine.block_size block_size = cache_engine.block_size
...@@ -885,10 +898,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -885,10 +898,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if accept_legth > 0: if accept_legth > 0:
select_indices = select_indices_list[i][1:] + seq_lens[i] select_indices = select_indices_list[i][1:] + seq_lens[i]
select_indices = select_indices.tolist()
self.compute_slot_mapping(select_indices_slot_mapping, i*block_table_stride, self.compute_slot_mapping(select_indices_slot_mapping, i*block_table_stride,
select_indices, block_size, block_tables) select_indices, block_size, block_tables)
target_indices = torch.arange(accept_legth+1)[1:] + seq_lens[i] target_indices = torch.arange(accept_legth+1)[1:] + seq_lens[i]
target_indices = target_indices.tolist()
self.compute_slot_mapping(target_slot_mapping, i*block_table_stride, self.compute_slot_mapping(target_slot_mapping, i*block_table_stride,
target_indices, block_size, block_tables) target_indices, block_size, block_tables)
...@@ -900,12 +915,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -900,12 +915,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
dtype=torch.long, dtype=torch.long,
device=self.device).view(-1, 1) device=self.device).view(-1, 1)
src_dst_tensor = torch.cat([select_indices_slot_tensor, target_slot_mapping_tensor], dim=-1) #[batch_size*T, 2] src_dst_tensor = torch.cat([select_indices_slot_tensor, target_slot_mapping_tensor], dim=-1) #[batch_size*T, 2]
# kv_caches = self.scorer._scorer_worker.kv_cache[execute_model_req.virtual_engine]
# kv_cache_dtype = cache_engine.cache_config.cache_dtype
# backend = cache_engine.attn_backend
# num_kv_heads = cache_engine.num_kv_heads
# head_size = cache_engine.head_size
# backend.move_cache(kv_caches, src_dst_tensor, kv_cache_dtype, num_kv_heads*4, head_size)
self.kvcache_slot_to_be_moved = src_dst_tensor self.kvcache_slot_to_be_moved = src_dst_tensor
def compute_slot_mapping(self, slot_mapping: List[int], def compute_slot_mapping(self, slot_mapping: List[int],
......
...@@ -469,6 +469,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -469,6 +469,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.sliding_window + self.block_size - 1) // self.block_size self.sliding_window + self.block_size - 1) // self.block_size
self.block_aligned_sliding_window = \ self.block_aligned_sliding_window = \
self.sliding_window_blocks * self.block_size self.sliding_window_blocks * self.block_size
if hasattr(self.runner, "tree_attn_masks"):
self.tree_attn_masks = self.runner.tree_attn_masks
self.tree_position_ids = self.runner.tree_position_ids
else:
self.tree_attn_masks = None
self.tree_position_ids = None
self.is_encoder_decoder_model = self.runner.model_config.is_encoder_decoder_model
def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata): seq_group_metadata: SequenceGroupMetadata):
...@@ -511,10 +520,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -511,10 +520,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data.input_positions[seq_idx] = list(range(context_len, seq_len)) inter_data.input_positions[seq_idx] = list(range(context_len, seq_len))
if seq_group_metadata.tree_position_ids is not None:
inter_data.input_positions[seq_idx] = seq_group_metadata.tree_position_ids.contiguous().tolist()
inter_data.tree_attn_masks[seq_idx] = seq_group_metadata.tree_attn_masks
inter_data.query_lens[ inter_data.query_lens[
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1 seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
...@@ -718,7 +723,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -718,7 +723,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
encoder_seq_len = 0 encoder_seq_len = 0
if self.runner.model_config.is_encoder_decoder_model: if self.is_encoder_decoder_model:
encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len()
inter_data = self.init_cached_inter_data( inter_data = self.init_cached_inter_data(
...@@ -796,7 +801,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -796,7 +801,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if not inter_data.is_prompt: if not inter_data.is_prompt:
max_decode_seq_len = max(max_decode_seq_len, max_decode_seq_len = max(max_decode_seq_len,
max(inter_data.seq_lens)) max(inter_data.seq_lens))
if self.runner.model_config.is_encoder_decoder_model: if self.is_encoder_decoder_model:
max_encoder_seq_len = max(max_encoder_seq_len, max_encoder_seq_len = max(max_encoder_seq_len,
inter_data.encoder_seq_len) inter_data.encoder_seq_len)
...@@ -847,22 +852,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -847,22 +852,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Sequence and query lengths. # Sequence and query lengths.
if cuda_graph_pad_size: if cuda_graph_pad_size:
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size)) seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
# prepare tree attention masks # prepare tree attention masks
max_context_len = 0 tree_attention_masks_tensor = self.tree_attn_masks
for inter_data in self.inter_data_list: if tree_attention_masks_tensor is not None:
max_context_len = max(max_context_len, max(inter_data.context_lens))
tree_attention_masks_list = []
for inter_data in self.inter_data_list:
for i in range(len(inter_data.seq_lens)):
if inter_data.tree_attn_masks:
tree_attn_masks = inter_data.tree_attn_masks[i]
if tree_attn_masks is not None:
tree_attention_masks_list.append(tree_attn_masks)
tree_attention_masks_tensor = None
if tree_attention_masks_list:
tree_attention_masks_tensor = torch.stack(tree_attention_masks_list, dim=0)
tree_attention_masks_tensor = tree_attention_masks_tensor.contiguous() tree_attention_masks_tensor = tree_attention_masks_tensor.contiguous()
input_positions_tensor = self.tree_position_ids.contiguous()
# Attention metadata. # Attention metadata.
attn_metadata = self.attn_metadata_builder.build( attn_metadata = self.attn_metadata_builder.build(
...@@ -1038,6 +1033,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1038,6 +1033,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.inter_data_cache: Dict[int, PyObjectCache] = {} self.inter_data_cache: Dict[int, PyObjectCache] = {}
self.sampling_metadata_cache: SamplingMetadataCache = \ self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache() SamplingMetadataCache()
self.tree_attn_masks: Optional[torch.Tensor] = None
self.tree_position_ids : Optional[torch.Tensor] = None
def load_model(self) -> None: def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
...@@ -1505,6 +1503,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1505,6 +1503,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
@property @property
def vocab_size(self) -> int: def vocab_size(self) -> int:
return self.model_config.get_vocab_size() return self.model_config.get_vocab_size()
def set_tree_style_args(self, tree_attn_masks: Optional[torch.Tensor],
tree_position_ids: Optional[torch.Tensor]):
self.tree_attn_masks = tree_attn_masks
self.tree_position_ids = tree_position_ids
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
......
...@@ -283,11 +283,9 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -283,11 +283,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input: WorkerInput = self.prepare_worker_input( worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
# set tree_attn_masks and position ids to seq_group_metadata_list if hasattr(self.model_runner, "set_tree_style_args"):
if execute_model_req.tree_attn_masks is not None: self.model_runner.set_tree_style_args(tree_attn_masks=execute_model_req.tree_attn_masks,
for i, seq_group_metadata in enumerate(execute_model_req.seq_group_metadata_list): tree_position_ids=execute_model_req.tree_position_ids)
seq_group_metadata.set_tree_style_args(tree_attn_masks=execute_model_req.tree_attn_masks[i],
tree_position_ids=execute_model_req.tree_position_ids[i])
model_input: ModelRunnerInputBase = ( model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input( self.model_runner.prepare_model_input(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment