Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
28375803
Commit
28375803
authored
Nov 27, 2024
by
王敏
Browse files
1.优化medusa推理,节省cpu耗时
2.更新medusa readme 3.解决benchmark_moe报错问题
parent
3c9817d2
Changes
16
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
765 additions
and
81 deletions
+765
-81
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+3
-1
examples/medusa/README.md
examples/medusa/README.md
+15
-11
examples/medusa/medusa_benchmark_throughput.py
examples/medusa/medusa_benchmark_throughput.py
+664
-0
examples/medusa/medusa_weight_converter.py
examples/medusa/medusa_weight_converter.py
+3
-3
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+6
-2
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+5
-2
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+2
-1
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+5
-2
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-2
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+1
-2
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+3
-2
vllm/sequence.py
vllm/sequence.py
+0
-8
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+1
-1
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+29
-19
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+23
-20
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+3
-5
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
28375803
...
@@ -304,7 +304,9 @@ def main(args: argparse.Namespace):
...
@@ -304,7 +304,9 @@ def main(args: argparse.Namespace):
else
:
else
:
batch_sizes
=
[
args
.
batch_size
]
batch_sizes
=
[
args
.
batch_size
]
ray
.
init
()
ray
.
init
(
address
=
None
,
ignore_reinit_error
=
True
,
num_gpus
=
args
.
tp_size
)
num_gpus
=
int
(
ray
.
available_resources
()[
"GPU"
])
num_gpus
=
int
(
ray
.
available_resources
()[
"GPU"
])
workers
=
[
BenchmarkWorker
.
remote
(
args
.
seed
)
for
_
in
range
(
num_gpus
)]
workers
=
[
BenchmarkWorker
.
remote
(
args
.
seed
)
for
_
in
range
(
num_gpus
)]
...
...
examples/medusa/README.md
View file @
28375803
...
@@ -17,9 +17,9 @@ Vllm medusa model的实现在[vllm/model_executor/models/medusa.py]
...
@@ -17,9 +17,9 @@ Vllm medusa model的实现在[vllm/model_executor/models/medusa.py]
# medusa 模型需要转换为vllm中Medusa的模型格式
# medusa 模型需要转换为vllm中Medusa的模型格式
```
bash
```
bash
python medusa_weight_converter.py
--medusa_num_heads
4
--medusa_num_layers
1
--medusa_model_path
/work/
medusa/qwen2_72b_head_4/adapter_
model.bin
--vocab_size
152064
--hidden_size
8192
--output_dir
/work/medusa/
sugon/
vllm-medusa-qwen2-72b-head-4
--medusa_choices
=
"[(0
,
), (0, 0), (0, 0, 0), (0,
0, 0, 0
), (
0,
1), (1,), (
1
, 0), (0, 0, 1), (0,
1, 0), (0,
2), (
1
,
0
, 0), (2
,
), (
2, 0)
,
(
0,
3
), (0,
0, 2), (0, 2
, 0
)
,
(0, 4), (0, 0, 1, 0), (0, 1, 0,
0), (2, 0, 0),
(3,),
(0,
5
), (0, 0,
0, 1), (3, 0), (0, 0,
3), (
1,
0, 0, 0
), (0, 3
,
0
), (0,
6
), (0, 0,
4), (0, 4
, 0), (
1
,
1), (4,
)]"
python medusa_weight_converter.py
--medusa_num_heads
4
--medusa_num_layers
1
--medusa_model_path
/work/model.bin
--vocab_size
152064
--hidden_size
8192
--output_dir
/work/medusa/vllm-medusa-qwen2-72b-head-4
--medusa_choices
=
"[(0), (0, 0), (0, 0, 0), (0,
1
), (1), (1,
0
), (
0, 0, 0
, 0), (0, 0, 1), (0, 2), (
0
,
1
, 0), (2), (
0
, 0,
2
), (0,
3), (1
, 0, 0), (2, 0
), (0, 2
, 0), (0,
4
), (0, 0,
3), (
3), (0, 0, 0,
1
), (0,
5
), (0, 0,
1
, 0), (
0
,
0, 4
)]"
```
```
此处
qwen2_72b_head_4是medusa模型使用peft lora训练后保存的权重,其他格式也可参考[medusa_weight_converter.py]修改进行权重转换
此处
model.bin是训练后保存的medusa head权重
### Run
### Run
...
@@ -28,30 +28,34 @@ python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --m
...
@@ -28,30 +28,34 @@ python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --m
python3
-m
vllm.entrypoints.openai.api_server
\
python3
-m
vllm.entrypoints.openai.api_server
\
--served-model-name
qwen_medusa
\
--served-model-name
qwen_medusa
\
--model
/models/Qwen2-72B-Instruct/
-tp
4
\
--model
/models/Qwen2-72B-Instruct/
-tp
4
\
--max-model-len
1024
--max-num-seqs
8
--gpu-memory-utilization
0.
7
\
--max-model-len
1024
--max-num-seqs
8
--gpu-memory-utilization
0.
8
\
--speculative-model
/work/medusa/vllm-medusa-qwen2-72b-head-4
\
--speculative-model
/work/medusa/vllm-medusa-qwen2-72b-head-4
\
--speculative-draft-tensor-parallel-size
4
\
--speculative-draft-tensor-parallel-size
4
\
--speculative-disable-by-batch-size
4
\
--speculative-disable-by-batch-size
4
\
--use-v2-block-manager
\
--use-v2-block-manager
\
--spec-decoding-acceptance-method
typical_acceptance_sampler
\
--spec-decoding-acceptance-method
typical_acceptance_sampler
\
--enforce-eager
--dtype
float16
--trust-remote-code
--port
8086
\
--dtype
float16
--trust-remote-code
--port
8086
\
--enable-lora
--lora-modules
medusa-lora
=
/work/qwen2_72b_head_4
\
--max-lora-rank
32
--lora-extra-vocab-size
0
--merge-lora
True
\
--lora-target-modules
qkv_proj
\
--tree-style-spec-decoding
True
\
--tree-style-spec-decoding
True
\
--num-speculative-heads
4
--num-speculative-tokens
33
--num-speculative-heads
4
--num-speculative-tokens
24
```
```
merge-lora可以将lora权重和base model权重融合,提升整体推理速度,若对精度有严格要求,可不设置此参数
merge-lora可以将lora权重和base model权重融合,提升整体推理速度,若对精度有严格要求,可不设置此参数
num-speculative-tokens和medusa choices的个数相关,num_speculative_tokens = len(medusa_choices) +
2
num-speculative-tokens和medusa choices的个数相关,num_speculative_tokens = len(medusa_choices) +
1
# do request
# do request
```
bash
```
bash
curl http://localhost:8086/v1/completions
\
curl http://localhost:8086/v1/completions
\
-H
"Content-Type: application/json"
\
-H
"Content-Type: application/json"
\
-d
'{
-d
'{
"model": "medus
a-lor
a",
"model": "
qwen_
medusa",
"prompt": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n帮我写一个C++的快速排序算法<|im_end|>\n<|im_start|>assistant\n",
"prompt": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n帮我写一个C++的快速排序算法<|im_end|>\n<|im_start|>assistant\n",
"max_tokens": 256,
"max_tokens": 256,
"temperature": 0.0
"temperature": 0.0
}'
}'
```
bash
```
### benchmark
python medusa_benchmark_throughput.py --model /data/llm-models/qwen2/Qwen2-72B-Instruct/ -tp 4 --dtype float16 --trust-remote-code --max-num-seqs 1 --dataset /work/test/medusa_benchmark_data.json --max-model-len 4096 --gpu-memory-utilization 0.9
可设置max-num-seqs对不同的batch进行性能测试
examples/medusa/medusa_benchmark_throughput.py
0 → 100644
View file @
28375803
This diff is collapsed.
Click to expand it.
examples/medusa/medusa_weight_converter.py
View file @
28375803
...
@@ -20,9 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -20,9 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE
=
64
DEFAULT_VOCAB_PADDING_SIZE
=
64
TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE
=
'
base_model.model.
medusa_head.{}.{}.linear.weight'
TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE
=
'medusa_head.{}.{}.linear.weight'
TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE
=
'
base_model.model.
medusa_head.{}.1.weight'
TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE
=
'medusa_head.{}.1.weight'
TRAINED_BLOCK_BIAS_NAME_TEMPLATE
=
'
base_model.model.
medusa_head.{}.{}.linear.bias'
TRAINED_BLOCK_BIAS_NAME_TEMPLATE
=
'medusa_head.{}.{}.linear.bias'
VLLM_BLOCK_WEIGHT_NAME_TEMPLATE
=
'blocks.{}.layers.{}.weight'
VLLM_BLOCK_WEIGHT_NAME_TEMPLATE
=
'blocks.{}.layers.{}.weight'
VLLM_BLOCK_BIAS_NAME_TEMPLATE
=
'blocks.{}.layers.{}.bias'
VLLM_BLOCK_BIAS_NAME_TEMPLATE
=
'blocks.{}.layers.{}.bias'
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
28375803
...
@@ -238,6 +238,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
...
@@ -238,6 +238,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
block_tables_list
:
Optional
[
List
[
int
]]
=
None
@
property
@
property
def
prefill_metadata
(
def
prefill_metadata
(
self
)
->
Optional
[
"BlocksparseFlashAttentionMetadata"
]:
self
)
->
Optional
[
"BlocksparseFlashAttentionMetadata"
]:
...
@@ -269,7 +271,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
...
@@ -269,7 +271,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
],
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
],
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
use_cuda_graph
=
False
,
use_cuda_graph
=
False
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables_list
)
)
return
self
.
_cached_prefill_metadata
return
self
.
_cached_prefill_metadata
...
@@ -298,7 +301,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
...
@@ -298,7 +301,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor
=
None
,
context_lens_tensor
=
None
,
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
use_cuda_graph
=
self
.
use_cuda_graph
,
use_cuda_graph
=
self
.
use_cuda_graph
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables_list
)
)
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
28375803
...
@@ -168,6 +168,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -168,6 +168,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
_cached_decode_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
block_tables_list
:
Optional
[
List
[
int
]]
=
None
@
property
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"ROCmFlashAttentionMetadata"
]:
def
prefill_metadata
(
self
)
->
Optional
[
"ROCmFlashAttentionMetadata"
]:
...
@@ -199,7 +200,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -199,7 +200,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
],
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
],
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
use_cuda_graph
=
False
,
use_cuda_graph
=
False
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables_list
)
)
return
self
.
_cached_prefill_metadata
return
self
.
_cached_prefill_metadata
...
@@ -228,7 +230,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -228,7 +230,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor
=
None
,
context_lens_tensor
=
None
,
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
use_cuda_graph
=
self
.
use_cuda_graph
,
use_cuda_graph
=
self
.
use_cuda_graph
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables_list
)
)
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
...
...
vllm/attention/backends/utils.py
View file @
28375803
...
@@ -272,7 +272,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -272,7 +272,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
context_lens_tensor
=
context_lens_tensor
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
use_cuda_graph
=
use_captured_graph
,
tree_attention_masks_tensor
=
tree_attention_masks_tensor
tree_attention_masks_tensor
=
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables
)
)
...
...
vllm/attention/backends/xformers.py
View file @
28375803
...
@@ -190,6 +190,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -190,6 +190,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
block_tables_list
:
Optional
[
List
[
int
]]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# Set during the execution of the first attention op.
...
@@ -271,7 +272,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -271,7 +272,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
,
cross_block_tables
=
self
.
cross_block_tables
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
)
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables_list
)
return
self
.
_cached_prefill_metadata
return
self
.
_cached_prefill_metadata
@
property
@
property
...
@@ -311,7 +313,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -311,7 +313,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
,
cross_block_tables
=
self
.
cross_block_tables
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
)
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables_list
)
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
...
...
vllm/engine/llm_engine.py
View file @
28375803
...
@@ -41,7 +41,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
...
@@ -41,7 +41,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
,
VLLM_INVALID_TOKEN_ID
)
SequenceStatus
,
CompletionSequenceGroupOutput
,
VLLM_INVALID_TOKEN_ID
)
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
init_tracer
)
init_tracer
)
from
vllm.transformers_utils.config
import
try_get_generation_config
from
vllm.transformers_utils.config
import
try_get_generation_config
...
@@ -989,7 +989,7 @@ class LLMEngine:
...
@@ -989,7 +989,7 @@ class LLMEngine:
output
=
[
outputs_by_sequence_group
[
0
][
i
]]
output
=
[
outputs_by_sequence_group
[
0
][
i
]]
# tree style speculative decoding may generate empty output in first step
# tree style speculative decoding may generate empty output in first step
if
outputs
and
isinstance
(
output
[
0
],
Sa
mple
r
Output
):
if
outputs
and
isinstance
(
output
[
0
],
Co
mple
tionSequenceGroup
Output
):
samples
=
[
o
.
samples
[
0
]
for
o
in
output
]
samples
=
[
o
.
samples
[
0
]
for
o
in
output
]
valid_samples
=
[
valid_samples
=
[
sample
for
sample
in
samples
sample
for
sample
in
samples
...
...
vllm/model_executor/layers/sampler.py
View file @
28375803
...
@@ -235,7 +235,6 @@ class Sampler(nn.Module):
...
@@ -235,7 +235,6 @@ class Sampler(nn.Module):
sampling_metadata: Metadata for sampling.
sampling_metadata: Metadata for sampling.
"""
"""
assert
logits
is
not
None
assert
logits
is
not
None
original_logits
=
logits
.
clone
()
_
,
vocab_size
=
logits
.
shape
_
,
vocab_size
=
logits
.
shape
# Prepare sampling tensors with pinned memory to avoid blocking.
# Prepare sampling tensors with pinned memory to avoid blocking.
...
@@ -320,7 +319,7 @@ class Sampler(nn.Module):
...
@@ -320,7 +319,7 @@ class Sampler(nn.Module):
sample_logprobs
,
sample_logprobs
,
on_device_tensors
=
on_device_tensors
,
on_device_tensors
=
on_device_tensors
,
skip_sampler_cpu_output
=
sampling_metadata
.
skip_sampler_cpu_output
,
skip_sampler_cpu_output
=
sampling_metadata
.
skip_sampler_cpu_output
,
logits
=
original_
logits
)
logits
=
logits
)
@
property
@
property
def
_should_modify_greedy_probs_inplace
(
self
)
->
bool
:
def
_should_modify_greedy_probs_inplace
(
self
)
->
bool
:
...
...
vllm/model_executor/layers/typical_acceptance_sampler.py
View file @
28375803
...
@@ -198,10 +198,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -198,10 +198,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
k
=
draft_token_ids
.
shape
[
-
1
]
k
=
draft_token_ids
.
shape
[
-
1
]
output_token_id_list
=
[]
output_token_id_list
=
[]
logger
.
info
(
"accept_length:%s"
,
accept_length
)
accept_length_list
=
accept_length
.
cpu
().
tolist
()
logger
.
info
(
"accept_length:%s"
,
accept_length_list
)
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
output_best_candidates
.
append
(
best_candidate
[
i
])
output_best_candidates
.
append
(
best_candidate
[
i
])
accept_lengths
.
append
(
accept_length
[
i
])
accept_lengths
.
append
(
accept_length
_list
[
i
])
if
not
first_step_flags
[
i
]:
if
not
first_step_flags
[
i
]:
select_indices
=
cart_candidates
[
i
,
best_candidate
[
i
],
:
accept_length
[
i
]
+
1
]
select_indices
=
cart_candidates
[
i
,
best_candidate
[
i
],
:
accept_length
[
i
]
+
1
]
...
...
vllm/sequence.py
View file @
28375803
...
@@ -996,9 +996,6 @@ class SequenceGroupMetadata(
...
@@ -996,9 +996,6 @@ class SequenceGroupMetadata(
# TODO: We should maintain this states out of the sequence group.
# TODO: We should maintain this states out of the sequence group.
num_speculative_tokens
:
Optional
[
int
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
tree_attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
tree_position_ids
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
seq_data
is
not
None
and
self
.
token_chunk_size
is
None
:
if
self
.
seq_data
is
not
None
and
self
.
token_chunk_size
is
None
:
if
self
.
is_prompt
:
if
self
.
is_prompt
:
...
@@ -1036,11 +1033,6 @@ class SequenceGroupMetadata(
...
@@ -1036,11 +1033,6 @@ class SequenceGroupMetadata(
assert
self
.
state
.
current_step
<
self
.
state
.
num_steps
assert
self
.
state
.
current_step
<
self
.
state
.
num_steps
self
.
state
.
current_step
+=
1
self
.
state
.
current_step
+=
1
def
set_tree_style_args
(
self
,
tree_attn_masks
:
Optional
[
torch
.
Tensor
],
tree_position_ids
:
Optional
[
torch
.
Tensor
]):
self
.
tree_attn_masks
=
tree_attn_masks
self
.
tree_position_ids
=
tree_position_ids
class
SequenceOutput
(
class
SequenceOutput
(
msgspec
.
Struct
,
msgspec
.
Struct
,
...
...
vllm/spec_decode/medusa_worker.py
View file @
28375803
...
@@ -343,7 +343,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
...
@@ -343,7 +343,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
# Move the tensors in the dictionary to the specified device
# Move the tensors in the dictionary to the specified device
medusa_buffers
=
{
medusa_buffers
=
{
k
:
(
v
.
clone
().
to
(
device
)
if
k
!=
"tree_position_ids"
else
v
.
clone
())
k
:
v
.
clone
().
to
(
device
)
if
isinstance
(
v
,
torch
.
Tensor
)
if
isinstance
(
v
,
torch
.
Tensor
)
else
torch
.
tensor
(
v
,
device
=
device
)
else
torch
.
tensor
(
v
,
device
=
device
)
for
k
,
v
in
medusa_buffers
.
items
()
for
k
,
v
in
medusa_buffers
.
items
()
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
28375803
...
@@ -318,8 +318,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -318,8 +318,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
"""
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
)
=
True
)
=
True
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
should_modify_greedy_probs_inplace
)
=
True
# tree_style decoding modify probs in _verify_tokens
if
not
self
.
tree_style_spec_decoding
:
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
should_modify_greedy_probs_inplace
)
=
True
self
.
proposer_worker
.
set_include_gpu_probs_tensor
()
self
.
proposer_worker
.
set_include_gpu_probs_tensor
()
self
.
proposer_worker
.
set_should_modify_greedy_probs_inplace
()
self
.
proposer_worker
.
set_should_modify_greedy_probs_inplace
()
...
@@ -694,7 +697,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -694,7 +697,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_scores
:
SpeculativeScores
,
proposal_scores
:
SpeculativeScores
,
proposals
:
SpeculativeProposals
,
proposals
:
SpeculativeProposals
,
max_proposal_len
:
int
,
max_proposal_len
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
List
[
torch
.
Tensor
],
List
[
int
]]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
List
[
List
[
int
]
],
List
[
int
]]:
"""Determine which speculative tokens are accepted using the
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
probabilities of each token according to the proposer and scorer models.
...
@@ -712,7 +715,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -712,7 +715,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
original_indices
=
spec_indices
+
non_spec_indices
original_indices
=
spec_indices
+
non_spec_indices
# Get probabilities of target model, including bonus tokens.
# Get probabilities of target model, including bonus tokens.
proposal_verifier_probs
=
proposal_scores
.
probs
[
spec_indices
]
if
non_spec_indices
:
proposal_verifier_probs
=
proposal_scores
.
probs
[
spec_indices
]
else
:
proposal_verifier_probs
=
proposal_scores
.
probs
if
self
.
tree_style_spec_decoding
:
if
self
.
tree_style_spec_decoding
:
retrieve_indices
=
proposals
.
retrieve_indices
retrieve_indices
=
proposals
.
retrieve_indices
...
@@ -722,17 +728,24 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -722,17 +728,24 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
non_spec_token_ids
=
proposal_scores
.
token_ids
[
non_spec_indices
]
non_spec_token_ids
=
proposal_scores
.
token_ids
[
non_spec_indices
]
# Get bonus tokens from target model.
# Get bonus tokens from target model.
bonus_token_ids
=
proposal_scores
.
token_ids
[
spec_indices
,
-
1
:]
bonus_token_ids
=
proposal_scores
.
token_ids
[:,
-
1
:]
if
non_spec_indices
:
bonus_token_ids
=
bonus_token_ids
[
spec_indices
,
:]
# Get probabilities according to proposal method.
# Get probabilities according to proposal method.
proposal_probs
=
proposals
.
proposal_probs
[
spec_indices
]
\
proposal_probs
=
proposals
.
proposal_probs
if
proposals
.
proposal_probs
is
not
None
else
None
if
proposals
.
proposal_probs
is
not
None
else
None
if
non_spec_indices
:
proposal_probs
=
proposal_probs
[
spec_indices
]
# Get proposed tokens.
# Get proposed tokens.
proposal_token_ids
=
proposals
.
proposal_token_ids
[
spec_indices
]
proposal_token_ids
=
proposals
.
proposal_token_ids
if
non_spec_indices
:
proposal_token_ids
=
proposal_token_ids
[
spec_indices
]
# Get tree buffers.
# Get tree buffers.
cart_candidates
=
proposals
.
cart_candidates
[
spec_indices
]
if
proposals
.
cart_candidates
is
not
None
else
None
cart_candidates
=
proposals
.
cart_candidates
if
proposals
.
cart_candidates
is
not
None
else
None
if
non_spec_indices
:
cart_candidates
=
cart_candidates
[
spec_indices
]
# Sampler arguments
# Sampler arguments
sampler_extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
sampler_extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
...
@@ -820,6 +833,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -820,6 +833,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
previous_logits_list
=
[]
previous_logits_list
=
[]
previous_hidden_state_list
=
[]
previous_hidden_state_list
=
[]
retrieve_indices
=
retrieve_indices
.
cpu
()
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
logit
=
logits
[
i
,
best_candidates
[
i
],
accept_lengths
[
i
]].
unsqueeze
(
0
)
logit
=
logits
[
i
,
best_candidates
[
i
],
accept_lengths
[
i
]].
unsqueeze
(
0
)
...
@@ -865,13 +880,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -865,13 +880,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
model_input
=
self
.
scorer
.
_scorer_worker
.
model_input
model_input
=
self
.
scorer
.
_scorer_worker
.
model_input
block_tables
=
None
block_tables
=
None
if
hasattr
(
model_input
,
'attn_metadata'
)
and
hasattr
(
model_input
.
attn_metadata
,
'block_tables'
):
if
hasattr
(
model_input
,
'attn_metadata'
)
and
hasattr
(
model_input
.
attn_metadata
,
'block_tables
_list
'
):
block_tables
=
model_input
.
attn_metadata
.
block_tables
block_tables
=
model_input
.
attn_metadata
.
block_tables
_list
if
block_tables
is
None
:
if
block_tables
is
None
:
raise
RuntimeError
(
"Can not get block_tables from model_input."
)
raise
RuntimeError
(
"Can not get block_tables from model_input."
)
block_tables
=
block_tables
.
cpu
().
tolist
()
cache_engine
=
self
.
scorer
.
_scorer_worker
.
cache_engines
[
execute_model_req
.
virtual_engine
]
cache_engine
=
self
.
scorer
.
_scorer_worker
.
cache_engines
[
execute_model_req
.
virtual_engine
]
block_size
=
cache_engine
.
block_size
block_size
=
cache_engine
.
block_size
...
@@ -885,10 +898,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -885,10 +898,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if
accept_legth
>
0
:
if
accept_legth
>
0
:
select_indices
=
select_indices_list
[
i
][
1
:]
+
seq_lens
[
i
]
select_indices
=
select_indices_list
[
i
][
1
:]
+
seq_lens
[
i
]
select_indices
=
select_indices
.
tolist
()
self
.
compute_slot_mapping
(
select_indices_slot_mapping
,
i
*
block_table_stride
,
self
.
compute_slot_mapping
(
select_indices_slot_mapping
,
i
*
block_table_stride
,
select_indices
,
block_size
,
block_tables
)
select_indices
,
block_size
,
block_tables
)
target_indices
=
torch
.
arange
(
accept_legth
+
1
)[
1
:]
+
seq_lens
[
i
]
target_indices
=
torch
.
arange
(
accept_legth
+
1
)[
1
:]
+
seq_lens
[
i
]
target_indices
=
target_indices
.
tolist
()
self
.
compute_slot_mapping
(
target_slot_mapping
,
i
*
block_table_stride
,
self
.
compute_slot_mapping
(
target_slot_mapping
,
i
*
block_table_stride
,
target_indices
,
block_size
,
block_tables
)
target_indices
,
block_size
,
block_tables
)
...
@@ -900,12 +915,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -900,12 +915,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
device
).
view
(
-
1
,
1
)
device
=
self
.
device
).
view
(
-
1
,
1
)
src_dst_tensor
=
torch
.
cat
([
select_indices_slot_tensor
,
target_slot_mapping_tensor
],
dim
=-
1
)
#[batch_size*T, 2]
src_dst_tensor
=
torch
.
cat
([
select_indices_slot_tensor
,
target_slot_mapping_tensor
],
dim
=-
1
)
#[batch_size*T, 2]
# kv_caches = self.scorer._scorer_worker.kv_cache[execute_model_req.virtual_engine]
# kv_cache_dtype = cache_engine.cache_config.cache_dtype
# backend = cache_engine.attn_backend
# num_kv_heads = cache_engine.num_kv_heads
# head_size = cache_engine.head_size
# backend.move_cache(kv_caches, src_dst_tensor, kv_cache_dtype, num_kv_heads*4, head_size)
self
.
kvcache_slot_to_be_moved
=
src_dst_tensor
self
.
kvcache_slot_to_be_moved
=
src_dst_tensor
def
compute_slot_mapping
(
self
,
slot_mapping
:
List
[
int
],
def
compute_slot_mapping
(
self
,
slot_mapping
:
List
[
int
],
...
...
vllm/worker/model_runner.py
View file @
28375803
...
@@ -469,6 +469,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -469,6 +469,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
sliding_window
+
self
.
block_size
-
1
)
//
self
.
block_size
self
.
sliding_window
+
self
.
block_size
-
1
)
//
self
.
block_size
self
.
block_aligned_sliding_window
=
\
self
.
block_aligned_sliding_window
=
\
self
.
sliding_window_blocks
*
self
.
block_size
self
.
sliding_window_blocks
*
self
.
block_size
if
hasattr
(
self
.
runner
,
"tree_attn_masks"
):
self
.
tree_attn_masks
=
self
.
runner
.
tree_attn_masks
self
.
tree_position_ids
=
self
.
runner
.
tree_position_ids
else
:
self
.
tree_attn_masks
=
None
self
.
tree_position_ids
=
None
self
.
is_encoder_decoder_model
=
self
.
runner
.
model_config
.
is_encoder_decoder_model
def
_compute_lens
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
def
_compute_lens
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
seq_group_metadata
:
SequenceGroupMetadata
):
seq_group_metadata
:
SequenceGroupMetadata
):
...
@@ -511,10 +520,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -511,10 +520,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data
.
input_positions
[
seq_idx
]
=
list
(
range
(
context_len
,
seq_len
))
inter_data
.
input_positions
[
seq_idx
]
=
list
(
range
(
context_len
,
seq_len
))
if
seq_group_metadata
.
tree_position_ids
is
not
None
:
inter_data
.
input_positions
[
seq_idx
]
=
seq_group_metadata
.
tree_position_ids
.
contiguous
().
tolist
()
inter_data
.
tree_attn_masks
[
seq_idx
]
=
seq_group_metadata
.
tree_attn_masks
inter_data
.
query_lens
[
inter_data
.
query_lens
[
seq_idx
]
=
seq_len
-
context_len
if
inter_data
.
is_prompt
else
1
seq_idx
]
=
seq_len
-
context_len
if
inter_data
.
is_prompt
else
1
...
@@ -718,7 +723,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -718,7 +723,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
encoder_seq_len
=
0
encoder_seq_len
=
0
if
self
.
runner
.
model_config
.
is_encoder_decoder_model
:
if
self
.
is_encoder_decoder_model
:
encoder_seq_len
=
seq_group_metadata
.
encoder_seq_data
.
get_len
()
encoder_seq_len
=
seq_group_metadata
.
encoder_seq_data
.
get_len
()
inter_data
=
self
.
init_cached_inter_data
(
inter_data
=
self
.
init_cached_inter_data
(
...
@@ -796,7 +801,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -796,7 +801,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if
not
inter_data
.
is_prompt
:
if
not
inter_data
.
is_prompt
:
max_decode_seq_len
=
max
(
max_decode_seq_len
,
max_decode_seq_len
=
max
(
max_decode_seq_len
,
max
(
inter_data
.
seq_lens
))
max
(
inter_data
.
seq_lens
))
if
self
.
runner
.
model_config
.
is_encoder_decoder_model
:
if
self
.
is_encoder_decoder_model
:
max_encoder_seq_len
=
max
(
max_encoder_seq_len
,
max_encoder_seq_len
=
max
(
max_encoder_seq_len
,
inter_data
.
encoder_seq_len
)
inter_data
.
encoder_seq_len
)
...
@@ -847,22 +852,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -847,22 +852,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Sequence and query lengths.
# Sequence and query lengths.
if
cuda_graph_pad_size
:
if
cuda_graph_pad_size
:
seq_lens
.
extend
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
seq_lens
.
extend
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
# prepare tree attention masks
# prepare tree attention masks
max_context_len
=
0
tree_attention_masks_tensor
=
self
.
tree_attn_masks
for
inter_data
in
self
.
inter_data_list
:
if
tree_attention_masks_tensor
is
not
None
:
max_context_len
=
max
(
max_context_len
,
max
(
inter_data
.
context_lens
))
tree_attention_masks_list
=
[]
for
inter_data
in
self
.
inter_data_list
:
for
i
in
range
(
len
(
inter_data
.
seq_lens
)):
if
inter_data
.
tree_attn_masks
:
tree_attn_masks
=
inter_data
.
tree_attn_masks
[
i
]
if
tree_attn_masks
is
not
None
:
tree_attention_masks_list
.
append
(
tree_attn_masks
)
tree_attention_masks_tensor
=
None
if
tree_attention_masks_list
:
tree_attention_masks_tensor
=
torch
.
stack
(
tree_attention_masks_list
,
dim
=
0
)
tree_attention_masks_tensor
=
tree_attention_masks_tensor
.
contiguous
()
tree_attention_masks_tensor
=
tree_attention_masks_tensor
.
contiguous
()
input_positions_tensor
=
self
.
tree_position_ids
.
contiguous
()
# Attention metadata.
# Attention metadata.
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
...
@@ -1038,6 +1033,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1038,6 +1033,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
inter_data_cache
:
Dict
[
int
,
PyObjectCache
]
=
{}
self
.
inter_data_cache
:
Dict
[
int
,
PyObjectCache
]
=
{}
self
.
sampling_metadata_cache
:
SamplingMetadataCache
=
\
self
.
sampling_metadata_cache
:
SamplingMetadataCache
=
\
SamplingMetadataCache
()
SamplingMetadataCache
()
self
.
tree_attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
self
.
tree_position_ids
:
Optional
[
torch
.
Tensor
]
=
None
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
...
@@ -1505,6 +1503,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1505,6 +1503,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
@
property
@
property
def
vocab_size
(
self
)
->
int
:
def
vocab_size
(
self
)
->
int
:
return
self
.
model_config
.
get_vocab_size
()
return
self
.
model_config
.
get_vocab_size
()
def
set_tree_style_args
(
self
,
tree_attn_masks
:
Optional
[
torch
.
Tensor
],
tree_position_ids
:
Optional
[
torch
.
Tensor
]):
self
.
tree_attn_masks
=
tree_attn_masks
self
.
tree_position_ids
=
tree_position_ids
class
ModelRunner
(
GPUModelRunnerBase
[
ModelInputForGPUWithSamplingMetadata
]):
class
ModelRunner
(
GPUModelRunnerBase
[
ModelInputForGPUWithSamplingMetadata
]):
...
...
vllm/worker/worker_base.py
View file @
28375803
...
@@ -283,11 +283,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -283,11 +283,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input
:
WorkerInput
=
self
.
prepare_worker_input
(
worker_input
:
WorkerInput
=
self
.
prepare_worker_input
(
execute_model_req
=
execute_model_req
)
execute_model_req
=
execute_model_req
)
# set tree_attn_masks and position ids to seq_group_metadata_list
if
hasattr
(
self
.
model_runner
,
"set_tree_style_args"
):
if
execute_model_req
.
tree_attn_masks
is
not
None
:
self
.
model_runner
.
set_tree_style_args
(
tree_attn_masks
=
execute_model_req
.
tree_attn_masks
,
for
i
,
seq_group_metadata
in
enumerate
(
execute_model_req
.
seq_group_metadata_list
):
tree_position_ids
=
execute_model_req
.
tree_position_ids
)
seq_group_metadata
.
set_tree_style_args
(
tree_attn_masks
=
execute_model_req
.
tree_attn_masks
[
i
],
tree_position_ids
=
execute_model_req
.
tree_position_ids
[
i
])
model_input
:
ModelRunnerInputBase
=
(
model_input
:
ModelRunnerInputBase
=
(
self
.
model_runner
.
prepare_model_input
(
self
.
model_runner
.
prepare_model_input
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment