Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b12e87f9
Unverified
Commit
b12e87f9
authored
Dec 30, 2024
by
youkaichao
Committed by
GitHub
Dec 30, 2024
Browse files
[platforms] enable platform plugins (#11602)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
5dbf8545
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
9 deletions
+11
-9
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+2
-3
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+1
-0
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+8
-6
No files found.
vllm/worker/model_runner_base.py
View file @
b12e87f9
...
@@ -12,7 +12,6 @@ from torch import is_tensor
...
@@ -12,7 +12,6 @@ from torch import is_tensor
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -265,13 +264,13 @@ class ModelRunnerBase(ABC, Generic[T]):
...
@@ -265,13 +264,13 @@ class ModelRunnerBase(ABC, Generic[T]):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
@
current_platform
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
model_input
:
T
,
model_input
:
T
,
kv_caches
:
Optional
[
List
[
torch
.
Tensor
]],
kv_caches
:
Optional
[
List
[
torch
.
Tensor
]],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
num_steps
:
int
=
1
,
**
kwargs
,
)
->
Optional
[
List
[
SamplerOutput
]]:
)
->
Optional
[
List
[
SamplerOutput
]]:
"""
"""
Execute the model on the given input.
Execute the model on the given input.
...
...
vllm/worker/multi_step_model_runner.py
View file @
b12e87f9
...
@@ -544,6 +544,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -544,6 +544,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input
.
record_step_event
(
current_stream
)
model_input
.
record_step_event
(
current_stream
)
if
get_pp_group
().
is_last_rank
and
self
.
is_driver_worker
:
if
get_pp_group
().
is_last_rank
and
self
.
is_driver_worker
:
assert
isinstance
(
output
,
list
)
assert
len
(
assert
len
(
output
output
)
==
1
,
"MultiStepModelRunner requires single-step base_models"
)
==
1
,
"MultiStepModelRunner requires single-step base_models"
...
...
vllm/worker/worker_base.py
View file @
b12e87f9
...
@@ -11,7 +11,6 @@ from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
...
@@ -11,7 +11,6 @@ from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.utils
import
(
enable_trace_function_call_for_thread
,
from
vllm.utils
import
(
enable_trace_function_call_for_thread
,
resolve_obj_by_qualname
,
update_environment_variables
)
resolve_obj_by_qualname
,
update_environment_variables
)
...
@@ -44,6 +43,8 @@ class WorkerBase(ABC):
...
@@ -44,6 +43,8 @@ class WorkerBase(ABC):
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
kv_transfer_config
=
vllm_config
.
kv_transfer_config
self
.
kv_transfer_config
=
vllm_config
.
kv_transfer_config
from
vllm.platforms
import
current_platform
self
.
current_platform
=
current_platform
@
abstractmethod
@
abstractmethod
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
...
@@ -74,17 +75,17 @@ class WorkerBase(ABC):
...
@@ -74,17 +75,17 @@ class WorkerBase(ABC):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
@
current_platform
.
inference_mode
()
def
start_worker_execution_loop
(
self
)
->
None
:
def
start_worker_execution_loop
(
self
)
->
None
:
"""Execute model loop in parallel worker.
"""Execute model loop in parallel worker.
You can stop the loop by executing a driver worker with an empty output.
You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
See `stop_remote_worker_execution_loop` for more details.
"""
"""
while
True
:
with
self
.
current_platform
.
inference_mode
():
output
=
self
.
execute_model
(
execute_model_req
=
None
)
while
True
:
if
output
is
None
:
output
=
self
.
execute_model
(
execute_model_req
=
None
)
return
None
if
output
is
None
:
return
None
@
abstractmethod
@
abstractmethod
def
execute_model
(
def
execute_model
(
...
@@ -352,6 +353,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -352,6 +353,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
model_execute_time
=
time
.
perf_counter
()
-
start_time
model_execute_time
=
time
.
perf_counter
()
-
start_time
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
# output is IntermediateTensors
# output is IntermediateTensors
assert
isinstance
(
output
,
IntermediateTensors
)
if
(
self
.
observability_config
is
not
None
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_execute_time
):
and
self
.
observability_config
.
collect_model_execute_time
):
output
.
tensors
[
"model_execute_time"
]
=
torch
.
tensor
(
output
.
tensors
[
"model_execute_time"
]
=
torch
.
tensor
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment