Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7678fcd5
Unverified
Commit
7678fcd5
authored
Apr 10, 2025
by
Lu Fang
Committed by
GitHub
Apr 10, 2025
Browse files
Fix the torch version parsing logic (#15857)
parent
8661c024
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
26 additions
and
11 deletions
+26
-11
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+2
-3
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+3
-3
vllm/config.py
vllm/config.py
+3
-5
vllm/utils.py
vllm/utils.py
+18
-0
No files found.
vllm/compilation/compiler_interface.py
View file @
7678fcd5
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
import
contextlib
import
contextlib
import
copy
import
copy
import
hashlib
import
hashlib
import
importlib.metadata
import
os
import
os
from
contextlib
import
ExitStack
from
contextlib
import
ExitStack
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
...
@@ -11,9 +10,9 @@ from unittest.mock import patch
...
@@ -11,9 +10,9 @@ from unittest.mock import patch
import
torch
import
torch
import
torch._inductor.compile_fx
import
torch._inductor.compile_fx
import
torch.fx
as
fx
import
torch.fx
as
fx
from
packaging.version
import
Version
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.utils
import
is_torch_equal_or_newer
class
CompilerInterface
:
class
CompilerInterface
:
...
@@ -379,7 +378,7 @@ class InductorAdaptor(CompilerInterface):
...
@@ -379,7 +378,7 @@ class InductorAdaptor(CompilerInterface):
manually setting up internal contexts. But we also rely on non-public
manually setting up internal contexts. But we also rely on non-public
APIs which might not provide these guarantees.
APIs which might not provide these guarantees.
"""
"""
if
Version
(
importlib
.
metadata
.
version
(
'torch'
))
>=
Version
(
"2.6"
):
if
is_torch_equal_or_newer
(
"2.6"
):
import
torch._dynamo.utils
import
torch._dynamo.utils
return
torch
.
_dynamo
.
utils
.
get_metrics_context
()
return
torch
.
_dynamo
.
utils
.
get_metrics_context
()
else
:
else
:
...
...
vllm/compilation/inductor_pass.py
View file @
7678fcd5
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
hashlib
import
hashlib
import
importlib.metadata
import
inspect
import
inspect
import
json
import
json
import
types
import
types
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Union
import
torch
import
torch
from
packaging.version
import
Version
from
torch
import
fx
from
torch
import
fx
if
Version
(
importlib
.
metadata
.
version
(
'torch'
))
>=
Version
(
"2.6"
):
from
vllm.utils
import
is_torch_equal_or_newer
if
is_torch_equal_or_newer
(
"2.6"
):
from
torch._inductor.custom_graph_pass
import
CustomGraphPass
from
torch._inductor.custom_graph_pass
import
CustomGraphPass
else
:
else
:
# CustomGraphPass is not present in 2.5 or lower, import our version
# CustomGraphPass is not present in 2.5 or lower, import our version
...
...
vllm/config.py
View file @
7678fcd5
...
@@ -4,7 +4,6 @@ import ast
...
@@ -4,7 +4,6 @@ import ast
import
copy
import
copy
import
enum
import
enum
import
hashlib
import
hashlib
import
importlib.metadata
import
json
import
json
import
sys
import
sys
import
warnings
import
warnings
...
@@ -18,7 +17,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
...
@@ -18,7 +17,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional
,
Protocol
,
Union
)
Optional
,
Protocol
,
Union
)
import
torch
import
torch
from
packaging.version
import
Version
from
pydantic
import
BaseModel
,
Field
,
PrivateAttr
from
pydantic
import
BaseModel
,
Field
,
PrivateAttr
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
...
@@ -40,8 +38,8 @@ from vllm.transformers_utils.config import (
...
@@ -40,8 +38,8 @@ from vllm.transformers_utils.config import (
from
vllm.transformers_utils.s3_utils
import
S3Model
from
vllm.transformers_utils.s3_utils
import
S3Model
from
vllm.transformers_utils.utils
import
is_s3
,
maybe_model_redirect
from
vllm.transformers_utils.utils
import
is_s3
,
maybe_model_redirect
from
vllm.utils
import
(
GiB_bytes
,
LayerBlockType
,
cuda_device_count_stateless
,
from
vllm.utils
import
(
GiB_bytes
,
LayerBlockType
,
cuda_device_count_stateless
,
get_cpu_memory
,
get_open_port
,
random_uuid
,
get_cpu_memory
,
get_open_port
,
is_torch_equal_or_newer
,
resolve_obj_by_qualname
)
random_uuid
,
resolve_obj_by_qualname
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
ray.util.placement_group
import
PlacementGroup
...
@@ -3285,7 +3283,7 @@ class CompilationConfig(BaseModel):
...
@@ -3285,7 +3283,7 @@ class CompilationConfig(BaseModel):
# and it is not yet a priority. RFC here:
# and it is not yet a priority. RFC here:
# https://github.com/vllm-project/vllm/issues/14703
# https://github.com/vllm-project/vllm/issues/14703
if
Version
(
importlib
.
metadata
.
version
(
'torch'
))
>=
Version
(
"2.6"
):
if
is_torch_equal_or_newer
(
"2.6"
):
KEY
=
'enable_auto_functionalized_v2'
KEY
=
'enable_auto_functionalized_v2'
if
KEY
not
in
self
.
inductor_compile_config
:
if
KEY
not
in
self
.
inductor_compile_config
:
self
.
inductor_compile_config
[
KEY
]
=
False
self
.
inductor_compile_config
[
KEY
]
=
False
...
...
vllm/utils.py
View file @
7678fcd5
...
@@ -53,6 +53,7 @@ import torch.types
...
@@ -53,6 +53,7 @@ import torch.types
import
yaml
import
yaml
import
zmq
import
zmq
import
zmq.asyncio
import
zmq.asyncio
from
packaging
import
version
from
packaging.version
import
Version
from
packaging.version
import
Version
from
torch.library
import
Library
from
torch.library
import
Library
from
typing_extensions
import
Never
,
ParamSpec
,
TypeIs
,
assert_never
from
typing_extensions
import
Never
,
ParamSpec
,
TypeIs
,
assert_never
...
@@ -2580,3 +2581,20 @@ def sha256(input) -> int:
...
@@ -2580,3 +2581,20 @@ def sha256(input) -> int:
input_bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
input_bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
return
int
.
from_bytes
(
hashlib
.
sha256
(
input_bytes
).
digest
(),
return
int
.
from_bytes
(
hashlib
.
sha256
(
input_bytes
).
digest
(),
byteorder
=
"big"
)
byteorder
=
"big"
)
def is_torch_equal_or_newer(target: str) -> bool:
    """Check if the installed torch version is >= the target version.

    Args:
        target: a version string, like "2.6.0".

    Returns:
        True if the installed torch version is greater than or equal to
        ``target``, False otherwise.

    Raises:
        packaging.version.InvalidVersion: if ``target`` is not a valid
            version string.
    """
    # Parse the target once, outside the try block, so that a malformed
    # target is reported directly instead of being swallowed by the broad
    # handler below and re-routed through the metadata fallback.
    target_version = Version(target)
    try:
        torch_version = version.parse(str(torch.__version__))
    except Exception:
        # Fallback to PKG-INFO to load the package info, needed by the doc gen.
        # The handler is intentionally broad: this is a best-effort path for
        # when ``torch.__version__`` cannot be read or parsed.
        torch_version = Version(importlib.metadata.version('torch'))
    return torch_version >= target_version
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment