Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ba775279
Unverified
Commit
ba775279
authored
Sep 12, 2024
by
William Lin
Committed by
GitHub
Sep 12, 2024
Browse files
[bugfix] torch profiler bug for single gpu with GPUExecutor (#8354)
parent
68210201
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
5 deletions
+27
-5
examples/offline_inference_with_profiler.py
examples/offline_inference_with_profiler.py
+1
-1
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+13
-2
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+13
-2
No files found.
examples/offline_inference_with_profiler.py
View file @
ba775279
...
...
@@ -16,7 +16,7 @@ prompts = [
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
# Create an LLM.
llm
=
LLM
(
model
=
"facebook/opt-125m"
)
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
tensor_parallel_size
=
1
)
llm
.
start_profile
()
...
...
vllm/engine/async_llm_engine.py
View file @
ba775279
...
...
@@ -13,6 +13,7 @@ from vllm.engine.async_timeout import asyncio_timeout
from
vllm.engine.llm_engine
import
LLMEngine
,
SchedulerOutputState
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
PromptInputs
from
vllm.logger
import
init_logger
...
...
@@ -1019,7 +1020,17 @@ class AsyncLLMEngine:
self
.
engine
.
remove_logger
(
logger_name
=
logger_name
)
async def start_profile(self) -> None:
    """Start profiling on the engine's model executor.

    When the executor is exactly ``GPUExecutorAsync``, its own
    ``start_profile`` method is called directly. Any other executor is
    asked to run ``"start_profile"`` through ``_run_workers``.
    """
    executor = self.engine.model_executor
    # Exact type comparison on purpose: isinstance() would also match
    # classes that inherit from GPUExecutorAsync.
    if type(executor) == GPUExecutorAsync:
        executor.start_profile()
        return
    executor._run_workers("start_profile")
async def stop_profile(self) -> None:
    """Stop profiling on the engine's model executor.

    When the executor is exactly ``GPUExecutorAsync``, its own
    ``stop_profile`` method is called directly. Any other executor is
    asked to run ``"stop_profile"`` through ``_run_workers``.
    """
    executor = self.engine.model_executor
    # Exact type comparison on purpose: isinstance() would also match
    # classes that inherit from GPUExecutorAsync.
    if type(executor) == GPUExecutorAsync:
        executor.stop_profile()
        return
    executor._run_workers("stop_profile")
vllm/engine/llm_engine.py
View file @
ba775279
...
...
@@ -26,6 +26,7 @@ from vllm.engine.output_processor.interfaces import (
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.engine.output_processor.util
import
create_output_by_sequence_group
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
(
INPUT_REGISTRY
,
EncoderDecoderLLMInputs
,
InputRegistry
,
LLMInputs
,
PromptInputs
)
...
...
@@ -1597,10 +1598,20 @@ class LLMEngine:
self
.
model_executor
.
check_health
()
def start_profile(self) -> None:
    """Start profiling on this engine's model executor.

    When the executor is exactly ``GPUExecutor``, its own
    ``start_profile`` method is called directly. Any other executor is
    asked to run ``"start_profile"`` through ``_run_workers``.
    """
    executor = self.model_executor
    # Exact type comparison on purpose: isinstance() would also match
    # inherited classes (MultiprocessingGPUExecutor).
    if type(executor) == GPUExecutor:
        executor.start_profile()
        return
    executor._run_workers("start_profile")
def stop_profile(self) -> None:
    """Stop profiling on this engine's model executor.

    When the executor is exactly ``GPUExecutor``, its own
    ``stop_profile`` method is called directly. Any other executor is
    asked to run ``"stop_profile"`` through ``_run_workers``.
    """
    executor = self.model_executor
    # Exact type comparison on purpose: isinstance() would also match
    # inherited classes (MultiprocessingGPUExecutor).
    if type(executor) == GPUExecutor:
        executor.stop_profile()
        return
    executor._run_workers("stop_profile")
def is_tracing_enabled(self) -> bool:
    """Return True when ``self.tracer`` is set (not ``None``), else False."""
    if self.tracer is None:
        return False
    return True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment