Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d531ad7
Unverified
Commit
6d531ad7
authored
Mar 28, 2025
by
Nick Hill
Committed by
GitHub
Mar 28, 2025
Browse files
[Misc][V1] Misc code streamlining (#15723)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
762b424a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
33 additions
and
39 deletions
+33
-39
vllm/distributed/utils.py
vllm/distributed/utils.py
+1
-4
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+25
-30
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+1
-1
vllm/v1/engine/output_processor.py
vllm/v1/engine/output_processor.py
+1
-1
vllm/v1/request.py
vllm/v1/request.py
+5
-3
No files found.
vllm/distributed/utils.py
View file @
6d531ad7
...
...
@@ -207,10 +207,7 @@ class StatelessProcessGroup:
def
barrier
(
self
):
"""A barrier to synchronize all ranks."""
for
i
in
range
(
self
.
world_size
):
if
i
==
self
.
rank
:
self
.
broadcast_obj
(
None
,
src
=
self
.
rank
)
else
:
self
.
broadcast_obj
(
None
,
src
=
i
)
self
.
broadcast_obj
(
None
,
src
=
i
)
@
staticmethod
def
create
(
...
...
vllm/v1/core/sched/scheduler.py
View file @
6d531ad7
...
...
@@ -269,29 +269,26 @@ class Scheduler(SchedulerInterface):
request
=
self
.
waiting
[
0
]
# Waiting request skipping logic
is_skipped
=
False
# Skip request if the structured output request is still waiting
# for FSM.
if
(
not
is_skipped
and
request
.
status
==
RequestStatus
.
WAITING_FOR_FSM
):
# for FSM compilation.
if
request
.
status
==
RequestStatus
.
WAITING_FOR_FSM
:
structured_output_req
=
request
.
structured_output_request
is_skipped
=
(
not
structured_output_req
or
not
structured_output_req
.
grammar
)
if
not
is_skipped
:
if
structured_output_req
and
structured_output_req
.
grammar
:
request
.
status
=
RequestStatus
.
WAITING
# Skip request if max_loras can't be honored.
if
(
not
is_skipped
and
self
.
lora_config
and
request
.
lora_request
):
req_lora_id
=
request
.
lora_request
.
lora_int_id
is_skipped
=
(
len
(
scheduled_loras
)
==
self
.
lora_config
.
max_loras
and
(
req_lora_id
not
in
scheduled_loras
))
if
is_skipped
:
skipped_waiting_requests
.
appendleft
(
request
)
else
:
self
.
waiting
.
popleft
()
skipped_waiting_requests
.
appendleft
(
request
)
continue
# Check that adding the request still respects the max_loras
# constraint.
if
self
.
lora_config
and
request
.
lora_request
and
(
len
(
scheduled_loras
)
==
self
.
lora_config
.
max_loras
and
request
.
lora_request
.
lora_int_id
not
in
scheduled_loras
):
# Scheduling would exceed max_loras, skip.
self
.
waiting
.
popleft
()
skipped_waiting_requests
.
appendleft
(
request
)
continue
# Get already-cached tokens.
...
...
@@ -602,8 +599,9 @@ class Scheduler(SchedulerInterface):
# OPTIMIZATION: Avoid list(set) if the set is empty.
if
cached_encoder_input_ids
:
for
input_id
in
list
(
cached_encoder_input_ids
):
start_pos
=
request
.
mm_positions
[
input_id
][
"offset"
]
num_tokens
=
request
.
mm_positions
[
input_id
][
"length"
]
mm_positions
=
request
.
mm_positions
[
input_id
]
start_pos
=
mm_positions
[
"offset"
]
num_tokens
=
mm_positions
[
"length"
]
if
start_pos
+
num_tokens
<=
request
.
num_computed_tokens
:
# The encoder output is already processed and stored
# in the decoder's KV cache.
...
...
@@ -616,25 +614,24 @@ class Scheduler(SchedulerInterface):
stopped
=
False
new_logprobs
=
None
new_token_ids
:
list
[
int
]
=
[]
new_token_ids
=
generated_token_ids
# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
# to return empty token ids for the request.
for
output_token_id
in
g
enerate
d
_token_ids
:
for
num_new
,
output_token_id
in
en
um
erate
(
new
_token_ids
,
1
)
:
request
.
append_output_token_ids
(
output_token_id
)
new_token_ids
.
append
(
output_token_id
)
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped
=
check_stop
(
request
,
self
.
max_model_len
)
if
stopped
:
self
.
_free_request
(
request
)
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
break
# Extract sample logprobs if needed.
if
(
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
is
not
None
):
if
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
:
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
...
...
@@ -644,9 +641,7 @@ class Scheduler(SchedulerInterface):
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
request
.
structured_output_request
.
grammar
.
accept_tokens
(
# type: ignore[union-attr]
request
.
request_id
,
new_token_ids
,
)
req_id
,
new_token_ids
)
# Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
...
...
@@ -665,7 +660,7 @@ class Scheduler(SchedulerInterface):
# Invariant: EngineCore returns no partial prefill outputs.
assert
not
prompt_logprobs_tensors
self
.
scheduled_req_ids
.
remove
(
req
uest
.
request
_id
)
self
.
scheduled_req_ids
.
remove
(
req_id
)
if
not
stopped
:
new_running
.
append
(
request
)
...
...
vllm/v1/engine/core_client.py
View file @
6d531ad7
...
...
@@ -416,9 +416,9 @@ class SyncMPClient(MPClient):
def
process_outputs_socket
():
shutdown_socket
=
ctx
.
socket
(
zmq
.
PAIR
)
shutdown_socket
.
bind
(
shutdown_path
)
out_socket
=
make_zmq_socket
(
ctx
,
output_path
,
zmq
.
constants
.
PULL
)
try
:
shutdown_socket
.
bind
(
shutdown_path
)
poller
=
zmq
.
Poller
()
poller
.
register
(
shutdown_socket
)
poller
.
register
(
out_socket
)
...
...
vllm/v1/engine/output_processor.py
View file @
6d531ad7
...
...
@@ -328,7 +328,7 @@ class OutputProcessor:
# 2) Detokenize the token ids into text and perform stop checks.
stop_string
=
req_state
.
detokenizer
.
update
(
new_token_ids
,
finish_reason
==
FinishReason
.
STOP
)
if
stop_string
and
finish_reason
!=
FinishReason
.
STOP
:
if
stop_string
:
finish_reason
=
FinishReason
.
STOP
stop_reason
=
stop_string
...
...
vllm/v1/request.py
View file @
6d531ad7
...
...
@@ -93,9 +93,11 @@ class Request:
token_ids
:
Union
[
int
,
list
[
int
]],
)
->
None
:
if
isinstance
(
token_ids
,
int
):
token_ids
=
[
token_ids
]
self
.
_output_token_ids
.
extend
(
token_ids
)
self
.
_all_token_ids
.
extend
(
token_ids
)
self
.
_output_token_ids
.
append
(
token_ids
)
self
.
_all_token_ids
.
append
(
token_ids
)
else
:
self
.
_output_token_ids
.
extend
(
token_ids
)
self
.
_all_token_ids
.
extend
(
token_ids
)
@
property
def
num_tokens
(
self
)
->
int
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment