Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ca1b1e72
Unverified
Commit
ca1b1e72
authored
Nov 28, 2025
by
Woosuk Kwon
Committed by
GitHub
Nov 28, 2025
Browse files
[Model Runner V2] Refactor prefill token preparation (#29712)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
762a4a6c
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
83 additions
and
78 deletions
+83
-78
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+1
-1
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+52
-34
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+17
-29
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+8
-11
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+5
-3
No files found.
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
ca1b1e72
...
@@ -78,7 +78,7 @@ class CudaGraphManager:
...
@@ -78,7 +78,7 @@ class CudaGraphManager:
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
)
->
None
:
)
->
None
:
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
input_ids
=
input_buffers
.
input_ids
.
gpu
[:
num_tokens
]
input_ids
=
input_buffers
.
input_ids
[:
num_tokens
]
positions
=
input_buffers
.
positions
[:
num_tokens
]
positions
=
input_buffers
.
positions
[:
num_tokens
]
attn_metadata
=
prepare_inputs_to_capture
(
attn_metadata
=
prepare_inputs_to_capture
(
num_reqs
,
num_reqs
,
...
...
vllm/v1/worker/gpu/input_batch.py
View file @
ca1b1e72
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
from
typing
import
Any
import
numba
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -30,15 +29,12 @@ class InputBuffers:
...
@@ -30,15 +29,12 @@ class InputBuffers:
self
.
pin_memory
=
pin_memory
self
.
pin_memory
=
pin_memory
self
.
idx_mapping
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
idx_mapping
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
input_ids
=
self
.
_make_buffer
(
max_num_tokens
,
dtype
=
torch
.
int32
)
self
.
input_ids
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
positions
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
positions
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
query_start_loc
=
self
.
_make_buffer
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
self
.
query_start_loc
=
self
.
_make_buffer
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
self
.
seq_lens
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
seq_lens
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
cu_num_logits
=
self
.
_make_buffer
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
self
.
cu_num_logits
=
self
.
_make_buffer
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
# Spec decoding.
self
.
next_prefill_tokens
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
# Structured outputs.
# Structured outputs.
self
.
bitmask_indices
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
bitmask_indices
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
grammar_bitmask
=
self
.
_make_buffer
(
self
.
grammar_bitmask
=
self
.
_make_buffer
(
...
@@ -120,7 +116,7 @@ class InputBatch:
...
@@ -120,7 +116,7 @@ class InputBatch:
input_buffers
.
seq_lens
[
num_reqs
:]
=
0
input_buffers
.
seq_lens
[
num_reqs
:]
=
0
seq_lens
=
input_buffers
.
seq_lens
[:
num_reqs
]
seq_lens
=
input_buffers
.
seq_lens
[:
num_reqs
]
input_ids
=
input_buffers
.
input_ids
.
copy_to_gpu
(
num_tokens
)
input_ids
=
input_buffers
.
input_ids
[:
num_tokens
]
positions
=
input_buffers
.
positions
[:
num_tokens
]
positions
=
input_buffers
.
positions
[:
num_tokens
]
# attn_metadata = defaultdict(lambda: None)
# attn_metadata = defaultdict(lambda: None)
logits_indices
=
query_start_loc
[
1
:]
-
1
logits_indices
=
query_start_loc
[
1
:]
-
1
...
@@ -146,41 +142,63 @@ class InputBatch:
...
@@ -146,41 +142,63 @@ class InputBatch:
)
)
@
numba
.
njit
(
cache
=
True
)
@
triton
.
jit
def
_prepare_prefill_inputs
(
def
_prepare_prefill_inputs_kernel
(
idx_mapping
:
np
.
ndarray
,
# [B]
input_ids_ptr
,
query_lens
:
np
.
ndarray
,
# [B]
next_prefill_tokens_ptr
,
query_start_loc
:
np
.
ndarray
,
# [B + 1]
idx_mapping_ptr
,
prefill_token_ids
:
np
.
ndarray
,
# [N, max_model_len]
query_start_loc_ptr
,
num_computed_prefill_tokens
:
np
.
ndarray
,
# [N]
prefill_token_ids_ptr
,
input_ids
:
np
.
ndarray
,
# [num_input_tokens]
prefill_token_ids_stride
,
)
->
None
:
prefill_lens_ptr
,
num_reqs
=
idx_mapping
.
shape
[
0
]
num_computed_tokens_ptr
,
query_starts
=
query_start_loc
[:
num_reqs
]
BLOCK_SIZE
:
tl
.
constexpr
,
query_ends
=
query_start_loc
[
1
:
num_reqs
+
1
]
):
starts
=
num_computed_prefill_tokens
[
idx_mapping
]
batch_idx
=
tl
.
program_id
(
0
)
ends
=
starts
+
query_lens
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
for
i
in
range
(
num_reqs
):
prefill_len
=
tl
.
load
(
prefill_lens_ptr
+
req_state_idx
)
input_ids
[
query_starts
[
i
]
:
query_ends
[
i
]]
=
prefill_token_ids
[
num_computed
=
tl
.
load
(
num_computed_tokens_ptr
+
req_state_idx
)
idx_mapping
[
i
],
starts
[
i
]
:
ends
[
i
]
if
num_computed
>=
prefill_len
:
]
# Not prefill.
return
query_start
=
tl
.
load
(
query_start_loc_ptr
+
batch_idx
)
query_end
=
tl
.
load
(
query_start_loc_ptr
+
batch_idx
+
1
)
query_len
=
query_end
-
query_start
prefill_ptr
=
prefill_token_ids_ptr
+
req_state_idx
*
prefill_token_ids_stride
for
i
in
range
(
0
,
query_len
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
query_len
tokens
=
tl
.
load
(
prefill_ptr
+
num_computed
+
block
,
mask
=
mask
)
tl
.
store
(
input_ids_ptr
+
query_start
+
block
,
tokens
,
mask
=
mask
)
next_pos
=
num_computed
+
query_len
if
next_pos
<
prefill_len
:
next_token
=
tl
.
load
(
prefill_ptr
+
next_pos
)
tl
.
store
(
next_prefill_tokens_ptr
+
req_state_idx
,
next_token
)
def
prepare_prefill_inputs
(
def
prepare_prefill_inputs
(
idx_mapping
:
np
.
ndarray
,
input_ids
:
torch
.
Tensor
,
num_scheduled_tokens
:
np
.
ndarray
,
next_prefill_tokens
:
torch
.
Tensor
,
query_start_loc
:
np
.
ndarray
,
idx_mapping
:
torch
.
Tensor
,
prefill_token_ids
:
np
.
ndarray
,
query_start_loc
:
torch
.
Tensor
,
num_computed_prefill_tokens
:
np
.
ndarray
,
prefill_token_ids
:
torch
.
Tensor
,
input_ids
:
np
.
ndarray
,
prefill_len
:
torch
.
Tensor
,
num_computed_tokens
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
_prepare_prefill_inputs
(
num_reqs
=
idx_mapping
.
shape
[
0
]
_prepare_prefill_inputs_kernel
[(
num_reqs
,)](
input_ids
,
next_prefill_tokens
,
idx_mapping
,
idx_mapping
,
num_scheduled_tokens
,
query_start_loc
,
query_start_loc
,
prefill_token_ids
,
prefill_token_ids
,
num_computed_prefill_tokens
,
prefill_token_ids
.
stride
(
0
),
input_ids
,
prefill_len
,
num_computed_tokens
,
BLOCK_SIZE
=
1024
,
)
)
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
ca1b1e72
...
@@ -104,11 +104,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -104,11 +104,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
self
.
use_async_scheduling
:
if
self
.
use_async_scheduling
:
self
.
input_prep_event
=
torch
.
cuda
.
Event
()
self
.
input_prep_event
=
torch
.
cuda
.
Event
()
self
.
structured_outputs_event
=
torch
.
cuda
.
Event
()
self
.
structured_outputs_event
=
torch
.
cuda
.
Event
()
self
.
spec_decode_event
=
torch
.
cuda
.
Event
()
else
:
else
:
self
.
input_prep_event
=
None
self
.
input_prep_event
=
None
self
.
structured_outputs_event
=
None
self
.
structured_outputs_event
=
None
self
.
spec_decode_event
=
None
if
self
.
speculative_config
is
not
None
:
if
self
.
speculative_config
is
not
None
:
self
.
do_spec_decode
=
True
self
.
do_spec_decode
=
True
...
@@ -412,9 +410,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -412,9 +410,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cu_num_new_blocks
[
i
].
append
(
x
+
len
(
block_ids
))
cu_num_new_blocks
[
i
].
append
(
x
+
len
(
block_ids
))
new_block_ids
[
i
].
extend
(
block_ids
)
new_block_ids
[
i
].
extend
(
block_ids
)
overwrite
.
append
(
True
)
overwrite
.
append
(
True
)
# Update the GPU tensors for request states.
if
scheduler_output
.
scheduled_new_reqs
:
self
.
req_states
.
prefill_len
.
copy_to_gpu
()
# Add new blocks for the existing requests.
# Add new blocks for the existing requests.
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
...
@@ -507,16 +502,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -507,16 +502,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
query_start_loc_cpu
=
self
.
input_buffers
.
query_start_loc
.
cpu
[:
num_reqs
+
1
]
query_start_loc_cpu
=
self
.
input_buffers
.
query_start_loc
.
cpu
[:
num_reqs
+
1
]
query_start_loc_np
=
self
.
input_buffers
.
query_start_loc
.
np
[:
num_reqs
+
1
]
query_start_loc_np
=
self
.
input_buffers
.
query_start_loc
.
np
[:
num_reqs
+
1
]
#
Copy
prefill tokens
from CPU to GPU
.
#
Get
prefill tokens.
prepare_prefill_inputs
(
prepare_prefill_inputs
(
idx_mapping_np
,
self
.
input_buffers
.
input_ids
,
num_scheduled_tokens
,
self
.
req_states
.
next_prefill_tokens
,
query_start_loc_np
,
idx_mapping
,
self
.
req_states
.
prefill_token_ids
.
np
,
query_start_loc_gpu
,
self
.
req_states
.
num_computed_prefill_tokens
,
self
.
req_states
.
prefill_token_ids
.
gpu
,
self
.
input_buffers
.
input_ids
.
np
,
self
.
req_states
.
prefill_len
.
gpu
,
self
.
req_states
.
num_computed_tokens
,
)
)
self
.
input_buffers
.
input_ids
.
copy_to_gpu
(
num_tokens
)
# Prepare positions and seq_lens.
# Prepare positions and seq_lens.
prepare_pos_seq_lens
(
prepare_pos_seq_lens
(
...
@@ -531,7 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -531,7 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Some input token ids are directly read from the last sampled tokens
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
# and draft tokens. Also, get the logits indices to sample tokens from.
logits_indices
=
combine_sampled_and_draft_tokens
(
logits_indices
=
combine_sampled_and_draft_tokens
(
self
.
input_buffers
.
input_ids
.
gpu
,
self
.
input_buffers
.
input_ids
,
idx_mapping
,
idx_mapping
,
self
.
req_states
.
last_sampled_tokens
,
self
.
req_states
.
last_sampled_tokens
,
query_start_loc_gpu
,
query_start_loc_gpu
,
...
@@ -572,7 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -572,7 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_config
=
self
.
kv_cache_config
,
kv_cache_config
=
self
.
kv_cache_config
,
)
)
input_ids
=
self
.
input_buffers
.
input_ids
.
gpu
[:
num_tokens_after_padding
]
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens_after_padding
]
positions
=
self
.
input_buffers
.
positions
[:
num_tokens_after_padding
]
positions
=
self
.
input_buffers
.
positions
[:
num_tokens_after_padding
]
return
InputBatch
(
return
InputBatch
(
req_ids
=
req_ids
,
req_ids
=
req_ids
,
...
@@ -782,20 +777,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -782,20 +777,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_sampled
:
torch
.
Tensor
,
num_sampled
:
torch
.
Tensor
,
num_rejected
:
torch
.
Tensor
,
num_rejected
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_reqs
=
input_batch
.
num_reqs
idx_mapping_np
=
input_batch
.
idx_mapping_np
with
async_barrier
(
self
.
spec_decode_event
):
self
.
input_buffers
.
next_prefill_tokens
.
np
[:
num_reqs
]
=
(
self
.
req_states
.
prefill_token_ids
.
np
[
idx_mapping_np
,
self
.
req_states
.
num_computed_prefill_tokens
[
idx_mapping_np
],
]
)
next_prefill_tokens
=
self
.
input_buffers
.
next_prefill_tokens
.
copy_to_gpu
(
num_reqs
)
assert
self
.
speculator
is
not
None
assert
self
.
speculator
is
not
None
last_sampled_tokens
=
self
.
req_states
.
last_sampled_tokens
[
input_batch
.
idx_mapping
]
next_prefill_tokens
=
self
.
req_states
.
next_prefill_tokens
[
input_batch
.
idx_mapping
]
draft_tokens
=
self
.
speculator
.
propose
(
draft_tokens
=
self
.
speculator
.
propose
(
input_batch
,
input_batch
,
sampling_metadata
,
sampling_metadata
,
...
@@ -803,7 +791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -803,7 +791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
aux_hidden_states
,
aux_hidden_states
,
num_sampled
,
num_sampled
,
num_rejected
,
num_rejected
,
self
.
req_states
.
last_sampled_tokens
,
last_sampled_tokens
,
next_prefill_tokens
,
next_prefill_tokens
,
)
)
return
draft_tokens
return
draft_tokens
...
...
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
ca1b1e72
...
@@ -121,7 +121,7 @@ class EagleSpeculator:
...
@@ -121,7 +121,7 @@ class EagleSpeculator:
num_tokens_across_dp
=
num_tokens_across_dp
,
num_tokens_across_dp
=
num_tokens_across_dp
,
):
):
ret_hidden_states
=
self
.
model
(
ret_hidden_states
=
self
.
model
(
input_ids
=
self
.
input_buffers
.
input_ids
.
gpu
[:
num_tokens
],
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens
],
positions
=
self
.
input_buffers
.
positions
[:
num_tokens
],
positions
=
self
.
input_buffers
.
positions
[:
num_tokens
],
hidden_states
=
self
.
hidden_states
[:
num_tokens
],
hidden_states
=
self
.
hidden_states
[:
num_tokens
],
)
)
...
@@ -194,7 +194,7 @@ class EagleSpeculator:
...
@@ -194,7 +194,7 @@ class EagleSpeculator:
num_sampled
:
torch
.
Tensor
,
num_sampled
:
torch
.
Tensor
,
# [num_reqs]
# [num_reqs]
num_rejected
:
torch
.
Tensor
,
num_rejected
:
torch
.
Tensor
,
# [
max_
num_reqs
, 1
]
# [num_reqs]
last_sampled
:
torch
.
Tensor
,
last_sampled
:
torch
.
Tensor
,
# [num_reqs]
# [num_reqs]
next_prefill_tokens
:
torch
.
Tensor
,
next_prefill_tokens
:
torch
.
Tensor
,
...
@@ -316,7 +316,6 @@ def _prepare_eagle_inputs_kernel(
...
@@ -316,7 +316,6 @@ def _prepare_eagle_inputs_kernel(
eagle_positions_ptr
,
eagle_positions_ptr
,
target_input_ids_ptr
,
target_input_ids_ptr
,
target_positions_ptr
,
target_positions_ptr
,
idx_mapping_ptr
,
last_sampled_ptr
,
last_sampled_ptr
,
next_prefill_tokens_ptr
,
next_prefill_tokens_ptr
,
num_sampled_ptr
,
num_sampled_ptr
,
...
@@ -335,8 +334,7 @@ def _prepare_eagle_inputs_kernel(
...
@@ -335,8 +334,7 @@ def _prepare_eagle_inputs_kernel(
num_sampled
=
tl
.
load
(
num_sampled_ptr
+
batch_idx
)
num_sampled
=
tl
.
load
(
num_sampled_ptr
+
batch_idx
)
if
num_sampled
>
0
:
if
num_sampled
>
0
:
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
next_token
=
tl
.
load
(
last_sampled_ptr
+
batch_idx
).
to
(
tl
.
int32
)
next_token
=
tl
.
load
(
last_sampled_ptr
+
req_state_idx
).
to
(
tl
.
int32
)
else
:
else
:
# Chunked prefilling.
# Chunked prefilling.
# Get the next prefill token.
# Get the next prefill token.
...
@@ -368,9 +366,9 @@ def prepare_eagle_inputs(
...
@@ -368,9 +366,9 @@ def prepare_eagle_inputs(
num_sampled
:
torch
.
Tensor
,
num_sampled
:
torch
.
Tensor
,
# [num_reqs]
# [num_reqs]
num_rejected
:
torch
.
Tensor
,
num_rejected
:
torch
.
Tensor
,
# [
max_
num_reqs
, 1
]
# [num_reqs]
last_sampled
:
torch
.
Tensor
,
last_sampled
:
torch
.
Tensor
,
# [
max_
num_reqs]
# [num_reqs]
next_prefill_tokens
:
torch
.
Tensor
,
next_prefill_tokens
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_reqs
=
input_batch
.
num_reqs
num_reqs
=
input_batch
.
num_reqs
...
@@ -381,11 +379,10 @@ def prepare_eagle_inputs(
...
@@ -381,11 +379,10 @@ def prepare_eagle_inputs(
)
)
_prepare_eagle_inputs_kernel
[(
num_reqs
,)](
_prepare_eagle_inputs_kernel
[(
num_reqs
,)](
last_token_indices
,
last_token_indices
,
input_buffers
.
input_ids
.
gpu
,
input_buffers
.
input_ids
,
input_buffers
.
positions
,
input_buffers
.
positions
,
input_batch
.
input_ids
,
input_batch
.
input_ids
,
input_batch
.
positions
,
input_batch
.
positions
,
input_batch
.
idx_mapping
,
last_sampled
,
last_sampled
,
next_prefill_tokens
,
next_prefill_tokens
,
num_sampled
,
num_sampled
,
...
@@ -485,7 +482,7 @@ def prepare_eagle_decode(
...
@@ -485,7 +482,7 @@ def prepare_eagle_decode(
last_token_indices
,
last_token_indices
,
target_seq_lens
,
target_seq_lens
,
num_rejected
,
num_rejected
,
input_buffers
.
input_ids
.
gpu
,
input_buffers
.
input_ids
,
input_buffers
.
positions
,
input_buffers
.
positions
,
input_hidden_states
,
input_hidden_states
,
input_hidden_states
.
stride
(
0
),
input_hidden_states
.
stride
(
0
),
...
@@ -553,7 +550,7 @@ def update_eagle_inputs(
...
@@ -553,7 +550,7 @@ def update_eagle_inputs(
):
):
num_reqs
,
hidden_size
=
output_hidden_states
.
shape
num_reqs
,
hidden_size
=
output_hidden_states
.
shape
_update_eagle_inputs_kernel
[(
num_reqs
,)](
_update_eagle_inputs_kernel
[(
num_reqs
,)](
input_buffers
.
input_ids
.
gpu
,
input_buffers
.
input_ids
,
input_buffers
.
positions
,
input_buffers
.
positions
,
hidden_states
,
hidden_states
,
hidden_states
.
stride
(
0
),
hidden_states
.
stride
(
0
),
...
...
vllm/v1/worker/gpu/states.py
View file @
ca1b1e72
...
@@ -117,8 +117,7 @@ class RequestState:
...
@@ -117,8 +117,7 @@ class RequestState:
self
.
prefill_token_ids
=
UvaBuffer
(
self
.
prefill_token_ids
=
UvaBuffer
(
self
.
max_num_reqs
,
self
.
max_model_len
,
dtype
=
torch
.
int32
self
.
max_num_reqs
,
self
.
max_model_len
,
dtype
=
torch
.
int32
)
)
self
.
prefill_len
=
self
.
_make_buffer
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
prefill_len
=
UvaBuffer
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
# Number of computed tokens.
# Number of computed tokens.
self
.
num_computed_prefill_tokens
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_prefill_tokens
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_tokens
=
torch
.
zeros
(
self
.
num_computed_tokens
=
torch
.
zeros
(
...
@@ -140,6 +139,9 @@ class RequestState:
...
@@ -140,6 +139,9 @@ class RequestState:
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
device
,
device
=
device
,
)
)
self
.
next_prefill_tokens
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
# LoRA.
# LoRA.
self
.
lora_ids
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
lora_ids
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
...
@@ -380,13 +382,13 @@ def _expand_sampling_metadata_kernel(
...
@@ -380,13 +382,13 @@ def _expand_sampling_metadata_kernel(
expanded_top_p_ptr
,
expanded_top_p_ptr
,
top_k_ptr
,
top_k_ptr
,
expanded_top_k_ptr
,
expanded_top_k_ptr
,
seeds_ptr
,
rep_penalty_ptr
,
rep_penalty_ptr
,
expanded_rep_penalty_ptr
,
expanded_rep_penalty_ptr
,
freq_penalty_ptr
,
freq_penalty_ptr
,
expanded_freq_penalty_ptr
,
expanded_freq_penalty_ptr
,
pres_penalty_ptr
,
pres_penalty_ptr
,
expanded_pres_penalty_ptr
,
expanded_pres_penalty_ptr
,
seeds_ptr
,
expanded_seeds_ptr
,
expanded_seeds_ptr
,
cu_num_logits_ptr
,
cu_num_logits_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment