Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2ff48576
Unverified
Commit
2ff48576
authored
Feb 10, 2025
by
Woosuk Kwon
Committed by
GitHub
Feb 11, 2025
Browse files
[V1][Minor] Move scheduler outputs to a separate file (#13062)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
91e87675
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
113 additions
and
89 deletions
+113
-89
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+3
-86
vllm/v1/core/scheduler_output.py
vllm/v1/core/scheduler_output.py
+108
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-1
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+1
-2
No files found.
vllm/v1/core/scheduler.py
View file @
2ff48576
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
collections
import
deque
from
collections
import
deque
from
dataclasses
import
dataclass
from
typing
import
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
(
TYPE_CHECKING
,
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
)
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
compute_encoder_budget
)
compute_encoder_budget
)
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
from
vllm.v1.core.scheduler_output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
)
from
vllm.v1.engine
import
EngineCoreOutput
,
EngineCoreOutputs
from
vllm.v1.engine
import
EngineCoreOutput
,
EngineCoreOutputs
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
if
TYPE_CHECKING
:
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.base
import
PlaceholderRange
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -600,80 +594,3 @@ class Scheduler:
...
@@ -600,80 +594,3 @@ class Scheduler:
num_waiting_reqs
=
len
(
self
.
waiting
),
num_waiting_reqs
=
len
(
self
.
waiting
),
gpu_cache_usage
=
self
.
kv_cache_manager
.
usage
,
gpu_cache_usage
=
self
.
kv_cache_manager
.
usage
,
)
)
@
dataclass
class
NewRequestData
:
req_id
:
str
prompt_token_ids
:
List
[
int
]
prompt
:
Optional
[
str
]
mm_inputs
:
List
[
"MultiModalKwargs"
]
mm_hashes
:
List
[
str
]
mm_positions
:
List
[
"PlaceholderRange"
]
sampling_params
:
SamplingParams
block_ids
:
List
[
int
]
num_computed_tokens
:
int
lora_request
:
Optional
[
LoRARequest
]
@
classmethod
def
from_request
(
cls
,
request
:
Request
,
block_ids
:
List
[
int
],
num_computed_tokens
:
int
,
)
->
"NewRequestData"
:
return
cls
(
req_id
=
request
.
request_id
,
prompt_token_ids
=
request
.
prompt_token_ids
,
prompt
=
request
.
prompt
,
mm_inputs
=
request
.
mm_inputs
,
mm_hashes
=
request
.
mm_hashes
,
mm_positions
=
request
.
mm_positions
,
sampling_params
=
request
.
sampling_params
,
block_ids
=
block_ids
,
num_computed_tokens
=
num_computed_tokens
,
lora_request
=
request
.
lora_request
,
)
@
dataclass
class
CachedRequestData
:
req_id
:
str
# If resumed_from_preemption is False, new_block_ids will be appended to
# the request's block IDs. If True, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption
:
bool
new_block_ids
:
List
[
int
]
num_computed_tokens
:
int
@
classmethod
def
from_request
(
cls
,
request
:
Request
,
resumed_from_preemption
:
bool
,
new_block_ids
:
List
[
int
],
num_computed_tokens
:
int
,
)
->
"CachedRequestData"
:
return
cls
(
req_id
=
request
.
request_id
,
resumed_from_preemption
=
resumed_from_preemption
,
new_block_ids
=
new_block_ids
,
num_computed_tokens
=
num_computed_tokens
,
)
@
dataclass
class
SchedulerOutput
:
scheduled_new_reqs
:
List
[
NewRequestData
]
scheduled_cached_reqs
:
List
[
CachedRequestData
]
num_scheduled_tokens
:
Dict
[
str
,
int
]
total_num_scheduled_tokens
:
int
scheduled_encoder_inputs
:
Dict
[
str
,
List
[
int
]]
num_common_prefix_blocks
:
int
finished_req_ids
:
Set
[
str
]
free_encoder_input_ids
:
List
[
Tuple
[
str
,
int
]]
vllm/v1/core/scheduler_output.py
0 → 100644
View file @
2ff48576
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
,
Tuple
if
TYPE_CHECKING
:
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.base
import
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.request
import
Request
@
dataclass
class
NewRequestData
:
req_id
:
str
prompt_token_ids
:
List
[
int
]
prompt
:
Optional
[
str
]
mm_inputs
:
List
[
"MultiModalKwargs"
]
mm_hashes
:
List
[
str
]
mm_positions
:
List
[
"PlaceholderRange"
]
sampling_params
:
"SamplingParams"
block_ids
:
List
[
int
]
num_computed_tokens
:
int
lora_request
:
Optional
[
"LoRARequest"
]
@
classmethod
def
from_request
(
cls
,
request
:
"Request"
,
block_ids
:
List
[
int
],
num_computed_tokens
:
int
,
)
->
"NewRequestData"
:
return
cls
(
req_id
=
request
.
request_id
,
prompt_token_ids
=
request
.
prompt_token_ids
,
prompt
=
request
.
prompt
,
mm_inputs
=
request
.
mm_inputs
,
mm_hashes
=
request
.
mm_hashes
,
mm_positions
=
request
.
mm_positions
,
sampling_params
=
request
.
sampling_params
,
block_ids
=
block_ids
,
num_computed_tokens
=
num_computed_tokens
,
lora_request
=
request
.
lora_request
,
)
@
dataclass
class
CachedRequestData
:
req_id
:
str
# If resumed_from_preemption is False, new_block_ids will be appended to
# the request's block IDs. If True, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption
:
bool
new_block_ids
:
List
[
int
]
num_computed_tokens
:
int
@
classmethod
def
from_request
(
cls
,
request
:
"Request"
,
resumed_from_preemption
:
bool
,
new_block_ids
:
List
[
int
],
num_computed_tokens
:
int
,
)
->
"CachedRequestData"
:
return
cls
(
req_id
=
request
.
request_id
,
resumed_from_preemption
=
resumed_from_preemption
,
new_block_ids
=
new_block_ids
,
num_computed_tokens
=
num_computed_tokens
,
)
@
dataclass
class
SchedulerOutput
:
# List of the requests that are scheduled for the first time.
# We cache the request's data in each worker process, so that we don't
# need to re-send it every scheduling step.
scheduled_new_reqs
:
List
[
NewRequestData
]
# List of the requests that have been scheduled before.
# Since the request's data is already cached in the worker processes,
# we only send the diff to minimize the communication cost.
scheduled_cached_reqs
:
List
[
CachedRequestData
]
# req_id -> num_scheduled_tokens
# Number of tokens scheduled for each request.
num_scheduled_tokens
:
Dict
[
str
,
int
]
# Total number of tokens scheduled for all requests.
# Equal to sum(num_scheduled_tokens.values())
total_num_scheduled_tokens
:
int
# req_id -> encoder input indices that need processing.
# E.g., if a request has [0, 1], it could mean the vision encoder needs
# to process the request's 0-th and 1-st images in the current step.
scheduled_encoder_inputs
:
Dict
[
str
,
List
[
int
]]
# Number of common prefix blocks for all requests.
# This can be used for cascade attention.
num_common_prefix_blocks
:
int
# Request IDs that are finished in between the previous and the current
# steps. This is used to notify the workers about the finished requests
# so that they can free the cached states for those requests.
finished_req_ids
:
Set
[
str
]
# List of (req_id, encoder_input_index) tuples.
# Used to free the encoder cache.
free_encoder_input_ids
:
List
[
Tuple
[
str
,
int
]]
vllm/v1/worker/gpu_model_runner.py
View file @
2ff48576
...
@@ -36,7 +36,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...
@@ -36,7 +36,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.core.scheduler
import
SchedulerOutput
from
vllm.v1.core.scheduler
_output
import
SchedulerOutput
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/v1/worker/gpu_worker.py
View file @
2ff48576
...
@@ -18,7 +18,6 @@ from vllm.logger import init_logger
...
@@ -18,7 +18,6 @@ from vllm.logger import init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
GiB_bytes
from
vllm.utils
import
GiB_bytes
from
vllm.v1.core.scheduler
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
...
@@ -26,7 +25,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
...
@@ -26,7 +25,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.core.scheduler
import
SchedulerOutput
from
vllm.v1.core.scheduler
_output
import
SchedulerOutput
class
Worker
:
class
Worker
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment