backends.py 30.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
import dataclasses
6
import hashlib
7
import json
8
import operator
9
10
import os
import pprint
11
import time
12
from collections.abc import Callable, Sequence
13
from contextlib import contextmanager
14
from copy import deepcopy
15
from functools import partial
16
from typing import Any
17
18
19

import torch
import torch.fx as fx
20
from torch._dispatch.python import enable_python_dispatcher
21

22
import vllm.envs as envs
23
24
25
from vllm.compilation.inductor_pass import pass_context
from vllm.compilation.partition_rules import (
    inductor_partition_rule_context,
26
    should_split,
27
)
28
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
29
from vllm.config.compilation import DynamicShapesType
30
from vllm.config.utils import Range, hash_factors
31
from vllm.logger import init_logger
32
from vllm.logging_utils import lazy
33
from vllm.platforms import current_platform
34
from vllm.utils.import_utils import resolve_obj_by_qualname
35
from vllm.utils.torch_utils import is_torch_equal_or_newer
36

37
from .caching import VllmSerializableFunction
38
39
40
41
42
from .compiler_interface import (
    CompilerInterface,
    EagerAdaptor,
    InductorAdaptor,
    InductorStandaloneAdaptor,
43
    is_compile_cache_enabled,
44
)
45
from .counter import compilation_counter
46
47
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
48
49
50

logger = init_logger(__name__)

51

52
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Instantiate the compiler adaptor named by `compilation_config.backend`.

    - ``"eager"``: an `EagerAdaptor`.
    - ``"inductor"``: an `InductorStandaloneAdaptor` (built with
      `compilation_config.compile_cache_save_format`) when
      `envs.VLLM_USE_STANDALONE_COMPILE` is set, torch is at least
      ``2.8.0.dev`` and `torch._inductor.standalone_compile` exists;
      otherwise an `InductorAdaptor`.
    - anything else: the class whose qualified name is returned by
      `current_platform.get_compile_backend()`, instantiated with no
      arguments and asserted to be a `CompilerInterface`.
    """
    backend_name = compilation_config.backend

    if backend_name == "eager":
        logger.debug("Using EagerAdaptor")
        return EagerAdaptor()

    if backend_name != "inductor":
        logger.debug("Using custom backend: %s", backend_name)
        backend_cls = resolve_obj_by_qualname(current_platform.get_compile_backend())
        custom_compiler = backend_cls()
        assert isinstance(custom_compiler, CompilerInterface)
        return custom_compiler

    # Standalone compile needs to be requested explicitly, needs a recent
    # enough torch, and the symbol must actually exist in this build.
    standalone_available = (
        envs.VLLM_USE_STANDALONE_COMPILE
        and is_torch_equal_or_newer("2.8.0.dev")
        and hasattr(torch._inductor, "standalone_compile")
    )
    if not standalone_available:
        logger.debug("Using InductorAdaptor")
        return InductorAdaptor()

    logger.debug("Using InductorStandaloneAdaptor")
    return InductorStandaloneAdaptor(compilation_config.compile_cache_save_format)
76
77


78
79
80
81
82
class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(compile_range, graph_index, backend_name)`
    to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int as key. Compile ranges are written as plain
    `(start, end)` tuples so that the file only contains Python
    literals, and they are turned back into `Range` objects on load.
    """

    def __init__(self, compilation_config: CompilationConfig):
        self.cache: dict[tuple[Range, int, str], Any] = dict()
        # Flipped to True by `compile` when a new handle is stored;
        # `save_to_file` does nothing until then.
        self.is_cache_updated = False
        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the underlying compiler's hash of `vllm_config`."""
        return self.compiler.compute_hash(vllm_config)

    @contextmanager
    def compile_context(self, compile_range: Range):
        """Provide compilation context for the duration of compilation to set
        any torch global properties we want to scope to a single Inductor
        compilation (e.g. partition rules, pass context)."""
        with pass_context(compile_range):
            if self.compilation_config.use_inductor_graph_partition:
                with inductor_partition_rule_context(
                    self.compilation_config.splitting_ops
                ):
                    yield
            else:
                yield

    @staticmethod
    def _parse_cache_key(raw_key: Any) -> tuple[Range, int, str] | None:
        """Turn a key read back from the cache file into an in-memory key.

        Keys are persisted as `((start, end), graph_index, compiler_name)`.
        Returns None for anything that does not have exactly that shape
        (e.g. entries written by an older, incompatible format).
        """
        if not (isinstance(raw_key, tuple) and len(raw_key) == 3):
            return None
        range_part, graph_index, compiler_name = raw_key
        if not isinstance(graph_index, int) or not isinstance(compiler_name, str):
            return None
        if not (
            isinstance(range_part, tuple)
            and len(range_part) == 2
            and all(isinstance(bound, int) for bound in range_part)
        ):
            return None
        start, end = range_part
        return Range(start=start, end=end), graph_index, compiler_name

    def _load_cache_file(self) -> None:
        """Populate `self.cache` from `self.cache_file_path`.

        The cache is only an optimization, so an unreadable file or
        unrecognized entries are logged and skipped (the affected graphs are
        simply recompiled) instead of aborting start-up.
        """
        with open(self.cache_file_path) as f:
            content = f.read()
        try:
            # we use ast.literal_eval to parse the data
            # because it is a safe way to parse Python literals.
            # do not use eval(), it is unsafe.
            raw_cache = ast.literal_eval(content)
        except (SyntaxError, ValueError):
            logger.warning(
                "Ignoring unreadable compile cache file %s", self.cache_file_path
            )
            return
        if not isinstance(raw_cache, dict):
            logger.warning(
                "Ignoring compile cache file %s: expected a dict, got %s",
                self.cache_file_path,
                type(raw_cache).__name__,
            )
            return

        loaded: dict[tuple[Range, int, str], Any] = {}
        skipped = 0
        for raw_key, handle in raw_cache.items():
            key = self._parse_cache_key(raw_key)
            if key is None:
                skipped += 1
                continue
            loaded[key] = handle
        if skipped:
            logger.warning(
                "Ignored %d unrecognized entries in compile cache file %s",
                skipped,
                self.cache_file_path,
            )
        self.cache = loaded

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            self._load_cache_file()

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self):
        """Write the cache to `self.cache_file_path` if it changed.

        Ranges are written as `(start, end)` tuples so the file can be read
        back with `ast.literal_eval` in `initialize_cache`.
        """
        if self.disable_cache or not self.is_cache_updated:
            return
        serializable = {
            ((compile_range.start, compile_range.end), graph_index, name): handle
            for (compile_range, graph_index, name), handle in self.cache.items()
        }
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(serializable)
        with open(self.cache_file_path, "w") as f:
            f.write(data)

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        compile_range: Range,
    ) -> Callable | None:
        """Return the cached compiled graph for this key, or None on a miss."""
        key = (compile_range, graph_index, self.compiler.name)
        if key not in self.cache:
            return None
        handle = self.cache[key]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, compile_range
        )
        logger.debug(
            "Directly load the %s-th graph for compile range %s from %s via handle %s",
            graph_index,
            str(compile_range),
            self.compiler.name,
            handle,
        )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs,
        additional_inductor_config,
        compilation_config: CompilationConfig,
        compile_range: Range,
        graph_index: int = 0,
        num_graphs: int = 1,
    ) -> Any:
        """Load `graph` from the cache if possible, otherwise compile it.

        `graph_index`/`num_graphs` place this graph among the piecewise
        graphs of one compile range: graph 0 starts the timer kept in the
        module-level `compilation_start_time`, and the last graph adds the
        elapsed time to `compilation_config.compilation_time`. A freshly
        compiled handle is stored in `self.cache` when
        `is_compile_cache_enabled(additional_inductor_config)` holds and the
        compiler returned a non-None handle.
        """
        if graph_index == 0:
            # before compiling the first graph, record the start time
            global compilation_start_time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                now = time.time()
                elapsed = now - compilation_start_time
                compilation_config.compilation_time += elapsed
                logger.info(
                    "Directly load the compiled graph(s) for compile range %s "
                    "from the cache, took %.3f s",
                    str(compile_range),
                    elapsed,
                )
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            # Let compile_fx generate a key for us
            maybe_key = None
        else:
            maybe_key = "artifact_compile_range_"
            maybe_key += f"{compile_range.start}_{compile_range.end}"
            maybe_key += f"_subgraph_{graph_index}"
        with self.compile_context(compile_range):
            compiled_graph, handle = self.compiler.compile(
                graph,
                example_inputs,
                additional_inductor_config,
                compile_range,
                maybe_key,
            )

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
            self.cache[(compile_range, graph_index, self.compiler.name)] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                logger.info_once(
                    "Cache the graph of compile range %s for later use",
                    str(compile_range),
                )
            logger.debug(
                "Store the %s-th graph for compile range %s from %s via handle %s",
                graph_index,
                str(compile_range),
                self.compiler.name,
                handle,
            )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            now = time.time()
            elapsed = now - compilation_start_time
            compilation_config.compilation_time += elapsed
            logger.info_once(
                "Compiling a graph for compile range %s takes %.2f s",
                str(compile_range),
                elapsed,
                scope="local",
            )

        return compiled_graph
268
269


270
271
272
@dataclasses.dataclass
class SplitItem:
    """Description of one direct submodule produced by `split_graph`.

    Fields, in constructor order:
        submod_name: attribute name of the submodule on the split module
            (e.g. ``"submod_3"``).
        graph_id: the integer parsed out of `submod_name`; items are
            ordered by it.
        is_splitting_graph: True when this submodule holds a splitting op
            rather than a piece that gets compiled.
        graph: the submodule itself.
    """

    submod_name: str
    graph_id: int
    is_splitting_graph: bool
    graph: fx.GraphModule


278
def split_graph(
    graph: fx.GraphModule, splitting_ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Partition `graph` into sequential submodules around splitting ops.

    Every node for which `should_split(node, splitting_ops)` holds gets a
    subgraph of its own, and the runs of nodes between such ops form the
    other subgraphs. Placeholder and output nodes are left for
    `split_module` to handle.

    Returns:
        `(split_gm, items)`: the stitching module built by `split_module`,
        and one `SplitItem` per direct `submod_<id>` child, sorted by the
        integer id.
    """
    next_id = 0
    assignment: dict[fx.Node, int] = {}
    splitting_ids: set[int] = set()

    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue

        # A getitem on a node from an earlier subgraph joins that node's
        # subgraph, so a whole tuple never has to be passed as a submodule
        # input (standalone_compile and AoTAutograd do not accept that).
        is_getitem = node.op == "call_function" and node.target == operator.getitem
        if is_getitem:
            producer = node.args[0]
            if producer.op != "placeholder":
                assert producer in assignment
                assignment[node] = assignment[producer]
                continue

        if should_split(node, splitting_ops):
            # The splitting op sits alone in subgraph `next_id + 1`; the
            # nodes after it start a fresh subgraph at `next_id + 2`.
            assignment[node] = next_id + 1
            splitting_ids.add(next_id + 1)
            next_id += 2
        else:
            assignment[node] = next_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph, None, lambda node: assignment[node], keep_original_order=True
    )

    items: list[SplitItem] = []
    for qualified_name, _ in split_gm.named_modules():
        # "" is the root module and dotted names are nested descendants;
        # only the direct children are the split submodules.
        if qualified_name == "" or "." in qualified_name:
            continue
        submodule = getattr(split_gm, qualified_name)
        graph_id = int(qualified_name.replace("submod_", ""))
        items.append(
            SplitItem(qualified_name, graph_id, graph_id in splitting_ids, submodule)
        )

    # order by the integer id, not by the string name ("submod_10" < "submod_2")
    items.sort(key=operator.attrgetter("graph_id"))
    return split_gm, items
335
336


337
338
# Wall-clock timestamp (seconds since the epoch, from time.time()) taken by
# CompilerManager.compile when it starts on graph 0. That method updates it
# through a `global` statement and reads it again when the last graph
# (graph_index == num_graphs - 1) is done, to add the elapsed time to
# compilation_config.compilation_time and log it.
compilation_start_time: float = 0.0

339
340
341
342
343
344

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Run the split graph on fake inputs and install piecewise backends.

    Adapted from `torch.fx.passes.shape_prop.ShapeProp`. Positional inputs
    that are tensors are converted with the detected fake mode before the
    graph is interpreted. Every submodule whose name appears in
    `compile_submod_names` is replaced, in the parent module's `__dict__`,
    by a `PiecewiseBackend`. When piecewise cudagraphs are enabled and
    Inductor graph partitioning is not used, that backend is additionally
    wrapped in the platform's static graph wrapper.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def run(self, *args):
        """Interpret the graph with tensor arguments converted to fake tensors."""
        # maybe instead just assert inputs are fake?
        converted = []
        for arg in args:
            if isinstance(arg, torch.Tensor):
                converted.append(self.fake_mode.from_tensor(arg))
            else:
                converted.append(arg)
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*converted)

    def _wrap_in_static_graph(self, piecewise_backend):
        """Wrap `piecewise_backend` in the platform's static graph wrapper."""
        # We're using Dynamo-based piecewise splitting, so we wrap
        # the whole subgraph with a static graph wrapper.
        from .cuda_graph import CUDAGraphOptions

        # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
        # class) as platform dependent.
        wrapper_cls = resolve_obj_by_qualname(
            current_platform.get_static_graph_wrapper_cls()
        )

        # Always assign PIECEWISE runtime mode to the
        # CUDAGraphWrapper for piecewise_backend, to distinguish
        # it from the FULL cudagraph runtime mode, no matter it
        # is wrapped on a full or piecewise fx graph.
        options = CUDAGraphOptions(
            debug_log_enable=piecewise_backend.is_first_graph,
            gc_disable=not piecewise_backend.is_first_graph,
            weak_ref_output=piecewise_backend.is_last_graph,
        )
        return wrapper_cls(
            runnable=piecewise_backend,
            vllm_config=self.vllm_config,
            runtime_mode=CUDAGraphMode.PIECEWISE,
            cudagraph_options=options,
        )

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Run submodule `target`; if it is to be compiled, install its backend."""
        assert isinstance(target, str)

        result = super().call_module(target, args, kwargs)
        if target not in self.compile_submod_names:
            return result

        position = self.compile_submod_names.index(target)
        submodule = self.fetch_attr(target)
        symint_positions = [
            pos for pos, arg in enumerate(args) if isinstance(arg, torch.SymInt)
        ]

        # Lazy import here to avoid circular import
        from .piecewise_backend import PiecewiseBackend

        backend = PiecewiseBackend(
            submodule,
            self.vllm_config,
            position,
            len(self.compile_submod_names),
            symint_positions,
            self.vllm_backend,
        )

        needs_static_wrapper = (
            self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
            and not self.compilation_config.use_inductor_graph_partition
        )
        if needs_static_wrapper:
            replacement = self._wrap_in_static_graph(backend)
        else:
            replacement = backend
        self.module.__dict__[target] = replacement

        compilation_counter.num_piecewise_capturable_graphs_seen += 1
        return result


445
446
447
448
449
450
451
452
453
# Tag of the model part currently being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Temporarily replace the module-level `model_tag` with `tag`.

    The previous tag is restored on exit, including when the body raises.
    Entering with the tag that is already active fails the assertion.
    """
    global model_tag
    assert tag != model_tag, (
        f"Model tag {tag} is the same as the current tag {model_tag}."
    )
    saved_tag, model_tag = model_tag, tag
    try:
        yield
    finally:
        model_tag = saved_tag


465
class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.
    It is used for compilation mode of `CompilationMode.VLLM_COMPILE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    # Set at the end of `__call__`; each instance may only compile once.
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    # Positions of example inputs whose sizes are symbolic, and the static
    # buffers those inputs are copied into before each call.
    sym_tensor_indices: list[int]
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager
    # Copy of CompilationConfig.inductor_compile_config +
    # an entry for PostGradPassManager
    inductor_config: dict[str, Any]

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        # if the model is initialized with a non-empty prefix,
        # then usually it's enough to use that prefix,
        # e.g. language_model, vision_model, etc.
        # when multiple parts are initialized as independent
        # models, we need to use the model_tag to distinguish
        # them, e.g. backbone (default), eagle_head, etc.
        self.prefix = prefix or model_tag

        # Passes to run on the graph post-grad.
        self.pass_manager = resolve_obj_by_qualname(
            current_platform.get_pass_manager_cls()
        )()
        self.pass_key = current_platform.pass_key

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config
        )

        # Deepcopy the inductor config to detach the post-grad custom pass
        # from CompilationConfig.
        # We want to avoid PostGradPassManager in CompilationConfig because
        # in future we need PostGradPassManager.uuid() to be executed
        # only at compile time.
        self.inductor_config = deepcopy(self.compilation_config.inductor_compile_config)
        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Configure the pass manager and install it under `self.pass_key`.

        Raises:
            ValueError: if the user config already holds a PostGradPassManager
                under that key.
        """
        self.pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        if self.pass_key in self.inductor_config:
            if isinstance(self.inductor_config[self.pass_key], PostGradPassManager):
                raise ValueError(
                    "PostGradPassManager can not be kept in CompilationConfig."
                )
            else:
                # Config should automatically wrap all inductor passes
                assert isinstance(
                    self.compilation_config.inductor_compile_config[self.pass_key],
                    InductorPass,
                )
                self.pass_manager.add(
                    self.compilation_config.inductor_compile_config[self.pass_key]
                )
        self.inductor_config[self.pass_key] = self.pass_manager

    def __call__(
        self, graph: fx.GraphModule, example_inputs
    ) -> VllmSerializableFunction:
        """Compile `graph` piecewise and return the callable handed to Dynamo.

        In order, this:
        1. hashes env/config/compiler/traced-code factors and sets up the
           cache directory (`<cache_dir>/rank_<rank>_<dp_rank>/<prefix>`);
        2. records the Dynamo bytecode-transform time;
        3. splits the graph at the splitting ops and compiles every
           non-splitting submodule via `PiecewiseCompileInterpreter`;
        4. dumps the split graph to `computation_graph.py` for debugging;
        5. returns the split module, or a wrapper that first copies the
           symbolic-shape inputs into static buffers when cudagraphs are
           enabled together with `cudagraph_copy_inputs`.

        May only be called once per instance.
        """
        # we control the compilation process, each instance can only be
        # called once. Check this before touching any shared state (traced
        # files, counters, compilation time) so that misuse fails cleanly.
        assert not self._called, "VllmBackend can only be called once"

        vllm_config = self.vllm_config
        # Minimal hashing here with existing utilities, reused below.

        env_factors = envs.compile_factors()
        env_hash = hash_factors(env_factors)
        # Compute config/compiler/code hashes once and reuse
        config_hash = vllm_config.compute_hash()
        compiler_hash = self.compiler_manager.compute_hash(vllm_config)
        forward_code_files = sorted(self.compilation_config.traced_files)

        logger.debug(
            "Traced files (to be considered for compilation cache):\n%s",
            lazy(lambda: "\n".join(forward_code_files)),
        )
        hash_content = []
        for filepath in forward_code_files:
            hash_content.append(filepath)
            if filepath == "<string>":
                # This means the function was dynamically generated, with
                # e.g. exec(). We can't actually check these.
                continue
            try:
                with open(filepath) as f:
                    hash_content.append(f.read())
            except Exception:
                logger.warning("Failed to read file %s", filepath)
                continue
        code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest()
        # Clear after consumption
        self.compilation_config.traced_files.clear()

        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.
            factors = [env_hash, config_hash, code_hash, compiler_hash]
            # Use SHA-256 for cache key hashing to be consistent across
            # compute_hash functions. Truncate for a short cache dir name.
            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]
            cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key
            )
            self.compilation_config.cache_dir = cache_dir

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
        disable_cache = not is_compile_cache_enabled(self.inductor_config)

        if disable_cache:
            logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
        else:
            logger.info_once(
                "Using cache directory: %s for vLLM's torch.compile",
                local_cache_dir,
                scope="local",
            )

        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache, self.prefix
        )

        # Reuses existing cache key

        logger.debug(
            "torch.compile cache factors: env=%s cfg=%s comp=%s code=%s dir=%s",
            env_hash,
            config_hash,
            compiler_hash,
            code_hash,
            local_cache_dir,
        )

        # Persist and log only hash-relevant factors together.
        try:
            logger.debug(
                "Compile env factors (raw):\n%s\nVllm config hash: %s",
                lazy(partial(pprint.pformat, env_factors, width=120)),
                config_hash,
            )
            meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")
            if not os.path.exists(meta_path):
                with open(meta_path, "w") as f:
                    json.dump(
                        {
                            "env": env_factors,  # raw factors used for env_hash
                            "config_hash": config_hash,
                            "code_hash": code_hash,
                            "compiler_hash": compiler_hash,
                        },
                        f,
                        indent=2,
                        sort_keys=True,
                    )
        except Exception:
            # Best-effort only; metadata write failures are non-fatal.
            logger.warning(
                (
                    "Could not write compile cache metadata at %s; continuing without "
                    "metadata. Compiled cache remains valid; diagnostics may be "
                    "limited."
                ),
                local_cache_dir,
                exc_info=True,
            )

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time

        dynamo_time = time.time() - torch_compile_start_time
        logger.info_once(
            "Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
        )
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        if self.compilation_config.use_inductor_graph_partition:
            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
            fx_split_ops: list[str] = []
        else:
            fx_split_ops = self.compilation_config.splitting_ops or []

        self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # Extract fake values from the graph to use them when needed.
        all_fake_values = []
        for i in graph.graph.find_nodes(op="placeholder"):
            all_fake_values.append(i.meta["example_value"])

        fake_args = [
            all_fake_values[i] if isinstance(t, torch.Tensor) else t
            for i, t in enumerate(example_inputs)
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(
            self.split_gm, submod_names_to_compile, self.vllm_config, self
        ).run(*fake_args)

        from torch._guards import detect_fake_mode

        fake_mode = detect_fake_mode()

        if (
            self.compilation_config.dynamic_shapes_config.evaluate_guards
            and self.compilation_config.dynamic_shapes_config.type
            == DynamicShapesType.BACKED
        ):
            from torch.utils._sympy.value_ranges import ValueRanges

            # Drop counter-0/1 specializations guards; for backed dynamic shapes,
            # torch.compile will specialize for 0/1 inputs or otherwise guards that
            # shape is >= 2. This is because it's really hard not to hit a check
            # against 0/1. When we evaluate shape guards, we exclude checking those
            # guards (We would fail always otherwise).

            # We avoid that by updating the ranges of backed sizes when the min is
            # 2 for any, we assume it's 0.
            for s, r in fake_mode.shape_env.var_to_range.items():
                if r.lower == 2:
                    fake_mode.shape_env.var_to_range[s] = ValueRanges(0, r.upper)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from
            # https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30
            # use `print_readable` because it can include submodules
            src = (
                "from __future__ import annotations\nimport torch\n"
                + self.split_gm.print_readable(print_output=False)
            )
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)

            logger.debug_once(
                "Computation graph saved to %s", graph_path, scope="local"
            )

        self._called = True

        if (
            self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
            or not self.compilation_config.cudagraph_copy_inputs
        ):
            return VllmSerializableFunction(
                graph, example_inputs, self.prefix, self.split_gm
            )

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        self.sym_tensor_indices = [
            i
            for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            # NOTE(review): only dim 0 is treated as the dynamic dimension
            # when slicing the static buffer below.
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return VllmSerializableFunction(
            graph, example_inputs, self.prefix, copy_and_call
        )