Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96ae75ad
Commit
96ae75ad
authored
Jan 04, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.6.6.post1' into v0.6.6.post1-dev
parents
f9f4a735
2339d59f
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
559 additions
and
158 deletions
+559
-158
vllm/v1/sample/ops/penalties.py
vllm/v1/sample/ops/penalties.py
+59
-0
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+201
-0
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+72
-95
vllm/v1/utils.py
vllm/v1/utils.py
+0
-25
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+142
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+48
-10
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+0
-1
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+7
-9
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+1
-2
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+13
-6
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+8
-7
vllm/worker/pooling_model_runner.py
vllm/worker/pooling_model_runner.py
+6
-1
vllm/worker/utils.py
vllm/worker/utils.py
+1
-1
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+1
-1
No files found.
vllm/v1/sample/ops/penalties.py
0 → 100644
View file @
96ae75ad
from
typing
import
List
,
Set
,
Tuple
import
torch
from
vllm.model_executor.layers.utils
import
apply_penalties
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
def
apply_min_token_penalties
(
logits
:
torch
.
Tensor
,
output_token_ids
:
List
[
List
[
int
]],
stop_token_ids
:
List
[
Set
[
int
]],
min_tokens
:
List
[
int
])
->
None
:
"""
Applies minimum token penalty by setting the logits of the stop tokens
to -inf.
"""
min_tokens_logits_to_penalize
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
index
,
min_token
in
enumerate
(
min_tokens
):
if
len
(
output_token_ids
[
index
])
<
min_token
:
for
stop_token_id
in
stop_token_ids
[
index
]:
min_tokens_logits_to_penalize
.
append
((
index
,
stop_token_id
))
if
min_tokens_logits_to_penalize
:
logits
[
tuple
(
zip
(
*
min_tokens_logits_to_penalize
))]
=
-
float
(
"inf"
)
def
apply_all_penalties
(
logits
:
torch
.
Tensor
,
prompt_token_ids
:
torch
.
Tensor
,
presence_penalties
:
torch
.
Tensor
,
frequency_penalties
:
torch
.
Tensor
,
repetition_penalties
:
torch
.
Tensor
,
output_token_ids
:
List
[
List
[
int
]],
)
->
torch
.
Tensor
:
"""
Applies presence, frequency and repetition penalties to the logits.
"""
_
,
vocab_size
=
logits
.
shape
output_tokens_t
=
_convert_to_tensors
(
output_token_ids
,
vocab_size
,
logits
.
device
)
return
apply_penalties
(
logits
,
prompt_token_ids
,
output_tokens_t
,
presence_penalties
,
frequency_penalties
,
repetition_penalties
)
def
_convert_to_tensors
(
output_token_ids
:
List
[
List
[
int
]],
vocab_size
:
int
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
"""
Convert the different list data structures to tensors.
"""
output_tokens_tensor
=
make_tensor_with_pad
(
output_token_ids
,
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
pad
=
vocab_size
,
device
=
"cpu"
,
dtype
=
torch
.
int64
,
pin_memory
=
is_pin_memory_available
(),
)
return
output_tokens_tensor
.
to
(
device
,
non_blocking
=
True
)
vllm/v1/sample/ops/topk_topp_sampler.py
0 → 100644
View file @
96ae75ad
from
typing
import
Dict
import
torch
import
torch.nn
as
nn
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
try
:
import
flashinfer.sampling
is_flashinfer_available
=
True
except
ImportError
:
is_flashinfer_available
=
False
class
TopKTopPSampler
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
if
current_platform
.
is_cuda
:
if
is_flashinfer_available
:
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
is
not
False
:
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
# sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
# default it is unused). For backward compatibility, we set
# `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
# interpret it differently in V0 and V1 samplers: In V0,
# None means False, while in V1, None means True. This is
# why we use the condition
# `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
logger
.
info
(
"Using FlashInfer for top-p & top-k sampling."
)
self
.
forward
=
self
.
forward_cuda
else
:
logger
.
warning
(
"FlashInfer is available, but it is not enabled. "
"Falling back to the PyTorch-native implementation of "
"top-p & top-k sampling. For the best performance, "
"please set VLLM_USE_FLASHINFER_SAMPLER=1."
)
self
.
forward
=
self
.
forward_native
else
:
logger
.
warning
(
"FlashInfer is not available. Falling back to the PyTorch-"
"native implementation of top-p & top-k sampling. For the "
"best performance, please install FalshInfer."
)
self
.
forward
=
self
.
forward_native
else
:
self
.
forward
=
self
.
forward_native
def
forward_native
(
self
,
logits
:
torch
.
Tensor
,
generators
:
Dict
[
int
,
torch
.
Generator
],
no_top_k
:
bool
,
k
:
torch
.
Tensor
,
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""PyTorch-native implementation of top-k and top-p sampling."""
logits
=
apply_top_k_top_p
(
logits
,
no_top_k
,
k
,
no_top_p
,
p
)
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
return
random_sample
(
probs
,
generators
)
def
forward_cuda
(
self
,
logits
:
torch
.
Tensor
,
generators
:
Dict
[
int
,
torch
.
Generator
],
no_top_k
:
bool
,
k
:
torch
.
Tensor
,
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""More optimized implementation for top-k and top-p sampling."""
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
if
no_top_k
and
no_top_p
:
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
return
random_sample
(
probs
,
generators
)
return
flashinfer_sample
(
probs
,
no_top_k
,
k
,
no_top_p
,
p
,
generators
)
def
apply_top_k_top_p
(
logits
:
torch
.
Tensor
,
no_top_k
:
bool
,
k
:
torch
.
Tensor
,
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Apply top-k and top-p masks to the logits.
This function sorts the logits tensor, which can be slow for large batches.
"""
if
no_top_k
and
no_top_p
:
return
logits
logits_sort
,
logits_idx
=
logits
.
sort
(
dim
=-
1
,
descending
=
False
)
if
not
no_top_k
:
# Apply top-k.
top_k_mask
=
logits_sort
.
size
(
1
)
-
k
.
to
(
torch
.
long
)
# Get all the top_k values.
top_k_mask
=
logits_sort
.
gather
(
1
,
top_k_mask
.
unsqueeze
(
dim
=
1
))
top_k_mask
=
logits_sort
<
top_k_mask
logits_sort
.
masked_fill_
(
top_k_mask
,
-
float
(
"inf"
))
if
not
no_top_p
:
# Apply top-p.
probs_sort
=
logits_sort
.
softmax
(
dim
=-
1
)
probs_sum
=
probs_sort
.
cumsum
(
dim
=-
1
)
top_p_mask
=
probs_sum
<=
1
-
p
.
unsqueeze
(
dim
=
1
)
# at least one
top_p_mask
[:,
-
1
]
=
False
logits_sort
.
masked_fill_
(
top_p_mask
,
-
float
(
"inf"
))
# Re-sort the probabilities.
logits
=
logits_sort
.
scatter
(
dim
=-
1
,
index
=
logits_idx
,
src
=
logits_sort
)
return
logits
def
random_sample
(
probs
:
torch
.
Tensor
,
generators
:
Dict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
"""Randomly sample from the probabilities.
We use this function instead of torch.multinomial because torch.multinomial
causes CPU-GPU synchronization.
"""
q
=
torch
.
empty_like
(
probs
)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
if
len
(
generators
)
!=
probs
.
shape
[
0
]:
q
.
exponential_
()
if
generators
:
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
for
i
,
generator
in
generators
.
items
():
q
[
i
].
exponential_
(
generator
=
generator
)
return
probs
.
div_
(
q
).
argmax
(
dim
=-
1
).
view
(
-
1
)
def
flashinfer_sample
(
probs
:
torch
.
Tensor
,
no_top_k
:
bool
,
k
:
torch
.
Tensor
,
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
generators
:
Dict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
"""Sample from the probabilities using FlashInfer.
Statistically, this function is equivalent to the `random_sample` function.
However, this function is faster because it avoids sorting the logits tensor
via rejection sampling.
NOTE: The outputs of this function do not necessarily match the outputs of
the `random_sample` function. It only guarantees that the outputs are
statistically equivalent.
NOTE: This function includes CPU-GPU synchronization, while `random_sample`
does not. Call this function at the end of the forward pass to minimize
the synchronization overhead.
"""
assert
not
(
no_top_k
and
no_top_p
)
max_top_k_round
=
32
batch_size
=
probs
.
shape
[
0
]
uniform_samples
=
torch
.
empty
((
max_top_k_round
,
batch_size
),
device
=
probs
.
device
)
if
len
(
generators
)
!=
batch_size
:
uniform_samples
.
uniform_
()
if
generators
:
for
i
,
generator
in
generators
.
items
():
uniform_samples
[:,
i
].
uniform_
(
generator
=
generator
)
if
no_top_k
:
# Top-p only.
next_token_ids
,
success
=
flashinfer
.
sampling
.
top_p_sampling_from_probs
(
probs
,
uniform_samples
,
p
,
deterministic
=
True
)
elif
no_top_p
:
# Top-k only.
next_token_ids
,
success
=
flashinfer
.
sampling
.
top_k_sampling_from_probs
(
probs
,
uniform_samples
,
k
,
deterministic
=
True
)
else
:
# Both top-k and top-p.
next_token_ids
,
success
=
(
flashinfer
.
sampling
.
top_k_top_p_sampling_from_probs
(
probs
,
uniform_samples
,
k
,
p
,
deterministic
=
True
))
# NOTE: CPU-GPU synchronization happens here.
if
not
success
.
all
():
if
not
no_top_k
:
probs
=
flashinfer
.
sampling
.
top_k_renorm_prob
(
probs
,
k
)
if
not
no_top_p
:
probs
=
flashinfer
.
sampling
.
top_p_renorm_prob
(
probs
,
p
)
next_token_ids
=
flashinfer
.
sampling
.
sampling_from_probs
(
probs
,
uniform_samples
[
0
],
deterministic
=
True
)
return
next_token_ids
.
view
(
-
1
)
vllm/v1/sample/sampler.py
View file @
96ae75ad
"""A layer that samples the next tokens from the model's outputs."""
from
typing
import
Dict
from
typing
import
Tuple
import
torch
import
torch.nn
as
nn
from
vllm.v1.outputs
import
SamplerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.ops.penalties
import
(
apply_all_penalties
,
apply_min_token_penalties
)
from
vllm.v1.sample.ops.topk_topp_sampler
import
TopKTopPSampler
_SAMPLING_EPS
=
1e-5
class
Sampler
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
topk_topp_sampler
=
TopKTopPSampler
()
def
forward
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
logits
=
self
.
apply_temperature
(
logits
,
sampling_metadata
.
temperature
)
logits
=
self
.
apply_top_k_top_p
(
logits
,
sampling_metadata
)
probs
=
self
.
get_probs
(
logits
)
sampled
=
self
.
sample
(
probs
,
sampling_metadata
)
# Use int32 to reduce the tensor size.
sampled
=
sampled
.
to
(
torch
.
int32
)
if
sampling_metadata
.
max_num_logprobs
>
0
:
logprobs
=
self
.
get_logprobs
(
logits
)
# FIXME: Mask the sampled token_id, get topk logprobs,
# and concatenate the topk with the sampled token_id.
topk_logprobs
,
topk_indices
=
torch
.
topk
(
logprobs
,
sampling_metadata
.
max_num_logprobs
,
dim
=-
1
)
# Use int32 to reduce the tensor size.
topk_indices
=
topk_indices
.
to
(
torch
.
int32
)
needs_logprobs
=
sampling_metadata
.
max_num_logprobs
>
0
if
needs_logprobs
:
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
# is used for sampling (after penalties and temperature scaling).
# NOTE: We compute logprobs first because the below ops may
# modify the logits tensor in-place (and we don't want to clone
# the logits tensor for memory efficiency).
topk_logprobs
,
topk_indices
=
self
.
get_topk_logprobs
(
logits
,
sampling_metadata
)
else
:
topk_logprobs
=
None
topk_indices
=
None
# Use float32 for the logits.
logits
=
logits
.
to
(
torch
.
float32
)
# Apply penalties (e.g., min_tokens, freq_penalties).
logits
=
self
.
apply_penalties
(
logits
,
sampling_metadata
)
# Apply temperature.
logits
=
self
.
apply_temperature
(
logits
,
sampling_metadata
.
temperature
)
# Sample the next token.
sampled
=
self
.
sample
(
logits
,
sampling_metadata
)
# Use int32 to reduce the tensor size.
sampled
=
sampled
.
to
(
torch
.
int32
)
# NOTE: CPU-GPU synchronization happens here.
sampler_output
=
SamplerOutput
(
sampled_token_ids
=
sampled
.
tolist
(),
...
...
@@ -52,71 +65,37 @@ class Sampler(nn.Module):
logits
:
torch
.
Tensor
,
temp
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# Use float32 to apply temperature scaling.
logits
=
logits
.
to
(
torch
.
float32
)
# Avoid division by zero.
temp
=
torch
.
where
(
temp
<
_SAMPLING_EPS
,
1.0
,
temp
)
# Use in-place division to avoid creating a new tensor.
logits
.
div_
(
temp
.
unsqueeze
(
dim
=
1
))
return
logits
def
apply_top_k_top_p
(
def
greedy_sample
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
logits
.
argmax
(
dim
=-
1
).
view
(
-
1
)
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
return
_apply_top_k_top_p
(
assert
not
(
sampling_metadata
.
all_greedy
and
sampling_metadata
.
all_random
)
if
sampling_metadata
.
all_greedy
:
return
self
.
greedy_sample
(
logits
)
random_sampled
=
self
.
topk_topp_sampler
(
logits
,
sampling_metadata
.
generators
,
sampling_metadata
.
no_top_k
,
sampling_metadata
.
top_k
,
sampling_metadata
.
no_top_p
,
sampling_metadata
.
top_p
,
)
def
get_probs
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
def
get_logprobs
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
def
greedy_sample
(
self
,
probs
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
probs
.
argmax
(
dim
=-
1
).
view
(
-
1
)
def
random_sample
(
self
,
probs
:
torch
.
Tensor
,
generators
:
Dict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
q
=
torch
.
empty_like
(
probs
)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
if
len
(
generators
)
!=
probs
.
shape
[
0
]:
# This might still be done here unnecessarily if there are greedies
q
.
exponential_
()
if
generators
:
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
for
i
,
generator
in
generators
.
items
():
q
[
i
].
exponential_
(
generator
=
generator
)
return
probs
.
div_
(
q
).
argmax
(
dim
=-
1
).
view
(
-
1
)
def
sample
(
self
,
probs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
assert
not
(
sampling_metadata
.
all_greedy
and
sampling_metadata
.
all_random
)
if
sampling_metadata
.
all_greedy
:
return
self
.
greedy_sample
(
probs
)
if
sampling_metadata
.
all_random
:
return
self
.
random_sample
(
probs
,
sampling_metadata
.
generators
)
return
random_sample
d
greedy_sampled
=
self
.
greedy_sample
(
probs
)
random_sampled
=
self
.
random_sample
(
probs
,
sampling_metadata
.
generators
)
greedy_sampled
=
self
.
greedy_sample
(
logits
)
sampled
=
torch
.
where
(
sampling_metadata
.
temperature
<
_SAMPLING_EPS
,
greedy_sampled
,
...
...
@@ -124,36 +103,34 @@ class Sampler(nn.Module):
)
return
sampled
def
get_topk_logprobs
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
logprobs
=
logits
.
log_softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
# FIXME: Mask the sampled token_id, get topk logprobs,
# and concatenate the topk with the sampled token_id.
topk_logprobs
,
topk_indices
=
torch
.
topk
(
logprobs
,
sampling_metadata
.
max_num_logprobs
,
dim
=-
1
)
# Use int32 to reduce the tensor size.
topk_indices
=
topk_indices
.
to
(
torch
.
int32
)
return
topk_logprobs
,
topk_indices
# TODO(woosuk): Optimize this with a custom kernel.
def
_apply_top_k_top_p
(
logits
:
torch
.
Tensor
,
no_top_k
:
bool
,
k
:
torch
.
Tensor
,
no_top_p
:
bool
,
p
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
if
no_top_k
and
no_top_p
:
def
apply_penalties
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
apply_min_token_penalties
(
logits
,
sampling_metadata
.
output_token_ids
,
sampling_metadata
.
stop_token_ids
,
sampling_metadata
.
min_tokens
)
if
not
sampling_metadata
.
no_penalties
:
assert
sampling_metadata
.
prompt_token_ids
is
not
None
logits
=
apply_all_penalties
(
logits
,
sampling_metadata
.
prompt_token_ids
,
sampling_metadata
.
presence_penalties
,
sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
output_token_ids
)
return
logits
logits_sort
,
logits_idx
=
logits
.
sort
(
dim
=-
1
,
descending
=
False
)
if
not
no_top_k
:
# Apply top-k.
top_k_mask
=
logits_sort
.
size
(
1
)
-
k
.
to
(
torch
.
long
)
# Get all the top_k values.
top_k_mask
=
logits_sort
.
gather
(
1
,
top_k_mask
.
unsqueeze
(
dim
=
1
))
top_k_mask
=
logits_sort
<
top_k_mask
logits_sort
.
masked_fill_
(
top_k_mask
,
-
float
(
"inf"
))
if
not
no_top_p
:
# Apply top-p.
probs_sort
=
logits_sort
.
softmax
(
dim
=-
1
)
probs_sum
=
probs_sort
.
cumsum
(
dim
=-
1
)
top_p_mask
=
probs_sum
<=
1
-
p
.
unsqueeze
(
dim
=
1
)
# at least one
top_p_mask
[:,
-
1
]
=
False
logits_sort
.
masked_fill_
(
top_p_mask
,
-
float
(
"inf"
))
# Re-sort the probabilities.
logits
=
logits_sort
.
scatter
(
dim
=-
1
,
index
=
logits_idx
,
src
=
logits_sort
)
return
logits
vllm/v1/utils.py
View file @
96ae75ad
from
collections
import
OrderedDict
from
collections.abc
import
Sequence
from
contextlib
import
contextmanager
from
typing
import
(
Any
,
Generic
,
Iterator
,
List
,
Optional
,
TypeVar
,
Union
,
...
...
@@ -102,27 +101,3 @@ def make_zmq_socket(
finally
:
ctx
.
destroy
(
linger
=
0
)
K
=
TypeVar
(
'K'
)
V
=
TypeVar
(
'V'
)
class
LRUDictCache
(
Generic
[
K
,
V
]):
def
__init__
(
self
,
size
:
int
):
self
.
cache
:
OrderedDict
[
K
,
V
]
=
OrderedDict
()
self
.
size
=
size
def
get
(
self
,
key
:
K
,
default
=
None
)
->
V
:
if
key
not
in
self
.
cache
:
return
default
self
.
cache
.
move_to_end
(
key
)
return
self
.
cache
[
key
]
def
put
(
self
,
key
:
K
,
value
:
V
):
self
.
cache
[
key
]
=
value
self
.
cache
.
move_to_end
(
key
)
if
len
(
self
.
cache
)
>
self
.
size
:
self
.
cache
.
popitem
(
last
=
False
)
vllm/v1/worker/gpu_input_batch.py
View file @
96ae75ad
...
...
@@ -43,12 +43,14 @@ class InputBatch:
max_num_blocks_per_req
:
int
,
device
:
torch
.
device
,
pin_memory
:
bool
,
vocab_size
:
int
,
):
self
.
max_num_reqs
=
max_num_reqs
self
.
max_model_len
=
max_model_len
self
.
max_num_blocks_per_req
=
max_num_blocks_per_req
self
.
device
=
device
self
.
pin_memory
=
pin_memory
self
.
vocab_size
=
vocab_size
self
.
req_ids
:
List
[
Optional
[
str
]]
=
[
None
]
*
max_num_reqs
self
.
req_id_to_index
:
Dict
[
str
,
int
]
=
{}
...
...
@@ -63,6 +65,7 @@ class InputBatch:
)
self
.
token_ids_cpu
=
self
.
token_ids_cpu_tensor
.
numpy
()
self
.
num_computed_tokens_cpu
=
np
.
empty
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_prompt_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
# Attention-related.
self
.
block_table
=
torch
.
zeros
(
...
...
@@ -110,6 +113,50 @@ class InputBatch:
self
.
top_k_cpu
=
self
.
top_k_cpu_tensor
.
numpy
()
self
.
top_k_reqs
:
Set
[
str
]
=
set
()
# Frequency penalty related data structures
self
.
frequency_penalties
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
device
)
self
.
frequency_penalties_cpu_tensor
=
torch
.
empty
(
(
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
frequency_penalties_cpu
=
\
self
.
frequency_penalties_cpu_tensor
.
numpy
()
self
.
frequency_penalties_reqs
:
Set
[
str
]
=
set
()
# Presence penalty related data structures
self
.
presence_penalties
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
device
)
self
.
presence_penalties_cpu_tensor
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
presence_penalties_cpu
=
\
self
.
presence_penalties_cpu_tensor
.
numpy
()
self
.
presence_penalties_reqs
:
Set
[
str
]
=
set
()
# Repetition penalty related data structures
self
.
repetition_penalties
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
device
)
self
.
repetition_penalties_cpu_tensor
=
torch
.
empty
(
(
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
repetition_penalties_cpu
=
\
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_reqs
:
Set
[
str
]
=
set
()
self
.
min_tokens
:
List
[
int
]
=
[
0
]
*
max_num_reqs
self
.
stop_token_ids
:
List
[
Set
[
int
]]
=
[
set
()
for
_
in
range
(
max_num_reqs
)
]
self
.
prompt_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
# req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary.
...
...
@@ -133,6 +180,7 @@ class InputBatch:
# Copy the prompt token ids and output token ids.
num_prompt_tokens
=
len
(
request
.
prompt_token_ids
)
self
.
num_prompt_tokens
[
req_index
]
=
num_prompt_tokens
self
.
token_ids_cpu
[
req_index
,
:
num_prompt_tokens
]
=
request
.
prompt_token_ids
start_idx
=
num_prompt_tokens
...
...
@@ -157,6 +205,20 @@ class InputBatch:
self
.
top_k_cpu
[
req_index
]
=
sampling_params
.
top_k
if
sampling_params
.
top_k
>
0
:
self
.
top_k_reqs
.
add
(
req_id
)
self
.
frequency_penalties_cpu
[
req_index
]
=
\
sampling_params
.
frequency_penalty
if
sampling_params
.
frequency_penalty
!=
0.0
:
self
.
frequency_penalties_reqs
.
add
(
req_id
)
self
.
presence_penalties_cpu
[
req_index
]
=
\
sampling_params
.
presence_penalty
if
sampling_params
.
presence_penalty
!=
0.0
:
self
.
presence_penalties_reqs
.
add
(
req_id
)
self
.
repetition_penalties_cpu
[
req_index
]
=
\
sampling_params
.
repetition_penalty
if
sampling_params
.
repetition_penalty
!=
1.0
:
self
.
repetition_penalties_reqs
.
add
(
req_id
)
self
.
min_tokens
[
req_index
]
=
sampling_params
.
min_tokens
self
.
stop_token_ids
[
req_index
]
=
sampling_params
.
all_stop_token_ids
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
...
...
@@ -179,6 +241,9 @@ class InputBatch:
self
.
random_reqs
.
discard
(
req_id
)
self
.
top_p_reqs
.
discard
(
req_id
)
self
.
top_k_reqs
.
discard
(
req_id
)
self
.
frequency_penalties_reqs
.
discard
(
req_id
)
self
.
presence_penalties_reqs
.
discard
(
req_id
)
self
.
repetition_penalties_reqs
.
discard
(
req_id
)
self
.
generators
.
pop
(
req_index
,
None
)
self
.
num_logprobs
.
pop
(
req_id
,
None
)
self
.
prompt_logprob_reqs
.
discard
(
req_id
)
...
...
@@ -191,6 +256,9 @@ class InputBatch:
self
.
random_reqs
.
clear
()
self
.
top_p_reqs
.
clear
()
self
.
top_k_reqs
.
clear
()
self
.
frequency_penalties_reqs
.
clear
()
self
.
presence_penalties_reqs
.
clear
()
self
.
repetition_penalties_reqs
.
clear
()
self
.
generators
.
clear
()
self
.
num_logprobs
.
clear
()
self
.
prompt_logprob_reqs
.
clear
()
...
...
@@ -224,6 +292,8 @@ class InputBatch:
# block_table_cpu.
self
.
token_ids_cpu
[
empty_index
]
=
self
.
token_ids_cpu
[
last_req_index
]
self
.
num_prompt_tokens
[
empty_index
]
=
\
self
.
num_prompt_tokens
[
last_req_index
]
self
.
num_computed_tokens_cpu
[
empty_index
]
=
self
.
num_computed_tokens_cpu
[
last_req_index
]
self
.
block_table_cpu
[
empty_index
]
=
self
.
block_table_cpu
[
...
...
@@ -232,6 +302,15 @@ class InputBatch:
last_req_index
]
self
.
top_p_cpu
[
empty_index
]
=
self
.
top_p_cpu
[
last_req_index
]
self
.
top_k_cpu
[
empty_index
]
=
self
.
top_k_cpu
[
last_req_index
]
self
.
frequency_penalties_cpu
[
empty_index
]
=
\
self
.
frequency_penalties_cpu
[
last_req_index
]
self
.
presence_penalties_cpu
[
empty_index
]
=
\
self
.
presence_penalties_cpu
[
last_req_index
]
self
.
repetition_penalties_cpu
[
empty_index
]
=
\
self
.
repetition_penalties_cpu
[
last_req_index
]
self
.
min_tokens
[
empty_index
]
=
self
.
min_tokens
[
last_req_index
]
self
.
stop_token_ids
[
empty_index
]
=
\
self
.
stop_token_ids
[
last_req_index
]
generator
=
self
.
generators
.
pop
(
last_req_index
,
None
)
if
generator
is
not
None
:
self
.
generators
[
empty_index
]
=
generator
...
...
@@ -241,6 +320,7 @@ class InputBatch:
def
make_sampling_metadata
(
self
,
req_id_output_token_ids
:
Dict
[
str
,
List
[
int
]],
skip_copy
:
bool
=
False
,
)
->
SamplingMetadata
:
if
not
skip_copy
:
...
...
@@ -250,6 +330,37 @@ class InputBatch:
self
.
top_p_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
self
.
top_k
[:
self
.
num_reqs
].
copy_
(
self
.
top_k_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
if
not
self
.
no_penalties
:
# Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require
# penalties to be applied during sampling.
self
.
frequency_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
frequency_penalties_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
self
.
presence_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
presence_penalties_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
self
.
repetition_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
repetition_penalties_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
# The prompt tokens are used only for applying penalties during
# the sampling process. Hence copy these tensors only when
# there are requests which need penalties to be applied.
self
.
prompt_token_ids
=
self
.
_make_prompt_token_ids_tensor
()
output_token_ids
:
List
[
List
[
int
]]
=
[]
for
req_id
in
self
.
req_ids
[:
self
.
num_reqs
]:
assert
req_id
is
not
None
# Currently we create a tensor for output_token_ids from scratch
# at each step. However, for the penalties computation what we
# need is stats about the token ids present in the output. This
# stats can be maintained incrementally instead of computing it
# from scratch at each step.
# TODO - Replace this with incremental update to output token
# statistics.
output_token_ids
.
append
(
req_id_output_token_ids
[
req_id
])
return
SamplingMetadata
(
temperature
=
self
.
temperature
[:
self
.
num_reqs
],
all_greedy
=
self
.
all_greedy
,
...
...
@@ -260,8 +371,33 @@ class InputBatch:
no_top_k
=
self
.
no_top_k
,
generators
=
self
.
generators
,
max_num_logprobs
=
self
.
max_num_logprobs
,
prompt_token_ids
=
self
.
prompt_token_ids
,
frequency_penalties
=
self
.
frequency_penalties
[:
self
.
num_reqs
],
presence_penalties
=
self
.
presence_penalties
[:
self
.
num_reqs
],
repetition_penalties
=
self
.
repetition_penalties
[:
self
.
num_reqs
],
output_token_ids
=
output_token_ids
,
min_tokens
=
self
.
min_tokens
[:
self
.
num_reqs
],
stop_token_ids
=
self
.
stop_token_ids
[:
self
.
num_reqs
],
no_penalties
=
self
.
no_penalties
,
)
def
_make_prompt_token_ids_tensor
(
self
)
->
torch
.
Tensor
:
max_prompt_len
=
self
.
num_prompt_tokens
[:
self
.
num_reqs
].
max
()
prompt_token_ids_cpu_tensor
=
torch
.
empty
(
(
self
.
num_reqs
,
max_prompt_len
),
device
=
"cpu"
,
dtype
=
torch
.
int64
,
pin_memory
=
self
.
pin_memory
)
prompt_token_ids
=
prompt_token_ids_cpu_tensor
.
numpy
()
prompt_token_ids
[:]
=
(
self
.
token_ids_cpu
[:
self
.
num_reqs
,
:
max_prompt_len
])
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for
i
in
range
(
self
.
num_reqs
):
prompt_token_ids
[
i
,
self
.
num_prompt_tokens
[
i
]:]
=
self
.
vocab_size
return
prompt_token_ids_cpu_tensor
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
@
property
def
num_reqs
(
self
)
->
int
:
return
len
(
self
.
req_id_to_index
)
...
...
@@ -282,6 +418,12 @@ class InputBatch:
def
no_top_k
(
self
)
->
bool
:
return
len
(
self
.
top_k_reqs
)
==
0
@
property
def
no_penalties
(
self
)
->
bool
:
return
(
len
(
self
.
presence_penalties_reqs
)
==
0
and
len
(
self
.
frequency_penalties_reqs
)
==
0
and
len
(
self
.
repetition_penalties_reqs
)
==
0
)
@
property
def
max_num_logprobs
(
self
)
->
int
:
return
max
(
self
.
num_logprobs
.
values
())
if
self
.
num_logprobs
else
0
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
96ae75ad
...
...
@@ -19,7 +19,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType
,
cdiv
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionBackend
,
FlashAttentionMetadata
)
from
vllm.v1.engine.mm_input_mapper
import
MMInputMapperClient
from
vllm.v1.engine.mm_input_mapper
import
MMHasher
,
MMInputMapperClient
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
...
...
@@ -79,8 +79,14 @@ class GPUModelRunner:
# Multi-modal data support
self
.
input_registry
=
INPUT_REGISTRY
self
.
mm_registry
=
MULTIMODAL_REGISTRY
# NOTE: mm_input_mapper is only used for memory profiling.
self
.
mm_input_mapper
=
MMInputMapperClient
(
self
.
model_config
)
# NOTE: mm_input_mapper_client and mm_hasher are only used for memory
# profiling.
self
.
mm_input_mapper_client
=
MMInputMapperClient
(
self
.
model_config
)
self
.
mm_hasher
=
MMHasher
()
self
.
use_hash
=
(
not
model_config
.
disable_mm_preprocessor_cache
)
or
\
cache_config
.
enable_prefix_caching
self
.
max_num_encoder_input_tokens
=
self
.
scheduler_config
.
max_num_encoder_input_tokens
# noqa: E501
self
.
encoder_cache_size
=
self
.
scheduler_config
.
encoder_cache_size
...
...
@@ -99,6 +105,7 @@ class GPUModelRunner:
max_num_blocks_per_req
=
self
.
max_num_blocks_per_req
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
model_config
.
get_vocab_size
(),
)
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
...
...
@@ -377,7 +384,12 @@ class GPUModelRunner:
or
scheduler_output
.
scheduled_resumed_reqs
):
skip_copy
=
False
# Create the sampling metadata.
sampling_metadata
=
self
.
input_batch
.
make_sampling_metadata
(
skip_copy
)
req_id_output_token_ids
:
Dict
[
str
,
List
[
int
]]
=
\
{
req_id
:
req
.
output_token_ids
\
for
req_id
,
req
in
self
.
requests
.
items
()}
sampling_metadata
=
self
.
input_batch
.
make_sampling_metadata
(
req_id_output_token_ids
,
skip_copy
)
return
sampling_metadata
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
...
...
@@ -628,11 +640,6 @@ class GPUModelRunner:
mm_registry
=
self
.
mm_registry
,
)
dummy_mm_data
=
dummy_request_data
.
multi_modal_data
dummy_mm_kwargs
,
_
=
self
.
mm_input_mapper
.
process_inputs
(
mm_data
=
dummy_mm_data
,
mm_hashes
=
None
,
mm_processor_kwargs
=
None
,
precomputed_mm_inputs
=
None
)
# NOTE: Currently model is profiled with a single non-text
# modality even when it supports multiple.
...
...
@@ -648,8 +655,39 @@ class GPUModelRunner:
# (e.g, multiple images) for a single request, therefore here we
# always replicate first item by max_num_mm_items times since in V1
# they are scheduled to be processed separately.
# Case when models have a merged processor, their dummy data is
# already batched `MultiModalKwargs`, therefore we need to "unbatch"
# and take the first item in each batched tensor.
# TODO (ywang96): This is somewhat hacky. Refactor this to be
# consistent with the other case.
if
isinstance
(
dummy_mm_data
,
MultiModalKwargs
):
dummy_mm_kwargs
=
{
k
:
v
[
0
].
unsqueeze
(
0
)
for
k
,
v
in
dummy_mm_data
.
items
()
}
# Case when models have dummy data explicitly defined as
# `MultiModalDataDict`, so they need to be processed through input
# mapper.
else
:
# Compute MM hashes (if enabled)
mm_hashes
=
None
if
self
.
use_hash
:
mm_hashes
=
self
.
mm_hasher
.
hash_dummy_mm_data
(
dummy_mm_data
)
mm_kwargs_list
=
self
.
mm_input_mapper_client
.
process_inputs
(
mm_data
=
dummy_mm_data
,
mm_hashes
=
mm_hashes
,
mm_processor_kwargs
=
None
,
precomputed_mm_inputs
=
None
)
# Take the first `MultiModalKwargs`
dummy_mm_kwargs
=
mm_kwargs_list
[
0
]
batched_dummy_mm_inputs
=
MultiModalKwargs
.
batch
(
[
dummy_mm_kwargs
[
0
]
]
*
max_num_mm_items
)
[
dummy_mm_kwargs
]
*
max_num_mm_items
)
batched_dummy_mm_inputs
=
MultiModalKwargs
.
as_kwargs
(
batched_dummy_mm_inputs
,
device
=
self
.
device
)
...
...
vllm/v1/worker/gpu_worker.py
View file @
96ae75ad
...
...
@@ -202,7 +202,6 @@ class Worker:
)
->
ModelRunnerOutput
:
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
)
return
output
if
self
.
rank
==
0
else
None
return
output
def
profile
(
self
,
is_start
:
bool
=
True
):
if
self
.
profiler
is
None
:
...
...
vllm/worker/cpu_model_runner.py
View file @
96ae75ad
...
...
@@ -114,8 +114,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
def
__init__
(
self
,
use_mrope
:
bool
):
self
.
use_mrope
=
use_mrope
self
.
input_tokens
:
List
[
int
]
=
[]
self
.
input_positions
:
Optional
[
List
[
int
]]
=
[]
if
not
self
.
use_mrope
else
None
self
.
input_positions
:
List
[
int
]
=
[]
self
.
token_type_ids
:
Optional
[
List
[
int
]]
=
[]
self
.
seq_lens
:
List
[
int
]
=
[]
self
.
query_lens
:
List
[
int
]
=
[]
...
...
@@ -130,9 +129,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
self
.
multi_modal_placeholder_maps
:
Dict
[
str
,
MultiModalPlaceholderMap
]
=
defaultdict
(
MultiModalPlaceholderMap
)
self
.
input_mrope_positions
:
Optional
[
List
[
List
[
int
]]]
=
[
[]
for
_
in
range
(
3
)
]
if
self
.
use_mrope
else
None
self
.
input_mrope_positions
:
List
[
List
[
int
]]
=
[[]
for
_
in
range
(
3
)]
def
__init__
(
self
,
runner
:
"CPUModelRunner"
,
...
...
@@ -167,7 +165,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
device
=
"cpu"
)
input_positions
=
torch
.
tensor
(
input_data
.
input_positions
if
not
input_data
.
use_mrope
else
input_data
.
input_mrope_positions
,
if
not
any
(
input_data
.
input_mrope_positions
)
else
input_data
.
input_mrope_positions
,
dtype
=
torch
.
long
,
device
=
"cpu"
)
token_type_ids
=
torch
.
tensor
(
input_data
.
token_type_ids
,
...
...
@@ -236,7 +235,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
block_table
=
block_table
[
start_block
:]
# For MRotaryEmbedding
if
data
.
input
_position
s
is
None
:
if
seq_
data
.
mrope
_position
_delta
is
not
None
:
next_pos
=
MRotaryEmbedding
.
get_next_input_positions
(
seq_data
.
mrope_position_delta
,
context_len
,
...
...
@@ -309,8 +308,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
data
.
slot_mapping
.
extend
(
slot_mapping
)
# The MROPE positions are prepared in _compute_multi_modal_input
if
data
.
input_positions
is
not
None
:
data
.
input_positions
.
extend
(
token_positions
)
data
.
input_positions
.
extend
(
token_positions
)
if
data
.
token_type_ids
is
not
None
:
data
.
token_type_ids
.
extend
(
token_types
if
token_types
else
[])
...
...
vllm/worker/cpu_worker.py
View file @
96ae75ad
...
...
@@ -338,9 +338,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def
prepare_worker_input
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
WorkerInput
:
assert
execute_model_req
is
not
None
virtual_engine
=
execute_model_req
.
virtual_engine
virtual_engine
:
int
=
execute_model_req
.
virtual_engine
num_seq_groups
:
int
=
len
(
execute_model_req
.
seq_group_metadata_list
)
blocks_to_copy
=
execute_model_req
.
blocks_to_copy
blocks_to_copy
=
torch
.
tensor
(
execute_model_req
.
blocks_to_copy
,
device
=
"cpu"
,
dtype
=
torch
.
int64
).
view
(
-
1
,
2
)
...
...
vllm/worker/model_runner.py
View file @
96ae75ad
...
...
@@ -13,7 +13,7 @@ import numpy as np
import
torch
import
torch.distributed
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
tqdm
import
tqdm
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
...
...
@@ -22,7 +22,8 @@ from vllm.attention.backends.utils import CommonAttentionState
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.distributed
import
get_kv_transfer_group
,
get_pp_group
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
graph_capture
)
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
...
...
@@ -1416,8 +1417,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
logger
.
info
(
"Capturing cudagraphs for decoding. This may lead to "
"unexpected consequences if the model is not static. To "
"run the model in eager mode, set 'enforce_eager=True' or "
"use '--enforce-eager' in the CLI."
)
logger
.
info
(
"If out-of-memory error occurs during cudagraph capture,"
"use '--enforce-eager' in the CLI.
"
"If out-of-memory error occurs during cudagraph capture,"
" consider decreasing `gpu_memory_utilization` or "
"switching to eager mode. You can also reduce the "
"`max_num_seqs` as needed to decrease memory usage."
)
...
...
@@ -1454,8 +1455,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# memory usage of CUDA graph.
for
virtual_engine
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
):
for
batch_size
in
\
self
.
vllm_config
.
compilation_config
.
capture_sizes
:
# Only rank 0 should print progress bar during capture
capture_sizes
=
(
tqdm
(
self
.
vllm_config
.
compilation_config
.
capture_sizes
,
desc
=
"Capturing CUDA graph shapes"
,
)
if
get_tensor_model_parallel_rank
()
==
0
else
self
.
vllm_config
.
compilation_config
.
capture_sizes
)
for
batch_size
in
capture_sizes
:
attn_metadata
=
(
self
.
attn_state
.
graph_capture_get_metadata_for_batch
(
batch_size
,
...
...
vllm/worker/multi_step_model_runner.py
View file @
96ae75ad
...
...
@@ -406,8 +406,9 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
if
not
cont
:
break
def
_final_process_outputs
(
self
,
model_input
:
StatefulModelInput
,
output_proc_callback
:
Optional
[
Callable
]):
def
_final_process_outputs
(
self
,
model_input
:
StatefulModelInput
,
output_proc_callback
:
Optional
[
Callable
])
->
List
[
SamplerOutput
]:
assert
model_input
.
frozen_model_input
is
not
None
has_async_callback
=
output_proc_callback
is
not
None
...
...
@@ -594,8 +595,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# should be [SamplerOutput]
return
output
def
_update_sampling_metadata
(
self
,
sampling_metadata
,
num_seqs
,
num_queries
):
def
_update_sampling_metadata
(
self
,
sampling_metadata
:
SamplingMetadata
,
num_seqs
:
Optional
[
int
],
num_queries
:
int
):
assert
sampling_metadata
.
num_prompts
==
0
assert
len
(
sampling_metadata
.
seq_groups
)
==
num_queries
...
...
@@ -820,7 +821,7 @@ def _pythonize_sampler_output(
for
sgdx
,
(
seq_group
,
sample_result
)
in
enumerate
(
zip
(
seq_groups
,
samples_list
)):
# Reminder: Please update docs/source/usage/compatibility_matrix.
rst
# Reminder: Please update docs/source/usage/compatibility_matrix.
md
# If the feature combo become valid
# (Check for Guided Decoding)
if
seq_group
.
sampling_params
.
logits_processors
:
...
...
@@ -850,13 +851,13 @@ def _pythonize_sampler_output(
seq_ids
=
seq_group
.
seq_ids
next_token_ids
=
sample_result
parent_ids
=
[
0
]
seq_outputs
:
List
[
SequenceOutput
]
if
cache
is
not
None
:
completion_seq_group_output
:
CompletionSequenceGroupOutput
=
\
cache
.
cached_completion_seq_group_output
.
get_object
()
completion_seq_group_output
.
samples
.
clear
()
seq_outputs
:
List
[
SequenceOutput
]
=
completion_seq_group_output
.
samples
seq_outputs
=
completion_seq_group_output
.
samples
else
:
seq_outputs
=
[]
...
...
vllm/worker/pooling_model_runner.py
View file @
96ae75ad
...
...
@@ -91,6 +91,10 @@ class PoolingModelRunner(
]
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
or
{}
seqlen_agnostic_kwargs
=
{
"finished_requests_ids"
:
model_input
.
finished_requests_ids
,
"request_ids_to_seq_ids"
:
model_input
.
request_ids_to_seq_ids
,
}
if
self
.
has_inner_state
else
{}
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
):
model_forward_start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
...
...
@@ -110,7 +114,8 @@ class PoolingModelRunner(
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
device
=
self
.
device
),
**
cross_enc_kwargs
)
**
cross_enc_kwargs
,
**
seqlen_agnostic_kwargs
)
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
):
...
...
vllm/worker/utils.py
View file @
96ae75ad
...
...
@@ -13,7 +13,7 @@ def assert_enc_dec_mr_supported_scenario(
a supported scenario.
'''
# Reminder: Please update docs/source/usage/compatibility_matrix.
rst
# Reminder: Please update docs/source/usage/compatibility_matrix.
md
# If the feature combo become valid
if
enc_dec_mr
.
cache_config
.
enable_prefix_caching
:
...
...
vllm/worker/worker_base.py
View file @
96ae75ad
...
...
@@ -485,7 +485,7 @@ class WorkerWrapperBase:
self
.
worker
=
worker_class
(
*
args
,
**
kwargs
)
assert
self
.
worker
is
not
None
def
execute_method
(
self
,
method
,
*
args
,
**
kwargs
):
def
execute_method
(
self
,
method
:
str
,
*
args
,
**
kwargs
):
try
:
target
=
self
if
self
.
worker
is
None
else
self
.
worker
executor
=
getattr
(
target
,
method
)
...
...
Prev
1
…
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment