Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
19a53b27
Unverified
Commit
19a53b27
authored
Jun 18, 2025
by
afeldman-nm
Committed by
GitHub
Jun 18, 2025
Browse files
[V1] Decouple GPU and TPU `InputBatch` (#19778)
Signed-off-by:
Andrew Feldman
<
afeldman@redhat.com
>
parent
eccdc831
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
597 additions
and
4 deletions
+597
-4
vllm/v1/sample/tpu/metadata.py
vllm/v1/sample/tpu/metadata.py
+1
-1
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+6
-1
vllm/v1/worker/lora_model_runner_mixin.py
vllm/v1/worker/lora_model_runner_mixin.py
+5
-1
vllm/v1/worker/tpu_input_batch.py
vllm/v1/worker/tpu_input_batch.py
+584
-0
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+1
-1
No files found.
vllm/v1/sample/tpu/metadata.py
View file @
19a53b27
...
@@ -5,7 +5,7 @@ from typing import Optional
...
@@ -5,7 +5,7 @@ from typing import Optional
import
torch
import
torch
from
vllm.v1.worker.
g
pu_input_batch
import
InputBatch
from
vllm.v1.worker.
t
pu_input_batch
import
InputBatch
DEFAULT_SAMPLING_PARAMS
=
dict
(
DEFAULT_SAMPLING_PARAMS
=
dict
(
temperature
=-
1.0
,
temperature
=-
1.0
,
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
19a53b27
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a
n
input batch
# Datastructures defining a
GPU
input batch
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
cast
from
typing
import
Optional
,
cast
...
@@ -453,6 +453,11 @@ class InputBatch:
...
@@ -453,6 +453,11 @@ class InputBatch:
self
.
block_table
.
swap_row
(
i1
,
i2
)
self
.
block_table
.
swap_row
(
i1
,
i2
)
def
condense
(
self
,
empty_req_indices
:
list
[
int
])
->
None
:
def
condense
(
self
,
empty_req_indices
:
list
[
int
])
->
None
:
"""Move non-empty requests down into lower, empty indices.
Args:
empty_req_indices: empty batch indices, sorted descending.
"""
num_reqs
=
self
.
num_reqs
num_reqs
=
self
.
num_reqs
if
num_reqs
==
0
:
if
num_reqs
==
0
:
# The batched states are empty.
# The batched states are empty.
...
...
vllm/v1/worker/lora_model_runner_mixin.py
View file @
19a53b27
...
@@ -5,6 +5,7 @@ Define LoRA functionality mixin for model runners.
...
@@ -5,6 +5,7 @@ Define LoRA functionality mixin for model runners.
"""
"""
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Union
import
numpy
as
np
import
numpy
as
np
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -15,7 +16,10 @@ from vllm.lora.layers import LoRAMapping
...
@@ -15,7 +16,10 @@ from vllm.lora.layers import LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.model_executor.models
import
supports_lora
,
supports_multimodal
from
vllm.model_executor.models
import
supports_lora
,
supports_multimodal
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
as
GPUInputBatch
from
vllm.v1.worker.tpu_input_batch
import
InputBatch
as
TPUInputBatch
InputBatch
=
Union
[
TPUInputBatch
,
GPUInputBatch
]
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/v1/worker/tpu_input_batch.py
0 → 100644
View file @
19a53b27
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a TPU input batch
from
typing
import
Optional
,
cast
import
numpy
as
np
import
torch
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingType
from
vllm.utils
import
swap_dict_values
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.worker.block_table
import
MultiGroupBlockTable
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
_SAMPLING_EPS
=
1e-5
class
InputBatch
:
def
__init__
(
self
,
max_num_reqs
:
int
,
max_model_len
:
int
,
max_num_batched_tokens
:
int
,
device
:
torch
.
device
,
pin_memory
:
bool
,
vocab_size
:
int
,
block_sizes
:
list
[
int
],
# The block_size of each kv cache group
):
self
.
max_num_reqs
=
max_num_reqs
self
.
max_model_len
=
max_model_len
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
device
=
device
self
.
pin_memory
=
pin_memory
self
.
vocab_size
=
vocab_size
self
.
_req_ids
:
list
[
Optional
[
str
]]
=
[]
self
.
req_id_to_index
:
dict
[
str
,
int
]
=
{}
# TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage.
# This buffer is not directly transferred to the GPU, so it does not
# need to be pinned.
self
.
token_ids_cpu_tensor
=
torch
.
zeros
(
(
max_num_reqs
,
max_model_len
),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
False
,
)
self
.
token_ids_cpu
=
self
.
token_ids_cpu_tensor
.
numpy
()
self
.
num_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_tokens_no_spec
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_prompt_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_tokens_cpu_tensor
=
torch
.
zeros
(
(
max_num_reqs
,
),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
pin_memory
,
)
self
.
num_computed_tokens_cpu
=
\
self
.
num_computed_tokens_cpu_tensor
.
numpy
()
# Block table.
self
.
block_table
=
MultiGroupBlockTable
(
max_num_reqs
=
max_num_reqs
,
max_model_len
=
max_model_len
,
max_num_batched_tokens
=
max_num_batched_tokens
,
pin_memory
=
pin_memory
,
device
=
device
,
block_sizes
=
block_sizes
,
)
# Sampling-related.
self
.
temperature
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float32
,
device
=
device
)
self
.
temperature_cpu_tensor
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
temperature_cpu
=
self
.
temperature_cpu_tensor
.
numpy
()
self
.
greedy_reqs
:
set
[
str
]
=
set
()
self
.
random_reqs
:
set
[
str
]
=
set
()
self
.
top_p
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float32
,
device
=
device
)
self
.
top_p_cpu_tensor
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
top_p_cpu
=
self
.
top_p_cpu_tensor
.
numpy
()
self
.
top_p_reqs
:
set
[
str
]
=
set
()
self
.
top_k
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
int32
,
device
=
device
)
self
.
top_k_cpu_tensor
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
top_k_cpu
=
self
.
top_k_cpu_tensor
.
numpy
()
self
.
top_k_reqs
:
set
[
str
]
=
set
()
self
.
min_p
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float32
,
device
=
device
)
self
.
min_p_cpu_tensor
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
min_p_cpu
=
self
.
min_p_cpu_tensor
.
numpy
()
self
.
min_p_reqs
:
set
[
str
]
=
set
()
# Frequency penalty related data structures
self
.
frequency_penalties
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
device
)
self
.
frequency_penalties_cpu_tensor
=
torch
.
empty
(
(
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
frequency_penalties_cpu
=
\
self
.
frequency_penalties_cpu_tensor
.
numpy
()
self
.
frequency_penalties_reqs
:
set
[
str
]
=
set
()
# Presence penalty related data structures
self
.
presence_penalties
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
device
)
self
.
presence_penalties_cpu_tensor
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
presence_penalties_cpu
=
self
.
presence_penalties_cpu_tensor
.
numpy
(
)
self
.
presence_penalties_reqs
:
set
[
str
]
=
set
()
# Repetition penalty related data structures
self
.
repetition_penalties
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
device
)
self
.
repetition_penalties_cpu_tensor
=
torch
.
empty
(
(
max_num_reqs
,
),
dtype
=
torch
.
float
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
repetition_penalties_cpu
=
\
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_reqs
:
set
[
str
]
=
set
()
# req_index -> (min_tokens, stop_token_ids)
self
.
min_tokens
:
dict
[
int
,
tuple
[
int
,
set
[
int
]]]
=
{}
# lora related
self
.
request_lora_mapping
=
np
.
zeros
((
self
.
max_num_reqs
,
),
dtype
=
np
.
int32
)
self
.
lora_id_to_request_ids
:
dict
[
int
,
set
[
str
]]
=
{}
self
.
lora_id_to_lora_request
:
dict
[
int
,
LoRARequest
]
=
{}
# req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary.
self
.
generators
:
dict
[
int
,
torch
.
Generator
]
=
{}
self
.
num_logprobs
:
dict
[
str
,
int
]
=
{}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self
.
num_prompt_logprobs
:
dict
[
str
,
int
]
=
{}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self
.
in_progress_prompt_logprobs_cpu
:
dict
[
str
,
LogprobsTensors
]
=
{}
self
.
logit_bias
:
list
[
Optional
[
dict
[
int
,
float
]]]
=
[
None
]
*
max_num_reqs
self
.
has_allowed_token_ids
:
set
[
str
]
=
set
()
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
# the value is False. Since we use masked_fill_ to set -inf.
self
.
allowed_token_ids_mask
:
Optional
[
torch
.
Tensor
]
=
None
self
.
allowed_token_ids_mask_cpu_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# req_index -> bad_words_token_ids
self
.
bad_words_token_ids
:
dict
[
int
,
list
[
list
[
int
]]]
=
{}
self
.
req_output_token_ids
:
list
[
Optional
[
list
[
int
]]]
=
[]
@
property
def
req_ids
(
self
)
->
list
[
str
]:
# None elements should only be present transiently
# while performing state updates to the batch.
return
cast
(
list
[
str
],
self
.
_req_ids
)
def
add_request
(
self
,
request
:
"CachedRequestState"
,
req_index
:
Optional
[
int
]
=
None
,
)
->
None
:
if
req_index
is
None
:
req_index
=
self
.
num_reqs
assert
req_index
<
self
.
max_num_reqs
req_id
=
request
.
req_id
if
req_index
==
len
(
self
.
_req_ids
):
self
.
_req_ids
.
append
(
req_id
)
self
.
req_output_token_ids
.
append
(
request
.
output_token_ids
)
else
:
self
.
_req_ids
[
req_index
]
=
req_id
self
.
req_output_token_ids
[
req_index
]
=
request
.
output_token_ids
self
.
req_id_to_index
[
req_id
]
=
req_index
# Copy the prompt token ids and output token ids.
num_prompt_tokens
=
len
(
request
.
prompt_token_ids
)
self
.
num_prompt_tokens
[
req_index
]
=
num_prompt_tokens
self
.
token_ids_cpu
[
req_index
,
:
num_prompt_tokens
]
=
request
.
prompt_token_ids
start_idx
=
num_prompt_tokens
end_idx
=
start_idx
+
len
(
request
.
output_token_ids
)
self
.
token_ids_cpu
[
req_index
,
start_idx
:
end_idx
]
=
request
.
output_token_ids
# Number of token ids in token_ids_cpu.
# NOTE(woosuk): This may include spec decode tokens.
self
.
num_tokens
[
req_index
]
=
request
.
num_tokens
# Number of tokens without spec decode tokens.
self
.
num_tokens_no_spec
[
req_index
]
=
request
.
num_tokens
self
.
num_computed_tokens_cpu
[
req_index
]
=
request
.
num_computed_tokens
self
.
block_table
.
add_row
(
request
.
block_ids
,
req_index
)
sampling_params
=
request
.
sampling_params
if
sampling_params
.
sampling_type
==
SamplingType
.
GREEDY
:
# Avoid later division by zero.
self
.
temperature_cpu
[
req_index
]
=
-
1.0
self
.
greedy_reqs
.
add
(
req_id
)
else
:
self
.
temperature_cpu
[
req_index
]
=
sampling_params
.
temperature
self
.
random_reqs
.
add
(
req_id
)
self
.
top_p_cpu
[
req_index
]
=
sampling_params
.
top_p
if
sampling_params
.
top_p
<
1
:
self
.
top_p_reqs
.
add
(
req_id
)
top_k
=
sampling_params
.
top_k
if
0
<
top_k
<
self
.
vocab_size
:
self
.
top_k_reqs
.
add
(
req_id
)
else
:
top_k
=
self
.
vocab_size
self
.
top_k_cpu
[
req_index
]
=
top_k
self
.
min_p_cpu
[
req_index
]
=
sampling_params
.
min_p
self
.
frequency_penalties_cpu
[
req_index
]
=
sampling_params
.
frequency_penalty
if
sampling_params
.
min_p
>
_SAMPLING_EPS
:
self
.
min_p_reqs
.
add
(
req_id
)
if
sampling_params
.
frequency_penalty
!=
0.0
:
self
.
frequency_penalties_reqs
.
add
(
req_id
)
self
.
presence_penalties_cpu
[
req_index
]
=
sampling_params
.
presence_penalty
if
sampling_params
.
presence_penalty
!=
0.0
:
self
.
presence_penalties_reqs
.
add
(
req_id
)
self
.
repetition_penalties_cpu
[
req_index
]
=
sampling_params
.
repetition_penalty
if
sampling_params
.
repetition_penalty
!=
1.0
:
self
.
repetition_penalties_reqs
.
add
(
req_id
)
if
sampling_params
.
min_tokens
:
self
.
min_tokens
[
req_index
]
=
(
sampling_params
.
min_tokens
,
sampling_params
.
all_stop_token_ids
)
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
if
request
.
generator
is
not
None
:
self
.
generators
[
req_index
]
=
request
.
generator
if
sampling_params
.
logprobs
is
not
None
:
self
.
num_logprobs
[
req_id
]
=
sampling_params
.
logprobs
if
sampling_params
.
prompt_logprobs
is
not
None
:
self
.
num_prompt_logprobs
[
req_id
]
=
sampling_params
.
prompt_logprobs
if
sampling_params
.
logit_bias
is
not
None
:
self
.
logit_bias
[
req_index
]
=
sampling_params
.
logit_bias
if
sampling_params
.
allowed_token_ids
:
self
.
has_allowed_token_ids
.
add
(
req_id
)
if
self
.
allowed_token_ids_mask_cpu_tensor
is
None
:
# Lazy allocation for this tensor, which can be large.
# False means we don't fill with -inf.
self
.
allowed_token_ids_mask
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
bool
,
device
=
self
.
device
)
self
.
allowed_token_ids_mask_cpu_tensor
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
bool
,
device
=
"cpu"
)
self
.
allowed_token_ids_mask_cpu_tensor
[
req_index
]
=
True
# False means we don't fill with -inf.
self
.
allowed_token_ids_mask_cpu_tensor
[
req_index
][
sampling_params
.
allowed_token_ids
]
=
False
if
sampling_params
.
bad_words_token_ids
:
self
.
bad_words_token_ids
[
req_index
]
=
sampling_params
.
bad_words_token_ids
# Add request lora ID
if
request
.
lora_request
:
lora_id
=
request
.
lora_request
.
lora_int_id
if
lora_id
not
in
self
.
lora_id_to_request_ids
:
self
.
lora_id_to_request_ids
[
lora_id
]
=
set
()
self
.
request_lora_mapping
[
req_index
]
=
lora_id
self
.
lora_id_to_request_ids
[
lora_id
].
add
(
request
.
req_id
)
self
.
lora_id_to_lora_request
[
lora_id
]
=
request
.
lora_request
else
:
# No LoRA
self
.
request_lora_mapping
[
req_index
]
=
0
def
remove_request
(
self
,
req_id
:
str
)
->
Optional
[
int
]:
"""This method must always be followed by a call to condense()."""
req_index
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
if
req_index
is
None
:
return
None
self
.
_req_ids
[
req_index
]
=
None
self
.
req_output_token_ids
[
req_index
]
=
None
self
.
greedy_reqs
.
discard
(
req_id
)
self
.
random_reqs
.
discard
(
req_id
)
self
.
top_p_reqs
.
discard
(
req_id
)
self
.
top_k_reqs
.
discard
(
req_id
)
self
.
min_p_reqs
.
discard
(
req_id
)
self
.
min_tokens
.
pop
(
req_index
,
None
)
self
.
frequency_penalties_reqs
.
discard
(
req_id
)
self
.
presence_penalties_reqs
.
discard
(
req_id
)
self
.
repetition_penalties_reqs
.
discard
(
req_id
)
self
.
generators
.
pop
(
req_index
,
None
)
self
.
num_logprobs
.
pop
(
req_id
,
None
)
self
.
num_prompt_logprobs
.
pop
(
req_id
,
None
)
self
.
in_progress_prompt_logprobs_cpu
.
pop
(
req_id
,
None
)
# LoRA
lora_id
=
self
.
request_lora_mapping
[
req_index
]
if
lora_id
!=
0
:
self
.
lora_id_to_request_ids
[
lora_id
].
discard
(
req_id
)
if
len
(
self
.
lora_id_to_request_ids
[
lora_id
])
==
0
:
self
.
lora_id_to_request_ids
.
pop
(
lora_id
)
self
.
lora_id_to_lora_request
.
pop
(
lora_id
)
self
.
request_lora_mapping
[
req_index
]
=
0
self
.
logit_bias
[
req_index
]
=
None
self
.
has_allowed_token_ids
.
discard
(
req_id
)
if
self
.
allowed_token_ids_mask_cpu_tensor
is
not
None
:
# False means we don't fill with -inf.
self
.
allowed_token_ids_mask_cpu_tensor
[
req_index
].
fill_
(
False
)
self
.
bad_words_token_ids
.
pop
(
req_index
,
None
)
return
req_index
def
swap_states
(
self
,
i1
:
int
,
i2
:
int
)
->
None
:
old_id_i1
=
self
.
_req_ids
[
i1
]
old_id_i2
=
self
.
_req_ids
[
i2
]
self
.
_req_ids
[
i1
],
self
.
_req_ids
[
i2
]
=
\
self
.
_req_ids
[
i2
],
self
.
_req_ids
[
i1
]
# noqa
self
.
req_output_token_ids
[
i1
],
self
.
req_output_token_ids
[
i2
]
=
\
self
.
req_output_token_ids
[
i2
],
self
.
req_output_token_ids
[
i1
]
assert
old_id_i1
is
not
None
and
old_id_i2
is
not
None
self
.
req_id_to_index
[
old_id_i1
],
self
.
req_id_to_index
[
old_id_i2
]
=
\
self
.
req_id_to_index
[
old_id_i2
],
self
.
req_id_to_index
[
old_id_i1
]
self
.
num_tokens
[
i1
],
self
.
num_tokens
[
i2
]
=
\
self
.
num_tokens
[
i2
],
self
.
num_tokens
[
i1
]
self
.
num_tokens_no_spec
[
i1
],
self
.
num_tokens_no_spec
[
i2
]
=
\
self
.
num_tokens_no_spec
[
i2
],
self
.
num_tokens_no_spec
[
i1
]
self
.
num_prompt_tokens
[
i1
],
self
.
num_prompt_tokens
[
i2
]
=
\
self
.
num_prompt_tokens
[
i2
],
self
.
num_prompt_tokens
[
i1
]
self
.
num_computed_tokens_cpu
[
i1
],
self
.
num_computed_tokens_cpu
[
i2
]
=
\
self
.
num_computed_tokens_cpu
[
i2
],
self
.
num_computed_tokens_cpu
[
i1
]
self
.
temperature_cpu
[
i1
],
self
.
temperature_cpu
[
i2
]
=
\
self
.
temperature_cpu
[
i2
],
self
.
temperature_cpu
[
i1
]
self
.
top_p_cpu
[
i1
],
self
.
top_p_cpu
[
i2
]
=
\
self
.
top_p_cpu
[
i2
],
self
.
top_p_cpu
[
i1
]
self
.
top_k_cpu
[
i1
],
self
.
top_k_cpu
[
i2
]
=
\
self
.
top_k_cpu
[
i2
],
self
.
top_k_cpu
[
i1
]
self
.
frequency_penalties_cpu
[
i1
],
self
.
frequency_penalties_cpu
[
i2
]
=
\
self
.
frequency_penalties_cpu
[
i2
],
self
.
frequency_penalties_cpu
[
i1
]
self
.
presence_penalties_cpu
[
i1
],
self
.
presence_penalties_cpu
[
i2
]
=
\
self
.
presence_penalties_cpu
[
i2
],
self
.
presence_penalties_cpu
[
i1
]
self
.
repetition_penalties_cpu
[
i1
],
self
.
repetition_penalties_cpu
[
i2
]
=
\
self
.
repetition_penalties_cpu
[
i2
],
self
.
repetition_penalties_cpu
[
i1
]
self
.
min_p_cpu
[
i1
],
self
.
min_p_cpu
[
i2
]
=
\
self
.
min_p_cpu
[
i2
],
self
.
min_p_cpu
[
i1
]
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
# instead, we need to temporiarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp
=
self
.
token_ids_cpu
[
i1
,
...].
copy
()
self
.
token_ids_cpu
[
i1
,
...]
=
self
.
token_ids_cpu
[
i2
,
...]
self
.
token_ids_cpu
[
i2
,
...]
=
tmp
swap_dict_values
(
self
.
generators
,
i1
,
i2
)
swap_dict_values
(
self
.
min_tokens
,
i1
,
i2
)
swap_dict_values
(
self
.
bad_words_token_ids
,
i1
,
i2
)
self
.
request_lora_mapping
[
i1
],
self
.
request_lora_mapping
[
i2
]
=
\
self
.
request_lora_mapping
[
i2
],
self
.
request_lora_mapping
[
i1
]
self
.
logit_bias
[
i1
],
self
.
logit_bias
[
i2
]
=
\
self
.
logit_bias
[
i2
],
self
.
logit_bias
[
i1
]
if
self
.
allowed_token_ids_mask_cpu_tensor
is
not
None
:
self
.
allowed_token_ids_mask_cpu_tensor
[
i1
],
\
self
.
allowed_token_ids_mask_cpu_tensor
[
i2
]
=
\
self
.
allowed_token_ids_mask_cpu_tensor
[
i2
],
\
self
.
allowed_token_ids_mask_cpu_tensor
[
i1
]
self
.
block_table
.
swap_row
(
i1
,
i2
)
def
condense
(
self
,
empty_req_indices
:
list
[
int
])
->
None
:
"""Move non-empty requests down into lower, empty indices.
Args:
empty_req_indices: empty batch indices, sorted descending.
"""
num_reqs
=
self
.
num_reqs
if
num_reqs
==
0
:
# The batched states are empty.
self
.
_req_ids
.
clear
()
self
.
req_output_token_ids
.
clear
()
return
# NOTE(woosuk): This function assumes that the empty_req_indices
# is sorted in descending order.
last_req_index
=
num_reqs
+
len
(
empty_req_indices
)
-
1
while
empty_req_indices
:
# Find the largest non-empty index.
while
last_req_index
in
empty_req_indices
:
last_req_index
-=
1
# Find the smallest empty index.
empty_index
=
empty_req_indices
.
pop
()
if
empty_index
>=
last_req_index
:
break
# Swap the states.
req_id
=
self
.
_req_ids
[
last_req_index
]
output_token_ids
=
self
.
req_output_token_ids
[
last_req_index
]
assert
req_id
is
not
None
self
.
_req_ids
[
empty_index
]
=
req_id
self
.
_req_ids
[
last_req_index
]
=
None
self
.
req_output_token_ids
[
empty_index
]
=
output_token_ids
self
.
req_output_token_ids
[
last_req_index
]
=
None
self
.
req_id_to_index
[
req_id
]
=
empty_index
num_tokens
=
self
.
num_tokens
[
last_req_index
]
self
.
token_ids_cpu
[
empty_index
,
:
num_tokens
]
=
self
.
token_ids_cpu
[
last_req_index
,
:
num_tokens
]
self
.
num_tokens
[
empty_index
]
=
num_tokens
self
.
num_tokens_no_spec
[
empty_index
]
=
self
.
num_tokens_no_spec
[
last_req_index
]
self
.
num_prompt_tokens
[
empty_index
]
=
self
.
num_prompt_tokens
[
last_req_index
]
self
.
num_computed_tokens_cpu
[
empty_index
]
=
self
.
num_computed_tokens_cpu
[
last_req_index
]
self
.
block_table
.
move_row
(
last_req_index
,
empty_index
)
self
.
temperature_cpu
[
empty_index
]
=
self
.
temperature_cpu
[
last_req_index
]
self
.
top_p_cpu
[
empty_index
]
=
self
.
top_p_cpu
[
last_req_index
]
self
.
top_k_cpu
[
empty_index
]
=
self
.
top_k_cpu
[
last_req_index
]
self
.
frequency_penalties_cpu
[
empty_index
]
=
self
.
frequency_penalties_cpu
[
last_req_index
]
self
.
presence_penalties_cpu
[
empty_index
]
=
self
.
presence_penalties_cpu
[
last_req_index
]
self
.
repetition_penalties_cpu
[
empty_index
]
=
self
.
repetition_penalties_cpu
[
last_req_index
]
self
.
min_p_cpu
[
empty_index
]
=
self
.
min_p_cpu
[
last_req_index
]
generator
=
self
.
generators
.
pop
(
last_req_index
,
None
)
if
generator
is
not
None
:
self
.
generators
[
empty_index
]
=
generator
min_token
=
self
.
min_tokens
.
pop
(
last_req_index
,
None
)
if
min_token
is
not
None
:
self
.
min_tokens
[
empty_index
]
=
min_token
self
.
request_lora_mapping
[
empty_index
]
=
self
.
request_lora_mapping
[
last_req_index
]
self
.
logit_bias
[
empty_index
]
=
self
.
logit_bias
[
last_req_index
]
if
self
.
allowed_token_ids_mask_cpu_tensor
is
not
None
:
self
.
allowed_token_ids_mask_cpu_tensor
[
empty_index
]
=
self
.
allowed_token_ids_mask_cpu_tensor
[
last_req_index
]
bad_words_token_ids
=
self
.
bad_words_token_ids
.
pop
(
last_req_index
,
None
)
if
bad_words_token_ids
is
not
None
:
self
.
bad_words_token_ids
[
empty_index
]
=
bad_words_token_ids
# Decrement last_req_index since it is now empty.
last_req_index
-=
1
# Trim lists to the batch size.
del
self
.
_req_ids
[
self
.
num_reqs
:]
del
self
.
req_output_token_ids
[
self
.
num_reqs
:]
def
_make_prompt_token_ids_tensor
(
self
)
->
torch
.
Tensor
:
max_prompt_len
=
self
.
num_prompt_tokens
[:
self
.
num_reqs
].
max
()
prompt_token_ids_cpu_tensor
=
torch
.
empty
(
(
self
.
num_reqs
,
max_prompt_len
),
device
=
"cpu"
,
dtype
=
torch
.
int64
,
pin_memory
=
self
.
pin_memory
,
)
prompt_token_ids
=
prompt_token_ids_cpu_tensor
.
numpy
()
prompt_token_ids
[:]
=
self
.
token_ids_cpu
[:
self
.
num_reqs
,
:
max_prompt_len
]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for
i
in
range
(
self
.
num_reqs
):
prompt_token_ids
[
i
,
self
.
num_prompt_tokens
[
i
]:]
=
self
.
vocab_size
return
prompt_token_ids_cpu_tensor
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
def
make_lora_inputs
(
self
,
num_scheduled_tokens
:
np
.
ndarray
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
set
[
LoRARequest
]]:
"""
Given the num_scheduled_tokens for each request in the batch, return
datastructures used to activate the current LoRAs.
Returns:
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
where, token_lora_mapping[i] is the LoRA id to use for ith token.
3. lora_requests: Set of relevant LoRA requests.
"""
req_lora_mapping
=
self
.
request_lora_mapping
[:
self
.
num_reqs
]
prompt_lora_mapping
=
tuple
(
req_lora_mapping
)
token_lora_mapping
=
tuple
(
req_lora_mapping
.
repeat
(
num_scheduled_tokens
))
active_lora_requests
:
set
[
LoRARequest
]
=
set
(
self
.
lora_id_to_lora_request
.
values
())
return
prompt_lora_mapping
,
token_lora_mapping
,
active_lora_requests
@
property
def
num_reqs
(
self
)
->
int
:
return
len
(
self
.
req_id_to_index
)
@
property
def
all_greedy
(
self
)
->
bool
:
return
len
(
self
.
random_reqs
)
==
0
@
property
def
all_random
(
self
)
->
bool
:
return
len
(
self
.
greedy_reqs
)
==
0
@
property
def
no_top_p
(
self
)
->
bool
:
return
len
(
self
.
top_p_reqs
)
==
0
@
property
def
no_top_k
(
self
)
->
bool
:
return
len
(
self
.
top_k_reqs
)
==
0
@
property
def
no_min_p
(
self
)
->
bool
:
return
len
(
self
.
min_p_reqs
)
==
0
@
property
def
no_penalties
(
self
)
->
bool
:
return
(
len
(
self
.
presence_penalties_reqs
)
==
0
and
len
(
self
.
frequency_penalties_reqs
)
==
0
and
len
(
self
.
repetition_penalties_reqs
)
==
0
)
@
property
def
max_num_logprobs
(
self
)
->
Optional
[
int
]:
return
max
(
self
.
num_logprobs
.
values
())
if
self
.
num_logprobs
else
None
@
property
def
no_prompt_logprob
(
self
)
->
bool
:
return
not
self
.
num_prompt_logprobs
@
property
def
no_allowed_token_ids
(
self
)
->
bool
:
return
len
(
self
.
has_allowed_token_ids
)
==
0
vllm/v1/worker/tpu_model_runner.py
View file @
19a53b27
...
@@ -42,8 +42,8 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
...
@@ -42,8 +42,8 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
from
vllm.v1.sample.tpu.sampler
import
Sampler
as
TPUSampler
from
vllm.v1.sample.tpu.sampler
import
Sampler
as
TPUSampler
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.tpu_input_batch
import
CachedRequestState
,
InputBatch
from
.utils
import
(
initialize_kv_cache_for_kv_sharing
,
from
.utils
import
(
initialize_kv_cache_for_kv_sharing
,
sanity_check_mm_encoder_outputs
)
sanity_check_mm_encoder_outputs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment