Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
28375803
Commit
28375803
authored
Nov 27, 2024
by
王敏
Browse files
1.优化medusa推理,节省cpu耗时
2.更新medusa readme 3.解决benchmark_moe报错问题
parent
3c9817d2
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
765 additions
and
81 deletions
+765
-81
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+3
-1
examples/medusa/README.md
examples/medusa/README.md
+15
-11
examples/medusa/medusa_benchmark_throughput.py
examples/medusa/medusa_benchmark_throughput.py
+664
-0
examples/medusa/medusa_weight_converter.py
examples/medusa/medusa_weight_converter.py
+3
-3
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+6
-2
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+5
-2
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+2
-1
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+5
-2
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-2
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+1
-2
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+3
-2
vllm/sequence.py
vllm/sequence.py
+0
-8
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+1
-1
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+29
-19
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+23
-20
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+3
-5
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
28375803
...
...
@@ -304,7 +304,9 @@ def main(args: argparse.Namespace):
else
:
batch_sizes
=
[
args
.
batch_size
]
ray
.
init
()
ray
.
init
(
address
=
None
,
ignore_reinit_error
=
True
,
num_gpus
=
args
.
tp_size
)
num_gpus
=
int
(
ray
.
available_resources
()[
"GPU"
])
workers
=
[
BenchmarkWorker
.
remote
(
args
.
seed
)
for
_
in
range
(
num_gpus
)]
...
...
examples/medusa/README.md
View file @
28375803
...
...
@@ -17,9 +17,9 @@ Vllm medusa model的实现在[vllm/model_executor/models/medusa.py]
# medusa 模型需要转换为vllm中Medusa的模型格式
```
bash
python medusa_weight_converter.py
--medusa_num_heads
4
--medusa_num_layers
1
--medusa_model_path
/work/
medusa/qwen2_72b_head_4/adapter_
model.bin
--vocab_size
152064
--hidden_size
8192
--output_dir
/work/medusa/
sugon/
vllm-medusa-qwen2-72b-head-4
--medusa_choices
=
"[(0
,
), (0, 0), (0, 0, 0), (0,
0, 0, 0
), (
0,
1), (1,), (
1
, 0), (0, 0, 1), (0,
1, 0), (0,
2), (
1
,
0
, 0), (2
,
), (
2, 0)
,
(
0,
3
), (0,
0, 2), (0, 2
, 0
)
,
(0, 4), (0, 0, 1, 0), (0, 1, 0,
0), (2, 0, 0),
(3,),
(0,
5
), (0, 0,
0, 1), (3, 0), (0, 0,
3), (
1,
0, 0, 0
), (0, 3
,
0
), (0,
6
), (0, 0,
4), (0, 4
, 0), (
1
,
1), (4,
)]"
python medusa_weight_converter.py
--medusa_num_heads
4
--medusa_num_layers
1
--medusa_model_path
/work/model.bin
--vocab_size
152064
--hidden_size
8192
--output_dir
/work/medusa/vllm-medusa-qwen2-72b-head-4
--medusa_choices
=
"[(0), (0, 0), (0, 0, 0), (0,
1
), (1), (1,
0
), (
0, 0, 0
, 0), (0, 0, 1), (0, 2), (
0
,
1
, 0), (2), (
0
, 0,
2
), (0,
3), (1
, 0, 0), (2, 0
), (0, 2
, 0), (0,
4
), (0, 0,
3), (
3), (0, 0, 0,
1
), (0,
5
), (0, 0,
1
, 0), (
0
,
0, 4
)]"
```
此处
qwen2_72b_head_4是medusa模型使用peft lora训练后保存的权重,其他格式也可参考[medusa_weight_converter.py]修改进行权重转换
此处
model.bin是训练后保存的medusa head权重
### Run
...
...
@@ -28,30 +28,34 @@ python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --m
python3
-m
vllm.entrypoints.openai.api_server
\
--served-model-name
qwen_medusa
\
--model
/models/Qwen2-72B-Instruct/
-tp
4
\
--max-model-len
1024
--max-num-seqs
8
--gpu-memory-utilization
0.
7
\
--max-model-len
1024
--max-num-seqs
8
--gpu-memory-utilization
0.
8
\
--speculative-model
/work/medusa/vllm-medusa-qwen2-72b-head-4
\
--speculative-draft-tensor-parallel-size
4
\
--speculative-disable-by-batch-size
4
\
--use-v2-block-manager
\
--spec-decoding-acceptance-method
typical_acceptance_sampler
\
--enforce-eager
--dtype
float16
--trust-remote-code
--port
8086
\
--enable-lora
--lora-modules
medusa-lora
=
/work/qwen2_72b_head_4
\
--max-lora-rank
32
--lora-extra-vocab-size
0
--merge-lora
True
\
--lora-target-modules
qkv_proj
\
--dtype
float16
--trust-remote-code
--port
8086
\
--tree-style-spec-decoding
True
\
--num-speculative-heads
4
--num-speculative-tokens
33
--num-speculative-heads
4
--num-speculative-tokens
24
```
merge-lora可以将lora权重和base model权重融合,提升整体推理速度,若对精度有严格要求,可不设置此参数
num-speculative-tokens和medusa choices的个数相关,num_speculative_tokens = len(medusa_choices) +
2
num-speculative-tokens和medusa choices的个数相关,num_speculative_tokens = len(medusa_choices) +
1
# do request
```
bash
curl http://localhost:8086/v1/completions
\
-H
"Content-Type: application/json"
\
-d
'{
"model": "medus
a-lor
a",
"model": "
qwen_
medusa",
"prompt": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n帮我写一个C++的快速排序算法<|im_end|>\n<|im_start|>assistant\n",
"max_tokens": 256,
"temperature": 0.0
}'
```
bash
```
### benchmark
python medusa_benchmark_throughput.py --model /data/llm-models/qwen2/Qwen2-72B-Instruct/ -tp 4 --dtype float16 --trust-remote-code --max-num-seqs 1 --dataset /work/test/medusa_benchmark_data.json --max-model-len 4096 --gpu-memory-utilization 0.9
可设置max-num-seqs对不同的batch进行性能测试
examples/medusa/medusa_benchmark_throughput.py
0 → 100644
View file @
28375803
"""Benchmark offline inference throughput."""
import
argparse
import
json
import
random
import
time
from
typing
import
List
,
Optional
,
Tuple
import
numpy
as
np
import
torch
import
uvloop
from
tqdm
import
tqdm
from
transformers
import
(
AutoModelForCausalLM
,
AutoTokenizer
,
PreTrainedTokenizerBase
)
from
vllm.inputs
import
PromptInputs
from
vllm.engine.arg_utils
import
DEVICE_OPTIONS
,
AsyncEngineArgs
,
EngineArgs
from
vllm.entrypoints.openai.api_server
import
(
build_async_engine_client_from_engine_args
)
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.utils
import
FlexibleArgumentParser
,
merge_async_iterators
from
vllm.lora.request
import
LoRARequest
def
nullable_str
(
val
:
str
):
if
not
val
or
val
==
"None"
:
return
None
return
val
def
sample_requests
(
dataset_path
:
str
,
num_requests
:
int
,
tokenizer
:
PreTrainedTokenizerBase
,
fixed_output_len
:
Optional
[
int
],
)
->
List
[
Tuple
[
str
,
int
,
int
]]:
if
fixed_output_len
is
not
None
and
fixed_output_len
<
4
:
raise
ValueError
(
"output_len too small"
)
# Load the dataset.
with
open
(
dataset_path
)
as
f
:
dataset
=
json
.
load
(
f
)
# Only keep the first two turns of each conversation.
dataset
=
[
data
[
"prompt"
]
for
data
in
dataset
]
# Shuffle the dataset.
random
.
shuffle
(
dataset
)
# Filter out sequences that are too long or too short
filtered_dataset
:
List
[
Tuple
[
str
,
int
,
int
]]
=
[]
for
i
in
range
(
len
(
dataset
)):
if
len
(
filtered_dataset
)
==
num_requests
:
break
# Tokenize the prompts and completions.
prompt
=
dataset
[
i
]
prompt_token_ids
=
tokenizer
(
prompt
).
input_ids
prompt_len
=
len
(
prompt_token_ids
)
output_len
=
fixed_output_len
filtered_dataset
.
append
((
prompt
,
prompt_len
,
output_len
))
return
filtered_dataset
def
run_vllm
(
warmup_requests
:
List
[
Tuple
[
str
,
int
,
int
]],
requests
:
List
[
Tuple
[
str
,
int
,
int
]],
model
:
str
,
tokenizer
:
str
,
quantization
:
Optional
[
str
],
tensor_parallel_size
:
int
,
seed
:
int
,
n
:
int
,
use_beam_search
:
bool
,
trust_remote_code
:
bool
,
dtype
:
str
,
max_model_len
:
Optional
[
int
],
enforce_eager
:
bool
,
kv_cache_dtype
:
str
,
quantization_param_path
:
Optional
[
str
],
device
:
str
,
enable_prefix_caching
:
bool
,
enable_chunked_prefill
:
bool
,
max_num_batched_tokens
:
int
,
distributed_executor_backend
:
Optional
[
str
],
gpu_memory_utilization
:
float
=
0.9
,
num_scheduler_steps
:
int
=
1
,
use_v2_block_manager
:
bool
=
False
,
download_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
EngineArgs
.
load_format
,
disable_async_output_proc
:
bool
=
False
,
max_num_seqs
:
int
=
8
,
speculative_model
:
str
=
None
,
speculative_draft_tensor_parallel_size
:
int
=
1
,
speculative_disable_by_batch_size
:
int
=
4
,
spec_decoding_acceptance_method
:
str
=
None
,
enable_lora
:
bool
=
False
,
max_lora_rank
:
int
=
32
,
merge_lora
:
bool
=
False
,
lora_extra_vocab_size
:
int
=
0
,
lora_target_modules
:
List
[
str
]
=
None
,
tree_style_spec_decoding
:
bool
=
False
,
num_speculative_heads
:
int
=
5
,
num_speculative_tokens
:
int
=
64
,
use_new_beam_search_impl
:
bool
=
False
,
lora_modules
:
str
=
None
)
->
float
:
from
vllm
import
LLM
,
SamplingParams
llm
=
LLM
(
model
=
model
,
tokenizer
=
tokenizer
,
quantization
=
quantization
,
tensor_parallel_size
=
tensor_parallel_size
,
seed
=
seed
,
trust_remote_code
=
trust_remote_code
,
dtype
=
dtype
,
max_model_len
=
max_model_len
,
gpu_memory_utilization
=
gpu_memory_utilization
,
enforce_eager
=
enforce_eager
,
kv_cache_dtype
=
kv_cache_dtype
,
quantization_param_path
=
quantization_param_path
,
device
=
device
,
enable_prefix_caching
=
enable_prefix_caching
,
download_dir
=
download_dir
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_batched_tokens
=
max_num_batched_tokens
,
distributed_executor_backend
=
distributed_executor_backend
,
load_format
=
load_format
,
num_scheduler_steps
=
num_scheduler_steps
,
use_v2_block_manager
=
use_v2_block_manager
,
disable_async_output_proc
=
disable_async_output_proc
,
max_num_seqs
=
max_num_seqs
,
speculative_model
=
speculative_model
,
speculative_draft_tensor_parallel_size
=
speculative_draft_tensor_parallel_size
,
speculative_disable_by_batch_size
=
speculative_disable_by_batch_size
,
spec_decoding_acceptance_method
=
spec_decoding_acceptance_method
,
enable_lora
=
enable_lora
,
max_lora_rank
=
max_lora_rank
,
merge_lora
=
merge_lora
,
lora_extra_vocab_size
=
lora_extra_vocab_size
,
lora_target_modules
=
lora_target_modules
,
tree_style_spec_decoding
=
tree_style_spec_decoding
,
num_speculative_heads
=
num_speculative_heads
,
num_speculative_tokens
=
num_speculative_tokens
)
# Add the requests to the engine.
prompts
:
List
[
str
]
=
[]
sampling_params
:
List
[
SamplingParams
]
=
[]
for
prompt
,
_
,
output_len
in
requests
:
prompts
.
append
(
prompt
)
sampling_params
.
append
(
SamplingParams
(
n
=
n
,
temperature
=
0.0
,
top_p
=
1.0
,
use_beam_search
=
use_beam_search
,
ignore_eos
=
False
,
max_tokens
=
output_len
,
))
# warmup
warmup_prompts
=
[]
warmup_sampling_params
=
[]
for
prompt
,
_
,
output_len
in
warmup_requests
:
warmup_prompts
.
append
(
prompt
)
warmup_sampling_params
.
append
(
SamplingParams
(
n
=
n
,
temperature
=
0.0
,
top_p
=
1.0
,
use_beam_search
=
use_beam_search
,
ignore_eos
=
False
,
max_tokens
=
output_len
,
))
print
(
"Warming up..."
)
for
_
in
tqdm
(
range
(
args
.
num_iters_warmup
),
desc
=
"Warmup iterations"
):
if
lora_modules
is
None
:
llm
.
generate
(
warmup_prompts
,
warmup_sampling_params
,
use_tqdm
=
True
)
else
:
llm
.
generate
(
warmup_prompts
,
warmup_sampling_params
,
use_tqdm
=
True
,
lora_request
=
LoRARequest
(
"medusa-lora"
,
1
,
lora_modules
))
total_out_tokens
=
0
start
=
time
.
perf_counter
()
if
lora_modules
is
None
:
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
False
)
else
:
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
False
,
lora_request
=
LoRARequest
(
"medusa-lora"
,
1
,
lora_modules
))
for
output
in
outputs
:
print
(
"token_ids len:{} text:{}"
.
format
(
len
(
output
.
outputs
[
0
].
token_ids
),
output
.
outputs
[
0
].
text
))
total_out_tokens
+=
len
(
output
.
outputs
[
0
].
token_ids
)
end
=
time
.
perf_counter
()
return
end
-
start
,
total_out_tokens
async
def
run_vllm_async
(
requests
:
List
[
Tuple
[
str
,
int
,
int
]],
model
:
str
,
tokenizer
:
str
,
quantization
:
Optional
[
str
],
tensor_parallel_size
:
int
,
seed
:
int
,
n
:
int
,
use_beam_search
:
bool
,
trust_remote_code
:
bool
,
dtype
:
str
,
max_model_len
:
Optional
[
int
],
enforce_eager
:
bool
,
kv_cache_dtype
:
str
,
quantization_param_path
:
Optional
[
str
],
device
:
str
,
enable_prefix_caching
:
bool
,
enable_chunked_prefill
:
bool
,
max_num_batched_tokens
:
int
,
distributed_executor_backend
:
Optional
[
str
],
gpu_memory_utilization
:
float
=
0.9
,
num_scheduler_steps
:
int
=
1
,
use_v2_block_manager
:
bool
=
False
,
download_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
EngineArgs
.
load_format
,
disable_async_output_proc
:
bool
=
False
,
disable_frontend_multiprocessing
:
bool
=
False
,
max_num_seqs
:
int
=
8
,
speculative_model
:
str
=
None
,
speculative_draft_tensor_parallel_size
:
int
=
1
,
speculative_disable_by_batch_size
:
int
=
4
,
spec_decoding_acceptance_method
:
str
=
None
,
enable_lora
:
bool
=
False
,
max_lora_rank
:
int
=
32
,
merge_lora
:
bool
=
False
,
lora_extra_vocab_size
:
int
=
0
,
lora_target_modules
:
List
[
str
]
=
None
,
tree_style_spec_decoding
:
bool
=
False
,
num_speculative_heads
:
int
=
5
,
num_speculative_tokens
:
int
=
64
,
use_new_beam_search_impl
:
bool
=
False
,
lora_modules
:
str
=
None
)
->
float
:
from
vllm
import
SamplingParams
engine_args
=
AsyncEngineArgs
(
model
=
model
,
tokenizer
=
tokenizer
,
quantization
=
quantization
,
tensor_parallel_size
=
tensor_parallel_size
,
seed
=
seed
,
trust_remote_code
=
trust_remote_code
,
dtype
=
dtype
,
max_model_len
=
max_model_len
,
gpu_memory_utilization
=
gpu_memory_utilization
,
enforce_eager
=
enforce_eager
,
kv_cache_dtype
=
kv_cache_dtype
,
quantization_param_path
=
quantization_param_path
,
device
=
device
,
enable_prefix_caching
=
enable_prefix_caching
,
download_dir
=
download_dir
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_batched_tokens
=
max_num_batched_tokens
,
distributed_executor_backend
=
distributed_executor_backend
,
load_format
=
load_format
,
num_scheduler_steps
=
num_scheduler_steps
,
use_v2_block_manager
=
use_v2_block_manager
,
disable_async_output_proc
=
disable_async_output_proc
,
worker_use_ray
=
False
,
disable_log_requests
=
True
,
max_num_seqs
=
max_num_seqs
,
speculative_model
=
speculative_model
,
speculative_draft_tensor_parallel_size
=
speculative_draft_tensor_parallel_size
,
speculative_disable_by_batch_size
=
speculative_disable_by_batch_size
,
spec_decoding_acceptance_method
=
spec_decoding_acceptance_method
,
enable_lora
=
enable_lora
,
max_lora_rank
=
max_lora_rank
,
merge_lora
=
merge_lora
,
lora_extra_vocab_size
=
lora_extra_vocab_size
,
lora_target_modules
=
lora_target_modules
,
tree_style_spec_decoding
=
tree_style_spec_decoding
,
num_speculative_heads
=
num_speculative_heads
,
num_speculative_tokens
=
num_speculative_tokens
)
async
with
build_async_engine_client_from_engine_args
(
engine_args
,
disable_frontend_multiprocessing
)
as
llm
:
# Add the requests to the engine.
prompts
:
List
[
str
]
=
[]
sampling_params
:
List
[
SamplingParams
]
=
[]
for
prompt
,
_
,
output_len
in
requests
:
prompts
.
append
(
prompt
)
sampling_params
.
append
(
SamplingParams
(
n
=
n
,
temperature
=
0.0
if
use_beam_search
else
1.0
,
top_p
=
1.0
,
use_beam_search
=
use_beam_search
,
ignore_eos
=
False
,
max_tokens
=
output_len
,
))
generators
=
[]
start
=
time
.
perf_counter
()
for
i
,
(
prompt
,
sp
)
in
enumerate
(
zip
(
prompts
,
sampling_params
)):
generator
=
llm
.
generate
(
prompt
,
sp
,
request_id
=
f
"test
{
i
}
"
)
generators
.
append
(
generator
)
all_gens
=
merge_async_iterators
(
*
generators
)
out_dict
=
{}
async
for
i
,
res
in
all_gens
:
#print("res:", res)
out_dict
[
res
.
request_id
]
=
len
(
res
.
outputs
[
0
].
token_ids
)
end
=
time
.
perf_counter
()
total_out_tokens
=
0
for
token_num
in
out_dict
.
values
():
total_out_tokens
+=
token_num
return
end
-
start
,
total_out_tokens
def
main
(
args
:
argparse
.
Namespace
):
print
(
args
)
random
.
seed
(
args
.
seed
)
# Sample the requests.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
tokenizer
,
trust_remote_code
=
args
.
trust_remote_code
)
warmup_prompt
=
"hi"
*
10
warmup_requests
=
[(
warmup_prompt
,
10
,
10
)
for
_
in
range
(
1
)]
if
args
.
dataset
is
None
:
# Synthesize a prompt with the given input length.
prompt
=
"hi"
*
(
args
.
input_len
-
1
)
requests
=
[(
prompt
,
args
.
input_len
,
args
.
output_len
)
for
_
in
range
(
args
.
num_prompts
)]
else
:
requests
=
sample_requests
(
args
.
dataset
,
args
.
num_prompts
,
tokenizer
,
args
.
output_len
)
if
args
.
async_engine
:
run_args
=
[
requests
,
args
.
model
,
args
.
tokenizer
,
args
.
quantization
,
args
.
tensor_parallel_size
,
args
.
seed
,
args
.
n
,
args
.
use_beam_search
,
args
.
trust_remote_code
,
args
.
dtype
,
args
.
max_model_len
,
args
.
enforce_eager
,
args
.
kv_cache_dtype
,
args
.
quantization_param_path
,
args
.
device
,
args
.
enable_prefix_caching
,
args
.
enable_chunked_prefill
,
args
.
max_num_batched_tokens
,
args
.
distributed_executor_backend
,
args
.
gpu_memory_utilization
,
args
.
num_scheduler_steps
,
args
.
use_v2_block_manager
,
args
.
download_dir
,
args
.
load_format
,
args
.
disable_async_output_proc
,
False
,
args
.
max_num_seqs
,
args
.
speculative_model
,
args
.
speculative_draft_tensor_parallel_size
,
args
.
speculative_disable_by_batch_size
,
args
.
spec_decoding_acceptance_method
,
args
.
enable_lora
,
args
.
max_lora_rank
,
args
.
merge_lora
,
args
.
lora_extra_vocab_size
,
args
.
lora_target_modules
,
args
.
tree_style_spec_decoding
,
args
.
num_speculative_heads
,
args
.
num_speculative_tokens
]
else
:
run_args
=
[
warmup_requests
,
requests
,
args
.
model
,
args
.
tokenizer
,
args
.
quantization
,
args
.
tensor_parallel_size
,
args
.
seed
,
args
.
n
,
args
.
use_beam_search
,
args
.
trust_remote_code
,
args
.
dtype
,
args
.
max_model_len
,
args
.
enforce_eager
,
args
.
kv_cache_dtype
,
args
.
quantization_param_path
,
args
.
device
,
args
.
enable_prefix_caching
,
args
.
enable_chunked_prefill
,
args
.
max_num_batched_tokens
,
args
.
distributed_executor_backend
,
args
.
gpu_memory_utilization
,
args
.
num_scheduler_steps
,
args
.
use_v2_block_manager
,
args
.
download_dir
,
args
.
load_format
,
args
.
disable_async_output_proc
,
args
.
max_num_seqs
,
args
.
speculative_model
,
args
.
speculative_draft_tensor_parallel_size
,
args
.
speculative_disable_by_batch_size
,
args
.
spec_decoding_acceptance_method
,
args
.
enable_lora
,
args
.
max_lora_rank
,
args
.
merge_lora
,
args
.
lora_extra_vocab_size
,
args
.
lora_target_modules
,
args
.
tree_style_spec_decoding
,
args
.
num_speculative_heads
,
args
.
num_speculative_tokens
]
if
args
.
async_engine
:
run_args
.
append
(
args
.
disable_frontend_multiprocessing
)
elapsed_time
,
total_out_tokens
=
uvloop
.
run
(
run_vllm_async
(
*
run_args
))
else
:
elapsed_time
,
total_out_tokens
=
run_vllm
(
*
run_args
,
args
.
use_new_beam_search_impl
,
args
.
lora_modules
)
total_num_tokens
=
total_out_tokens
+
sum
(
prompt_len
for
_
,
prompt_len
,
_
in
requests
)
print
(
f
"Latency:
{
elapsed_time
:.
2
f
}
s"
)
print
(
f
"All Throughput:
{
len
(
requests
)
/
elapsed_time
:.
2
f
}
requests/s, "
f
"
{
total_num_tokens
/
elapsed_time
:.
2
f
}
tokens/s"
)
print
(
f
"Generate Throughput:
{
total_out_tokens
/
elapsed_time
:.
2
f
}
tokens/s"
)
# Output JSON results if specified
if
args
.
output_json
:
results
=
{
"elapsed_time"
:
elapsed_time
,
"num_requests"
:
len
(
requests
),
"total_num_tokens"
:
total_num_tokens
,
"requests_per_second"
:
len
(
requests
)
/
elapsed_time
,
"tokens_per_second"
:
total_num_tokens
/
elapsed_time
,
}
with
open
(
args
.
output_json
,
"w"
)
as
f
:
json
.
dump
(
results
,
f
,
indent
=
4
)
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark the throughput."
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
default
=
None
,
help
=
"Path to the dataset."
)
parser
.
add_argument
(
"--input-len"
,
type
=
int
,
default
=
None
,
help
=
"Input prompt length for each request"
)
parser
.
add_argument
(
"--output-len"
,
type
=
int
,
default
=
256
,
help
=
"Output length for each request. Overrides the "
"output length from the dataset."
)
parser
.
add_argument
(
"--model"
,
type
=
str
,
default
=
"facebook/opt-125m"
)
parser
.
add_argument
(
"--tokenizer"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--quantization'
,
'-q'
,
choices
=
[
*
QUANTIZATION_METHODS
,
None
],
default
=
None
)
parser
.
add_argument
(
"--tensor-parallel-size"
,
"-tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
1
,
help
=
"Number of generated sequences per prompt."
)
parser
.
add_argument
(
"--use-beam-search"
,
action
=
"store_true"
)
parser
.
add_argument
(
'--num-iters-warmup'
,
type
=
int
,
default
=
1
,
help
=
'Number of iterations to run for warmup.'
)
parser
.
add_argument
(
"--use-new-beam-search-impl"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--num-prompts"
,
type
=
int
,
default
=
1000
,
help
=
"Number of prompts to process."
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
'--trust-remote-code'
,
action
=
'store_true'
,
help
=
'trust remote code from huggingface'
)
parser
.
add_argument
(
'--max-model-len'
,
type
=
int
,
default
=
None
,
help
=
'Maximum length of a sequence (including prompt and output). '
'If None, will be derived from the model.'
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
default
=
'auto'
,
choices
=
[
'auto'
,
'half'
,
'float16'
,
'bfloat16'
,
'float'
,
'float32'
],
help
=
'data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.'
)
parser
.
add_argument
(
'--gpu-memory-utilization'
,
type
=
float
,
default
=
0.9
,
help
=
'the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.'
)
parser
.
add_argument
(
"--enforce-eager"
,
action
=
"store_true"
,
help
=
"enforce eager execution"
)
parser
.
add_argument
(
'--kv-cache-dtype'
,
type
=
str
,
choices
=
[
'auto'
,
'fp8'
,
'fp8_e5m2'
,
'fp8_e4m3'
],
default
=
"auto"
,
help
=
'Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)'
)
parser
.
add_argument
(
'--quantization-param-path'
,
type
=
str
,
default
=
None
,
help
=
'Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.'
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
"auto"
,
choices
=
DEVICE_OPTIONS
,
help
=
'device type for vLLM execution'
)
parser
.
add_argument
(
"--num-scheduler-steps"
,
type
=
int
,
default
=
1
,
help
=
"Maximum number of forward steps per scheduler call."
)
parser
.
add_argument
(
"--use-v2-block-manager"
,
action
=
'store_true'
,
help
=
"Enable block manager v2."
)
parser
.
add_argument
(
"--enable-prefix-caching"
,
action
=
'store_true'
,
help
=
"Enable automatic prefix caching for vLLM backend."
)
parser
.
add_argument
(
"--enable-chunked-prefill"
,
action
=
'store_true'
,
help
=
"enable chunked prefill for vLLM backend."
)
parser
.
add_argument
(
'--max-num-batched-tokens'
,
type
=
int
,
default
=
None
,
help
=
'maximum number of batched tokens per '
'iteration'
)
parser
.
add_argument
(
'--download-dir'
,
type
=
str
,
default
=
None
,
help
=
'directory to download and load the weights, '
'default to the default cache dir of huggingface'
)
parser
.
add_argument
(
'--output-json'
,
type
=
str
,
default
=
None
,
help
=
'Path to save the throughput results in JSON format.'
)
parser
.
add_argument
(
'--distributed-executor-backend'
,
choices
=
[
'ray'
,
'mp'
],
default
=
None
,
help
=
'Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.'
)
parser
.
add_argument
(
'--load-format'
,
type
=
str
,
default
=
EngineArgs
.
load_format
,
choices
=
[
'auto'
,
'pt'
,
'safetensors'
,
'npcache'
,
'dummy'
,
'tensorizer'
,
'bitsandbytes'
],
help
=
'The format of the model weights to load.
\n\n
'
'* "auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
'is not available.
\n
'
'* "pt" will load the weights in the pytorch bin format.
\n
'
'* "safetensors" will load the weights in the safetensors format.
\n
'
'* "npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading.
\n
'
'* "dummy" will initialize the weights with random values, '
'which is mainly for profiling.
\n
'
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave. See the Tensorize vLLM Model script in the Examples'
'section for more information.
\n
'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.
\n
'
)
parser
.
add_argument
(
"--disable-async-output-proc"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Disable async output processor for vLLM backend."
)
parser
.
add_argument
(
"--async-engine"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Use vLLM async engine rather than LLM class."
)
parser
.
add_argument
(
"--disable-frontend-multiprocessing"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Disable decoupled async engine frontend."
)
parser
.
add_argument
(
'--max-num-seqs'
,
type
=
int
,
default
=
EngineArgs
.
max_num_seqs
,
help
=
'Maximum number of sequences per iteration.'
)
parser
.
add_argument
(
'--speculative-model'
,
type
=
nullable_str
,
default
=
EngineArgs
.
speculative_model
,
help
=
'The name of the draft model to be used in speculative decoding.'
)
parser
.
add_argument
(
'--speculative-draft-tensor-parallel-size'
,
'-spec-draft-tp'
,
type
=
int
,
default
=
EngineArgs
.
speculative_draft_tensor_parallel_size
,
help
=
'Number of tensor parallel replicas for '
'the draft model in speculative decoding.'
)
parser
.
add_argument
(
'--speculative-disable-by-batch-size'
,
type
=
int
,
default
=
EngineArgs
.
speculative_disable_by_batch_size
,
help
=
'Disable speculative decoding for new incoming requests '
'if the number of enqueue requests is larger than this value.'
)
parser
.
add_argument
(
'--spec-decoding-acceptance-method'
,
type
=
str
,
default
=
EngineArgs
.
spec_decoding_acceptance_method
,
choices
=
[
'rejection_sampler'
,
'typical_acceptance_sampler'
],
help
=
'Specify the acceptance method to use during draft token '
'verification in speculative decoding. Two types of acceptance '
'routines are supported: '
'1) RejectionSampler which does not allow changing the '
'acceptance rate of draft tokens, '
'2) TypicalAcceptanceSampler which is configurable, allowing for '
'a higher acceptance rate at the cost of lower quality, '
'and vice versa.'
)
# LoRA related configs
parser
.
add_argument
(
'--enable-lora'
,
action
=
'store_true'
,
help
=
'If True, enable handling of LoRA adapters.'
)
parser
.
add_argument
(
'--max-lora-rank'
,
type
=
int
,
default
=
EngineArgs
.
max_lora_rank
,
help
=
'Max LoRA rank.'
)
parser
.
add_argument
(
'--merge-lora'
,
type
=
bool
,
default
=
False
,
help
=
'If set to True, the weights of the base layer will be merged with the weights of Lora.'
)
parser
.
add_argument
(
'--lora-extra-vocab-size'
,
type
=
int
,
default
=
EngineArgs
.
lora_extra_vocab_size
,
help
=
(
'Maximum size of extra vocabulary that can be '
'present in a LoRA adapter (added to the base '
'model vocabulary).'
))
parser
.
add_argument
(
'--lora-target-modules'
,
nargs
=
'*'
,
default
=
None
,
help
=
'List of lora module name, If not specified, modules will be chosen according to the model architecture.'
)
parser
.
add_argument
(
'--tree-style-spec-decoding'
,
type
=
bool
,
default
=
False
,
help
=
'If set to True, tree-style generation will be activated.'
)
parser
.
add_argument
(
'--num-speculative-heads'
,
type
=
int
,
default
=
EngineArgs
.
num_speculative_heads
,
help
=
'The number of speculative heads to sample from '
'the draft model in speculative decoding.'
)
parser
.
add_argument
(
'--num-speculative-tokens'
,
type
=
int
,
default
=
EngineArgs
.
num_speculative_tokens
,
help
=
'The number of speculative tokens to sample from '
'the draft model in speculative decoding.'
)
parser
.
add_argument
(
'--lora-modules'
,
type
=
nullable_str
,
default
=
None
,
help
=
'Path of lora model.'
)
args
=
parser
.
parse_args
()
if
args
.
tokenizer
is
None
:
args
.
tokenizer
=
args
.
model
if
args
.
dataset
is
None
:
assert
args
.
input_len
is
not
None
assert
args
.
output_len
is
not
None
else
:
assert
args
.
input_len
is
None
main
(
args
)
\ No newline at end of file
examples/medusa/medusa_weight_converter.py
View file @
28375803
...
...
@@ -20,9 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE
=
64
TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE
=
'
base_model.model.
medusa_head.{}.{}.linear.weight'
TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE
=
'
base_model.model.
medusa_head.{}.1.weight'
TRAINED_BLOCK_BIAS_NAME_TEMPLATE
=
'
base_model.model.
medusa_head.{}.{}.linear.bias'
TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE
=
'medusa_head.{}.{}.linear.weight'
TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE
=
'medusa_head.{}.1.weight'
TRAINED_BLOCK_BIAS_NAME_TEMPLATE
=
'medusa_head.{}.{}.linear.bias'
VLLM_BLOCK_WEIGHT_NAME_TEMPLATE
=
'blocks.{}.layers.{}.weight'
VLLM_BLOCK_BIAS_NAME_TEMPLATE
=
'blocks.{}.layers.{}.bias'
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
28375803
...
...
@@ -238,6 +238,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
block_tables_list
:
Optional
[
List
[
int
]]
=
None
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"BlocksparseFlashAttentionMetadata"
]:
...
...
@@ -269,7 +271,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
],
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
use_cuda_graph
=
False
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables_list
)
return
self
.
_cached_prefill_metadata
...
...
@@ -298,7 +301,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor
=
None
,
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
use_cuda_graph
=
self
.
use_cuda_graph
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables_list
)
return
self
.
_cached_decode_metadata
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
28375803
...
...
@@ -168,6 +168,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
_cached_decode_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
block_tables_list
:
Optional
[
List
[
int
]]
=
None
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"ROCmFlashAttentionMetadata"
]:
...
...
@@ -199,7 +200,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
],
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
use_cuda_graph
=
False
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables_list
)
return
self
.
_cached_prefill_metadata
...
...
@@ -228,7 +230,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor
=
None
,
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
use_cuda_graph
=
self
.
use_cuda_graph
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables_list
)
return
self
.
_cached_decode_metadata
...
...
vllm/attention/backends/utils.py
View file @
28375803
...
...
@@ -272,7 +272,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
tree_attention_masks_tensor
=
tree_attention_masks_tensor
tree_attention_masks_tensor
=
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables
)
...
...
vllm/attention/backends/xformers.py
View file @
28375803
...
...
@@ -190,6 +190,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
block_tables_list
:
Optional
[
List
[
int
]]
=
None
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
...
...
@@ -271,7 +272,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
)
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables_list
)
return
self
.
_cached_prefill_metadata
@
property
...
...
@@ -311,7 +313,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
)
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables_list
)
return
self
.
_cached_decode_metadata
...
...
vllm/engine/llm_engine.py
View file @
28375803
...
...
@@ -41,7 +41,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
,
VLLM_INVALID_TOKEN_ID
)
SequenceStatus
,
CompletionSequenceGroupOutput
,
VLLM_INVALID_TOKEN_ID
)
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
init_tracer
)
from
vllm.transformers_utils.config
import
try_get_generation_config
...
...
@@ -989,7 +989,7 @@ class LLMEngine:
output
=
[
outputs_by_sequence_group
[
0
][
i
]]
# tree style speculative decoding may generate empty output in first step
if
outputs
and
isinstance
(
output
[
0
],
Sa
mple
r
Output
):
if
outputs
and
isinstance
(
output
[
0
],
Co
mple
tionSequenceGroup
Output
):
samples
=
[
o
.
samples
[
0
]
for
o
in
output
]
valid_samples
=
[
sample
for
sample
in
samples
...
...
vllm/model_executor/layers/sampler.py
View file @
28375803
...
...
@@ -235,7 +235,6 @@ class Sampler(nn.Module):
sampling_metadata: Metadata for sampling.
"""
assert
logits
is
not
None
original_logits
=
logits
.
clone
()
_
,
vocab_size
=
logits
.
shape
# Prepare sampling tensors with pinned memory to avoid blocking.
...
...
@@ -320,7 +319,7 @@ class Sampler(nn.Module):
sample_logprobs
,
on_device_tensors
=
on_device_tensors
,
skip_sampler_cpu_output
=
sampling_metadata
.
skip_sampler_cpu_output
,
logits
=
original_
logits
)
logits
=
logits
)
@
property
def
_should_modify_greedy_probs_inplace
(
self
)
->
bool
:
...
...
vllm/model_executor/layers/typical_acceptance_sampler.py
View file @
28375803
...
...
@@ -198,10 +198,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
k
=
draft_token_ids
.
shape
[
-
1
]
output_token_id_list
=
[]
logger
.
info
(
"accept_length:%s"
,
accept_length
)
accept_length_list
=
accept_length
.
cpu
().
tolist
()
logger
.
info
(
"accept_length:%s"
,
accept_length_list
)
for
i
in
range
(
batch_size
):
output_best_candidates
.
append
(
best_candidate
[
i
])
accept_lengths
.
append
(
accept_length
[
i
])
accept_lengths
.
append
(
accept_length
_list
[
i
])
if
not
first_step_flags
[
i
]:
select_indices
=
cart_candidates
[
i
,
best_candidate
[
i
],
:
accept_length
[
i
]
+
1
]
...
...
vllm/sequence.py
View file @
28375803
...
...
@@ -996,9 +996,6 @@ class SequenceGroupMetadata(
# TODO: We should maintain this states out of the sequence group.
num_speculative_tokens
:
Optional
[
int
]
=
None
tree_attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
tree_position_ids
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
if
self
.
seq_data
is
not
None
and
self
.
token_chunk_size
is
None
:
if
self
.
is_prompt
:
...
...
@@ -1036,11 +1033,6 @@ class SequenceGroupMetadata(
assert
self
.
state
.
current_step
<
self
.
state
.
num_steps
self
.
state
.
current_step
+=
1
def
set_tree_style_args
(
self
,
tree_attn_masks
:
Optional
[
torch
.
Tensor
],
tree_position_ids
:
Optional
[
torch
.
Tensor
]):
self
.
tree_attn_masks
=
tree_attn_masks
self
.
tree_position_ids
=
tree_position_ids
class
SequenceOutput
(
msgspec
.
Struct
,
...
...
vllm/spec_decode/medusa_worker.py
View file @
28375803
...
...
@@ -343,7 +343,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
# Move the tensors in the dictionary to the specified device
medusa_buffers
=
{
k
:
(
v
.
clone
().
to
(
device
)
if
k
!=
"tree_position_ids"
else
v
.
clone
())
k
:
v
.
clone
().
to
(
device
)
if
isinstance
(
v
,
torch
.
Tensor
)
else
torch
.
tensor
(
v
,
device
=
device
)
for
k
,
v
in
medusa_buffers
.
items
()
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
28375803
...
...
@@ -318,6 +318,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
)
=
True
# tree_style decoding modify probs in _verify_tokens
if
not
self
.
tree_style_spec_decoding
:
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
should_modify_greedy_probs_inplace
)
=
True
self
.
proposer_worker
.
set_include_gpu_probs_tensor
()
...
...
@@ -694,7 +697,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_scores
:
SpeculativeScores
,
proposals
:
SpeculativeProposals
,
max_proposal_len
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
List
[
torch
.
Tensor
],
List
[
int
]]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
List
[
List
[
int
]
],
List
[
int
]]:
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
...
...
@@ -712,7 +715,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
original_indices
=
spec_indices
+
non_spec_indices
# Get probabilities of target model, including bonus tokens.
if
non_spec_indices
:
proposal_verifier_probs
=
proposal_scores
.
probs
[
spec_indices
]
else
:
proposal_verifier_probs
=
proposal_scores
.
probs
if
self
.
tree_style_spec_decoding
:
retrieve_indices
=
proposals
.
retrieve_indices
...
...
@@ -722,17 +728,24 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
non_spec_token_ids
=
proposal_scores
.
token_ids
[
non_spec_indices
]
# Get bonus tokens from target model.
bonus_token_ids
=
proposal_scores
.
token_ids
[
spec_indices
,
-
1
:]
bonus_token_ids
=
proposal_scores
.
token_ids
[:,
-
1
:]
if
non_spec_indices
:
bonus_token_ids
=
bonus_token_ids
[
spec_indices
,
:]
# Get probabilities according to proposal method.
proposal_probs
=
proposals
.
proposal_probs
[
spec_indices
]
\
if
proposals
.
proposal_probs
is
not
None
else
None
proposal_probs
=
proposals
.
proposal_probs
if
proposals
.
proposal_probs
is
not
None
else
None
if
non_spec_indices
:
proposal_probs
=
proposal_probs
[
spec_indices
]
# Get proposed tokens.
proposal_token_ids
=
proposals
.
proposal_token_ids
[
spec_indices
]
proposal_token_ids
=
proposals
.
proposal_token_ids
if
non_spec_indices
:
proposal_token_ids
=
proposal_token_ids
[
spec_indices
]
# Get tree buffers.
cart_candidates
=
proposals
.
cart_candidates
[
spec_indices
]
if
proposals
.
cart_candidates
is
not
None
else
None
cart_candidates
=
proposals
.
cart_candidates
if
proposals
.
cart_candidates
is
not
None
else
None
if
non_spec_indices
:
cart_candidates
=
cart_candidates
[
spec_indices
]
# Sampler arguments
sampler_extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
...
...
@@ -821,6 +834,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
previous_hidden_state_list
=
[]
retrieve_indices
=
retrieve_indices
.
cpu
()
for
i
in
range
(
batch_size
):
logit
=
logits
[
i
,
best_candidates
[
i
],
accept_lengths
[
i
]].
unsqueeze
(
0
)
previous_logits_list
.
append
(
logit
)
...
...
@@ -865,14 +880,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
model_input
=
self
.
scorer
.
_scorer_worker
.
model_input
block_tables
=
None
if
hasattr
(
model_input
,
'attn_metadata'
)
and
hasattr
(
model_input
.
attn_metadata
,
'block_tables'
):
block_tables
=
model_input
.
attn_metadata
.
block_tables
if
hasattr
(
model_input
,
'attn_metadata'
)
and
hasattr
(
model_input
.
attn_metadata
,
'block_tables
_list
'
):
block_tables
=
model_input
.
attn_metadata
.
block_tables
_list
if
block_tables
is
None
:
raise
RuntimeError
(
"Can not get block_tables from model_input."
)
block_tables
=
block_tables
.
cpu
().
tolist
()
cache_engine
=
self
.
scorer
.
_scorer_worker
.
cache_engines
[
execute_model_req
.
virtual_engine
]
block_size
=
cache_engine
.
block_size
batch_size
=
len
(
select_indices_list
)
...
...
@@ -885,10 +898,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if
accept_legth
>
0
:
select_indices
=
select_indices_list
[
i
][
1
:]
+
seq_lens
[
i
]
select_indices
=
select_indices
.
tolist
()
self
.
compute_slot_mapping
(
select_indices_slot_mapping
,
i
*
block_table_stride
,
select_indices
,
block_size
,
block_tables
)
target_indices
=
torch
.
arange
(
accept_legth
+
1
)[
1
:]
+
seq_lens
[
i
]
target_indices
=
target_indices
.
tolist
()
self
.
compute_slot_mapping
(
target_slot_mapping
,
i
*
block_table_stride
,
target_indices
,
block_size
,
block_tables
)
...
...
@@ -900,12 +915,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
dtype
=
torch
.
long
,
device
=
self
.
device
).
view
(
-
1
,
1
)
src_dst_tensor
=
torch
.
cat
([
select_indices_slot_tensor
,
target_slot_mapping_tensor
],
dim
=-
1
)
#[batch_size*T, 2]
# kv_caches = self.scorer._scorer_worker.kv_cache[execute_model_req.virtual_engine]
# kv_cache_dtype = cache_engine.cache_config.cache_dtype
# backend = cache_engine.attn_backend
# num_kv_heads = cache_engine.num_kv_heads
# head_size = cache_engine.head_size
# backend.move_cache(kv_caches, src_dst_tensor, kv_cache_dtype, num_kv_heads*4, head_size)
self
.
kvcache_slot_to_be_moved
=
src_dst_tensor
def
compute_slot_mapping
(
self
,
slot_mapping
:
List
[
int
],
...
...
vllm/worker/model_runner.py
View file @
28375803
...
...
@@ -470,6 +470,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
block_aligned_sliding_window
=
\
self
.
sliding_window_blocks
*
self
.
block_size
if
hasattr
(
self
.
runner
,
"tree_attn_masks"
):
self
.
tree_attn_masks
=
self
.
runner
.
tree_attn_masks
self
.
tree_position_ids
=
self
.
runner
.
tree_position_ids
else
:
self
.
tree_attn_masks
=
None
self
.
tree_position_ids
=
None
self
.
is_encoder_decoder_model
=
self
.
runner
.
model_config
.
is_encoder_decoder_model
def
_compute_lens
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
seq_group_metadata
:
SequenceGroupMetadata
):
"""Compute context length, sequence length and tokens
...
...
@@ -511,10 +520,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data
.
input_positions
[
seq_idx
]
=
list
(
range
(
context_len
,
seq_len
))
if
seq_group_metadata
.
tree_position_ids
is
not
None
:
inter_data
.
input_positions
[
seq_idx
]
=
seq_group_metadata
.
tree_position_ids
.
contiguous
().
tolist
()
inter_data
.
tree_attn_masks
[
seq_idx
]
=
seq_group_metadata
.
tree_attn_masks
inter_data
.
query_lens
[
seq_idx
]
=
seq_len
-
context_len
if
inter_data
.
is_prompt
else
1
...
...
@@ -718,7 +723,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
encoder_seq_len
=
0
if
self
.
runner
.
model_config
.
is_encoder_decoder_model
:
if
self
.
is_encoder_decoder_model
:
encoder_seq_len
=
seq_group_metadata
.
encoder_seq_data
.
get_len
()
inter_data
=
self
.
init_cached_inter_data
(
...
...
@@ -796,7 +801,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if
not
inter_data
.
is_prompt
:
max_decode_seq_len
=
max
(
max_decode_seq_len
,
max
(
inter_data
.
seq_lens
))
if
self
.
runner
.
model_config
.
is_encoder_decoder_model
:
if
self
.
is_encoder_decoder_model
:
max_encoder_seq_len
=
max
(
max_encoder_seq_len
,
inter_data
.
encoder_seq_len
)
...
...
@@ -849,20 +854,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
seq_lens
.
extend
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
# prepare tree attention masks
max_context_len
=
0
for
inter_data
in
self
.
inter_data_list
:
max_context_len
=
max
(
max_context_len
,
max
(
inter_data
.
context_lens
))
tree_attention_masks_list
=
[]
for
inter_data
in
self
.
inter_data_list
:
for
i
in
range
(
len
(
inter_data
.
seq_lens
)):
if
inter_data
.
tree_attn_masks
:
tree_attn_masks
=
inter_data
.
tree_attn_masks
[
i
]
if
tree_attn_masks
is
not
None
:
tree_attention_masks_list
.
append
(
tree_attn_masks
)
tree_attention_masks_tensor
=
None
if
tree_attention_masks_list
:
tree_attention_masks_tensor
=
torch
.
stack
(
tree_attention_masks_list
,
dim
=
0
)
tree_attention_masks_tensor
=
self
.
tree_attn_masks
if
tree_attention_masks_tensor
is
not
None
:
tree_attention_masks_tensor
=
tree_attention_masks_tensor
.
contiguous
()
input_positions_tensor
=
self
.
tree_position_ids
.
contiguous
()
# Attention metadata.
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
...
...
@@ -1039,6 +1034,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
sampling_metadata_cache
:
SamplingMetadataCache
=
\
SamplingMetadataCache
()
self
.
tree_attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
self
.
tree_position_ids
:
Optional
[
torch
.
Tensor
]
=
None
def
load_model
(
self
)
->
None
:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
with
DeviceMemoryProfiler
()
as
m
:
...
...
@@ -1506,6 +1504,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def
vocab_size
(
self
)
->
int
:
return
self
.
model_config
.
get_vocab_size
()
def
set_tree_style_args
(
self
,
tree_attn_masks
:
Optional
[
torch
.
Tensor
],
tree_position_ids
:
Optional
[
torch
.
Tensor
]):
self
.
tree_attn_masks
=
tree_attn_masks
self
.
tree_position_ids
=
tree_position_ids
class
ModelRunner
(
GPUModelRunnerBase
[
ModelInputForGPUWithSamplingMetadata
]):
"""
...
...
vllm/worker/worker_base.py
View file @
28375803
...
...
@@ -283,11 +283,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input
:
WorkerInput
=
self
.
prepare_worker_input
(
execute_model_req
=
execute_model_req
)
# set tree_attn_masks and position ids to seq_group_metadata_list
if
execute_model_req
.
tree_attn_masks
is
not
None
:
for
i
,
seq_group_metadata
in
enumerate
(
execute_model_req
.
seq_group_metadata_list
):
seq_group_metadata
.
set_tree_style_args
(
tree_attn_masks
=
execute_model_req
.
tree_attn_masks
[
i
],
tree_position_ids
=
execute_model_req
.
tree_position_ids
[
i
])
if
hasattr
(
self
.
model_runner
,
"set_tree_style_args"
):
self
.
model_runner
.
set_tree_style_args
(
tree_attn_masks
=
execute_model_req
.
tree_attn_masks
,
tree_position_ids
=
execute_model_req
.
tree_position_ids
)
model_input
:
ModelRunnerInputBase
=
(
self
.
model_runner
.
prepare_model_input
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment