Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
408195ec
Unverified
Commit
408195ec
authored
Jan 21, 2026
by
Woosuk Kwon
Committed by
GitHub
Jan 21, 2026
Browse files
[Model Runner V2] Refactor Prompt Logprobs (#32811)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
63227acc
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
230 additions
and
142 deletions
+230
-142
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+17
-102
vllm/v1/worker/gpu/sample/logprob.py
vllm/v1/worker/gpu/sample/logprob.py
+0
-29
vllm/v1/worker/gpu/sample/prompt_logprob.py
vllm/v1/worker/gpu/sample/prompt_logprob.py
+212
-0
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+1
-11
No files found.
vllm/v1/worker/gpu/model_runner.py
View file @
408195ec
...
@@ -22,7 +22,6 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
...
@@ -22,7 +22,6 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.outputs
import
(
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
,
ModelRunnerOutput
,
)
)
from
vllm.v1.worker.gpu.async_utils
import
AsyncOutput
from
vllm.v1.worker.gpu.async_utils
import
AsyncOutput
...
@@ -51,8 +50,8 @@ from vllm.v1.worker.gpu.input_batch import (
...
@@ -51,8 +50,8 @@ from vllm.v1.worker.gpu.input_batch import (
)
)
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.sample.logprob
import
compute_prompt_logprobs
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.prompt_logprob
import
PromptLogprobsWorker
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.spec_decode
import
init_speculator
from
vllm.v1.worker.gpu.spec_decode
import
init_speculator
from
vllm.v1.worker.gpu.spec_decode.rejection_sample
import
rejection_sample
from
vllm.v1.worker.gpu.spec_decode.rejection_sample
import
rejection_sample
...
@@ -156,6 +155,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -156,6 +155,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
device
=
self
.
device
,
device
=
self
.
device
,
logprobs_mode
=
self
.
model_config
.
logprobs_mode
,
logprobs_mode
=
self
.
model_config
.
logprobs_mode
,
)
)
self
.
prompt_logprobs_worker
=
PromptLogprobsWorker
(
self
.
max_num_reqs
)
# CUDA graphs.
# CUDA graphs.
self
.
cudagraph_manager
=
CudaGraphManager
(
self
.
cudagraph_manager
=
CudaGraphManager
(
...
@@ -416,10 +416,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -416,10 +416,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
req_states
.
remove_request
(
req_id
)
self
.
req_states
.
remove_request
(
req_id
)
if
self
.
supports_mm_inputs
:
if
self
.
supports_mm_inputs
:
self
.
encoder_runner
.
remove_request
(
req_id
)
self
.
encoder_runner
.
remove_request
(
req_id
)
self
.
prompt_logprobs_worker
.
remove_request
(
req_id
)
for
req_id
in
scheduler_output
.
finished_req_ids
:
for
req_id
in
scheduler_output
.
finished_req_ids
:
self
.
req_states
.
remove_request
(
req_id
)
self
.
req_states
.
remove_request
(
req_id
)
if
self
.
supports_mm_inputs
:
if
self
.
supports_mm_inputs
:
self
.
encoder_runner
.
remove_request
(
req_id
)
self
.
encoder_runner
.
remove_request
(
req_id
)
self
.
prompt_logprobs_worker
.
remove_request
(
req_id
)
def
free_states
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
def
free_states
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
if
self
.
supports_mm_inputs
:
if
self
.
supports_mm_inputs
:
...
@@ -438,7 +440,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -438,7 +440,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
prompt_len
=
prompt_len
,
prompt_len
=
prompt_len
,
prefill_token_ids
=
new_req_data
.
prefill_token_ids
,
prefill_token_ids
=
new_req_data
.
prefill_token_ids
,
num_computed_tokens
=
new_req_data
.
num_computed_tokens
,
num_computed_tokens
=
new_req_data
.
num_computed_tokens
,
sampling_params
=
new_req_data
.
sampling_params
,
lora_request
=
new_req_data
.
lora_request
,
lora_request
=
new_req_data
.
lora_request
,
)
)
req_index
=
self
.
req_states
.
req_id_to_index
[
req_id
]
req_index
=
self
.
req_states
.
req_id_to_index
[
req_id
]
...
@@ -461,6 +462,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -461,6 +462,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
sampler
.
add_request
(
self
.
sampler
.
add_request
(
req_index
,
prompt_len
,
new_req_data
.
sampling_params
req_index
,
prompt_len
,
new_req_data
.
sampling_params
)
)
self
.
prompt_logprobs_worker
.
add_request
(
req_id
,
req_index
,
new_req_data
.
sampling_params
)
if
scheduler_output
.
scheduled_new_reqs
:
if
scheduler_output
.
scheduled_new_reqs
:
self
.
req_states
.
apply_staged_writes
()
self
.
req_states
.
apply_staged_writes
()
...
@@ -729,104 +733,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -729,104 +733,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
)
return
sampler_output
,
num_sampled
,
num_rejected
return
sampler_output
,
num_sampled
,
num_rejected
def compute_prompt_logprobs(
    self,
    hidden_states: torch.Tensor,
    input_batch: InputBatch,
) -> dict[str, LogprobsTensors]:
    """Compute prompt logprobs for the requests in this batch that need them.

    Returns a mapping from request id to the ``LogprobsTensors`` of every
    request whose prompt is fully covered after this step. Results for
    prompts that are still chunked are stashed in the request's
    ``in_progress_prompt_logprobs`` list and returned later, merged.
    """
    req_states = self.req_states
    state_indices = input_batch.idx_mapping_np

    # Fancy indexing yields a copy, so the in-place `&=` below does not
    # touch the persistent per-request flags.
    wanted = req_states.needs_prompt_logprobs[state_indices]
    if not np.any(wanted):
        # No request asks for prompt logprobs.
        return {}

    prompt_lens = req_states.prompt_len[state_indices]
    # NOTE(woosuk): -1 because the last prompt token's hidden state is not
    # needed for prompt logprobs.
    computed_prefill = req_states.num_computed_prefill_tokens[state_indices]
    includes_prompt = computed_prefill < prompt_lens - 1
    # NOTE(woosuk): If the request was resumed after preemption, its prompt
    # logprobs must have been computed before preemption. Skip.
    resumed_after_prompt = prompt_lens < req_states.prefill_len.np[state_indices]
    wanted &= includes_prompt & ~resumed_after_prompt
    if not np.any(wanted):
        return {}

    # Build the target token ids: the input ids shifted left by one.
    n = input_batch.num_tokens
    token_ids = torch.empty_like(input_batch.input_ids[:n])
    token_ids[: n - 1] = input_batch.input_ids[1:n]
    # To avoid out-of-bound access, set the last token id to 0.
    token_ids[n - 1] = 0

    # Handle chunked prompts.
    pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
    is_prompt_chunked = pos_after_step < prompt_lens
    prefill_token_ids = req_states.prefill_token_ids.gpu
    query_start_loc_np = input_batch.query_start_loc_np
    for i, _ in enumerate(input_batch.req_ids):
        if not (wanted[i] and is_prompt_chunked[i]):
            continue
        # The prompt is chunked. The target of this request's last position
        # is the next prompt token, which is not part of this step's input.
        req_idx = input_batch.idx_mapping_np[i]
        last_pos = int(query_start_loc_np[i + 1] - 1)
        # NOTE(woosuk): This triggers two GPU operations.
        token_ids[last_pos] = prefill_token_ids[req_idx, pos_after_step[i]]

    # NOTE(woosuk): We mask out logprobs for negative tokens.
    prompt_logprobs, prompt_ranks = compute_prompt_logprobs(
        token_ids,
        hidden_states[:n],
        self.model.compute_logits,
    )

    prompt_token_ids = token_ids.unsqueeze(-1)
    results: dict[str, LogprobsTensors] = {}
    for i, req_id in enumerate(input_batch.req_ids):
        if not wanted[i]:
            continue

        start_idx = query_start_loc_np[i]
        end_idx = query_start_loc_np[i + 1]
        assert start_idx < end_idx, (
            f"start_idx ({start_idx}) >= end_idx ({end_idx})"
        )
        span = slice(start_idx, end_idx)
        logprobs = LogprobsTensors(
            logprob_token_ids=prompt_token_ids[span],
            logprobs=prompt_logprobs[span],
            selected_token_ranks=prompt_ranks[span],
        )

        pending = req_states.extra_data[req_id].in_progress_prompt_logprobs
        if is_prompt_chunked[i]:
            # Prompt is chunked. Do not return the logprobs yet.
            pending.append(logprobs)
            continue

        if pending:
            # Merge the in-progress logprobs with this final piece.
            pending.append(logprobs)
            logprobs = LogprobsTensors(
                logprob_token_ids=torch.cat(
                    [piece.logprob_token_ids for piece in pending]
                ),
                logprobs=torch.cat([piece.logprobs for piece in pending]),
                selected_token_ranks=torch.cat(
                    [piece.selected_token_ranks for piece in pending]
                ),
            )
            pending.clear()

        results[req_id] = logprobs
    return results
def
postprocess
(
def
postprocess
(
self
,
self
,
input_batch
:
InputBatch
,
input_batch
:
InputBatch
,
...
@@ -1002,7 +908,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1002,7 +908,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampler_output
,
num_sampled
,
num_rejected
=
self
.
sample
(
sampler_output
,
num_sampled
,
num_rejected
=
self
.
sample
(
hidden_states
,
input_batch
,
grammar_output
hidden_states
,
input_batch
,
grammar_output
)
)
prompt_logprobs_dict
=
self
.
compute_prompt_logprobs
(
hidden_states
,
input_batch
)
prompt_logprobs_dict
=
self
.
prompt_logprobs_worker
.
compute_prompt_logprobs
(
self
.
model
.
compute_logits
,
hidden_states
,
input_batch
,
self
.
req_states
.
prefill_token_ids
.
gpu
,
self
.
req_states
.
num_computed_tokens
.
gpu
,
self
.
req_states
.
prompt_len
,
self
.
req_states
.
prefill_len
.
np
,
self
.
req_states
.
num_computed_prefill_tokens
,
)
# Prepare the model runner output.
# Prepare the model runner output.
model_runner_output
=
ModelRunnerOutput
(
model_runner_output
=
ModelRunnerOutput
(
...
...
vllm/v1/worker/gpu/sample/logprob.py
View file @
408195ec
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
import
torch
import
torch
...
@@ -137,31 +136,3 @@ def compute_topk_logprobs(
...
@@ -137,31 +136,3 @@ def compute_topk_logprobs(
logprobs
=
logprobs
,
logprobs
=
logprobs
,
selected_token_ranks
=
token_ranks
,
selected_token_ranks
=
token_ranks
,
)
)
def compute_prompt_logprobs(
    prompt_token_ids: torch.Tensor,
    prompt_hidden_states: torch.Tensor,
    logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (logprobs, ranks) of the given prompt tokens.

    The logits are produced by ``logits_fn`` from slices of
    ``prompt_hidden_states`` and scored against the matching slice of
    ``prompt_token_ids`` with ``compute_topk_logprobs`` (num_logprobs=0).
    """
    # Since materializing the full prompt logits can take too much memory,
    # we compute it in chunks.
    CHUNK_SIZE = 1024

    target_ids = prompt_token_ids.to(torch.int64)
    logprob_parts = []
    rank_parts = []
    for lo in range(0, target_ids.shape[0], CHUNK_SIZE):
        hi = lo + CHUNK_SIZE
        # NOTE(woosuk): logits_fn can be slow because it involves all-gather.
        chunk_logits = logits_fn(prompt_hidden_states[lo:hi])
        chunk_result = compute_topk_logprobs(
            chunk_logits,
            0,  # num_logprobs
            target_ids[lo:hi],
        )
        logprob_parts.append(chunk_result.logprobs)
        rank_parts.append(chunk_result.selected_token_ranks)

    # Skip the concatenation when there is only a single chunk.
    if len(logprob_parts) > 1:
        logprobs = torch.cat(logprob_parts, dim=0)
    else:
        logprobs = logprob_parts[0]
    if len(rank_parts) > 1:
        ranks = torch.cat(rank_parts, dim=0)
    else:
        ranks = rank_parts[0]
    return logprobs, ranks
vllm/v1/worker/gpu/sample/prompt_logprob.py
0 → 100644
View file @
408195ec
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
import
numpy
as
np
import
torch
from
vllm.sampling_params
import
SamplingParams
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.sample.logprob
import
compute_topk_logprobs
class PromptLogprobsWorker:
    """Tracks requests that asked for prompt logprobs and computes them.

    Prompt logprobs of a request whose prompt spans several steps are
    accumulated per step and only returned, merged, on the step that
    completes the prompt.
    """

    def __init__(self, max_num_reqs: int):
        self.max_num_reqs = max_num_reqs

        # Indexed by request-state index: True if the request occupying that
        # slot asked for prompt logprobs.
        self.uses_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
        # req_id -> list of in-progress LogprobsTensors
        self.in_progress_prompt_logprobs: dict[str, list[LogprobsTensors]] = {}

    def add_request(
        self, req_id: str, req_idx: int, sampling_params: SamplingParams
    ) -> None:
        """Register a request and record whether it wants prompt logprobs.

        The flag for ``req_idx`` is always overwritten, so a reused slot
        never inherits the previous occupant's setting.
        """
        # For now, only support prompt logprobs for the prompt tokens (not top-k).
        uses_prompt_logprobs = sampling_params.prompt_logprobs is not None
        self.uses_prompt_logprobs[req_idx] = uses_prompt_logprobs
        if uses_prompt_logprobs:
            # Start with a fresh accumulator for this request.
            self.in_progress_prompt_logprobs[req_id] = []

    def remove_request(self, req_id: str) -> None:
        """Drop any accumulated prompt logprobs of ``req_id`` (no-op if absent)."""
        self.in_progress_prompt_logprobs.pop(req_id, None)

    @staticmethod
    def _merge_logprobs(chunks: list[LogprobsTensors]) -> LogprobsTensors:
        """Concatenate per-step LogprobsTensors, field by field, in list order."""
        return LogprobsTensors(
            logprob_token_ids=torch.cat([x.logprob_token_ids for x in chunks]),
            logprobs=torch.cat([x.logprobs for x in chunks]),
            selected_token_ranks=torch.cat(
                [x.selected_token_ranks for x in chunks]
            ),
        )

    def compute_prompt_logprobs(
        self,
        logits_fn: Callable[[torch.Tensor], torch.Tensor],
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        # [max_num_reqs, max_model_len]
        prefill_token_ids: torch.Tensor,
        # [max_num_reqs]
        num_computed_tokens: torch.Tensor,
        # [max_num_reqs]
        prompt_lens: np.ndarray,
        # [max_num_reqs]
        prefill_lens: np.ndarray,
        # [max_num_reqs]
        num_computed_prefill_tokens: np.ndarray,
    ) -> dict[str, LogprobsTensors]:
        """Compute prompt logprobs for the requests in ``input_batch``.

        Returns a mapping from request id to ``LogprobsTensors`` containing
        only the requests whose prompt is completed by this step. Requests
        whose prompt is still chunked have their partial results appended to
        ``self.in_progress_prompt_logprobs`` instead. Returns an empty dict
        when no request in the batch needs prompt logprobs in this step.
        """
        idx_mapping_np = input_batch.idx_mapping_np
        # Indexing with an index array yields a copy, so the in-place `&=`
        # below does not modify `self.uses_prompt_logprobs`.
        needs_prompt_logprobs = self.uses_prompt_logprobs[idx_mapping_np]
        if not np.any(needs_prompt_logprobs):
            # Common case: No request asks for prompt logprobs.
            return {}

        prompt_lens = prompt_lens[idx_mapping_np]
        # NOTE(woosuk): -1 because the last prompt token's hidden state is not
        # needed for prompt logprobs.
        computed_prefill = num_computed_prefill_tokens[idx_mapping_np]
        includes_prompt = computed_prefill < prompt_lens - 1
        # NOTE(woosuk): If the request was resumed after preemption, its prompt
        # logprobs must have been computed before preemption. Skip.
        resumed_after_prompt = prompt_lens < prefill_lens[idx_mapping_np]
        needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
        if not np.any(needs_prompt_logprobs):
            return {}

        # Get the prompt logprobs token_ids.
        prompt_logprobs_token_ids = get_prompt_logprobs_token_ids(
            input_batch.num_tokens,
            input_batch.query_start_loc,
            input_batch.idx_mapping,
            num_computed_tokens,
            prefill_token_ids,
        )
        # Compute the prompt logprobs.
        prompt_logprobs, prompt_ranks = compute_prompt_logprobs_with_chunking(
            prompt_logprobs_token_ids,
            hidden_states[: input_batch.num_tokens],
            logits_fn,
        )

        # A prompt is "chunked" if it is still incomplete after this step.
        pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
        is_prompt_chunked = pos_after_step < prompt_lens

        query_start_loc_np = input_batch.query_start_loc_np
        prompt_token_ids = prompt_logprobs_token_ids.unsqueeze(-1)
        prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
        for i, req_id in enumerate(input_batch.req_ids):
            if not needs_prompt_logprobs[i]:
                continue

            start_idx = query_start_loc_np[i]
            end_idx = query_start_loc_np[i + 1]
            assert start_idx < end_idx, (
                f"start_idx ({start_idx}) >= end_idx ({end_idx})"
            )
            if not is_prompt_chunked[i]:
                # The prompt ends in this step: drop the final position,
                # since the last prompt token's hidden state is not needed
                # for prompt logprobs.
                end_idx -= 1
            logprobs = LogprobsTensors(
                logprob_token_ids=prompt_token_ids[start_idx:end_idx],
                logprobs=prompt_logprobs[start_idx:end_idx],
                selected_token_ranks=prompt_ranks[start_idx:end_idx],
            )

            prompt_logprobs_list = self.in_progress_prompt_logprobs[req_id]
            if is_prompt_chunked[i]:
                # Prompt is chunked. Do not return the logprobs yet.
                prompt_logprobs_list.append(logprobs)
                continue

            if prompt_logprobs_list:
                # Merge the in-progress logprobs with this final piece.
                prompt_logprobs_list.append(logprobs)
                logprobs = self._merge_logprobs(prompt_logprobs_list)
                prompt_logprobs_list.clear()

            prompt_logprobs_dict[req_id] = logprobs
        return prompt_logprobs_dict
@triton.jit
def _prompt_logprobs_token_ids_kernel(
    prompt_logprobs_token_ids_ptr,
    query_start_loc_ptr,
    idx_mapping_ptr,
    num_computed_tokens_ptr,
    prefill_token_ids_ptr,
    prefill_token_ids_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per batch entry. For each query position of that entry,
    # copy the *next* token of the request (taken from its row of
    # prefill_token_ids) into the flat output at the same query position.
    batch_idx = tl.program_id(0)
    # Widen the row index to int64 before it is multiplied by the row stride:
    # if both operands are 32-bit, the product can wrap for large
    # (number of rows * row length) buffers.
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx).to(tl.int64)

    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
    row_ptr = prefill_token_ids_ptr + req_state_idx * prefill_token_ids_stride
    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < query_len

        # NOTE(woosuk): We should shift the pos by one
        # because the logprob is computed for the next token.
        # NOTE(review): the load is bounded only by query_len, not by the
        # row length; presumably callers guarantee target_pos stays inside
        # the buffer -- confirm.
        target_pos = num_computed_tokens + 1 + block
        token_ids = tl.load(row_ptr + target_pos, mask=mask)
        tl.store(
            prompt_logprobs_token_ids_ptr + query_start + block,
            token_ids,
            mask=mask,
        )
def get_prompt_logprobs_token_ids(
    num_tokens: int,
    query_start_loc: torch.Tensor,
    idx_mapping: torch.Tensor,
    num_computed_tokens: torch.Tensor,
    prefill_token_ids: torch.Tensor,
) -> torch.Tensor:
    """Gather the target token id for every query position in the batch.

    Allocates an int64 tensor of length ``num_tokens`` on the device of
    ``idx_mapping`` and fills it by launching
    ``_prompt_logprobs_token_ids_kernel`` with one program per entry of
    ``idx_mapping``.
    """
    out = torch.empty(
        (num_tokens,),
        dtype=torch.int64,
        device=idx_mapping.device,
    )
    # One kernel program per request in the batch.
    grid = (idx_mapping.shape[0],)
    _prompt_logprobs_token_ids_kernel[grid](
        out,
        query_start_loc,
        idx_mapping,
        num_computed_tokens,
        prefill_token_ids,
        prefill_token_ids.stride(0),
        BLOCK_SIZE=1024,
    )
    return out
def compute_prompt_logprobs_with_chunking(
    prompt_token_ids: torch.Tensor,
    prompt_hidden_states: torch.Tensor,
    logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (logprobs, ranks) of the prompt tokens, computed chunk by chunk.

    Each chunk of ``prompt_hidden_states`` is turned into logits with
    ``logits_fn`` and scored against the matching chunk of
    ``prompt_token_ids`` via ``compute_topk_logprobs`` (num_logprobs=0).
    The per-chunk results are concatenated along dim 0.
    """
    # Since materializing the full prompt logits can take too much memory,
    # we compute it in chunks.
    CHUNK_SIZE = 1024

    target_ids = prompt_token_ids.to(torch.int64)
    num_positions = target_ids.shape[0]

    logprob_chunks = []
    rank_chunks = []
    for begin in range(0, num_positions, CHUNK_SIZE):
        end = begin + CHUNK_SIZE
        # NOTE(woosuk): logits_fn can be slow because it involves all-gather.
        logits = logits_fn(prompt_hidden_states[begin:end])
        chunk_out = compute_topk_logprobs(
            logits,
            0,  # num_logprobs
            target_ids[begin:end],
        )
        logprob_chunks.append(chunk_out.logprobs)
        rank_chunks.append(chunk_out.selected_token_ranks)

    # A single chunk is returned as-is, without a concatenation.
    if len(logprob_chunks) > 1:
        logprobs = torch.cat(logprob_chunks, dim=0)
    else:
        logprobs = logprob_chunks[0]
    if len(rank_chunks) > 1:
        ranks = torch.cat(rank_chunks, dim=0)
    else:
        ranks = rank_chunks[0]
    return logprobs, ranks
vllm/v1/worker/gpu/states.py
View file @
408195ec
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.worker.gpu.buffer_utils
import
StagedWriteTensor
,
UvaBackedTensor
from
vllm.v1.worker.gpu.buffer_utils
import
StagedWriteTensor
,
UvaBackedTensor
NO_LORA_ID
=
0
NO_LORA_ID
=
0
...
@@ -76,8 +74,6 @@ class RequestState:
...
@@ -76,8 +74,6 @@ class RequestState:
self
.
lora_ids
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
lora_ids
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
lora_ids
.
fill
(
NO_LORA_ID
)
self
.
lora_ids
.
fill
(
NO_LORA_ID
)
self
.
needs_prompt_logprobs
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
bool
)
@
property
@
property
def
num_reqs
(
self
)
->
int
:
def
num_reqs
(
self
)
->
int
:
return
len
(
self
.
req_id_to_index
)
return
len
(
self
.
req_id_to_index
)
...
@@ -88,7 +84,6 @@ class RequestState:
...
@@ -88,7 +84,6 @@ class RequestState:
prompt_len
:
int
,
prompt_len
:
int
,
prefill_token_ids
:
list
[
int
],
prefill_token_ids
:
list
[
int
],
num_computed_tokens
:
int
,
num_computed_tokens
:
int
,
sampling_params
:
SamplingParams
,
lora_request
:
LoRARequest
|
None
,
lora_request
:
LoRARequest
|
None
,
)
->
None
:
)
->
None
:
assert
len
(
self
.
free_indices
)
>
0
,
"No free indices"
assert
len
(
self
.
free_indices
)
>
0
,
"No free indices"
...
@@ -112,10 +107,6 @@ class RequestState:
...
@@ -112,10 +107,6 @@ class RequestState:
else
:
else
:
self
.
lora_ids
[
req_idx
]
=
NO_LORA_ID
self
.
lora_ids
[
req_idx
]
=
NO_LORA_ID
# For now, only support prompt logprobs for the prompt tokens.
needs_prompt_logprobs
=
sampling_params
.
prompt_logprobs
is
not
None
self
.
needs_prompt_logprobs
[
req_idx
]
=
needs_prompt_logprobs
def
apply_staged_writes
(
self
)
->
None
:
def
apply_staged_writes
(
self
)
->
None
:
self
.
prefill_len
.
copy_to_uva
()
self
.
prefill_len
.
copy_to_uva
()
self
.
prefill_token_ids
.
apply_write
()
self
.
prefill_token_ids
.
apply_write
()
...
@@ -151,4 +142,3 @@ class RequestState:
...
@@ -151,4 +142,3 @@ class RequestState:
@
dataclass
@
dataclass
class
ExtraData
:
class
ExtraData
:
lora_request
:
LoRARequest
|
None
lora_request
:
LoRARequest
|
None
in_progress_prompt_logprobs
:
list
[
LogprobsTensors
]
=
field
(
default_factory
=
list
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment