Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
19a53b27
Unverified
Commit
19a53b27
authored
Jun 18, 2025
by
afeldman-nm
Committed by
GitHub
Jun 18, 2025
Browse files
[V1] Decouple GPU and TPU `InputBatch` (#19778)
Signed-off-by:
Andrew Feldman
<
afeldman@redhat.com
>
parent
eccdc831
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
597 additions
and
4 deletions
+597
-4
vllm/v1/sample/tpu/metadata.py
vllm/v1/sample/tpu/metadata.py
+1
-1
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+6
-1
vllm/v1/worker/lora_model_runner_mixin.py
vllm/v1/worker/lora_model_runner_mixin.py
+5
-1
vllm/v1/worker/tpu_input_batch.py
vllm/v1/worker/tpu_input_batch.py
+584
-0
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+1
-1
No files found.
vllm/v1/sample/tpu/metadata.py
View file @
19a53b27
...
...
@@ -5,7 +5,7 @@ from typing import Optional
import
torch
from
vllm.v1.worker.
g
pu_input_batch
import
InputBatch
from
vllm.v1.worker.
t
pu_input_batch
import
InputBatch
DEFAULT_SAMPLING_PARAMS
=
dict
(
temperature
=-
1.0
,
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
19a53b27
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a
n
input batch
# Datastructures defining a
GPU
input batch
from
dataclasses
import
dataclass
from
typing
import
Optional
,
cast
...
...
@@ -453,6 +453,11 @@ class InputBatch:
self
.
block_table
.
swap_row
(
i1
,
i2
)
def
condense
(
self
,
empty_req_indices
:
list
[
int
])
->
None
:
"""Move non-empty requests down into lower, empty indices.
Args:
empty_req_indices: empty batch indices, sorted descending.
"""
num_reqs
=
self
.
num_reqs
if
num_reqs
==
0
:
# The batched states are empty.
...
...
vllm/v1/worker/lora_model_runner_mixin.py
View file @
19a53b27
...
...
@@ -5,6 +5,7 @@ Define LoRA functionality mixin for model runners.
"""
from
contextlib
import
contextmanager
from
typing
import
Union
import
numpy
as
np
import
torch.nn
as
nn
...
...
@@ -15,7 +16,10 @@ from vllm.lora.layers import LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.model_executor.models
import
supports_lora
,
supports_multimodal
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
as
GPUInputBatch
from
vllm.v1.worker.tpu_input_batch
import
InputBatch
as
TPUInputBatch
InputBatch
=
Union
[
TPUInputBatch
,
GPUInputBatch
]
logger
=
init_logger
(
__name__
)
...
...
vllm/v1/worker/tpu_input_batch.py
0 → 100644
View file @
19a53b27
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a TPU input batch
from
typing
import
Optional
,
cast
import
numpy
as
np
import
torch
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingType
from
vllm.utils
import
swap_dict_values
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.worker.block_table
import
MultiGroupBlockTable
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
_SAMPLING_EPS
=
1e-5
class
InputBatch
:
def
__init__
(
self
,
max_num_reqs
:
int
,
max_model_len
:
int
,
max_num_batched_tokens
:
int
,
device
:
torch
.
device
,
pin_memory
:
bool
,
vocab_size
:
int
,
block_sizes
:
list
[
int
],
# The block_size of each kv cache group
):
self
.
max_num_reqs
=
max_num_reqs
self
.
max_model_len
=
max_model_len
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
device
=
device
self
.
pin_memory
=
pin_memory
self
.
vocab_size
=
vocab_size
self
.
_req_ids
:
list
[
Optional
[
str
]]
=
[]
self
.
req_id_to_index
:
dict
[
str
,
int
]
=
{}
# TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage.
# This buffer is not directly transferred to the GPU, so it does not
# need to be pinned.
self
.
token_ids_cpu_tensor
=
torch
.
zeros
(
(
max_num_reqs
,
max_model_len
),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
False
,
)
self
.
token_ids_cpu
=
self
.
token_ids_cpu_tensor
.
numpy
()
self
.
num_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_tokens_no_spec
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_prompt_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_tokens_cpu_tensor
=
torch
.
zeros
(
(
max_num_reqs
,
),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
pin_memory
,
)
self
.
num_computed_tokens_cpu
=
\
self
.
num_computed_tokens_cpu_tensor
.
numpy
()
# Block table.
self
.
block_table
=
MultiGroupBlockTable
(
max_num_reqs
=
max_num_reqs
,
max_model_len
=
max_model_len
,
max_num_batched_tokens
=
max_num_batched_tokens
,
pin_memory
=
pin_memory
,
device
=
device
,
block_sizes
=
block_sizes
,
)
# Sampling-related.
self
.
temperature
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float32
,
device
=
device
)
self
.
temperature_cpu_tensor
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
temperature_cpu
=
self
.
temperature_cpu_tensor
.
numpy
()
self
.
greedy_reqs
:
set
[
str
]
=
set
()
self
.
random_reqs
:
set
[
str
]
=
set
()
self
.
top_p
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float32
,
device
=
device
)
self
.
top_p_cpu_tensor
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
top_p_cpu
=
self
.
top_p_cpu_tensor
.
numpy
()
self
.
top_p_reqs
:
set
[
str
]
=
set
()
self
.
top_k
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
int32
,
device
=
device
)
self
.
top_k_cpu_tensor
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
top_k_cpu
=
self
.
top_k_cpu_tensor
.
numpy
()
self
.
top_k_reqs
:
set
[
str
]
=
set
()
self
.
min_p
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float32
,
device
=
device
)
self
.
min_p_cpu_tensor
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
min_p_cpu
=
self
.
min_p_cpu_tensor
.
numpy
()
self
.
min_p_reqs
:
set
[
str
]
=
set
()
# Frequency penalty related data structures
self
.
frequency_penalties
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
device
)
self
.
frequency_penalties_cpu_tensor
=
torch
.
empty
(
(
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
frequency_penalties_cpu
=
\
self
.
frequency_penalties_cpu_tensor
.
numpy
()
self
.
frequency_penalties_reqs
:
set
[
str
]
=
set
()
# Presence penalty related data structures
self
.
presence_penalties
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
device
)
self
.
presence_penalties_cpu_tensor
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
presence_penalties_cpu
=
self
.
presence_penalties_cpu_tensor
.
numpy
(
)
self
.
presence_penalties_reqs
:
set
[
str
]
=
set
()
# Repetition penalty related data structures
self
.
repetition_penalties
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
device
)
self
.
repetition_penalties_cpu_tensor
=
torch
.
empty
(
(
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
repetition_penalties_cpu
=
\
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_reqs
:
set
[
str
]
=
set
()
# req_index -> (min_tokens, stop_token_ids)
self
.
min_tokens
:
dict
[
int
,
tuple
[
int
,
set
[
int
]]]
=
{}
# lora related
self
.
request_lora_mapping
=
np
.
zeros
((
self
.
max_num_reqs
,
),
dtype
=
np
.
int32
)
self
.
lora_id_to_request_ids
:
dict
[
int
,
set
[
str
]]
=
{}
self
.
lora_id_to_lora_request
:
dict
[
int
,
LoRARequest
]
=
{}
# req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary.
self
.
generators
:
dict
[
int
,
torch
.
Generator
]
=
{}
self
.
num_logprobs
:
dict
[
str
,
int
]
=
{}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self
.
num_prompt_logprobs
:
dict
[
str
,
int
]
=
{}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self
.
in_progress_prompt_logprobs_cpu
:
dict
[
str
,
LogprobsTensors
]
=
{}
self
.
logit_bias
:
list
[
Optional
[
dict
[
int
,
float
]]]
=
[
None
]
*
max_num_reqs
self
.
has_allowed_token_ids
:
set
[
str
]
=
set
()
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
# the value is False. Since we use masked_fill_ to set -inf.
self
.
allowed_token_ids_mask
:
Optional
[
torch
.
Tensor
]
=
None
self
.
allowed_token_ids_mask_cpu_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# req_index -> bad_words_token_ids
self
.
bad_words_token_ids
:
dict
[
int
,
list
[
list
[
int
]]]
=
{}
self
.
req_output_token_ids
:
list
[
Optional
[
list
[
int
]]]
=
[]
@
property
def
req_ids
(
self
)
->
list
[
str
]:
# None elements should only be present transiently
# while performing state updates to the batch.
return
cast
(
list
[
str
],
self
.
_req_ids
)
def
add_request
(
self
,
request
:
"CachedRequestState"
,
req_index
:
Optional
[
int
]
=
None
,
)
->
None
:
if
req_index
is
None
:
req_index
=
self
.
num_reqs
assert
req_index
<
self
.
max_num_reqs
req_id
=
request
.
req_id
if
req_index
==
len
(
self
.
_req_ids
):
self
.
_req_ids
.
append
(
req_id
)
self
.
req_output_token_ids
.
append
(
request
.
output_token_ids
)
else
:
self
.
_req_ids
[
req_index
]
=
req_id
self
.
req_output_token_ids
[
req_index
]
=
request
.
output_token_ids
self
.
req_id_to_index
[
req_id
]
=
req_index
# Copy the prompt token ids and output token ids.
num_prompt_tokens
=
len
(
request
.
prompt_token_ids
)
self
.
num_prompt_tokens
[
req_index
]
=
num_prompt_tokens
self
.
token_ids_cpu
[
req_index
,
:
num_prompt_tokens
]
=
request
.
prompt_token_ids
start_idx
=
num_prompt_tokens
end_idx
=
start_idx
+
len
(
request
.
output_token_ids
)
self
.
token_ids_cpu
[
req_index
,
start_idx
:
end_idx
]
=
request
.
output_token_ids
# Number of token ids in token_ids_cpu.
# NOTE(woosuk): This may include spec decode tokens.
self
.
num_tokens
[
req_index
]
=
request
.
num_tokens
# Number of tokens without spec decode tokens.
self
.
num_tokens_no_spec
[
req_index
]
=
request
.
num_tokens
self
.
num_computed_tokens_cpu
[
req_index
]
=
request
.
num_computed_tokens
self
.
block_table
.
add_row
(
request
.
block_ids
,
req_index
)
sampling_params
=
request
.
sampling_params
if
sampling_params
.
sampling_type
==
SamplingType
.
GREEDY
:
# Avoid later division by zero.
self
.
temperature_cpu
[
req_index
]
=
-
1.0
self
.
greedy_reqs
.
add
(
req_id
)
else
:
self
.
temperature_cpu
[
req_index
]
=
sampling_params
.
temperature
self
.
random_reqs
.
add
(
req_id
)
self
.
top_p_cpu
[
req_index
]
=
sampling_params
.
top_p
if
sampling_params
.
top_p
<
1
:
self
.
top_p_reqs
.
add
(
req_id
)
top_k
=
sampling_params
.
top_k
if
0
<
top_k
<
self
.
vocab_size
:
self
.
top_k_reqs
.
add
(
req_id
)
else
:
top_k
=
self
.
vocab_size
self
.
top_k_cpu
[
req_index
]
=
top_k
self
.
min_p_cpu
[
req_index
]
=
sampling_params
.
min_p
self
.
frequency_penalties_cpu
[
req_index
]
=
sampling_params
.
frequency_penalty
if
sampling_params
.
min_p
>
_SAMPLING_EPS
:
self
.
min_p_reqs
.
add
(
req_id
)
if
sampling_params
.
frequency_penalty
!=
0.0
:
self
.
frequency_penalties_reqs
.
add
(
req_id
)
self
.
presence_penalties_cpu
[
req_index
]
=
sampling_params
.
presence_penalty
if
sampling_params
.
presence_penalty
!=
0.0
:
self
.
presence_penalties_reqs
.
add
(
req_id
)
self
.
repetition_penalties_cpu
[
req_index
]
=
sampling_params
.
repetition_penalty
if
sampling_params
.
repetition_penalty
!=
1.0
:
self
.
repetition_penalties_reqs
.
add
(
req_id
)
if
sampling_params
.
min_tokens
:
self
.
min_tokens
[
req_index
]
=
(
sampling_params
.
min_tokens
,
sampling_params
.
all_stop_token_ids
)
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
if
request
.
generator
is
not
None
:
self
.
generators
[
req_index
]
=
request
.
generator
if
sampling_params
.
logprobs
is
not
None
:
self
.
num_logprobs
[
req_id
]
=
sampling_params
.
logprobs
if
sampling_params
.
prompt_logprobs
is
not
None
:
self
.
num_prompt_logprobs
[
req_id
]
=
sampling_params
.
prompt_logprobs
if
sampling_params
.
logit_bias
is
not
None
:
self
.
logit_bias
[
req_index
]
=
sampling_params
.
logit_bias
if
sampling_params
.
allowed_token_ids
:
self
.
has_allowed_token_ids
.
add
(
req_id
)
if
self
.
allowed_token_ids_mask_cpu_tensor
is
None
:
# Lazy allocation for this tensor, which can be large.
# False means we don't fill with -inf.
self
.
allowed_token_ids_mask
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
bool
,
device
=
self
.
device
)
self
.
allowed_token_ids_mask_cpu_tensor
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
bool
,
device
=
"cpu"
)
self
.
allowed_token_ids_mask_cpu_tensor
[
req_index
]
=
True
# False means we don't fill with -inf.
self
.
allowed_token_ids_mask_cpu_tensor
[
req_index
][
sampling_params
.
allowed_token_ids
]
=
False
if
sampling_params
.
bad_words_token_ids
:
self
.
bad_words_token_ids
[
req_index
]
=
sampling_params
.
bad_words_token_ids
# Add request lora ID
if
request
.
lora_request
:
lora_id
=
request
.
lora_request
.
lora_int_id
if
lora_id
not
in
self
.
lora_id_to_request_ids
:
self
.
lora_id_to_request_ids
[
lora_id
]
=
set
()
self
.
request_lora_mapping
[
req_index
]
=
lora_id
self
.
lora_id_to_request_ids
[
lora_id
].
add
(
request
.
req_id
)
self
.
lora_id_to_lora_request
[
lora_id
]
=
request
.
lora_request
else
:
# No LoRA
self
.
request_lora_mapping
[
req_index
]
=
0
def
remove_request
(
self
,
req_id
:
str
)
->
Optional
[
int
]:
"""This method must always be followed by a call to condense()."""
req_index
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
if
req_index
is
None
:
return
None
self
.
_req_ids
[
req_index
]
=
None
self
.
req_output_token_ids
[
req_index
]
=
None
self
.
greedy_reqs
.
discard
(
req_id
)
self
.
random_reqs
.
discard
(
req_id
)
self
.
top_p_reqs
.
discard
(
req_id
)
self
.
top_k_reqs
.
discard
(
req_id
)
self
.
min_p_reqs
.
discard
(
req_id
)
self
.
min_tokens
.
pop
(
req_index
,
None
)
self
.
frequency_penalties_reqs
.
discard
(
req_id
)
self
.
presence_penalties_reqs
.
discard
(
req_id
)
self
.
repetition_penalties_reqs
.
discard
(
req_id
)
self
.
generators
.
pop
(
req_index
,
None
)
self
.
num_logprobs
.
pop
(
req_id
,
None
)
self
.
num_prompt_logprobs
.
pop
(
req_id
,
None
)
self
.
in_progress_prompt_logprobs_cpu
.
pop
(
req_id
,
None
)
# LoRA
lora_id
=
self
.
request_lora_mapping
[
req_index
]
if
lora_id
!=
0
:
self
.
lora_id_to_request_ids
[
lora_id
].
discard
(
req_id
)
if
len
(
self
.
lora_id_to_request_ids
[
lora_id
])
==
0
:
self
.
lora_id_to_request_ids
.
pop
(
lora_id
)
self
.
lora_id_to_lora_request
.
pop
(
lora_id
)
self
.
request_lora_mapping
[
req_index
]
=
0
self
.
logit_bias
[
req_index
]
=
None
self
.
has_allowed_token_ids
.
discard
(
req_id
)
if
self
.
allowed_token_ids_mask_cpu_tensor
is
not
None
:
# False means we don't fill with -inf.
self
.
allowed_token_ids_mask_cpu_tensor
[
req_index
].
fill_
(
False
)
self
.
bad_words_token_ids
.
pop
(
req_index
,
None
)
return
req_index
def
swap_states
(
self
,
i1
:
int
,
i2
:
int
)
->
None
:
old_id_i1
=
self
.
_req_ids
[
i1
]
old_id_i2
=
self
.
_req_ids
[
i2
]
self
.
_req_ids
[
i1
],
self
.
_req_ids
[
i2
]
=
\
self
.
_req_ids
[
i2
],
self
.
_req_ids
[
i1
]
# noqa
self
.
req_output_token_ids
[
i1
],
self
.
req_output_token_ids
[
i2
]
=
\
self
.
req_output_token_ids
[
i2
],
self
.
req_output_token_ids
[
i1
]
assert
old_id_i1
is
not
None
and
old_id_i2
is
not
None
self
.
req_id_to_index
[
old_id_i1
],
self
.
req_id_to_index
[
old_id_i2
]
=
\
self
.
req_id_to_index
[
old_id_i2
],
self
.
req_id_to_index
[
old_id_i1
]
self
.
num_tokens
[
i1
],
self
.
num_tokens
[
i2
]
=
\
self
.
num_tokens
[
i2
],
self
.
num_tokens
[
i1
]
self
.
num_tokens_no_spec
[
i1
],
self
.
num_tokens_no_spec
[
i2
]
=
\
self
.
num_tokens_no_spec
[
i2
],
self
.
num_tokens_no_spec
[
i1
]
self
.
num_prompt_tokens
[
i1
],
self
.
num_prompt_tokens
[
i2
]
=
\
self
.
num_prompt_tokens
[
i2
],
self
.
num_prompt_tokens
[
i1
]
self
.
num_computed_tokens_cpu
[
i1
],
self
.
num_computed_tokens_cpu
[
i2
]
=
\
self
.
num_computed_tokens_cpu
[
i2
],
self
.
num_computed_tokens_cpu
[
i1
]
self
.
temperature_cpu
[
i1
],
self
.
temperature_cpu
[
i2
]
=
\
self
.
temperature_cpu
[
i2
],
self
.
temperature_cpu
[
i1
]
self
.
top_p_cpu
[
i1
],
self
.
top_p_cpu
[
i2
]
=
\
self
.
top_p_cpu
[
i2
],
self
.
top_p_cpu
[
i1
]
self
.
top_k_cpu
[
i1
],
self
.
top_k_cpu
[
i2
]
=
\
self
.
top_k_cpu
[
i2
],
self
.
top_k_cpu
[
i1
]
self
.
frequency_penalties_cpu
[
i1
],
self
.
frequency_penalties_cpu
[
i2
]
=
\
self
.
frequency_penalties_cpu
[
i2
],
self
.
frequency_penalties_cpu
[
i1
]
self
.
presence_penalties_cpu
[
i1
],
self
.
presence_penalties_cpu
[
i2
]
=
\
self
.
presence_penalties_cpu
[
i2
],
self
.
presence_penalties_cpu
[
i1
]
self
.
repetition_penalties_cpu
[
i1
],
self
.
repetition_penalties_cpu
[
i2
]
=
\
self
.
repetition_penalties_cpu
[
i2
],
self
.
repetition_penalties_cpu
[
i1
]
self
.
min_p_cpu
[
i1
],
self
.
min_p_cpu
[
i2
]
=
\
self
.
min_p_cpu
[
i2
],
self
.
min_p_cpu
[
i1
]
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
# instead, we need to temporiarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp
=
self
.
token_ids_cpu
[
i1
,
...].
copy
()
self
.
token_ids_cpu
[
i1
,
...]
=
self
.
token_ids_cpu
[
i2
,
...]
self
.
token_ids_cpu
[
i2
,
...]
=
tmp
swap_dict_values
(
self
.
generators
,
i1
,
i2
)
swap_dict_values
(
self
.
min_tokens
,
i1
,
i2
)
swap_dict_values
(
self
.
bad_words_token_ids
,
i1
,
i2
)
self
.
request_lora_mapping
[
i1
],
self
.
request_lora_mapping
[
i2
]
=
\
self
.
request_lora_mapping
[
i2
],
self
.
request_lora_mapping
[
i1
]
self
.
logit_bias
[
i1
],
self
.
logit_bias
[
i2
]
=
\
self
.
logit_bias
[
i2
],
self
.
logit_bias
[
i1
]
if
self
.
allowed_token_ids_mask_cpu_tensor
is
not
None
:
self
.
allowed_token_ids_mask_cpu_tensor
[
i1
],
\
self
.
allowed_token_ids_mask_cpu_tensor
[
i2
]
=
\
self
.
allowed_token_ids_mask_cpu_tensor
[
i2
],
\
self
.
allowed_token_ids_mask_cpu_tensor
[
i1
]
self
.
block_table
.
swap_row
(
i1
,
i2
)
def
condense
(
self
,
empty_req_indices
:
list
[
int
])
->
None
:
"""Move non-empty requests down into lower, empty indices.
Args:
empty_req_indices: empty batch indices, sorted descending.
"""
num_reqs
=
self
.
num_reqs
if
num_reqs
==
0
:
# The batched states are empty.
self
.
_req_ids
.
clear
()
self
.
req_output_token_ids
.
clear
()
return
# NOTE(woosuk): This function assumes that the empty_req_indices
# is sorted in descending order.
last_req_index
=
num_reqs
+
len
(
empty_req_indices
)
-
1
while
empty_req_indices
:
# Find the largest non-empty index.
while
last_req_index
in
empty_req_indices
:
last_req_index
-=
1
# Find the smallest empty index.
empty_index
=
empty_req_indices
.
pop
()
if
empty_index
>=
last_req_index
:
break
# Swap the states.
req_id
=
self
.
_req_ids
[
last_req_index
]
output_token_ids
=
self
.
req_output_token_ids
[
last_req_index
]
assert
req_id
is
not
None
self
.
_req_ids
[
empty_index
]
=
req_id
self
.
_req_ids
[
last_req_index
]
=
None
self
.
req_output_token_ids
[
empty_index
]
=
output_token_ids
self
.
req_output_token_ids
[
last_req_index
]
=
None
self
.
req_id_to_index
[
req_id
]
=
empty_index
num_tokens
=
self
.
num_tokens
[
last_req_index
]
self
.
token_ids_cpu
[
empty_index
,
:
num_tokens
]
=
self
.
token_ids_cpu
[
last_req_index
,
:
num_tokens
]
self
.
num_tokens
[
empty_index
]
=
num_tokens
self
.
num_tokens_no_spec
[
empty_index
]
=
self
.
num_tokens_no_spec
[
last_req_index
]
self
.
num_prompt_tokens
[
empty_index
]
=
self
.
num_prompt_tokens
[
last_req_index
]
self
.
num_computed_tokens_cpu
[
empty_index
]
=
self
.
num_computed_tokens_cpu
[
last_req_index
]
self
.
block_table
.
move_row
(
last_req_index
,
empty_index
)
self
.
temperature_cpu
[
empty_index
]
=
self
.
temperature_cpu
[
last_req_index
]
self
.
top_p_cpu
[
empty_index
]
=
self
.
top_p_cpu
[
last_req_index
]
self
.
top_k_cpu
[
empty_index
]
=
self
.
top_k_cpu
[
last_req_index
]
self
.
frequency_penalties_cpu
[
empty_index
]
=
self
.
frequency_penalties_cpu
[
last_req_index
]
self
.
presence_penalties_cpu
[
empty_index
]
=
self
.
presence_penalties_cpu
[
last_req_index
]
self
.
repetition_penalties_cpu
[
empty_index
]
=
self
.
repetition_penalties_cpu
[
last_req_index
]
self
.
min_p_cpu
[
empty_index
]
=
self
.
min_p_cpu
[
last_req_index
]
generator
=
self
.
generators
.
pop
(
last_req_index
,
None
)
if
generator
is
not
None
:
self
.
generators
[
empty_index
]
=
generator
min_token
=
self
.
min_tokens
.
pop
(
last_req_index
,
None
)
if
min_token
is
not
None
:
self
.
min_tokens
[
empty_index
]
=
min_token
self
.
request_lora_mapping
[
empty_index
]
=
self
.
request_lora_mapping
[
last_req_index
]
self
.
logit_bias
[
empty_index
]
=
self
.
logit_bias
[
last_req_index
]
if
self
.
allowed_token_ids_mask_cpu_tensor
is
not
None
:
self
.
allowed_token_ids_mask_cpu_tensor
[
empty_index
]
=
self
.
allowed_token_ids_mask_cpu_tensor
[
last_req_index
]
bad_words_token_ids
=
self
.
bad_words_token_ids
.
pop
(
last_req_index
,
None
)
if
bad_words_token_ids
is
not
None
:
self
.
bad_words_token_ids
[
empty_index
]
=
bad_words_token_ids
# Decrement last_req_index since it is now empty.
last_req_index
-=
1
# Trim lists to the batch size.
del
self
.
_req_ids
[
self
.
num_reqs
:]
del
self
.
req_output_token_ids
[
self
.
num_reqs
:]
def
_make_prompt_token_ids_tensor
(
self
)
->
torch
.
Tensor
:
max_prompt_len
=
self
.
num_prompt_tokens
[:
self
.
num_reqs
].
max
()
prompt_token_ids_cpu_tensor
=
torch
.
empty
(
(
self
.
num_reqs
,
max_prompt_len
),
device
=
"cpu"
,
dtype
=
torch
.
int64
,
pin_memory
=
self
.
pin_memory
,
)
prompt_token_ids
=
prompt_token_ids_cpu_tensor
.
numpy
()
prompt_token_ids
[:]
=
self
.
token_ids_cpu
[:
self
.
num_reqs
,
:
max_prompt_len
]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for
i
in
range
(
self
.
num_reqs
):
prompt_token_ids
[
i
,
self
.
num_prompt_tokens
[
i
]:]
=
self
.
vocab_size
return
prompt_token_ids_cpu_tensor
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
def
make_lora_inputs
(
self
,
num_scheduled_tokens
:
np
.
ndarray
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
set
[
LoRARequest
]]:
"""
Given the num_scheduled_tokens for each request in the batch, return
datastructures used to activate the current LoRAs.
Returns:
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
where, token_lora_mapping[i] is the LoRA id to use for ith token.
3. lora_requests: Set of relevant LoRA requests.
"""
req_lora_mapping
=
self
.
request_lora_mapping
[:
self
.
num_reqs
]
prompt_lora_mapping
=
tuple
(
req_lora_mapping
)
token_lora_mapping
=
tuple
(
req_lora_mapping
.
repeat
(
num_scheduled_tokens
))
active_lora_requests
:
set
[
LoRARequest
]
=
set
(
self
.
lora_id_to_lora_request
.
values
())
return
prompt_lora_mapping
,
token_lora_mapping
,
active_lora_requests
@
property
def
num_reqs
(
self
)
->
int
:
return
len
(
self
.
req_id_to_index
)
@
property
def
all_greedy
(
self
)
->
bool
:
return
len
(
self
.
random_reqs
)
==
0
@
property
def
all_random
(
self
)
->
bool
:
return
len
(
self
.
greedy_reqs
)
==
0
@
property
def
no_top_p
(
self
)
->
bool
:
return
len
(
self
.
top_p_reqs
)
==
0
@
property
def
no_top_k
(
self
)
->
bool
:
return
len
(
self
.
top_k_reqs
)
==
0
@
property
def
no_min_p
(
self
)
->
bool
:
return
len
(
self
.
min_p_reqs
)
==
0
@
property
def
no_penalties
(
self
)
->
bool
:
return
(
len
(
self
.
presence_penalties_reqs
)
==
0
and
len
(
self
.
frequency_penalties_reqs
)
==
0
and
len
(
self
.
repetition_penalties_reqs
)
==
0
)
@
property
def
max_num_logprobs
(
self
)
->
Optional
[
int
]:
return
max
(
self
.
num_logprobs
.
values
())
if
self
.
num_logprobs
else
None
@
property
def
no_prompt_logprob
(
self
)
->
bool
:
return
not
self
.
num_prompt_logprobs
@
property
def
no_allowed_token_ids
(
self
)
->
bool
:
return
len
(
self
.
has_allowed_token_ids
)
==
0
vllm/v1/worker/tpu_model_runner.py
View file @
19a53b27
...
...
@@ -42,8 +42,8 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
from
vllm.v1.sample.tpu.sampler
import
Sampler
as
TPUSampler
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.tpu_input_batch
import
CachedRequestState
,
InputBatch
from
.utils
import
(
initialize_kv_cache_for_kv_sharing
,
sanity_check_mm_encoder_outputs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment