Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4eabe123
Commit
4eabe123
authored
May 28, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori
parents
45840cd2
58738772
Changes
670
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
342 additions
and
252 deletions
+342
-252
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+9
-8
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+103
-96
vllm/v1/spec_decode/medusa.py
vllm/v1/spec_decode/medusa.py
+5
-18
vllm/v1/spec_decode/metrics.py
vllm/v1/spec_decode/metrics.py
+3
-3
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+27
-0
vllm/v1/structured_output/utils.py
vllm/v1/structured_output/utils.py
+1
-1
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+2
-8
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+2
-3
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+71
-27
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+12
-6
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+12
-5
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+1
-2
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+3
-3
vllm/worker/cpu_enc_dec_model_runner.py
vllm/worker/cpu_enc_dec_model_runner.py
+5
-2
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+4
-1
vllm/worker/cpu_pooling_model_runner.py
vllm/worker/cpu_pooling_model_runner.py
+5
-2
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+1
-2
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+7
-3
vllm/worker/hpu_worker.py
vllm/worker/hpu_worker.py
+5
-8
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+64
-54
No files found.
vllm/v1/sample/ops/topk_topp_sampler.py
View file @
4eabe123
...
@@ -89,18 +89,18 @@ class TopKTopPSampler(nn.Module):
...
@@ -89,18 +89,18 @@ class TopKTopPSampler(nn.Module):
p
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""More optimized implementation for top-k and top-p sampling."""
"""More optimized implementation for top-k and top-p sampling."""
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
if
k
is
None
and
p
is
None
:
if
k
is
None
and
p
is
None
:
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
# CPU-GPU synchronization while `flashinfer_sample` does.
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
return
random_sample
(
probs
,
generators
)
return
random_sample
(
probs
,
generators
)
if
generators
:
if
generators
:
logger
.
warning
(
"FlashInfer 0.2.3+ does not support "
logger
.
warning
(
"FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"per-request generators. Falling back to "
"PyTorch-native implementation."
)
"PyTorch-native implementation."
)
return
self
.
forward_native
(
logits
,
generators
,
k
,
p
)
return
self
.
forward_native
(
logits
,
generators
,
k
,
p
)
return
flashinfer_sample
(
prob
s
,
k
,
p
,
generators
)
return
flashinfer_sample
(
logit
s
,
k
,
p
,
generators
)
def
forward_tpu
(
def
forward_tpu
(
self
,
self
,
...
@@ -254,12 +254,12 @@ def random_sample(
...
@@ -254,12 +254,12 @@ def random_sample(
def
flashinfer_sample
(
def
flashinfer_sample
(
prob
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
k
:
Optional
[
torch
.
Tensor
],
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
generators
:
dict
[
int
,
torch
.
Generator
],
generators
:
dict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Sample from the
probabilitie
s using FlashInfer.
"""Sample from the
logit
s using FlashInfer.
Statistically, this function is equivalent to the `random_sample` function.
Statistically, this function is equivalent to the `random_sample` function.
However, this function is faster because it avoids sorting the logits tensor
However, this function is faster because it avoids sorting the logits tensor
...
@@ -274,18 +274,19 @@ def flashinfer_sample(
...
@@ -274,18 +274,19 @@ def flashinfer_sample(
the synchronization overhead.
the synchronization overhead.
"""
"""
assert
not
(
k
is
None
and
p
is
None
)
assert
not
(
k
is
None
and
p
is
None
)
if
k
is
None
:
if
k
is
None
:
# Top-p only.
# Top-p only.
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
next_token_ids
=
flashinfer
.
sampling
.
top_p_sampling_from_probs
(
next_token_ids
=
flashinfer
.
sampling
.
top_p_sampling_from_probs
(
probs
,
p
,
deterministic
=
True
)
probs
,
p
,
deterministic
=
True
)
elif
p
is
None
:
elif
p
is
None
:
# Top-k only.
# Top-k only.
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
next_token_ids
=
flashinfer
.
sampling
.
top_k_sampling_from_probs
(
next_token_ids
=
flashinfer
.
sampling
.
top_k_sampling_from_probs
(
probs
,
k
,
deterministic
=
True
)
probs
,
k
,
deterministic
=
True
)
else
:
else
:
# Both top-k and top-p.
# Both top-k and top-p.
next_token_ids
=
(
flashinfer
.
sampling
.
top_k_top_p_sampling_from_
prob
s
(
next_token_ids
=
flashinfer
.
sampling
.
top_k_top_p_sampling_from_
logit
s
(
prob
s
,
k
,
p
,
deterministic
=
True
)
)
logit
s
,
k
,
p
,
deterministic
=
True
)
return
next_token_ids
.
view
(
-
1
)
return
next_token_ids
.
view
(
-
1
)
vllm/v1/spec_decode/eagle.py
View file @
4eabe123
...
@@ -4,17 +4,17 @@ import torch.nn as nn
...
@@ -4,17 +4,17 @@ import torch.nn as nn
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
get_layers_from_vllm_config
,
set_current_vllm_config
)
get_layers_from_vllm_config
)
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backends.flash_attn
import
(
CommonAttentionMetadata
,
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
FlashAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.spec_decode.utils
import
prepare_eagle_input_kernel
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -27,12 +27,15 @@ class EagleProposer:
...
@@ -27,12 +27,15 @@ class EagleProposer:
self
,
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
device
:
torch
.
device
,
runner
=
None
,
):
):
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
draft_model_config
=
self
.
speculative_config
.
draft_model_config
self
.
draft_model_config
=
self
.
speculative_config
.
draft_model_config
self
.
method
=
self
.
speculative_config
.
method
self
.
method
=
self
.
speculative_config
.
method
self
.
runner
=
runner
self
.
dtype
=
vllm_config
.
model_config
.
dtype
self
.
dtype
=
vllm_config
.
model_config
.
dtype
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
...
@@ -108,9 +111,11 @@ class EagleProposer:
...
@@ -108,9 +111,11 @@ class EagleProposer:
# FA requires seq_len to have dtype int32.
# FA requires seq_len to have dtype int32.
seq_lens
=
(
target_positions
[
last_token_indices
]
+
1
).
int
()
seq_lens
=
(
target_positions
[
last_token_indices
]
+
1
).
int
()
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len
=
seq_lens
.
max
().
item
()
max_seq_len
=
seq_lens
.
max
().
item
()
max_num_tokens
=
(
cu_num_tokens
[
1
:]
-
cu_num_tokens
[:
-
1
]).
max
().
item
()
max_num_tokens
=
(
cu_num_tokens
[
1
:]
-
cu_num_tokens
[:
-
1
]).
max
().
item
()
attn_metadata
=
FlashAttentionMetadata
(
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
num_tokens
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_num_tokens
,
max_query_len
=
max_num_tokens
,
...
@@ -126,6 +131,31 @@ class EagleProposer:
...
@@ -126,6 +131,31 @@ class EagleProposer:
prefix_kv_lens
=
None
,
prefix_kv_lens
=
None
,
suffix_kv_lens
=
None
,
suffix_kv_lens
=
None
,
)
)
elif
self
.
method
==
"deepseek_mtp"
:
query_lens
=
cu_num_tokens
[
1
:]
-
cu_num_tokens
[:
-
1
]
max_query_len
=
query_lens
.
max
().
item
()
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
cu_num_tokens
,
seq_lens
=
seq_lens
)
assert
self
.
runner
is
not
None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata
=
self
.
runner
.
attn_metadata_builder
.
build
(
num_reqs
=
batch_size
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_query_len
,
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
,
)
else
:
raise
ValueError
(
f
"Unsupported method:
{
self
.
method
}
"
)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata
=
{}
for
layer_name
in
self
.
attn_layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
if
self
.
use_cuda_graph
and
\
if
self
.
use_cuda_graph
and
\
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_tokens
)
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_tokens
)
...
@@ -135,14 +165,18 @@ class EagleProposer:
...
@@ -135,14 +165,18 @@ class EagleProposer:
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
per_layer_
attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
):
num_tokens
=
num_input_tokens
):
last_hidden_states
,
hidden_states
=
self
.
model
(
ret_
hidden_states
=
self
.
model
(
input_ids
=
self
.
input_ids
[:
num_input_tokens
],
self
.
input_ids
[:
num_input_tokens
],
positions
=
self
.
positions
[:
num_input_tokens
],
self
.
positions
[:
num_input_tokens
],
hidden_states
=
self
.
hidden_states
[:
num_input_tokens
],
self
.
hidden_states
[:
num_input_tokens
],
)
)
if
self
.
method
==
"deepseek_mtp"
:
last_hidden_states
=
ret_hidden_states
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
...
@@ -152,6 +186,10 @@ class EagleProposer:
...
@@ -152,6 +186,10 @@ class EagleProposer:
# [batch_size, 1]
# [batch_size, 1]
return
draft_token_ids
.
view
(
-
1
,
1
)
return
draft_token_ids
.
view
(
-
1
,
1
)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Generate the remaining draft tokens.
# Generate the remaining draft tokens.
draft_token_ids_list
=
[
draft_token_ids
]
draft_token_ids_list
=
[
draft_token_ids
]
...
@@ -213,13 +251,13 @@ class EagleProposer:
...
@@ -213,13 +251,13 @@ class EagleProposer:
self
.
hidden_states
[:
batch_size
]
=
hidden_states
self
.
hidden_states
[:
batch_size
]
=
hidden_states
# Run the model.
# Run the model.
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
per_layer_
attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
num_tokens
=
input_batch_size
):
num_tokens
=
input_batch_size
):
last_hidden_states
,
hidden_states
=
self
.
model
(
last_hidden_states
,
hidden_states
=
self
.
model
(
input_ids
=
self
.
input_ids
[:
input_batch_size
],
self
.
input_ids
[:
input_batch_size
],
positions
=
self
.
positions
[:
input_batch_size
],
self
.
positions
[:
input_batch_size
],
hidden_states
=
self
.
hidden_states
[:
input_batch_size
],
self
.
hidden_states
[:
input_batch_size
],
)
)
hidden_states
=
hidden_states
[:
batch_size
]
hidden_states
=
hidden_states
[:
batch_size
]
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
],
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
],
...
@@ -239,6 +277,7 @@ class EagleProposer:
...
@@ -239,6 +277,7 @@ class EagleProposer:
cu_target_query_lens
:
torch
.
Tensor
,
cu_target_query_lens
:
torch
.
Tensor
,
# [batch_size]
# [batch_size]
num_rejected_tokens
:
torch
.
Tensor
,
num_rejected_tokens
:
torch
.
Tensor
,
num_tokens
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# num_rejected_tokens: [n1, n2, n3]
...
@@ -256,21 +295,16 @@ class EagleProposer:
...
@@ -256,21 +295,16 @@ class EagleProposer:
# [a - n1, b - n2, c - n3] ->
# [a - n1, b - n2, c - n3] ->
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
cu_num_tokens
=
torch
.
empty
_like
(
cu_target_query_lens
)
cu_num_tokens
=
torch
.
zeros
_like
(
cu_target_query_lens
)
torch
.
cumsum
(
num_tokens_per_req
,
dim
=
0
,
out
=
cu_num_tokens
[
1
:])
torch
.
cumsum
(
num_tokens_per_req
,
dim
=
0
,
out
=
cu_num_tokens
[
1
:])
cu_num_tokens
[
0
]
=
0
# FIXME(woosuk): Avoid synchronization.
num_tokens
=
cu_num_tokens
[
-
1
].
item
()
token_indices
=
torch
.
empty
(
token_indices
=
torch
.
empty
(
num_tokens
,
num_tokens
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
cu_
num_tok
ens
.
device
,
device
=
cu_
target_query_l
ens
.
device
,
)
)
batch_size
=
num_rejected_tokens
.
shape
[
0
]
batch_size
=
num_rejected_tokens
.
shape
[
0
]
BLOCK_SIZE
=
1024
BLOCK_SIZE
=
1024
prepare_input_kernel
[(
batch_size
,
)](
prepare_
eagle_
input_kernel
[(
batch_size
,
)](
token_indices
,
token_indices
,
cu_target_query_lens
,
cu_target_query_lens
,
cu_num_tokens
,
cu_num_tokens
,
...
@@ -279,48 +313,28 @@ class EagleProposer:
...
@@ -279,48 +313,28 @@ class EagleProposer:
return
cu_num_tokens
,
token_indices
return
cu_num_tokens
,
token_indices
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
loader
=
get_model_loader
(
self
.
vllm_config
.
load_config
)
draft_model_config
=
\
target_layer_num
=
self
.
vllm_config
.
model_config
.
get_num_layers
(
self
.
vllm_config
.
speculative_config
.
draft_model_config
self
.
vllm_config
.
parallel_config
)
target_attn_layer_names
=
set
(
target_attn_layer_names
=
set
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
())
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
())
draft_model_config
=
\
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
,
self
.
vllm_config
.
speculative_config
.
draft_model_config
model_config
=
draft_model_config
)
# FIXME(lily): This does not handle with distributed inference.
target_device
=
self
.
vllm_config
.
device_config
.
device
# We need to set the vllm_config here to register attention
# layers in the forward context.
with
set_default_torch_dtype
(
draft_model_config
.
dtype
),
set_current_vllm_config
(
self
.
vllm_config
):
draft_model_cls
,
arch
=
ModelRegistry
.
resolve_model_cls
(
draft_model_config
.
architectures
)
self
.
model
=
draft_model_cls
(
vllm_config
=
self
.
vllm_config
,
start_layer_id
=
target_layer_num
).
to
(
target_device
)
draft_attn_layer_names
=
(
draft_attn_layer_names
=
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
-
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
-
target_attn_layer_names
)
target_attn_layer_names
)
assert
len
(
draft_attn_layer_names
)
==
1
self
.
attn_layer_name
=
next
(
iter
(
draft_attn_layer_names
))
self
.
attn_layer_names
=
list
(
draft_attn_layer_names
)
loaded_weights
=
self
.
model
.
load_weights
(
loader
.
get_all_weights
(
draft_model_config
,
self
.
model
))
# share embed_tokens with the target model if needed
# share embed_tokens with the target model if needed
if
get_pp_group
().
world_size
==
1
:
if
get_pp_group
().
world_size
==
1
:
assert
"model.embed_tokens.weight"
not
in
loaded_weights
,
\
"For PP = 1, Eagle draft should share embed with target model"
logger
.
info
(
logger
.
info
(
"The EAGLE head shares the same vocab embedding"
\
"The EAGLE head shares the same vocab embedding"
\
" with the target model."
" with the target model."
)
)
self
.
model
.
model
.
embed_tokens
=
target_model
.
model
.
embed_tokens
self
.
model
.
model
.
embed_tokens
=
target_model
.
model
.
embed_tokens
else
:
else
:
assert
"model.embed_tokens.weight"
in
loaded_weights
,
\
"For PP > 1, Eagle draft checkpoint should its own copy of "
" the model.embed_tokens.weight"
logger
.
info
(
logger
.
info
(
"Since PP > 1, the EAGLE head loaded its own vocab embedding"
\
"Since PP > 1, the EAGLE head loaded its own vocab embedding"
\
" weights instead of sharing them with the target model."
" weights instead of sharing them with the target model."
...
@@ -342,11 +356,30 @@ class EagleProposer:
...
@@ -342,11 +356,30 @@ class EagleProposer:
with
set_forward_context
(
None
,
self
.
vllm_config
,
with
set_forward_context
(
None
,
self
.
vllm_config
,
num_tokens
=
num_tokens
):
num_tokens
=
num_tokens
):
self
.
model
(
self
.
model
(
input_ids
=
self
.
input_ids
[:
num_tokens
],
self
.
input_ids
[:
num_tokens
],
positions
=
self
.
positions
[:
num_tokens
],
self
.
positions
[:
num_tokens
],
hidden_states
=
self
.
hidden_states
[:
num_tokens
],
self
.
hidden_states
[:
num_tokens
],
)
)
def
validate_same_kv_cache_group
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
Validate that all eagle layers belong to the same KVCacheGroup.
Need this assumption to ensure all eagle layers can use the
same AttentionMetadata.
May extend to multiple AttentionMetadata in the future.
"""
kv_cache_groups
:
dict
[
str
,
int
]
=
{}
for
id
,
kv_cache_group
in
enumerate
(
kv_cache_config
.
kv_cache_groups
):
for
layer_name
in
kv_cache_group
.
layer_names
:
kv_cache_groups
[
layer_name
]
=
id
assert
len
(
set
([
kv_cache_groups
[
layer_name
]
for
layer_name
in
self
.
attn_layer_names
])
)
==
1
,
"All eagle layers should belong to the same kv cache group"
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# to sample the draft tokens. We will use this after we find a way to manage
...
@@ -389,29 +422,3 @@ def compute_probs_and_sample_next_token(
...
@@ -389,29 +422,3 @@ def compute_probs_and_sample_next_token(
next_token_ids
,
next_token_ids
,
)
)
return
next_token_ids
,
probs
return
next_token_ids
,
probs
@
triton
.
jit
def
prepare_input_kernel
(
out_ptr
,
cu_query_lens_ptr
,
cu_num_tokens_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
# [start_pos, end_pos)
start_pos
=
tl
.
load
(
cu_num_tokens_ptr
+
pid
)
end_pos
=
tl
.
load
(
cu_num_tokens_ptr
+
pid
+
1
)
num_tokens
=
end_pos
-
start_pos
index_start
=
tl
.
load
(
cu_query_lens_ptr
+
pid
)
num_blocks
=
tl
.
cdiv
(
num_tokens
,
BLOCK_SIZE
)
for
i
in
tl
.
range
(
num_blocks
):
offset
=
i
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
tl
.
store
(
out_ptr
+
start_pos
+
offset
,
index_start
+
offset
,
mask
=
offset
<
num_tokens
,
)
vllm/v1/spec_decode/medusa.py
View file @
4eabe123
...
@@ -3,12 +3,10 @@
...
@@ -3,12 +3,10 @@
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.models.medusa
import
Medusa
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
# Initialize logger
# Initialize logger
...
@@ -49,20 +47,9 @@ class MedusaProposer:
...
@@ -49,20 +47,9 @@ class MedusaProposer:
return
[
list
(
row
)
for
row
in
zip
(
*
draft_tokens
)]
return
[
list
(
row
)
for
row
in
zip
(
*
draft_tokens
)]
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
# Get model loader and config
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
,
loader
=
get_model_loader
(
self
.
vllm_config
.
load_config
)
model_config
=
self
.
vllm_config
.
draft_config
=
self
.
vllm_config
.
speculative_config
.
draft_model_config
speculative_config
.
draft_model_config
)
# Load model with proper dtype and config
with
set_default_torch_dtype
(
draft_config
.
dtype
),
\
set_current_vllm_config
(
self
.
vllm_config
):
self
.
model
=
Medusa
(
vllm_config
=
self
.
vllm_config
.
speculative_config
).
to
(
self
.
device
)
# Load model weights
weights
=
loader
.
get_all_weights
(
draft_config
,
self
.
model
)
self
.
model
.
load_weights
(
weights
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
dummy_run
(
self
,
num_tokens
:
int
)
->
None
:
def
dummy_run
(
self
,
num_tokens
:
int
)
->
None
:
...
...
vllm/v1/spec_decode/metrics.py
View file @
4eabe123
...
@@ -134,17 +134,17 @@ class SpecDecodingProm:
...
@@ -134,17 +134,17 @@ class SpecDecodingProm:
self
.
counter_spec_decode_num_drafts
=
\
self
.
counter_spec_decode_num_drafts
=
\
self
.
_counter_cls
(
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_drafts
_total
"
,
name
=
"vllm:spec_decode_num_drafts"
,
documentation
=
"Number of spec decoding drafts."
,
documentation
=
"Number of spec decoding drafts."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_spec_decode_num_draft_tokens
=
\
self
.
counter_spec_decode_num_draft_tokens
=
\
self
.
_counter_cls
(
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_draft_tokens
_total
"
,
name
=
"vllm:spec_decode_num_draft_tokens"
,
documentation
=
"Number of draft tokens."
,
documentation
=
"Number of draft tokens."
,
labelnames
=
labelnames
,).
labels
(
*
labelvalues
)
labelnames
=
labelnames
,).
labels
(
*
labelvalues
)
self
.
counter_spec_decode_num_accepted_tokens
=
\
self
.
counter_spec_decode_num_accepted_tokens
=
\
self
.
_counter_cls
(
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_accepted_tokens
_total
"
,
name
=
"vllm:spec_decode_num_accepted_tokens"
,
documentation
=
"Number of accepted tokens."
,
documentation
=
"Number of accepted tokens."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
...
...
vllm/v1/spec_decode/utils.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
...
@@ -16,3 +17,29 @@ def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
...
@@ -16,3 +17,29 @@ def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
return
False
return
False
return
True
return
True
@
triton
.
jit
def
prepare_eagle_input_kernel
(
out_ptr
,
cu_query_lens_ptr
,
cu_num_tokens_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
# [start_pos, end_pos)
start_pos
=
tl
.
load
(
cu_num_tokens_ptr
+
pid
)
end_pos
=
tl
.
load
(
cu_num_tokens_ptr
+
pid
+
1
)
num_tokens
=
end_pos
-
start_pos
index_start
=
tl
.
load
(
cu_query_lens_ptr
+
pid
)
num_blocks
=
tl
.
cdiv
(
num_tokens
,
BLOCK_SIZE
)
for
i
in
tl
.
range
(
num_blocks
):
offset
=
i
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
tl
.
store
(
out_ptr
+
start_pos
+
offset
,
index_start
+
offset
,
mask
=
offset
<
num_tokens
,
)
vllm/v1/structured_output/utils.py
View file @
4eabe123
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
from
__future__
import
annotations
from
__future__
import
annotations
import
re
import
re
gex
as
re
def
grammar_is_likely_lark
(
grammar_str
:
str
)
->
bool
:
def
grammar_is_likely_lark
(
grammar_str
:
str
)
->
bool
:
...
...
vllm/v1/worker/block_table.py
View file @
4eabe123
...
@@ -5,7 +5,6 @@ import torch
...
@@ -5,7 +5,6 @@ import torch
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -105,15 +104,10 @@ class MultiGroupBlockTable:
...
@@ -105,15 +104,10 @@ class MultiGroupBlockTable:
def
__init__
(
self
,
max_num_reqs
:
int
,
max_model_len
:
int
,
def
__init__
(
self
,
max_num_reqs
:
int
,
max_model_len
:
int
,
max_num_batched_tokens
:
int
,
pin_memory
:
bool
,
max_num_batched_tokens
:
int
,
pin_memory
:
bool
,
device
:
torch
.
device
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
device
:
torch
.
device
,
block_size
:
int
)
->
None
:
max_num_blocks_per_req
=
[
cdiv
(
max_model_len
,
g
.
kv_cache_spec
.
block_size
)
for
g
in
kv_cache_config
.
kv_cache_groups
]
self
.
block_tables
=
[
self
.
block_tables
=
[
BlockTable
(
max_num_reqs
,
max_num_blocks_per_req
[
i
]
,
BlockTable
(
max_num_reqs
,
cdiv
(
max_model_len
,
block_size
)
,
max_num_batched_tokens
,
pin_memory
,
device
)
max_num_batched_tokens
,
pin_memory
,
device
)
for
i
in
range
(
len
(
kv_cache_config
.
kv_cache_groups
))
]
]
def
append_row
(
self
,
block_ids
:
list
[
list
[
int
]],
row_idx
:
int
)
->
None
:
def
append_row
(
self
,
block_ids
:
list
[
list
[
int
]],
row_idx
:
int
)
->
None
:
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
4eabe123
...
@@ -11,7 +11,6 @@ from vllm.lora.request import LoRARequest
...
@@ -11,7 +11,6 @@ from vllm.lora.request import LoRARequest
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.utils
import
swap_dict_values
from
vllm.utils
import
swap_dict_values
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.utils
import
copy_slice
from
vllm.v1.utils
import
copy_slice
...
@@ -63,7 +62,7 @@ class InputBatch:
...
@@ -63,7 +62,7 @@ class InputBatch:
device
:
torch
.
device
,
device
:
torch
.
device
,
pin_memory
:
bool
,
pin_memory
:
bool
,
vocab_size
:
int
,
vocab_size
:
int
,
kv_cache_config
:
KVCacheConfig
,
block_size
:
int
,
):
):
self
.
max_num_reqs
=
max_num_reqs
self
.
max_num_reqs
=
max_num_reqs
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
...
@@ -105,7 +104,7 @@ class InputBatch:
...
@@ -105,7 +104,7 @@ class InputBatch:
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_batched_tokens
=
max_num_batched_tokens
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
device
=
device
,
device
=
device
,
kv_cache_config
=
kv_cache_config
,
block_size
=
block_size
,
)
)
# Sampling-related.
# Sampling-related.
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
4eabe123
...
@@ -27,15 +27,15 @@ from vllm.distributed.parallel_state import (
...
@@ -27,15 +27,15 @@ from vllm.distributed.parallel_state import (
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
TensorizerLoader
,
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
GiB_bytes
,
LazyLoader
,
cdiv
,
check_use_alibi
,
GiB_bytes
,
LazyLoader
,
async_tensor_h2d
,
cdiv
,
is_pin_memory_available
)
check_use_alibi
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
...
@@ -63,6 +63,7 @@ from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
...
@@ -63,6 +63,7 @@ from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
import
xgrammar
as
xgr
import
xgrammar
as
xgr
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
else
:
else
:
xgr
=
LazyLoader
(
"xgr"
,
globals
(),
"xgrammar"
)
xgr
=
LazyLoader
(
"xgr"
,
globals
(),
"xgrammar"
)
...
@@ -150,12 +151,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -150,12 +151,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
use_aux_hidden_state_outputs
=
False
self
.
use_aux_hidden_state_outputs
=
False
if
self
.
speculative_config
:
if
self
.
speculative_config
:
self
.
use_spec_decode
=
True
self
.
use_spec_decode
=
True
# NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many
# layers in the draft model.
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
if
self
.
speculative_config
.
method
==
"ngram"
:
if
self
.
speculative_config
.
method
==
"ngram"
:
self
.
drafter
=
NgramProposer
(
self
.
vllm_config
)
self
.
drafter
=
NgramProposer
(
self
.
vllm_config
)
elif
self
.
speculative_config
.
use_eagle
():
elif
self
.
speculative_config
.
use_eagle
():
self
.
drafter
=
EagleProposer
(
self
.
vllm_config
,
self
.
drafter
=
EagleProposer
(
self
.
vllm_config
,
self
.
device
,
self
.
device
)
# type: ignore
self
)
# type: ignore
if
self
.
speculative_config
.
method
==
"eagle3"
:
if
self
.
speculative_config
.
method
==
"eagle3"
:
self
.
use_aux_hidden_state_outputs
=
True
self
.
use_aux_hidden_state_outputs
=
True
elif
self
.
speculative_config
.
method
==
"medusa"
:
elif
self
.
speculative_config
.
method
==
"medusa"
:
...
@@ -170,6 +175,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -170,6 +175,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Request states.
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
self
.
model_config
.
get_vocab_size
(),
block_size
=
self
.
cache_config
.
block_size
,
)
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
==
CompilationLevel
.
PIECEWISE
==
CompilationLevel
.
PIECEWISE
and
not
self
.
model_config
.
enforce_eager
)
and
not
self
.
model_config
.
enforce_eager
)
...
@@ -914,8 +929,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -914,8 +929,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs
=
[]
encoder_outputs
=
[]
for
grouped_mm_inputs
in
grouped_mm_inputs_list
:
for
grouped_mm_inputs
in
grouped_mm_inputs_list
:
batched_mm_inputs
=
MultiModalKwargs
.
batch
(
grouped_mm_inputs
)
batched_mm_inputs
=
MultiModalKwargs
.
batch
(
grouped_mm_inputs
)
batched_mm_inputs
=
MultiModalKwargs
.
as_kwargs
(
batched_mm_inputs
,
batched_mm_inputs
=
MultiModalKwargs
.
as_kwargs
(
device
=
self
.
device
)
batched_mm_inputs
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
)
# Run the encoder.
# Run the encoder.
# `curr_group_outputs` is either of the following:
# `curr_group_outputs` is either of the following:
...
@@ -1348,7 +1366,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1348,7 +1366,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
next_token_ids
=
torch
.
tensor
(
next_token_ids
,
next_token_ids
=
torch
.
tensor
(
next_token_ids
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
self
.
device
)
eagle_attn_metadata
=
attn_metadata
[
self
.
drafter
.
attn_layer_name
]
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
eagle_attn_metadata
=
attn_metadata
[
self
.
drafter
.
attn_layer_names
[
0
]]
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
if
hasattr
(
eagle_attn_metadata
,
"block_table"
):
block_table
=
eagle_attn_metadata
.
block_table
else
:
block_table
=
None
if
spec_decode_metadata
is
None
:
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
# input_ids can be None for multimodal models.
...
@@ -1369,14 +1396,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1369,14 +1396,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
n
+
1
-
len
(
valid_sampled_token_ids
[
i
])
if
n
>
0
else
0
n
+
1
-
len
(
valid_sampled_token_ids
[
i
])
if
n
>
0
else
0
for
i
,
n
in
enumerate
(
num_draft_tokens
)
for
i
,
n
in
enumerate
(
num_draft_tokens
)
]
]
num_rejected_tokens
=
torch
.
tensor
(
num_rejected_tokens
_tensor
=
async_
tensor
_h2d
(
num_rejected_tokens
,
num_rejected_tokens
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
target_device
=
self
.
device
,
)
pin_memory
=
True
)
num_tokens
=
num_scheduled_tokens
-
sum
(
num_rejected_tokens
)
cu_num_tokens
,
token_indices
=
self
.
drafter
.
prepare_inputs
(
cu_num_tokens
,
token_indices
=
self
.
drafter
.
prepare_inputs
(
eagle_attn_metadata
.
query_start_loc
,
eagle_attn_metadata
.
query_start_loc
,
num_rejected_tokens
,
num_rejected_tokens_tensor
,
num_tokens
,
)
)
target_token_ids
=
self
.
input_ids
[
token_indices
]
target_token_ids
=
self
.
input_ids
[
token_indices
]
target_positions
=
positions
[
token_indices
]
target_positions
=
positions
[
token_indices
]
...
@@ -1387,7 +1416,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1387,7 +1416,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states
=
hidden_states
[
token_indices
]
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
token_indices
]
token_indices
]
draft_token_ids
=
self
.
drafter
.
propose
(
draft_token_ids
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_positions
=
target_positions
,
...
@@ -1395,7 +1423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1395,7 +1423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_slot_mapping
=
target_slot_mapping
,
target_slot_mapping
=
target_slot_mapping
,
next_token_ids
=
next_token_ids
,
next_token_ids
=
next_token_ids
,
cu_num_tokens
=
cu_num_tokens
,
cu_num_tokens
=
cu_num_tokens
,
block_table
=
eagle_attn_metadata
.
block_table
,
block_table
=
block_table
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
)
)
spec_token_ids
=
draft_token_ids
.
tolist
()
spec_token_ids
=
draft_token_ids
.
tolist
()
...
@@ -1523,6 +1551,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1523,6 +1551,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
time_after_load
-
time_before_load
)
time_after_load
-
time_before_load
)
prepare_communication_buffer_for_model
(
self
.
model
)
prepare_communication_buffer_for_model
(
self
.
model
)
def
save_tensorized_model
(
self
,
tensorizer_config
:
"TensorizerConfig"
,
)
->
None
:
TensorizerLoader
.
save_model
(
self
.
model
,
tensorizer_config
=
tensorizer_config
,
)
def
_get_prompt_logprobs_dict
(
def
_get_prompt_logprobs_dict
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -1703,8 +1740,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1703,8 +1740,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else
:
else
:
hidden_states
=
outputs
hidden_states
=
outputs
if
self
.
use_spec_decode
and
\
if
self
.
use_spec_decode
and
self
.
speculative_config
.
use_eagle
():
self
.
speculative_config
.
method
in
(
'eagle'
,
'eagle3'
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
self
.
drafter
.
dummy_run
(
num_tokens
)
self
.
drafter
.
dummy_run
(
num_tokens
)
...
@@ -1716,6 +1752,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1716,6 +1752,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# The dummy hidden states may contain special values,
# like `inf` or `nan`.
# To avoid breaking the sampler, we use a random tensor here instead.
hidden_states
=
torch
.
rand_like
(
hidden_states
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
num_reqs
=
logits
.
size
(
0
)
num_reqs
=
logits
.
size
(
0
)
...
@@ -1837,7 +1877,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1837,7 +1877,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
batched_dummy_mm_inputs
=
MultiModalKwargs
.
batch
(
batched_dummy_mm_inputs
=
MultiModalKwargs
.
batch
(
[
dummy_mm_kwargs
]
*
max_num_mm_items
)
[
dummy_mm_kwargs
]
*
max_num_mm_items
)
batched_dummy_mm_inputs
=
MultiModalKwargs
.
as_kwargs
(
batched_dummy_mm_inputs
=
MultiModalKwargs
.
as_kwargs
(
batched_dummy_mm_inputs
,
device
=
self
.
device
)
batched_dummy_mm_inputs
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
)
# Run multimodal encoder.
# Run multimodal encoder.
dummy_encoder_outputs
=
self
.
model
.
get_multimodal_embeddings
(
dummy_encoder_outputs
=
self
.
model
.
get_multimodal_embeddings
(
...
@@ -1947,16 +1990,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1947,16 +1990,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config: Configuration for the KV cache, including the KV
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
cache size of each layer
"""
"""
if
len
(
kv_cache_config
.
kv_cache_groups
)
>
1
:
raise
NotImplementedError
(
"Hybrid models with more than one KV cache type are not "
"supported yet."
)
self
.
kv_cache_config
=
kv_cache_config
self
.
kv_cache_config
=
kv_cache_config
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
self
.
model_config
.
get_vocab_size
(),
kv_cache_config
=
kv_cache_config
,
)
self
.
initialize_attn_backend
(
kv_cache_config
)
self
.
initialize_attn_backend
(
kv_cache_config
)
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
...
@@ -1988,6 +2026,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1988,6 +2026,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV cache specs.
# KV cache specs.
raise
ValueError
(
"Unknown KV cache spec type."
)
raise
ValueError
(
"Unknown KV cache spec type."
)
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
# validate all draft model layers belong to the same kv cache
# group
self
.
drafter
.
validate_same_kv_cache_group
(
kv_cache_config
)
bind_kv_cache
(
bind_kv_cache
(
kv_caches
,
kv_caches
,
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
...
...
vllm/v1/worker/gpu_worker.py
View file @
4eabe123
...
@@ -31,6 +31,7 @@ from vllm.v1.worker.worker_base import WorkerBase
...
@@ -31,6 +31,7 @@ from vllm.v1.worker.worker_base import WorkerBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
@@ -171,10 +172,9 @@ class Worker(WorkerBase):
...
@@ -171,10 +172,9 @@ class Worker(WorkerBase):
Then, it calculate the free memory that can be used for KV cache in
Then, it calculate the free memory that can be used for KV cache in
bytes.
bytes.
:::{t
ip
}
T
ip
:
You may limit the usage of GPU memory
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
by adjusting the `gpu_memory_utilization` parameter.
:::
"""
"""
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
reset_peak_memory_stats
()
torch
.
cuda
.
reset_peak_memory_stats
()
...
@@ -326,6 +326,13 @@ class Worker(WorkerBase):
...
@@ -326,6 +326,13 @@ class Worker(WorkerBase):
max_size
=
max_size
,
max_size
=
max_size
,
)
)
def
save_tensorized_model
(
self
,
tensorizer_config
:
"TensorizerConfig"
,
)
->
None
:
self
.
model_runner
.
save_tensorized_model
(
tensorizer_config
=
tensorizer_config
,
)
def
init_worker_distributed_environment
(
def
init_worker_distributed_environment
(
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
...
@@ -341,8 +348,7 @@ def init_worker_distributed_environment(
...
@@ -341,8 +348,7 @@ def init_worker_distributed_environment(
distributed_init_method
,
local_rank
)
distributed_init_method
,
local_rank
)
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
parallel_config
.
enable_expert_parallel
)
ensure_kv_transfer_initialized
(
vllm_config
)
ensure_kv_transfer_initialized
(
vllm_config
)
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
4eabe123
...
@@ -652,8 +652,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -652,8 +652,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs
=
[]
encoder_outputs
=
[]
for
grouped_mm_inputs
in
grouped_mm_inputs_list
:
for
grouped_mm_inputs
in
grouped_mm_inputs_list
:
batched_mm_inputs
=
MultiModalKwargs
.
batch
(
grouped_mm_inputs
)
batched_mm_inputs
=
MultiModalKwargs
.
batch
(
grouped_mm_inputs
)
batched_mm_inputs
=
MultiModalKwargs
.
as_kwargs
(
batched_mm_inputs
,
batched_mm_inputs
=
MultiModalKwargs
.
as_kwargs
(
device
=
self
.
device
)
batched_mm_inputs
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
)
# Run the encoder.
# Run the encoder.
# `curr_group_outputs` is either of the following:
# `curr_group_outputs` is either of the following:
...
@@ -1261,7 +1264,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1261,7 +1264,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
device
=
self
.
device
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
self
.
model_config
.
get_vocab_size
(),
vocab_size
=
self
.
model_config
.
get_vocab_size
(),
kv_cache_config
=
kv_cache_config
,
block_size
=
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
.
block_size
,
)
)
assert
self
.
block_table_cpu
.
dtype
==
self
.
input_batch
.
block_table
[
assert
self
.
block_table_cpu
.
dtype
==
self
.
input_batch
.
block_table
[
0
].
get_cpu_tensor
().
dtype
0
].
get_cpu_tensor
().
dtype
...
@@ -1434,8 +1438,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1434,8 +1438,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
batched_dummy_mm_inputs
=
MultiModalKwargs
.
batch
([
dummy_mm_kwargs
]
*
batched_dummy_mm_inputs
=
MultiModalKwargs
.
batch
([
dummy_mm_kwargs
]
*
batch_size
)
batch_size
)
return
MultiModalKwargs
.
as_kwargs
(
batched_dummy_mm_inputs
,
return
MultiModalKwargs
.
as_kwargs
(
device
=
self
.
device
)
batched_dummy_mm_inputs
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
)
def
_get_req_paddings
(
min_req_size
:
int
,
max_req_size
:
int
)
->
list
[
int
]:
def
_get_req_paddings
(
min_req_size
:
int
,
max_req_size
:
int
)
->
list
[
int
]:
...
...
vllm/v1/worker/tpu_worker.py
View file @
4eabe123
...
@@ -265,8 +265,7 @@ def init_tpu_worker_distributed_environment(
...
@@ -265,8 +265,7 @@ def init_tpu_worker_distributed_environment(
backend
=
"gloo"
,
backend
=
"gloo"
,
)
)
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
parallel_config
.
enable_expert_parallel
)
try
:
try
:
...
...
vllm/v1/worker/utils.py
View file @
4eabe123
...
@@ -10,7 +10,7 @@ def sanity_check_mm_encoder_outputs(
...
@@ -10,7 +10,7 @@ def sanity_check_mm_encoder_outputs(
)
->
None
:
)
->
None
:
"""
"""
Perform sanity checks for the result of
Perform sanity checks for the result of
{meth}
`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`.
[
`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`
][]
.
"""
"""
assert
isinstance
(
mm_embeddings
,
(
list
,
tuple
,
torch
.
Tensor
)),
(
assert
isinstance
(
mm_embeddings
,
(
list
,
tuple
,
torch
.
Tensor
)),
(
"Expected multimodal embeddings to be a list/tuple of 2D tensors, "
"Expected multimodal embeddings to be a list/tuple of 2D tensors, "
...
@@ -39,7 +39,7 @@ def scatter_mm_placeholders(
...
@@ -39,7 +39,7 @@ def scatter_mm_placeholders(
Scatter the multimodal embeddings into a contiguous tensor that represents
Scatter the multimodal embeddings into a contiguous tensor that represents
the placeholder tokens.
the placeholder tokens.
{class}
`vllm.multimodal.processing.PromptUpdateDetails.is_embed`.
[
`vllm.multimodal.processing.PromptUpdateDetails.is_embed`
][]
.
Args:
Args:
embeds: The multimodal embeddings.
embeds: The multimodal embeddings.
...
@@ -66,7 +66,7 @@ def gather_mm_placeholders(
...
@@ -66,7 +66,7 @@ def gather_mm_placeholders(
"""
"""
Reconstructs the embeddings from the placeholder tokens.
Reconstructs the embeddings from the placeholder tokens.
This is the operation of
{func}`
scatter_mm_placeholders
`
.
This is the operation of
[
scatter_mm_placeholders
][]
.
"""
"""
if
is_embed
is
None
:
if
is_embed
is
None
:
return
placeholders
return
placeholders
...
...
vllm/worker/cpu_enc_dec_model_runner.py
View file @
4eabe123
...
@@ -297,8 +297,11 @@ class CPUEncoderDecoderModelRunner(
...
@@ -297,8 +297,11 @@ class CPUEncoderDecoderModelRunner(
model_input
.
encoder_input_tokens
,
model_input
.
encoder_input_tokens
,
"encoder_positions"
:
"encoder_positions"
:
model_input
.
encoder_input_positions
,
model_input
.
encoder_input_positions
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
**
MultiModalKwargs
.
as_kwargs
(
device
=
self
.
device
),
model_input
.
multi_modal_kwargs
or
{},
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
),
"intermediate_tensors"
:
"intermediate_tensors"
:
intermediate_tensors
,
intermediate_tensors
,
}
}
...
...
vllm/worker/cpu_model_runner.py
View file @
4eabe123
...
@@ -628,7 +628,10 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
...
@@ -628,7 +628,10 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
multimodal_kwargs
=
{}
multimodal_kwargs
=
{}
if
model_input
.
multi_modal_kwargs
is
not
None
:
if
model_input
.
multi_modal_kwargs
is
not
None
:
multimodal_kwargs
=
MultiModalKwargs
.
as_kwargs
(
multimodal_kwargs
=
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
,
device
=
self
.
device
)
model_input
.
multi_modal_kwargs
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
)
execute_model_kwargs
=
{}
execute_model_kwargs
=
{}
if
previous_hidden_states
is
not
None
:
if
previous_hidden_states
is
not
None
:
execute_model_kwargs
.
update
(
execute_model_kwargs
.
update
(
...
...
vllm/worker/cpu_pooling_model_runner.py
View file @
4eabe123
...
@@ -50,8 +50,11 @@ class CPUPoolingModelRunner(
...
@@ -50,8 +50,11 @@ class CPUPoolingModelRunner(
model_input
.
input_tokens
,
model_input
.
input_tokens
,
"positions"
:
"positions"
:
model_input
.
input_positions
,
model_input
.
input_positions
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
**
MultiModalKwargs
.
as_kwargs
(
device
=
self
.
device
),
model_input
.
multi_modal_kwargs
or
{},
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
),
**
cross_enc_kwargs
,
**
cross_enc_kwargs
,
"intermediate_tensors"
:
"intermediate_tensors"
:
intermediate_tensors
,
intermediate_tensors
,
...
...
vllm/worker/cpu_worker.py
View file @
4eabe123
...
@@ -390,8 +390,7 @@ class CPUWorker(LocalOrDistributedWorkerBase):
...
@@ -390,8 +390,7 @@ class CPUWorker(LocalOrDistributedWorkerBase):
ensure_model_parallel_initialized
(
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
parallel_config
.
enable_expert_parallel
)
def
get_cache_block_size_bytes
(
self
)
->
int
:
def
get_cache_block_size_bytes
(
self
)
->
int
:
"""Return the size in bytes of a single KV cache block.
"""Return the size in bytes of a single KV cache block.
...
...
vllm/worker/enc_dec_model_runner.py
View file @
4eabe123
...
@@ -202,9 +202,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -202,9 +202,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
encoder_input_ids
=
model_input
.
encoder_input_tokens
,
encoder_input_ids
=
model_input
.
encoder_input_tokens
,
encoder_positions
=
model_input
.
encoder_input_positions
,
encoder_positions
=
model_input
.
encoder_input_positions
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
**
MultiModalKwargs
.
as_kwargs
(
device
=
self
.
device
),
multi_modal_kwargs
,
**
seqlen_agnostic_kwargs
)
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
),
**
seqlen_agnostic_kwargs
,
)
logits
=
self
.
model
.
compute_logits
(
hidden_or_intermediate_states
,
logits
=
self
.
model
.
compute_logits
(
hidden_or_intermediate_states
,
model_input
.
sampling_metadata
)
model_input
.
sampling_metadata
)
...
...
vllm/worker/hpu_worker.py
View file @
4eabe123
...
@@ -201,10 +201,9 @@ class HPUWorker(LocalOrDistributedWorkerBase):
...
@@ -201,10 +201,9 @@ class HPUWorker(LocalOrDistributedWorkerBase):
Then, it calculate the maximum possible number of GPU and CPU blocks
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
that can be allocated with the remaining free memory.
:::{t
ip
}
T
ip
:
You may limit the usage of GPU memory
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
by adjusting the `gpu_memory_utilization` parameter.
:::
"""
"""
# Profile the memory usage of the model and get the maximum number of
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
# cache blocks that can be allocated with the remaining free memory.
...
@@ -416,8 +415,7 @@ def init_worker_distributed_environment(
...
@@ -416,8 +415,7 @@ def init_worker_distributed_environment(
backend
=
'hccl'
)
backend
=
'hccl'
)
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
parallel_config
.
enable_expert_parallel
)
if
torch
.
distributed
.
is_initialized
():
if
torch
.
distributed
.
is_initialized
():
torch_world_size
=
torch
.
distributed
.
get_world_size
()
torch_world_size
=
torch
.
distributed
.
get_world_size
()
...
@@ -443,8 +441,7 @@ def init_worker_distributed_environment(
...
@@ -443,8 +441,7 @@ def init_worker_distributed_environment(
torch
.
distributed
.
all_reduce
(
dummy_tensor_hpu
)
torch
.
distributed
.
all_reduce
(
dummy_tensor_hpu
)
assert
dummy_tensor_hpu
.
item
()
==
parallel_config
.
world_size
assert
dummy_tensor_hpu
.
item
()
==
parallel_config
.
world_size
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
parallel_config
.
enable_expert_parallel
)
def
raise_if_cache_size_invalid
(
num_gpu_blocks
,
block_size
,
max_model_len
,
def
raise_if_cache_size_invalid
(
num_gpu_blocks
,
block_size
,
max_model_len
,
...
...
vllm/worker/model_runner.py
View file @
4eabe123
...
@@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionState
...
@@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionState
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.distributed
import
get_pp_group
from
vllm.distributed
import
broadcast_tensor_dict
,
get_pp_group
from
vllm.distributed.kv_transfer
import
get_kv_transfer_group
from
vllm.distributed.kv_transfer
import
get_kv_transfer_group
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
graph_capture
)
graph_capture
)
...
@@ -729,7 +729,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -729,7 +729,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
mm_kwargs
,
placeholder_maps
=
MultiModalPlaceholderMap
.
from_seq_group
(
mm_kwargs
,
placeholder_maps
=
MultiModalPlaceholderMap
.
from_seq_group
(
seq_group_metadata
,
seq_group_metadata
,
range
(
positions
[
0
],
positions
[
0
]
+
len
(
positions
)))
range
(
positions
[
0
],
positions
[
0
]
+
len
(
positions
)))
if
not
mm_kwargs
:
# M-RoPE requires mrope_positions even for plain text; return early
# when mm_kwargs is empty only if inter_data.is_prompt is False.
if
not
mm_kwargs
and
not
inter_data
.
is_prompt
:
return
return
inter_data
.
multi_modal_kwargs
=
mm_kwargs
inter_data
.
multi_modal_kwargs
=
mm_kwargs
...
@@ -741,12 +744,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -741,12 +744,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
video_grid_thw
=
mm_kwargs
.
get
(
"video_grid_thw"
,
None
)
video_grid_thw
=
mm_kwargs
.
get
(
"video_grid_thw"
,
None
)
audio_feature_lengths
=
mm_kwargs
.
get
(
"audio_feature_lengths"
,
audio_feature_lengths
=
mm_kwargs
.
get
(
"audio_feature_lengths"
,
None
)
None
)
assert
(
image_grid_thw
is
not
None
or
video_grid_thw
is
not
None
or
audio_feature_lengths
is
not
None
),
(
"mrope embedding type requires multi-modal input mapper "
"returns 'image_grid_thw' or 'video_grid_thw' or "
"'audio_feature_lengths'."
)
second_per_grid_ts
=
mm_kwargs
.
get
(
"second_per_grid_ts"
,
None
)
second_per_grid_ts
=
mm_kwargs
.
get
(
"second_per_grid_ts"
,
None
)
use_audio_in_video
=
mm_kwargs
.
get
(
"use_audio_in_video"
,
False
)
use_audio_in_video
=
mm_kwargs
.
get
(
"use_audio_in_video"
,
False
)
...
@@ -872,7 +869,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -872,7 +869,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
"""
"""
# Combine and flatten intermediate data.
# Combine and flatten intermediate data.
input_tokens
=
list
[
int
]()
input_tokens
=
list
[
int
]()
inputs_embeds_lst
=
list
[
torch
.
Tensor
]()
inputs_embeds_l
i
st
=
list
[
torch
.
Tensor
]()
token_types
=
list
[
int
]()
token_types
=
list
[
int
]()
for
inter_data
in
self
.
inter_data_list
:
for
inter_data
in
self
.
inter_data_list
:
for
cur_input_tokens
in
inter_data
.
input_tokens
:
for
cur_input_tokens
in
inter_data
.
input_tokens
:
...
@@ -880,15 +877,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -880,15 +877,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for
cur_token_types
in
inter_data
.
token_types
:
for
cur_token_types
in
inter_data
.
token_types
:
token_types
.
extend
(
cur_token_types
)
token_types
.
extend
(
cur_token_types
)
if
inter_data
.
inputs_embeds
is
not
None
:
if
inter_data
.
inputs_embeds
is
not
None
:
inputs_embeds_lst
.
append
(
inputs_embeds_l
i
st
.
append
(
inter_data
.
inputs_embeds
.
to
(
inter_data
.
inputs_embeds
.
to
(
dtype
=
self
.
runner
.
model_config
.
dtype
,
dtype
=
self
.
runner
.
model_config
.
dtype
,
device
=
self
.
runner
.
device
))
device
=
self
.
runner
.
device
))
inputs_embeds
:
Optional
[
torch
.
Tensor
]
inputs_embeds
:
Optional
[
torch
.
Tensor
]
if
len
(
inputs_embeds_lst
)
==
0
:
if
len
(
inputs_embeds_l
i
st
)
==
0
:
inputs_embeds
=
None
inputs_embeds
=
None
else
:
else
:
inputs_embeds
=
torch
.
cat
(
inputs_embeds_lst
,
dim
=
0
).
to
(
inputs_embeds
=
torch
.
cat
(
inputs_embeds_l
i
st
,
dim
=
0
).
to
(
dtype
=
self
.
runner
.
model_config
.
dtype
,
dtype
=
self
.
runner
.
model_config
.
dtype
,
device
=
self
.
runner
.
device
)
device
=
self
.
runner
.
device
)
assert
len
(
inputs_embeds
)
==
len
(
input_tokens
)
assert
len
(
inputs_embeds
)
==
len
(
input_tokens
)
...
@@ -1848,8 +1845,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1848,8 +1845,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
inputs_embeds
=
model_input
.
inputs_embeds
,
inputs_embeds
=
model_input
.
inputs_embeds
,
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
**
MultiModalKwargs
.
as_kwargs
(
device
=
self
.
device
),
multi_modal_kwargs
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
),
**
seqlen_agnostic_kwargs
,
**
seqlen_agnostic_kwargs
,
**
model_kwargs
,
**
model_kwargs
,
)
)
...
@@ -1893,15 +1893,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1893,15 +1893,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
logits
=
self
.
model
.
compute_logits
(
hidden_or_intermediate_states
,
logits
=
self
.
model
.
compute_logits
(
hidden_or_intermediate_states
,
model_input
.
sampling_metadata
)
model_input
.
sampling_metadata
)
if
not
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
return
[]
if
model_input
.
async_callback
is
not
None
:
if
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
model_input
.
async_callback
()
# Sample the next token.
# Sample the next token.
assert
isinstance
(
self
.
sampler
,
Sampler
)
assert
isinstance
(
self
.
sampler
,
Sampler
)
orig_include_gpu_probs
_tensor
=
self
.
sampler
.
include_gpu_probs_tensor
orig_include_gpu_probs
=
self
.
sampler
.
include_gpu_probs_tensor
if
model_input
.
inputs_embeds
is
not
None
:
if
model_input
.
inputs_embeds
is
not
None
:
self
.
sampler
.
include_gpu_probs_tensor
=
True
self
.
sampler
.
include_gpu_probs_tensor
=
True
...
@@ -1919,24 +1917,36 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1919,24 +1917,36 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if
intermediate_tensors
is
not
None
:
if
intermediate_tensors
is
not
None
:
orig_model_forward_time
=
intermediate_tensors
.
tensors
.
get
(
orig_model_forward_time
=
intermediate_tensors
.
tensors
.
get
(
"model_forward_time"
,
torch
.
tensor
(
0.0
)).
item
()
"model_forward_time"
,
torch
.
tensor
(
0.0
)).
item
()
# If there are multiple workers, we are still tracking the
latency
# If there are multiple workers, we are still tracking the
#
from the start time of the driver worker to the end
time of the
# latency
from the start time of the driver worker to the end
#
driver worker. The model forward time will then
end up covering
# time of the
driver worker. The model forward time will then
#
the communication time as well.
# end up covering
the communication time as well.
output
.
model_forward_time
=
(
orig_model_forward_time
+
output
.
model_forward_time
=
(
orig_model_forward_time
+
model_forward_time
)
model_forward_time
)
if
model_input
.
inputs_embeds
is
not
None
:
if
model_input
.
inputs_embeds
is
not
None
:
if
self
.
is_driver_worker
:
sampled
=
broadcast_tensor_dict
(
{
"token_ids"
:
output
.
sampled_token_ids
})
else
:
sampled
=
broadcast_tensor_dict
()
if
sampled
[
"token_ids"
]
is
not
None
:
sampled_token_embeds
=
self
.
model
.
get_input_embeddings
(
sampled
[
"token_ids"
].
squeeze
(
1
))
if
self
.
is_driver_worker
:
self
.
sampler
.
include_gpu_probs_tensor
=
\
self
.
sampler
.
include_gpu_probs_tensor
=
\
orig_include_gpu_probs_tensor
orig_include_gpu_probs
if
output
.
sampled_token_ids
is
not
None
:
output
.
sampled_token_embeds
=
self
.
model
.
get_input_embeddings
(
output
.
sampled_token_embeds
=
sampled_token_embeds
output
.
sampled_token_ids
.
squeeze
(
1
))
for
token_embed
,
sequence_group_output
in
zip
(
for
token_embed
,
sequence_group_output
in
zip
(
output
.
sampled_token_embeds
,
output
.
outputs
):
output
.
sampled_token_embeds
,
output
.
outputs
):
assert
len
(
sequence_group_output
.
samples
)
==
1
assert
len
(
sequence_group_output
.
samples
)
==
1
sequence_group_output
.
samples
[
0
].
output_embed
=
token_embed
sequence_group_output
.
samples
[
0
].
output_embed
=
token_embed
if
not
self
.
is_driver_worker
:
return
[]
if
self
.
return_hidden_states
:
if
self
.
return_hidden_states
:
# we only need to pass hidden states of most recent token
# we only need to pass hidden states of most recent token
...
...
Prev
1
…
29
30
31
32
33
34
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment