Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7678fcd5
Unverified
Commit
7678fcd5
authored
Apr 10, 2025
by
Lu Fang
Committed by
GitHub
Apr 10, 2025
Browse files
Fix the torch version parsing logic (#15857)
parent
8661c024
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
26 additions
and
11 deletions
+26
-11
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+2
-3
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+3
-3
vllm/config.py
vllm/config.py
+3
-5
vllm/utils.py
vllm/utils.py
+18
-0
No files found.
vllm/compilation/compiler_interface.py
View file @
7678fcd5
...
...
@@ -2,7 +2,6 @@
import
contextlib
import
copy
import
hashlib
import
importlib.metadata
import
os
from
contextlib
import
ExitStack
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
...
...
@@ -11,9 +10,9 @@ from unittest.mock import patch
import
torch
import
torch._inductor.compile_fx
import
torch.fx
as
fx
from
packaging.version
import
Version
from
vllm.config
import
VllmConfig
from
vllm.utils
import
is_torch_equal_or_newer
class
CompilerInterface
:
...
...
@@ -379,7 +378,7 @@ class InductorAdaptor(CompilerInterface):
manually setting up internal contexts. But we also rely on non-public
APIs which might not provide these guarantees.
"""
if
Version
(
importlib
.
metadata
.
version
(
'torch'
))
>=
Version
(
"2.6"
):
if
is_torch_equal_or_newer
(
"2.6"
):
import
torch._dynamo.utils
return
torch
.
_dynamo
.
utils
.
get_metrics_context
()
else
:
...
...
vllm/compilation/inductor_pass.py
View file @
7678fcd5
# SPDX-License-Identifier: Apache-2.0
import
hashlib
import
importlib.metadata
import
inspect
import
json
import
types
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Union
import
torch
from
packaging.version
import
Version
from
torch
import
fx
if
Version
(
importlib
.
metadata
.
version
(
'torch'
))
>=
Version
(
"2.6"
):
from
vllm.utils
import
is_torch_equal_or_newer
if
is_torch_equal_or_newer
(
"2.6"
):
from
torch._inductor.custom_graph_pass
import
CustomGraphPass
else
:
# CustomGraphPass is not present in 2.5 or lower, import our version
...
...
vllm/config.py
View file @
7678fcd5
...
...
@@ -4,7 +4,6 @@ import ast
import
copy
import
enum
import
hashlib
import
importlib.metadata
import
json
import
sys
import
warnings
...
...
@@ -18,7 +17,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional
,
Protocol
,
Union
)
import
torch
from
packaging.version
import
Version
from
pydantic
import
BaseModel
,
Field
,
PrivateAttr
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
transformers
import
PretrainedConfig
...
...
@@ -40,8 +38,8 @@ from vllm.transformers_utils.config import (
from
vllm.transformers_utils.s3_utils
import
S3Model
from
vllm.transformers_utils.utils
import
is_s3
,
maybe_model_redirect
from
vllm.utils
import
(
GiB_bytes
,
LayerBlockType
,
cuda_device_count_stateless
,
get_cpu_memory
,
get_open_port
,
random_uuid
,
resolve_obj_by_qualname
)
get_cpu_memory
,
get_open_port
,
is_torch_equal_or_newer
,
random_uuid
,
resolve_obj_by_qualname
)
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
...
...
@@ -3285,7 +3283,7 @@ class CompilationConfig(BaseModel):
# and it is not yet a priority. RFC here:
# https://github.com/vllm-project/vllm/issues/14703
if
Version
(
importlib
.
metadata
.
version
(
'torch'
))
>=
Version
(
"2.6"
):
if
is_torch_equal_or_newer
(
"2.6"
):
KEY
=
'enable_auto_functionalized_v2'
if
KEY
not
in
self
.
inductor_compile_config
:
self
.
inductor_compile_config
[
KEY
]
=
False
...
...
vllm/utils.py
View file @
7678fcd5
...
...
@@ -53,6 +53,7 @@ import torch.types
import
yaml
import
zmq
import
zmq.asyncio
from
packaging
import
version
from
packaging.version
import
Version
from
torch.library
import
Library
from
typing_extensions
import
Never
,
ParamSpec
,
TypeIs
,
assert_never
...
...
@@ -2580,3 +2581,20 @@ def sha256(input) -> int:
input_bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
return
int
.
from_bytes
(
hashlib
.
sha256
(
input_bytes
).
digest
(),
byteorder
=
"big"
)
def
is_torch_equal_or_newer
(
target
:
str
)
->
bool
:
"""Check if the installed torch version is >= the target version.
Args:
target: a version string, like "2.6.0".
Returns:
Whether the condition meets.
"""
try
:
torch_version
=
version
.
parse
(
str
(
torch
.
__version__
))
return
torch_version
>=
version
.
parse
(
target
)
except
Exception
:
# Fallback to PKG-INFO to load the package info, needed by the doc gen.
return
Version
(
importlib
.
metadata
.
version
(
'torch'
))
>=
Version
(
target
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment