Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
42de2cef
Unverified
Commit
42de2cef
authored
Jul 21, 2024
by
Woosuk Kwon
Committed by
GitHub
Jul 21, 2024
Browse files
[Misc] Add a wrapper for torch.inference_mode (#6618)
parent
c9eef37f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
49 additions
and
4 deletions
+49
-4
vllm/platforms/__init__.py
vllm/platforms/__init__.py
+7
-2
vllm/platforms/interface.py
vllm/platforms/interface.py
+21
-0
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+17
-0
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+2
-1
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+2
-1
No files found.
vllm/platforms/__init__.py
View file @
42de2cef
...
@@ -2,7 +2,9 @@ from typing import Optional
...
@@ -2,7 +2,9 @@ from typing import Optional
import
torch
import
torch
from
.interface
import
Platform
,
PlatformEnum
from
vllm.utils
import
is_tpu
from
.interface
import
Platform
,
PlatformEnum
,
UnspecifiedPlatform
current_platform
:
Optional
[
Platform
]
current_platform
:
Optional
[
Platform
]
...
@@ -12,7 +14,10 @@ if torch.version.cuda is not None:
...
@@ -12,7 +14,10 @@ if torch.version.cuda is not None:
elif
torch
.
version
.
hip
is
not
None
:
elif
torch
.
version
.
hip
is
not
None
:
from
.rocm
import
RocmPlatform
from
.rocm
import
RocmPlatform
current_platform
=
RocmPlatform
()
current_platform
=
RocmPlatform
()
elif
is_tpu
():
from
.tpu
import
TpuPlatform
current_platform
=
TpuPlatform
()
else
:
else
:
current_platform
=
None
current_platform
=
UnspecifiedPlatform
()
__all__
=
[
'Platform'
,
'PlatformEnum'
,
'current_platform'
]
__all__
=
[
'Platform'
,
'PlatformEnum'
,
'current_platform'
]
vllm/platforms/interface.py
View file @
42de2cef
import
enum
import
enum
from
typing
import
Tuple
from
typing
import
Tuple
import
torch
class
PlatformEnum
(
enum
.
Enum
):
class
PlatformEnum
(
enum
.
Enum
):
CUDA
=
enum
.
auto
()
CUDA
=
enum
.
auto
()
ROCM
=
enum
.
auto
()
ROCM
=
enum
.
auto
()
TPU
=
enum
.
auto
()
UNSPECIFIED
=
enum
.
auto
()
class
Platform
:
class
Platform
:
...
@@ -16,6 +20,23 @@ class Platform:
...
@@ -16,6 +20,23 @@ class Platform:
def
is_rocm
(
self
)
->
bool
:
def
is_rocm
(
self
)
->
bool
:
return
self
.
_enum
==
PlatformEnum
.
ROCM
return
self
.
_enum
==
PlatformEnum
.
ROCM
def
is_tpu
(
self
)
->
bool
:
return
self
.
_enum
==
PlatformEnum
.
TPU
@
staticmethod
@
staticmethod
def
get_device_capability
(
device_id
:
int
=
0
)
->
Tuple
[
int
,
int
]:
def
get_device_capability
(
device_id
:
int
=
0
)
->
Tuple
[
int
,
int
]:
raise
NotImplementedError
raise
NotImplementedError
@
staticmethod
def
inference_mode
():
"""A device-specific wrapper of `torch.inference_mode`.
This wrapper is recommended because some hardware backends such as TPU
do not support `torch.inference_mode`. In such a case, they will fall
back to `torch.no_grad` by overriding this method.
"""
return
torch
.
inference_mode
(
mode
=
True
)
class
UnspecifiedPlatform
(
Platform
):
_enum
=
PlatformEnum
.
UNSPECIFIED
vllm/platforms/tpu.py
0 → 100644
View file @
42de2cef
from
typing
import
Tuple
import
torch
from
.interface
import
Platform
,
PlatformEnum
class
TpuPlatform
(
Platform
):
_enum
=
PlatformEnum
.
TPU
@
staticmethod
def
get_device_capability
(
device_id
:
int
=
0
)
->
Tuple
[
int
,
int
]:
raise
RuntimeError
(
"TPU does not have device capability."
)
@
staticmethod
def
inference_mode
():
return
torch
.
no_grad
()
vllm/worker/model_runner_base.py
View file @
42de2cef
...
@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
...
@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
...
@@ -163,7 +164,7 @@ class ModelRunnerBase(ABC, Generic[T]):
...
@@ -163,7 +164,7 @@ class ModelRunnerBase(ABC, Generic[T]):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
@
torch
.
inference_mode
()
@
current_platform
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
model_input
:
T
,
model_input
:
T
,
...
...
vllm/worker/worker_base.py
View file @
42de2cef
...
@@ -9,6 +9,7 @@ import torch
...
@@ -9,6 +9,7 @@ import torch
from
vllm.distributed
import
broadcast_tensor_dict
,
get_pp_group
from
vllm.distributed
import
broadcast_tensor_dict
,
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
SamplerOutput
)
SamplerOutput
)
from
vllm.utils
import
(
enable_trace_function_call_for_thread
,
from
vllm.utils
import
(
enable_trace_function_call_for_thread
,
...
@@ -53,7 +54,7 @@ class WorkerBase(ABC):
...
@@ -53,7 +54,7 @@ class WorkerBase(ABC):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
@
torch
.
inference_mode
()
@
current_platform
.
inference_mode
()
def
start_worker_execution_loop
(
self
)
->
None
:
def
start_worker_execution_loop
(
self
)
->
None
:
"""Execute model loop in parallel worker.
"""Execute model loop in parallel worker.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment