Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c21ce9e
Unverified
Commit
4c21ce9e
authored
Feb 17, 2025
by
Woosuk Kwon
Committed by
GitHub
Feb 17, 2025
Browse files
[V1] Get input tokens from scheduler (#13339)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
ce77eb94
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
139 additions
and
139 deletions
+139
-139
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+1
-0
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+28
-15
vllm/v1/core/scheduler_output.py
vllm/v1/core/scheduler_output.py
+8
-7
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+102
-117
No files found.
tests/v1/worker/test_gpu_model_runner.py
View file @
4c21ce9e
...
...
@@ -154,6 +154,7 @@ def test_update_states_request_resumed(model_runner):
cached_req_data
=
CachedRequestData
(
req_id
=
req_id
,
resumed_from_preemption
=
False
,
new_token_ids
=
[],
new_block_ids
=
[],
num_computed_tokens
=
0
,
)
...
...
vllm/v1/core/scheduler.py
View file @
4c21ce9e
...
...
@@ -121,6 +121,8 @@ class Scheduler:
encoder_budget
=
self
.
max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens
:
Dict
[
str
,
List
[
int
]]
=
{}
# For logging.
scheduled_timestamp
=
time
.
monotonic
()
# First, schedule the RUNNING requests.
...
...
@@ -187,6 +189,15 @@ class Scheduler:
token_budget
-=
num_new_tokens
req_index
+=
1
# Speculative decode related.
if
request
.
spec_token_ids
:
num_scheduled_spec_tokens
=
(
num_new_tokens
+
request
.
num_computed_tokens
-
request
.
num_tokens
)
if
num_scheduled_spec_tokens
>
0
:
scheduled_spec_decode_tokens
[
request
.
request_id
]
=
(
request
.
spec_token_ids
[:
num_scheduled_spec_tokens
])
# Encoder-related.
if
encoder_inputs_to_schedule
:
scheduled_encoder_inputs
[
request
.
request_id
]
=
(
...
...
@@ -196,11 +207,6 @@ class Scheduler:
self
.
encoder_cache_manager
.
allocate
(
request
,
i
)
encoder_budget
=
new_encoder_budget
# Speculative decode related.
if
request
.
spec_token_ids
:
scheduled_spec_decode_tokens
[
request
.
request_id
]
=
request
.
spec_token_ids
# Record the LoRAs in scheduled_running_reqs
requested_loras
:
Set
[
int
]
=
set
()
if
self
.
lora_config
:
...
...
@@ -324,23 +330,24 @@ class Scheduler:
# Construct the scheduler output.
new_reqs_data
=
[
NewRequestData
.
from_request
(
req
,
req_to_new_block_ids
[
req
.
request_id
],
req
.
num_computed_tokens
)
req_to_new_block_ids
[
req
.
request_id
])
for
req
in
scheduled_new_reqs
]
resumed_reqs_data
=
[
self
.
_make_cached_request_data
(
req
,
num_scheduled_tokens
[
req
.
request_id
],
len
(
scheduled_spec_decode_tokens
.
get
(
req
.
request_id
,
())),
req_to_new_block_ids
[
req
.
request_id
],
req
.
num_computed_tokens
,
resumed_from_preemption
=
True
,
)
for
req
in
scheduled_resumed_reqs
]
running_reqs_data
=
[
self
.
_make_cached_request_data
(
req
,
num_scheduled_tokens
[
req
.
request_id
],
len
(
scheduled_spec_decode_tokens
.
get
(
req
.
request_id
,
())),
req_to_new_block_ids
[
req
.
request_id
],
req
.
num_computed_tokens
,
resumed_from_preemption
=
False
,
)
for
req
in
scheduled_running_reqs
]
...
...
@@ -349,8 +356,8 @@ class Scheduler:
scheduled_cached_reqs
=
resumed_reqs_data
+
running_reqs_data
,
num_scheduled_tokens
=
num_scheduled_tokens
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens
,
scheduled_encoder_inputs
=
scheduled_encoder_inputs
,
scheduled_spec_decode_tokens
=
scheduled_spec_decode_tokens
,
scheduled_encoder_inputs
=
scheduled_encoder_inputs
,
num_common_prefix_blocks
=
num_common_prefix_blocks
,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
...
...
@@ -366,22 +373,28 @@ class Scheduler:
def
_make_cached_request_data
(
self
,
request
:
Request
,
num_scheduled_tokens
:
int
,
num_scheduled_spec_tokens
:
int
,
new_block_ids
:
List
[
int
],
num_computed_tokens
:
int
,
resumed_from_preemption
:
bool
,
)
->
"CachedRequestData"
:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
if
request
.
request_id
in
self
.
_cached_reqs_data
:
req_data
=
self
.
_cached_reqs_data
[
request
.
request_id
]
num_computed_tokens
=
request
.
num_computed_tokens
num_regular_tokens
=
num_scheduled_tokens
-
num_scheduled_spec_tokens
new_token_ids
=
request
.
all_token_ids
[
num_computed_tokens
:
num_computed_tokens
+
num_regular_tokens
]
req_data
=
self
.
_cached_reqs_data
.
get
(
request
.
request_id
)
if
req_data
is
not
None
:
req_data
.
resumed_from_preemption
=
resumed_from_preemption
req_data
.
new_token_ids
=
new_token_ids
req_data
.
new_block_ids
=
new_block_ids
req_data
.
num_computed_tokens
=
num_computed_tokens
else
:
req_data
=
CachedRequestData
.
from_request
(
request
,
resumed_from_preemption
,
new_
block
_ids
,
n
um_computed_token
s
)
new_
token
_ids
,
n
ew_block_id
s
)
self
.
_cached_reqs_data
[
request
.
request_id
]
=
req_data
return
req_data
...
...
vllm/v1/core/scheduler_output.py
View file @
4c21ce9e
...
...
@@ -30,7 +30,6 @@ class NewRequestData:
cls
,
request
:
"Request"
,
block_ids
:
List
[
int
],
num_computed_tokens
:
int
,
)
->
"NewRequestData"
:
return
cls
(
req_id
=
request
.
request_id
,
...
...
@@ -41,7 +40,7 @@ class NewRequestData:
mm_positions
=
request
.
mm_positions
,
sampling_params
=
request
.
sampling_params
,
block_ids
=
block_ids
,
num_computed_tokens
=
num_computed_tokens
,
num_computed_tokens
=
request
.
num_computed_tokens
,
lora_request
=
request
.
lora_request
,
)
...
...
@@ -54,6 +53,7 @@ class CachedRequestData:
# the request's block IDs. If True, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption
:
bool
new_token_ids
:
List
[
int
]
new_block_ids
:
List
[
int
]
num_computed_tokens
:
int
...
...
@@ -62,14 +62,15 @@ class CachedRequestData:
cls
,
request
:
"Request"
,
resumed_from_preemption
:
bool
,
new_token_ids
:
List
[
int
],
new_block_ids
:
List
[
int
],
num_computed_tokens
:
int
,
)
->
"CachedRequestData"
:
return
cls
(
req_id
=
request
.
request_id
,
resumed_from_preemption
=
resumed_from_preemption
,
new_token_ids
=
new_token_ids
,
new_block_ids
=
new_block_ids
,
num_computed_tokens
=
num_computed_tokens
,
num_computed_tokens
=
request
.
num_computed_tokens
,
)
...
...
@@ -91,9 +92,9 @@ class SchedulerOutput:
# Total number of tokens scheduled for all requests.
# Equal to sum(num_scheduled_tokens.values())
total_num_scheduled_tokens
:
int
# req_id -> spec_
decode_
tokens
# If a request does not have any spec decode tokens, it will
#
not be
included in the dictionary.
# req_id -> spec_token
_id
s
# If a request does not have any spec decode tokens, it will
not be
# included in the dictionary.
scheduled_spec_decode_tokens
:
Dict
[
str
,
List
[
int
]]
# req_id -> encoder input indices that need processing.
# E.g., if a request has [0, 1], it could mean the vision encoder needs
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
4c21ce9e
...
...
@@ -2,7 +2,7 @@
import
gc
import
time
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
,
cast
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
...
...
@@ -184,7 +184,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
max_model_len
,
self
.
max_num_tokens
),
dtype
=
np
.
int32
)
self
.
arange_cpu
=
torch
.
from_numpy
(
self
.
arange_np
)
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
# a faster version of creating a new tensor every time. Thus, we should
# not make any assumptions about the values in these tensors.
...
...
@@ -327,7 +326,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
req_state
=
self
.
requests
[
req_id
]
# Update the cached states.
req_state
.
num_computed_tokens
=
req_data
.
num_computed_tokens
num_computed_tokens
=
req_data
.
num_computed_tokens
req_state
.
num_computed_tokens
=
num_computed_tokens
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec decode tokens.
num_new_tokens
=
(
num_computed_tokens
+
len
(
req_data
.
new_token_ids
)
-
req_state
.
num_tokens
)
new_token_ids
=
(
req_data
.
new_token_ids
[
-
num_new_tokens
:]
if
num_new_tokens
>
0
else
[])
req_state
.
output_token_ids
.
extend
(
new_token_ids
)
# Update the block IDs.
if
not
req_data
.
resumed_from_preemption
:
# Append the new blocks to the existing block IDs.
req_state
.
block_ids
.
extend
(
req_data
.
new_block_ids
)
...
...
@@ -346,12 +355,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Update the persistent batch.
self
.
input_batch
.
num_computed_tokens_cpu
[
req_index
]
=
(
req_data
.
num_computed_tokens
)
start_index
=
len
(
req_state
.
block_ids
)
-
len
(
req_data
.
new_block_ids
)
num_computed_tokens
)
start_index
=
(
len
(
req_state
.
block_ids
)
-
len
(
req_data
.
new_block_ids
)
)
self
.
input_batch
.
block_table
.
append_row
(
req_index
,
start_index
,
req_data
.
new_block_ids
)
# Add new_token_ids to token_ids_cpu.
start_token_index
=
num_computed_tokens
end_token_index
=
num_computed_tokens
+
len
(
req_data
.
new_token_ids
)
self
.
input_batch
.
token_ids_cpu
[
req_index
,
start_token_index
:
end_token_index
]
=
req_data
.
new_token_ids
# Add spec_token_ids to token_ids_cpu.
spec_token_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
[])
if
spec_token_ids
:
start_index
=
end_token_index
end_token_index
+=
len
(
spec_token_ids
)
self
.
input_batch
.
token_ids_cpu
[
req_index
,
start_index
:
end_token_index
]
=
spec_token_ids
# NOTE(woosuk): `num_tokens` here may include spec decode tokens.
self
.
input_batch
.
num_tokens
[
req_index
]
=
end_token_index
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
batch_changed
=
len
(
removed_req_indices
)
>
0
or
len
(
req_ids_to_add
)
>
0
# Add the new or resumed requests to the persistent batch.
...
...
@@ -374,7 +401,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
batch_changed
def
_prepare_inputs
(
self
,
scheduler_output
:
"SchedulerOutput"
self
,
scheduler_output
:
"SchedulerOutput"
,
)
->
Tuple
[
FlashAttentionMetadata
,
torch
.
Tensor
]:
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
assert
total_num_scheduled_tokens
>
0
...
...
@@ -387,24 +415,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens
_list
:
List
[
int
]
=
[]
num_scheduled_tokens
=
np
.
empty
(
num_reqs
,
dtype
=
np
.
int32
)
max_num_scheduled_tokens
=
0
all_spec_token_ids
:
List
[
int
]
=
[]
num_spec_tokens_list
:
List
[
int
]
=
[]
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_scheduled_tokens
_list
.
append
(
num_tokens
)
num_scheduled_tokens
[
i
]
=
num_tokens
max_num_scheduled_tokens
=
max
(
max_num_scheduled_tokens
,
num_tokens
)
spec_token_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
[])
all_spec_token_ids
.
extend
(
spec_token_ids
)
num_spec_tokens_list
.
append
(
len
(
spec_token_ids
))
num_scheduled_tokens
:
np
.
ndarray
=
np
.
array
(
num_scheduled_tokens_list
,
dtype
=
np
.
int32
)
assert
max_num_scheduled_tokens
>
0
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
...
...
@@ -441,78 +459,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
token_indices
=
(
positions_np
+
req_indices
*
self
.
input_batch
.
token_ids_cpu
.
shape
[
1
])
use_spec_decode
=
len
(
all_spec_token_ids
)
>
0
if
use_spec_decode
:
# 1. Write spec_token_ids to input batch.
# Step 1. Get req indices that perform spec decode and repeat
# the req indices by the number of spec tokens. Note
# for requests that don't perform spec decode, the
# number of spec tokens is 0 and the req index is
# repeated 0 times.
# E.g., num_spec_tokens_list: [3, 0, 2, 0, 1]
# spec_req_indices: [0, 0, 0, 2, 2, 4]
spec_req_indices
=
np
.
repeat
(
self
.
arange_np
[:
num_reqs
],
num_spec_tokens_list
)
# spec_offsets: offsets within each spec token list.
# E.g., [1, 2, 3, 1, 2, 1], TODO: avoid the for loop here
spec_offsets
=
np
.
concatenate
(
[
self
.
arange_np
[
1
:
val
+
1
]
for
val
in
num_spec_tokens_list
])
# spec_seq_offsets: offsets within each sequence.
# E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2]
# after repeating: [1, 1, 1, 3, 3, 2]
# spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1]
# = [2, 3, 4, 4, 5, 3]
spec_seq_offsets
=
np
.
repeat
(
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
],
num_spec_tokens_list
)
+
spec_offsets
# cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3]
cumsums_spec_offsets
=
(
spec_seq_offsets
+
spec_req_indices
*
self
.
input_batch
.
token_ids_cpu
.
shape
[
1
])
cumsums_spec_offsets
=
torch
.
from_numpy
(
cumsums_spec_offsets
).
to
(
torch
.
int64
)
all_spec_token_ids
=
torch
.
tensor
(
all_spec_token_ids
,
device
=
"cpu"
,
dtype
=
self
.
input_ids_cpu
.
dtype
)
# Step 2. Write spec token ids to input_ids_cpu.
self
.
input_batch
.
token_ids_cpu_tensor
.
flatten
().
scatter_
(
0
,
cumsums_spec_offsets
,
all_spec_token_ids
)
# 2. Get spec decode logits indices.
# E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
# cu_num_tokens: [4, 104, 107, 207, 209]
# num_spec_tokens_list: [3, 0, 2, 0, 1]
# num_sampled_tokens: [4, 1, 3, 1, 2]
# spec_decode_logits_indices:
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
num_spec_tokens_np
=
np
.
array
(
num_spec_tokens_list
,
dtype
=
np
.
int32
)
num_sampled_tokens
=
num_spec_tokens_np
+
1
# logits_start_loc: [0, 103, 104, 206, 207]
logits_start_loc
=
cu_num_tokens
-
num_sampled_tokens
# [0, 103, 104, 206, 207] ->
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_start_loc
=
np
.
repeat
(
logits_start_loc
,
num_sampled_tokens
)
# The following three lines:
# [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
cu_num_sampled_tokens
=
np
.
cumsum
(
num_sampled_tokens
)
# Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
# -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
cumsums_sampled_offsets
=
np
.
repeat
(
cu_num_sampled_tokens
-
num_sampled_tokens
,
num_sampled_tokens
)
# Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
# -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
total_num_sampled_tokens
=
num_sampled_tokens
.
sum
()
sampled_arange
=
(
self
.
arange_np
[:
total_num_sampled_tokens
]
-
cumsums_sampled_offsets
)
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
spec_decode_logits_indices
=
logits_start_loc
+
sampled_arange
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
...
...
@@ -606,9 +552,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
suffix_kv_lens
=
suffix_kv_lens
,
)
use_spec_decode
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
if
use_spec_decode
:
logits_indices
=
torch
.
from_numpy
(
spec_decode_logits_indices
).
to
(
s
elf
.
device
,
non_blocking
=
True
)
logits_indices
=
self
.
_calc_spec_decode_metadata
(
s
cheduler_output
,
cu_num_tokens
)
else
:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
...
...
@@ -762,6 +710,53 @@ class GPUModelRunner(LoRAModelRunnerMixin):
mrope_pos_ptr
+=
completion_part_len
def
_calc_spec_decode_metadata
(
self
,
scheduler_output
:
"SchedulerOutput"
,
cu_num_tokens
:
np
.
ndarray
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Get the number of spec decode tokens for each request.
num_reqs
=
self
.
input_batch
.
num_reqs
num_spec_decode_tokens
=
np
.
empty
(
num_reqs
,
dtype
=
np
.
int32
)
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
num_spec_decode_tokens
[
i
]
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
()))
# Get spec decode logits indices.
# E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
# cu_num_tokens: [4, 104, 107, 207, 209]
# num_spec_tokens_list: [3, 0, 2, 0, 1]
# num_sampled_tokens: [4, 1, 3, 1, 2]
# spec_decode_logits_indices:
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
num_sampled_tokens
=
num_spec_decode_tokens
+
1
# logits_start_loc: [0, 103, 104, 206, 207]
logits_start_loc
=
cu_num_tokens
-
num_sampled_tokens
# [0, 103, 104, 206, 207] ->
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_start_loc
=
np
.
repeat
(
logits_start_loc
,
num_sampled_tokens
)
# The following three lines:
# [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
cu_num_sampled_tokens
=
np
.
cumsum
(
num_sampled_tokens
)
# Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
# -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
cumsums_sampled_offsets
=
np
.
repeat
(
cu_num_sampled_tokens
-
num_sampled_tokens
,
num_sampled_tokens
)
# Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
# -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
total_num_sampled_tokens
=
num_sampled_tokens
.
sum
()
sampled_arange
=
(
self
.
arange_np
[:
total_num_sampled_tokens
]
-
cumsums_sampled_offsets
)
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
spec_decode_logits_indices
=
logits_start_loc
+
sampled_arange
return
torch
.
from_numpy
(
spec_decode_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
def
_prepare_sampling
(
self
,
batch_changed
:
bool
,
...
...
@@ -773,7 +768,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
req_id
,
req
in
self
.
requests
.
items
()}
sampling_metadata
=
self
.
input_batch
.
make_sampling_metadata
(
req_id_output_token_ids
,
req_to_spec_token_ids
,
not
batch_changed
)
req_id_output_token_ids
,
req_to_spec_token_ids
,
skip_copy
=
not
batch_changed
)
return
sampling_metadata
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
...
...
@@ -960,28 +957,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
num_reqs
=
self
.
input_batch
.
num_reqs
request_seq_lens
:
List
[
Tuple
[
int
,
CachedRequestState
,
int
]]
=
[]
req_ids
:
List
[
str
]
=
[]
# Because `input_batch.req_ids` is a list of length `max_num_reqs`,
# we need to stop at `num_reqs`.
# FIXME(woosuk): This is hacky. Refactor.
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
req_ids
.
append
(
req_id
)
req_state
=
self
.
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
if
seq_len
>=
req_state
.
num_tokens
:
request_seq_lens
.
append
((
i
,
req_state
,
seq_len
))
else
:
# Ignore the sampled token from the partial request.
if
seq_len
<
req_state
.
num_tokens
:
# Ignore the sampled token.
# Rewind the generator state as if the token was not sampled.
generator
=
self
.
input_batch
.
generators
.
get
(
i
)
if
generator
is
not
None
:
# This relies on cuda-specific torch-internal impl details
generator
.
set_offset
(
generator
.
get_offset
()
-
4
)
# num_reqs entries should be non-None
assert
all
(
req_id
is
not
None
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]),
"req_ids contains None"
req_ids
=
cast
(
List
[
str
],
self
.
input_batch
.
req_ids
[:
num_reqs
])
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_tensors
=
sampler_output
.
logprobs_tensors
...
...
@@ -994,29 +987,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output
,
)
#
Update batch with
the valid generated tokens.
#
Get
the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
if
max_gen_len
==
1
:
# No spec decode tokens.
valid_sampled_token_ids
=
sampled_token_ids
.
tolist
()
for
i
,
req_state
,
seq_len
in
request_seq_lens
:
token_id
=
valid_sampled_token_ids
[
i
][
0
]
self
.
input_batch
.
token_ids_cpu
[
i
,
seq_len
]
=
token_id
req_state
.
output_token_ids
.
append
(
token_id
)
self
.
input_batch
.
num_tokens
[
i
]
+=
1
else
:
# Includes spec decode tokens.
valid_mask
=
sampled_token_ids
!=
INVALID_TOKEN_ID
gen_lens
=
valid_mask
.
sum
(
dim
=
1
).
tolist
()
# TODO(woosuk): Optimize this.
valid_sampled_token_ids
=
[
seq
.
tolist
()
for
seq
in
sampled_token_ids
[
valid_mask
].
split
(
gen_lens
)
]
self
.
input_batch
.
num_tokens
[:
num_reqs
]
+=
gen_lens
for
i
,
req_state
,
seq_len
in
request_seq_lens
:
target_slice
=
slice
(
seq_len
-
gen_lens
[
i
]
+
1
,
seq_len
+
1
)
self
.
input_batch
.
token_ids_cpu
[
i
,
target_slice
]
=
valid_sampled_token_ids
[
i
]
req_state
.
output_token_ids
.
extend
(
valid_sampled_token_ids
[
i
])
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
req_ids
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment