Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0dd3f4f5
Unverified
Commit
0dd3f4f5
authored
Aug 18, 2025
by
Woosuk Kwon
Committed by
GitHub
Aug 18, 2025
Browse files
[Misc] Minor refactoring for prepare_inputs (#23116)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
498259cc
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
22 deletions
+21
-22
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+21
-22
No files found.
vllm/v1/worker/gpu_model_runner.py
View file @
0dd3f4f5
...
@@ -757,10 +757,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -757,10 +757,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Prepare the attention metadata.
# Prepare the attention metadata.
self
.
query_start_loc_np
[
0
]
=
0
self
.
query_start_loc_np
[
0
]
=
0
self
.
query_start_loc_np
[
1
:
num_reqs
+
1
]
=
cu_num_tokens
self
.
query_start_loc_np
[
1
:
num_reqs
+
1
]
=
cu_num_tokens
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
self
.
query_start_loc_np
[
num_reqs
+
1
:].
fill
(
cu_num_tokens
[
-
1
])
self
.
query_start_loc
.
copy_
(
self
.
query_start_loc_cpu
,
non_blocking
=
True
)
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
self
.
seq_lens_np
[:
num_reqs
]
=
(
self
.
seq_lens_np
[:
num_reqs
]
=
(
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
]
+
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
]
+
num_scheduled_tokens
)
num_scheduled_tokens
)
# Fill unused with 0 for full cuda graph mode.
self
.
seq_lens_np
[
num_reqs
:].
fill
(
0
)
self
.
seq_lens
.
copy_
(
self
.
seq_lens_cpu
,
non_blocking
=
True
)
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
# Copy the tensors to the GPU.
# Copy the tensors to the GPU.
self
.
input_ids
[:
total_num_scheduled_tokens
].
copy_
(
self
.
input_ids
[:
total_num_scheduled_tokens
].
copy_
(
...
@@ -776,22 +785,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -776,22 +785,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
positions_cpu
[:
total_num_scheduled_tokens
],
self
.
positions_cpu
[:
total_num_scheduled_tokens
],
non_blocking
=
True
)
non_blocking
=
True
)
self
.
query_start_loc
[:
num_reqs
+
1
].
copy_
(
self
.
query_start_loc_cpu
[:
num_reqs
+
1
],
non_blocking
=
True
)
self
.
seq_lens
[:
num_reqs
].
copy_
(
self
.
seq_lens_cpu
[:
num_reqs
],
non_blocking
=
True
)
# Fill unused with 0 for full cuda graph mode.
self
.
seq_lens
[
num_reqs
:].
fill_
(
0
)
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
self
.
query_start_loc
[
num_reqs
+
1
:].
fill_
(
self
.
query_start_loc_cpu
[
num_reqs
].
item
())
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
spec_decode_common_attn_metadata
=
None
use_spec_decode
=
len
(
use_spec_decode
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
if
not
use_spec_decode
:
if
not
use_spec_decode
:
...
@@ -860,6 +853,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -860,6 +853,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
per_layer_metadata
[
layer_name
]
per_layer_metadata
[
layer_name
]
attn_metadata
[
layer_name
]
=
encoder_attn_metadata
attn_metadata
[
layer_name
]
=
encoder_attn_metadata
# Used in the below loop.
query_start_loc_cpu
=
self
.
query_start_loc_cpu
[:
num_reqs
+
1
]
seq_lens_cpu
=
self
.
seq_lens_cpu
[:
num_reqs
]
num_computed_tokens_cpu
=
(
self
.
input_batch
.
num_computed_tokens_cpu_tensor
[:
num_reqs
])
spec_decode_common_attn_metadata
=
None
# Prepare the attention metadata for each KV cache group and make layers
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
# in the same group share the same metadata.
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
...
@@ -874,12 +874,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -874,12 +874,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
blk_table
.
slot_mapping
[
total_num_scheduled_tokens
:].
fill_
(
-
1
)
blk_table
.
slot_mapping
[
total_num_scheduled_tokens
:].
fill_
(
-
1
)
common_attn_metadata
=
CommonAttentionMetadata
(
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
],
query_start_loc
=
query_start_loc
,
query_start_loc_cpu
=
self
.
query_start_loc_cpu
[:
num_reqs
+
1
],
query_start_loc_cpu
=
query_start_loc_cpu
,
seq_lens
=
self
.
seq_lens
[:
num_reqs
],
seq_lens
=
seq_lens
,
seq_lens_cpu
=
self
.
seq_lens_cpu
[:
num_reqs
],
seq_lens_cpu
=
seq_lens_cpu
,
num_computed_tokens_cpu
=
self
.
input_batch
.
num_computed_tokens_cpu
=
num_computed_tokens_cpu
,
num_computed_tokens_cpu_tensor
[:
num_reqs
],
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment