env_override.py 23.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
# ruff: noqa: E402
import importlib.util
5
6
import os

7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86

def _get_torch_cuda_version():
    """Return the CUDA version torch was built against, without importing torch.

    Helper for _maybe_set_cuda_compatibility_path(). A regular ``import torch``
    would load torch's CUDA shared libraries, after which changing
    LD_LIBRARY_PATH no longer has any effect. Instead, this locates
    ``torch/version.py`` through the import system and executes only that
    file.

    Returns:
        The ``cuda`` attribute of ``torch/version.py`` (None if it has none),
        or None if torch cannot be located, the file is missing, or any step
        raises.
    """
    try:
        torch_spec = importlib.util.find_spec("torch")
        if not torch_spec:
            return None

        # Prefer the directory holding torch/__init__.py; otherwise fall back
        # to the first package search location.
        if torch_spec.origin:
            package_dir = os.path.dirname(torch_spec.origin)
        elif torch_spec.submodule_search_locations:
            package_dir = torch_spec.submodule_search_locations[0]
        else:
            return None

        version_file = os.path.join(package_dir, "version.py")
        if not os.path.exists(version_file):
            return None

        file_spec = importlib.util.spec_from_file_location(
            "torch.version", version_file
        )
        if not (file_spec and file_spec.loader):
            return None

        # The resulting module is never registered in sys.modules, so a later
        # real `import torch` is not confused by it.
        version_module = importlib.util.module_from_spec(file_spec)
        file_spec.loader.exec_module(version_module)
        return getattr(version_module, "cuda", None)
    except Exception:
        # Best effort: any failure just means "CUDA version unknown".
        return None


def _maybe_set_cuda_compatibility_path():
    """Put the CUDA forward-compatibility library dir first in LD_LIBRARY_PATH.

    Does nothing unless VLLM_ENABLE_CUDA_COMPATIBILITY is "1" or "true"
    (case-insensitive, surrounding whitespace ignored).

    The directory used is the first existing one among:
      1. $VLLM_CUDA_COMPATIBILITY_PATH
      2. $CONDA_PREFIX/cuda-compat
      3. /usr/local/cuda-<torch CUDA version>/compat
    If none exists, the environment is left untouched. If the normalized
    directory is already the first LD_LIBRARY_PATH entry, nothing changes.
    Otherwise it is placed first and any other entries that normalize to the
    same path are removed. Empty entries are preserved.

    Must run before 'import torch' since torch loads CUDA shared libraries
    at import time and the dynamic linker only consults LD_LIBRARY_PATH when
    a library is first loaded.

    CUDA forward compatibility is only supported on select professional and
    datacenter NVIDIA GPUs. Consumer GPUs (GeForce, RTX) do not support it
    and will get Error 803 if compat libs are loaded.
    """
    flag = os.environ.get("VLLM_ENABLE_CUDA_COMPATIBILITY", "0")
    if flag.strip().lower() not in ("1", "true"):
        return

    def usable(path):
        return bool(path) and os.path.isdir(path)

    compat_dir = os.environ.get("VLLM_CUDA_COMPATIBILITY_PATH", "")

    # Fallback 1: the cuda-compat directory of the active conda environment.
    if not usable(compat_dir):
        conda_prefix = os.environ.get("CONDA_PREFIX", "")
        conda_candidate = os.path.join(conda_prefix, "cuda-compat")
        if conda_prefix and os.path.isdir(conda_candidate):
            compat_dir = conda_candidate

    # Fallback 2: the system CUDA toolkit matching torch's CUDA version.
    if not usable(compat_dir):
        cuda_version = _get_torch_cuda_version()
        if cuda_version:
            system_candidate = f"/usr/local/cuda-{cuda_version}/compat"
            if os.path.isdir(system_candidate):
                compat_dir = system_candidate

    if not usable(compat_dir):
        return

    target = os.path.normpath(compat_dir)

    def is_target(entry):
        # Empty entries never count as a match, so they are always kept.
        return bool(entry) and os.path.normpath(entry) == target

    current = os.environ.get("LD_LIBRARY_PATH", "")
    entries = current.split(os.pathsep) if current else []
    if entries and is_target(entries[0]):
        return  # Already first; leave LD_LIBRARY_PATH untouched.

    os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(
        [target, *(entry for entry in entries if not is_target(entry))]
    )


_maybe_set_cuda_compatibility_path()

87
88
import torch

89
from vllm.logger import init_logger
90
from vllm.utils.torch_utils import is_torch_equal, is_torch_equal_or_newer
91
92
93

logger = init_logger(__name__)

94
95
96
97
98
99
100
# Common config/environment variables that should be set for all processes
# created by vllm and all processes that interact with vllm workers. They are
# executed whenever `import vllm` is called.
#
# Both of the following are forced, overriding any value already present:
# - PYTORCH_NVML_BASED_CUDA_CHECK=1 avoids unintentional cuda initialization
#   from torch.cuda.is_available();
#   see https://github.com/vllm-project/vllm/pull/15951
# - TORCHINDUCTOR_COMPILE_THREADS=1; see
#   https://github.com/vllm-project/vllm/issues/10480 and
#   https://github.com/vllm-project/vllm/issues/10619.
os.environ.update(
    {
        "PYTORCH_NVML_BASED_CUDA_CHECK": "1",
        "TORCHINDUCTOR_COMPILE_THREADS": "1",
    }
)

# Enable Triton autotuning result caching to disk by default. Without this,
# Triton re-runs autotuning on every process restart, adding significant
# latency to the first inference request. The results are written to
# TRITON_CACHE_DIR. Unlike the variables above, this is only a default: an
# existing TRITON_CACHE_AUTOTUNING (e.g. "0") in the environment wins.
os.environ.setdefault("TRITON_CACHE_AUTOTUNING", "1")

115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# ===================================================
# torch 2.9 Inductor PythonWrapperCodegen monkeypatch
# ===================================================
# This change monkeypatches memory_plan_reuse in pytorch 2.9.0 to work around
# a test failure for test_multi_graph_piecewise_compile_outputs_equal.
# For more context, see https://github.com/pytorch/pytorch/pull/165514.


def memory_plan_reuse_patched(self) -> None:
    """Patched replacement for ``PythonWrapperCodegen.memory_plan_reuse``.

    Assigned over the upstream method only on torch 2.9.0 (see the
    ``is_torch_equal("2.9.0")`` block in this module). It:

    1. Computes the names of the current graph's outputs. For a
       ``SubgraphPythonWrapperCodegen`` that has ``partition_signatures``,
       they are derived from that partition's ``output_nodes``. Otherwise
       ``V.graph.get_output_names()`` is used. NOTE(review): presumably this
       is the change carried over from pytorch/pytorch#165514; confirm
       against upstream before editing.
    2. Pops trailing ``MemoryPlanningLine`` entries of ``self.lines`` whose
       buffer is not one of those outputs.
    3. Replaces each remaining ``MemoryPlanningLine`` with
       ``line.plan(state)``. ``state`` is a fresh ``MemoryPlanningState`` per
       subgraph delimited by ``EnterSubgraphLine`` / ``ExitSubgraphLine``.

    ``self.lines`` is mutated in place; nothing is returned.
    """
    import torch._inductor.ir as ir
    from torch._inductor.codegen.wrapper import (
        EnterSubgraphLine,
        ExitSubgraphLine,
        MemoryPlanningLine,
        MemoryPlanningState,
        SubgraphPythonWrapperCodegen,
    )
    from torch._inductor.virtualized import V

    def get_output_names(graph_outputs) -> list[str]:
        """Return one name per entry of ``graph_outputs``.

        NoneAsConstantBuffer / ShapeAsConstantBuffer entries get synthesized
        names ``<graph>_none<i>`` / ``<graph>_shape<i>``, each kind with its
        own counter starting at 0. Every other node uses ``node.get_name()``.
        """
        import itertools

        names = []
        shape_counter = itertools.count(0)
        none_counter = itertools.count(0)
        for node in graph_outputs:
            if isinstance(node, ir.NoneAsConstantBuffer):
                names.append(f"{V.graph.name}_none{next(none_counter)}")
            elif isinstance(node, ir.ShapeAsConstantBuffer):
                names.append(f"{V.graph.name}_shape{next(shape_counter)}")
            else:
                names.append(node.get_name())
        return names

    # A partition subgraph's outputs are its partition signature's output
    # nodes, not the outputs of the enclosing graph.
    if (
        isinstance(V.graph.wrapper_code, SubgraphPythonWrapperCodegen)
        and V.graph.wrapper_code.partition_signatures is not None
    ):
        out_names = get_output_names(
            V.graph.wrapper_code.partition_signatures.output_nodes
        )
    else:
        out_names = V.graph.get_output_names()

    # Strip trailing planning lines for buffers that are not outputs.
    while (
        self.lines
        and isinstance(self.lines[-1], MemoryPlanningLine)
        and self.lines[-1].node.name not in out_names  # type: ignore[attr-defined]
    ):
        # these lines will be pointless
        self.lines.pop()

    # codegen allocations in two passes
    # NOTE(review): only one planning loop runs in this function; the
    # "two passes" wording above appears to be carried over from upstream.
    planning_states = [MemoryPlanningState()]  # stack; top = current (sub)graph
    past_planning_states = []  # finished states; not read again in this function
    for i in range(len(self.lines)):
        line = self.lines[i]
        if isinstance(line, MemoryPlanningLine):
            # Replace the line in place with its planned form.
            self.lines[i] = line.plan(planning_states[-1])
        elif isinstance(line, EnterSubgraphLine):
            planning_states.append(MemoryPlanningState())
        elif isinstance(line, ExitSubgraphLine):
            past_planning_states.append(planning_states.pop())
    # Pop the outermost state. An EnterSubgraphLine without a matching
    # ExitSubgraphLine would leave extra states and trip the assert.
    past_planning_states.append(planning_states.pop())
    assert len(planning_states) == 0

181

182
183
184
185
186
# ===================================================
# torch 2.9 Inductor get_graph_partition_signature monkeypatch
# ===================================================
# This change monkeypatches get_graph_partition_signature in pytorch 2.9.0 to
# fix inductor partition + attention-nvfp4 quant fusion, tested in
# `tests/compile/test_fusion_attn.py::test_attn_quant`.
# For more context, see https://github.com/pytorch/pytorch/pull/165815.


def get_graph_partition_signature_patched(
    self, partitions, skip_cudagraphs: list[bool]
):
    """Patched ``Scheduler.get_graph_partition_signature`` for torch 2.9.0.

    Gets the signature of each graph partition: its input nodes, its output
    nodes, and whether each input is deallocated within the partition.

    Installed on ``Scheduler`` by ``_update_scheduler_patched``, which is
    itself only installed on torch 2.9.0.

    Args:
        self: the Inductor ``Scheduler``.
        partitions: the graph partitions, each an iterable of scheduler
            nodes. The loop below treats later partitions as consumers of
            earlier ones, so they are presumably in execution order.
        skip_cudagraphs: one flag per partition, aligned with
            ``partitions``; stored in each ``GraphPartitionSignature``.

    Returns:
        A list with one ``GraphPartitionSignature`` per partition, in the
        same order as ``partitions``.
    """
    from torch._inductor import dependencies
    from torch._inductor.ir import GraphPartitionSignature, MutationOutput, NoneLayout
    from torch._inductor.virtualized import V
    from torch.utils._ordered_set import OrderedSet

    signatures = []

    # Names still needed by the graph or by partitions already processed
    # (i.e. later ones, since the loop walks backwards). Starts as the graph
    # outputs and is updated at the end of every iteration.
    unmet_output_names = OrderedSet(V.graph.get_output_names())
    name_to_node = self.get_name_to_nodes()

    def is_none_layout(buf_name: str) -> bool:
        """
        Check whether buf_name has a NoneLayout. Buffers with a NoneLayout are
        not allocated, so a graph partition should not take them as inputs or
        outputs. For a MutationOutput with a NoneLayout, the answer is taken
        from the buffer it mutates (via mutation_real_name). Unknown buffer
        names return False.
        """
        buf = self.name_to_buf.get(buf_name, None)

        if buf is None:
            return False

        if isinstance(buf.node.layout, NoneLayout):
            if isinstance(buf.node, MutationOutput) and (
                real_name := self.mutation_real_name.get(buf_name, None)
            ):
                return is_none_layout(real_name)

            return True

        return False

    # Walk partitions last-to-first so each one knows which of its outputs
    # are still needed downstream.
    for partition, skip_cudagraph in zip(
        reversed(partitions), reversed(skip_cudagraphs)
    ):
        output_names: OrderedSet[str] = OrderedSet()

        for node in partition:
            output_names.update(node.outputs_by_name.keys())

        # Outputs produced here that something downstream still needs.
        returned_output_names = output_names.intersection(unmet_output_names)

        # all reads/writes are partition inputs except those generated
        # within the partition and tensor constants
        read_writes = dependencies.ReadWrites.merge_list(
            [node.read_writes for node in partition]
        )

        # WeakDep is fake dependency on unused buffer. It should not appear
        # in partition_input_names for inputs that are actually read or written.
        partition_input_names = (
            OrderedSet(
                [
                    x.name
                    for x in read_writes.reads | read_writes.writes
                    if not is_none_layout(x.name)
                ]
            )
            - output_names
        )

        # Map mutation-output names back to the buffer they mutate.
        partition_input_names = OrderedSet(
            self.mutation_real_name.get(name, name) for name in partition_input_names
        )

        buffer_names_to_free: OrderedSet[str] = OrderedSet()
        for node in partition:
            buffer_names_to_free.update(node.last_usage)

        # buffer_names_to_free may contain buffers allocated in previous
        # graph partitions. These buffers should also be a partition
        # input.
        extra_input_names = [
            name
            for name in (buffer_names_to_free - output_names)
            if name in name_to_node
        ]
        partition_input_names.update(extra_input_names)

        input_nodes = {
            name: name_to_node[name]
            for name in partition_input_names
            if name in name_to_node
        }
        # True for inputs whose last use is inside this partition.
        input_deallocation = {
            name: name in buffer_names_to_free
            for name in partition_input_names
            if name in name_to_node
        }

        # if an input tensor is not freed in the partition function, it should
        # also be returned as an output. This brings benefits to cudagraph
        # since the returned output tensor is a cudagraph managed tensor with
        # a static tensor address.
        extra_output_names = [
            name
            for name in partition_input_names
            if name in name_to_node and name not in buffer_names_to_free
        ]

        returned_output_names.update(extra_output_names)

        returned_output_names = OrderedSet(
            self.mutation_real_name.get(name, name) for name in returned_output_names
        )

        output_nodes = [
            name_to_node[name]
            for name in returned_output_names
            if not is_none_layout(name)
        ]

        constant_names = [
            name for name in partition_input_names if name in V.graph.constants
        ]

        symbol_inputs = self.get_graph_partition_symbol_inputs(partition, input_nodes)

        partition_signature = GraphPartitionSignature(
            symbol_inputs,
            input_nodes,
            output_nodes,
            input_deallocation,
            skip_cudagraph,
            constant_names,
        )

        signatures.append(partition_signature)

        # Earlier partitions must now provide this partition's inputs. The
        # outputs this partition returns are no longer unmet.
        unmet_output_names = partition_input_names.union(
            unmet_output_names - returned_output_names
        )

    # Signatures were built last-to-first; restore the partitions' order.
    return signatures[::-1]


332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
# ========================================
# torch 2.9 Inductor Scheduler monkeypatch
# ========================================
# This change monkeypatches a function in Inductor to work around the following
# bug: https://github.com/vllm-project/vllm/issues/26678
#
# The bug occurs when `use_inductor_graph_partition` is turned on and there
# exists operators inside of `splitting_ops` that have an in-place mutation. In
# vllm, this specifically occurs on the operator
# vllm.unified_attention_with_output. In this case, inductor does not populate
# the inductor IR's `origin_node` field, causing an assertion error when trying
# to access the node's `origin_node` field.
#
# So, we will monkeypatch torch._inductor.scheduler.Scheduler.should_partition
# so that it does not access the inductor IR node's `origin_node` field and just
# returns True if a node is registered as having a custom partition function.
# This is ok for now since vllm's implementation of the custom partition
# functions just return True.
# ========================================


def should_partition_patched(self, node, should_log: bool = False) -> bool:
    # This is a patched version of
    # torch._inductor.scheduler.Scheduler.should_partition that modifies
    # the following piece of code so that we always return True:
    # https://github.com/pytorch/pytorch/blob/ecb53078faf86ca1b33277df33b82985675bb011/torch/_inductor/scheduler.py#L4712-L4724
    """Return True if we should partition the inductor graph on this node"""

    import torch._inductor.ir as ir
    from torch._inductor.scheduler import (
        BaseSchedulerNode,
        FusedSchedulerNode,
    )
    from torch._inductor.utils import (
        _unstable_customized_partition_wrapper,
        is_cudagraph_unsafe_op,
        maybe_log_cudagraph_partition,
    )

    # Allow users to manually specify if a node should be partitioned
    # Can only do this for FallbackKernels
    ir_node = node.node
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
    if isinstance(ir_node, torch._inductor.ir.FallbackKernel) and (
        op := ir_node.op_overload
    ):
        op_overload_packet_name = op.name()
        op_overload_name = (
            f"{op_overload_packet_name}.{op._overloadname}"
            if isinstance(op, torch._ops.OpOverload)
            else op_overload_packet_name
        )
        if (
            op_overload_packet_name
            in torch._inductor.config.custom_should_partition_ops
            or op_overload_name in torch._inductor.config.custom_should_partition_ops
        ):
            assert isinstance(op, torch._ops.OpOverload)
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
            return True

    # When not using cudagraphs, keep all kernels in the `call` function
    # instead of graph partition functions, since graph partition only brings
    # benefit to cudagraph
    if (
        not torch._inductor.config.triton.cudagraphs
        and _unstable_customized_partition_wrapper.wrapper is None
    ):
        return True

    # avoid duplicating logs when should_partition is called multiple times
    # on the same node
    def noop_log(msg: str, node: BaseSchedulerNode | None) -> None:
        return

    log_partition_reason = maybe_log_cudagraph_partition if should_log else noop_log

    if isinstance(node, FusedSchedulerNode):
        return any(self.should_partition(snode) for snode in node.snodes)

    assert node.node is not None

    if not node.is_gpu():
        log_partition_reason("non gpu ops", node=node)

        return True

    if isinstance(node.node, ir.DeviceCopy):
        log_partition_reason("DeviceCopy ops", node=node)
        return True

    if isinstance(node.node, ir.Conditional):
        log_partition_reason("Conditional ops", node=node)
        return True

    if getattr(node.node, "unbacked_bindings", None):
        log_partition_reason("unbacked binding ops", node=node)
        return True

    if is_cudagraph_unsafe_op(node.node):
        log_partition_reason("CUDAGraph-unsafe custom ops", node=node)
        return True

    return False


def _update_scheduler_patched(self) -> None:
    """(Re)initializes the scheduler member.

    When initializing the scheduler, no CUBIN files should be generated (to
    avoid biasing any benchmarks and pessimizing fusion decisions).

    Copied from torch._inductor.graph.GraphLowering._update_scheduler and
    installed in its place on torch 2.9.0. Before building the Scheduler, it
    installs should_partition_patched and
    get_graph_partition_signature_patched on the Scheduler class.

    Args:
        self: the ``GraphLowering`` instance; ``self.scheduler`` is
            (re)assigned from ``self.operations``.
    """
    import torch._inductor.config as config
    from torch._inductor.scheduler import Scheduler

    # Patch at class level. This is presumably done here, rather than at
    # import time, so torch._inductor.scheduler is only imported once a graph
    # is actually lowered. Re-assigning on every call is harmless.
    Scheduler.should_partition = should_partition_patched
    Scheduler.get_graph_partition_signature = get_graph_partition_signature_patched

    with config.patch("triton.store_cubin", False):
        self.scheduler = Scheduler(self.operations)


455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
# ===================================================
# torch 2.9 Inductor get_raw_stream workaround
# ===================================================
# Workaround for TorchInductor autotune using get_raw_stream() without defining it.
# This occurs when compile_sizes > 1 in compilation_config.
# For more context, see https://github.com/vllm-project/vllm/issues/30905.
def _patch_get_raw_stream_if_needed():
    """Workaround for TorchInductor autotune get_raw_stream() bug.

    On torch 2.9.0 and 2.9.1 only, exposes
    ``torch._C._cuda_getCurrentRawStream`` as the builtin name
    ``get_raw_stream``, so Inductor's autotuning code can resolve that name.
    Does nothing on other torch versions or on torch builds without that
    symbol.
    """
    from vllm.utils.torch_utils import is_torch_equal

    # Only apply the patch for torch 2.9.0 or 2.9.1
    if not (is_torch_equal("2.9.0") or is_torch_equal("2.9.1")):
        return

    # Check if CUDA functionality is available without initializing CUDA:
    # _cuda_getCurrentRawStream only exists in CUDA builds of PyTorch.
    if not hasattr(torch._C, "_cuda_getCurrentRawStream"):
        return

    import builtins

    from torch._C import _cuda_getCurrentRawStream as _get_raw_stream

    builtins.get_raw_stream = _get_raw_stream  # type: ignore[attr-defined]
475
476
477
478


_patch_get_raw_stream_if_needed()

# torch 2.9.0 only: install the Inductor monkeypatches defined above.
if is_torch_equal("2.9.0"):
    from torch._inductor.codegen.wrapper import PythonWrapperCodegen
    from torch._inductor.graph import GraphLowering
    from torch.utils._config_module import _Config, _ConfigEntry

    # `custom_should_partition_ops` is a new config after 2.9.0. So this would
    # not overwrite any user configs. Registering it lets
    # should_partition_patched read
    # torch._inductor.config.custom_should_partition_ops on 2.9.0.
    torch._inductor.config._config["custom_should_partition_ops"] = _ConfigEntry(
        _Config(default=[])
    )

    PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
    # Also installs should_partition_patched and
    # get_graph_partition_signature_patched on Scheduler at lowering time.
    GraphLowering._update_scheduler = _update_scheduler_patched
492
493
494
495
496
497
498
499
500
501

# ===================================================
# torch >=2.10,<2.12 GraphCaptureOutput.get_runtime_env monkeypatch
# ===================================================
# PyTorch's AOT compile path omits builtins from used_globals, causing
# 'Missing required external references' errors for refs such as 'type'
# (this happens in transformers code).
# This mirrors the fix in https://github.com/pytorch/pytorch/pull/177558
# and can be removed once torch >=2.12 is the minimum supported version.
# The patch itself is applied further below, after the
# constrain_to_fx_strides patch.

502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
# ===================================================
# torch >= 2.11 Inductor constrain_to_fx_strides monkeypatch
# ===================================================
# Inductor's constrain_to_fx_strides calls .stride() on every FX arg's meta
# value, which crashes on FakeScriptObject (the compile-time proxy for hoisted
# opaque types). The patched version skips args whose meta value is not a
# torch.Tensor.
# Upstream issue: https://github.com/pytorch/pytorch/issues/175973


# Set to True by the first call of _apply_constrain_to_fx_strides_patch(),
# whether or not the torch version gate lets the patch be installed.
_constrain_to_fx_strides_patched = False


def _apply_constrain_to_fx_strides_patch():
    """Replace ``lowering.constrain_to_fx_strides`` with a tensor-only variant.

    Safe to call repeatedly: only the first call does anything. The
    replacement is installed only for torch >= 2.11 and < 2.12 (dev builds
    included). As described in the section comment above, IR-node args whose
    FX meta ``"val"`` is not a ``torch.Tensor`` (e.g. a FakeScriptObject) are
    passed through unchanged instead of having ``.stride()`` called on them.
    """
    global _constrain_to_fx_strides_patched
    if _constrain_to_fx_strides_patched:
        return
    # Flip the flag before the version gate so the gate is evaluated once.
    _constrain_to_fx_strides_patched = True

    torch_in_range = is_torch_equal_or_newer("2.11.0.dev") and not (
        is_torch_equal_or_newer("2.12.0.dev")
    )
    if not torch_in_range:
        return

    import torch._inductor.ir as inductor_ir
    import torch._inductor.lowering as inductor_lowering
    from torch._inductor.virtualized import V

    def constrain_one(value, fx_value):
        if isinstance(value, inductor_ir.IRNode):
            fake_val = fx_value.meta.get("val")
            if not isinstance(fake_val, torch.Tensor):
                # Non-tensor meta value: there is no stride to match.
                return value
            order = inductor_ir.get_stride_order(
                fake_val.stride(), V.graph.sizevars.shape_env
            )
            return inductor_ir.ExternKernel.require_stride_order(value, order)
        if isinstance(value, dict):
            # Constrain dict entries key by key against the matching FX dict.
            return {key: constrain_one(value[key], fx_value[key]) for key in value}
        return value

    def constrain_to_fx_strides(fx_node, *args, **kwargs):
        constrained_args = tuple(
            constrain_one(value, fx_value)
            for value, fx_value in zip(args, fx_node.args)
        )
        constrained_kwargs = {
            name: constrain_one(value, fx_node.kwargs[name])
            for name, value in kwargs.items()
        }
        return constrained_args, constrained_kwargs

    inductor_lowering.constrain_to_fx_strides = constrain_to_fx_strides


556
# GraphCaptureOutput.get_runtime_env builtins patch (torch >=2.10,<2.12); the
# motivation is in the "GraphCaptureOutput.get_runtime_env monkeypatch"
# section header earlier in this file.
if is_torch_equal_or_newer("2.10.0") and not is_torch_equal_or_newer("2.12.0.dev"):
    import builtins as _builtins
    import pickle

    from torch._dynamo.convert_frame import GraphCaptureOutput

    _original_get_runtime_env = GraphCaptureOutput.get_runtime_env

    def _safe_builtins_dict(builtins_dict: dict) -> dict:
        """Filter a builtins dict to only picklable entries for serialization.

        Returns a new dict. Entries whose value ``pickle.dumps`` rejects
        (with any exception) are dropped, and ``builtins_dict`` itself is not
        modified.
        """
        result = {}
        for k, v in builtins_dict.items():
            try:
                pickle.dumps(v)
            except Exception:
                # Not picklable: leave it out of the serialized environment.
                continue
            result[k] = v
        return result

    def _patched_get_runtime_env(self):  # type: ignore[no-untyped-def]
        """Run the original get_runtime_env, then backfill missing builtins.

        For each external ref missing from ``used_globals``:
        - a ``__builtins_dict__*`` ref present in ``self.f_globals`` gets the
          picklable subset of that dict;
        - a name found on the ``builtins`` module (e.g. ``type``) gets that
          builtin.
        Any other missing ref is left alone.
        """
        runtime_env = _original_get_runtime_env(self)
        for ref in runtime_env.external_refs:
            if ref in runtime_env.used_globals:
                continue
            if ref.startswith("__builtins_dict__") and ref in self.f_globals:
                runtime_env.used_globals[ref] = _safe_builtins_dict(
                    self.f_globals[ref]
                )
            elif hasattr(_builtins, ref):
                runtime_env.used_globals[ref] = getattr(_builtins, ref)
        return runtime_env

    GraphCaptureOutput.get_runtime_env = _patched_get_runtime_env
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636

# ===================================================
# torch 2.10 FxGraphCachePickler.dumps ValueError fix
# ===================================================
# PyTorch 2.10's FxGraphCachePickler.dumps() doesn't catch ValueError,
# causing torch.compile cache failures when tensors with non-standard
# layouts (e.g. blocked-layout prepacked weights) are serialized.
# PyTorch mainline fixed this in pytorch/pytorch#176557 (merged 2026-03-04).
# This is a thin backport for 2.10 users; remove once 2.10 is dropped.


def _apply_fxgraphcache_pickle_patch(pickler_cls, bypass_cls):
    """Wrap pickler_cls.dumps to convert ValueError into bypass_cls.

    Idempotent: sets `_vllm_fxgraph_dumps_patched` on the class after the
    first apply to prevent re-application. The wrapper function is also
    marked with `_vllm_patched` as an additional safeguard.

    The wrapper is built with functools.wraps, so it keeps the original
    method's ``__name__``/``__qualname__``/``__doc__`` and exposes the
    original as ``__wrapped__``. Tracebacks and introspection then still
    identify it as ``dumps``.

    Args:
        pickler_cls: class whose ``dumps(self, obj)`` method is wrapped
            (FxGraphCachePickler in this module).
        bypass_cls: exception type raised instead of ValueError. The original
            ValueError is chained as its ``__cause__``.
    """
    import functools

    if getattr(pickler_cls, "_vllm_fxgraph_dumps_patched", False):
        return

    original_dumps = pickler_cls.dumps
    if hasattr(original_dumps, "_vllm_patched"):
        return

    @functools.wraps(original_dumps)
    def patched_dumps(self, obj):
        try:
            return original_dumps(self, obj)
        except ValueError as e:
            raise bypass_cls("Failed to pickle cache key") from e

    patched_dumps._vllm_patched = True  # type: ignore[attr-defined]
    pickler_cls.dumps = patched_dumps
    pickler_cls._vllm_fxgraph_dumps_patched = True  # type: ignore[attr-defined]


def _patch_fxgraphcache_pickle_if_needed():
    """Apply FxGraphCachePickler.dumps ValueError backport when on torch 2.10.x.

    The patch is applied when is_torch_equal_or_newer reports >= "2.10.0" but
    not >= "2.11.0". On any other version this is a no-op.
    """
    from vllm.utils.torch_utils import is_torch_equal_or_newer

    on_torch_2_10 = is_torch_equal_or_newer("2.10.0") and not (
        is_torch_equal_or_newer("2.11.0")
    )
    if on_torch_2_10:
        from torch._inductor.codecache import (
            BypassFxGraphCache,
            FxGraphCachePickler,
        )

        _apply_fxgraphcache_pickle_patch(FxGraphCachePickler, BypassFxGraphCache)


_patch_fxgraphcache_pickle_if_needed()