Unverified commit 7678fcd5, authored by Lu Fang, committed by GitHub
Browse files

Fix the torch version parsing logic (#15857)

parent 8661c024
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import contextlib import contextlib
import copy import copy
import hashlib import hashlib
import importlib.metadata
import os import os
from contextlib import ExitStack from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
...@@ -11,9 +10,9 @@ from unittest.mock import patch ...@@ -11,9 +10,9 @@ from unittest.mock import patch
import torch import torch
import torch._inductor.compile_fx import torch._inductor.compile_fx
import torch.fx as fx import torch.fx as fx
from packaging.version import Version
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils import is_torch_equal_or_newer
class CompilerInterface: class CompilerInterface:
...@@ -379,7 +378,7 @@ class InductorAdaptor(CompilerInterface): ...@@ -379,7 +378,7 @@ class InductorAdaptor(CompilerInterface):
manually setting up internal contexts. But we also rely on non-public manually setting up internal contexts. But we also rely on non-public
APIs which might not provide these guarantees. APIs which might not provide these guarantees.
""" """
if Version(importlib.metadata.version('torch')) >= Version("2.6"): if is_torch_equal_or_newer("2.6"):
import torch._dynamo.utils import torch._dynamo.utils
return torch._dynamo.utils.get_metrics_context() return torch._dynamo.utils.get_metrics_context()
else: else:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import hashlib import hashlib
import importlib.metadata
import inspect import inspect
import json import json
import types import types
from typing import Any, Callable, Dict, Optional, Union from typing import Any, Callable, Dict, Optional, Union
import torch import torch
from packaging.version import Version
from torch import fx from torch import fx
if Version(importlib.metadata.version('torch')) >= Version("2.6"): from vllm.utils import is_torch_equal_or_newer
if is_torch_equal_or_newer("2.6"):
from torch._inductor.custom_graph_pass import CustomGraphPass from torch._inductor.custom_graph_pass import CustomGraphPass
else: else:
# CustomGraphPass is not present in 2.5 or lower, import our version # CustomGraphPass is not present in 2.5 or lower, import our version
......
...@@ -4,7 +4,6 @@ import ast ...@@ -4,7 +4,6 @@ import ast
import copy import copy
import enum import enum
import hashlib import hashlib
import importlib.metadata
import json import json
import sys import sys
import warnings import warnings
...@@ -18,7 +17,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal, ...@@ -18,7 +17,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, Union) Optional, Protocol, Union)
import torch import torch
from packaging.version import Version
from pydantic import BaseModel, Field, PrivateAttr from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -40,8 +38,8 @@ from vllm.transformers_utils.config import ( ...@@ -40,8 +38,8 @@ from vllm.transformers_utils.config import (
from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, get_open_port, random_uuid, get_cpu_memory, get_open_port, is_torch_equal_or_newer,
resolve_obj_by_qualname) random_uuid, resolve_obj_by_qualname)
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
...@@ -3285,7 +3283,7 @@ class CompilationConfig(BaseModel): ...@@ -3285,7 +3283,7 @@ class CompilationConfig(BaseModel):
# and it is not yet a priority. RFC here: # and it is not yet a priority. RFC here:
# https://github.com/vllm-project/vllm/issues/14703 # https://github.com/vllm-project/vllm/issues/14703
if Version(importlib.metadata.version('torch')) >= Version("2.6"): if is_torch_equal_or_newer("2.6"):
KEY = 'enable_auto_functionalized_v2' KEY = 'enable_auto_functionalized_v2'
if KEY not in self.inductor_compile_config: if KEY not in self.inductor_compile_config:
self.inductor_compile_config[KEY] = False self.inductor_compile_config[KEY] = False
......
...@@ -53,6 +53,7 @@ import torch.types ...@@ -53,6 +53,7 @@ import torch.types
import yaml import yaml
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from packaging import version
from packaging.version import Version from packaging.version import Version
from torch.library import Library from torch.library import Library
from typing_extensions import Never, ParamSpec, TypeIs, assert_never from typing_extensions import Never, ParamSpec, TypeIs, assert_never
...@@ -2580,3 +2581,20 @@ def sha256(input) -> int: ...@@ -2580,3 +2581,20 @@ def sha256(input) -> int:
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return int.from_bytes(hashlib.sha256(input_bytes).digest(), return int.from_bytes(hashlib.sha256(input_bytes).digest(),
byteorder="big") byteorder="big")
def is_torch_equal_or_newer(target: str) -> bool:
    """Report whether the installed torch is at least version ``target``.

    Args:
        target: The minimum version to compare against, e.g. "2.6.0".

    Returns:
        True when the installed torch version is greater than or equal to
        ``target``; False otherwise.
    """
    try:
        installed = version.parse(str(torch.__version__))
        required = version.parse(target)
        return installed >= required
    except Exception:
        # Parsing or comparing torch.__version__ failed; fall back to the
        # installed package metadata (PKG-INFO). The doc generation needs
        # this path.
        metadata_version = importlib.metadata.version('torch')
        return Version(metadata_version) >= Version(target)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment