"""Compilation backends used by vLLM's ``torch.compile`` integration.

The piecewise backend splits a Dynamo-captured graph at the configured
splitting ops, compiles each piece (optionally with Inductor) and can
capture CUDA graphs for specific runtime shapes.
"""
import copy
import dataclasses
import time
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from unittest.mock import patch

import torch
import torch.fx as fx

import vllm.envs as envs
from vllm.config import CompilationConfig, VllmConfig
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors

from .counter import compilation_counter
from .inductor_pass import InductorPass
from .monitor import end_monitoring_torch_compile
from .pass_manager import PostGradPassManager

logger = init_logger(__name__)


def wrap_inductor(graph,
                  example_inputs,
                  additional_inductor_config,
                  compilation_config: CompilationConfig,
                  graph_index: int = 0,
                  num_graphs: int = 1,
                  runtime_shape: Optional[int] = None,
                  use_inductor: bool = True):
    """Compile a single piecewise graph with Inductor.

    The pieces of one model are timed as a whole. The piece with
    ``graph_index == 0`` stamps the module-level ``compilation_start_time``.
    The piece with ``graph_index == num_graphs - 1`` adds the elapsed time
    to ``compilation_config.compilation_time`` and logs it.

    Args:
        graph: the graph module of one piece; it is deep-copied before
            being handed to Inductor.
        example_inputs: inputs forwarded to ``compile_fx``.
        additional_inductor_config: optional mapping merged on top of a
            copy of the current Inductor config.
        compilation_config: accumulates the measured compilation time.
        graph_index: position of this piece among the ``num_graphs``.
        num_graphs: total number of pieces compiled together.
        runtime_shape: ``None`` when compiling for the general (symbolic)
            shape, or an ``int`` to specialize for that shape. An int also
            enables ``max_autotune`` and ``coordinate_descent_tuning``.
        use_inductor: if False, ``graph`` is returned as-is. The start
            timestamp is still recorded, but no time is accumulated.

    Returns:
        The callable produced by ``compile_fx``, or ``graph`` unchanged.
    """
    global compilation_start_time

    if graph_index == 0:
        # before compiling the first graph, record the start time
        compilation_start_time = time.time()

    if not use_inductor:
        return graph

    compilation_counter.num_inductor_compilations += 1

    from torch._inductor import config
    config_patches = config.get_config_copy()
    from torch._inductor.compile_fx import compile_fx

    if additional_inductor_config is not None:
        config_patches.update(additional_inductor_config)

    if isinstance(runtime_shape, int):
        # for a specific batchsize, tuning triton kernel parameters
        # can be beneficial
        config_patches.update(max_autotune=True,
                              coordinate_descent_tuning=True)

    # inductor can inplace modify the graph, so we need to copy it
    # see https://github.com/pytorch/pytorch/issues/138980
    compiled_graph = compile_fx(copy.deepcopy(graph),
                                example_inputs,
                                config_patches=config_patches)

    is_last_graph = graph_index == num_graphs - 1
    if not is_last_graph:
        return compiled_graph

    # after compiling the last graph, record the end time
    elapsed = time.time() - compilation_start_time
    compilation_config.compilation_time += elapsed
    if runtime_shape is None:
        logger.info("Compiling a graph for general shape takes %.2f s",
                    elapsed)
    else:
        logger.info("Compiling a graph for shape %s takes %.2f s",
                    runtime_shape, elapsed)
    return compiled_graph


@dataclasses.dataclass
class SplitItem:
    """One direct submodule of the graph module produced by `split_graph`."""

    # attribute name of the submodule on the split graph module,
    # e.g. "submod_0"
    submod_name: str
    # integer id parsed from `submod_name`; used to order the items
    graph_id: int
    # True if this submodule holds the single node of a splitting op
    is_splitting_graph: bool
    # the submodule itself
    graph: fx.GraphModule


def split_graph(graph: fx.GraphModule,
                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
    """Split ``graph`` into submodules around every call to one of ``ops``.

    A ``call_function`` node whose ``str(node.target)`` is in ``ops`` gets
    a subgraph of its own. The nodes between two such ops form the
    piecewise subgraphs.

    Args:
        graph: the graph module to split.
        ops: string names of the splitting ops.

    Returns:
        ``(split_gm, items)``. ``split_gm`` is the stitching module built by
        ``split_module``. ``items`` describes its direct submodules, sorted
        by integer graph id.
    """
    # set lookup: membership is tested once per node
    split_op_names = set(ops)

    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id = {}
    split_op_graphs = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == 'call_function' and str(node.target) in split_op_names:
            # isolate the splitting op in a subgraph of its own
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph,
        None,
        lambda node: node_to_subgraph_id[node],
        keep_original_order=True)

    outputs = []
    # only direct children are split pieces; the root and any
    # recursive child modules are not included
    for name, module in split_gm.named_children():
        graph_id = int(name.replace("submod_", ""))
        outputs.append(
            SplitItem(name, graph_id, graph_id in split_op_graphs, module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs


# we share the global graph pool among all the backends;
# it is created lazily by the first VllmBackend instance
global_graph_pool: Any = None

# wall-clock timestamp set by `wrap_inductor` when it starts compiling the
# first piecewise graph, and read when the last graph has been compiled
compilation_start_time: float = 0.0


class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs. It compiles the submodules
    named in `compile_submod_names` with the given compilation configs
    and replaces each of them with a `PiecewiseBackend`.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(self, module: torch.fx.GraphModule,
                 compile_submod_names: List[str], vllm_config: VllmConfig,
                 graph_pool):
        super().__init__(module)
        from torch._guards import detect_fake_mode
        # the ambient fake mode found by `detect_fake_mode`; `run` converts
        # real tensors into it so that no real computation happens
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.vllm_config = vllm_config

    def run(self, *args):
        """Interpret the module on fake versions of the tensor ``args``."""
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode:
            return super().run(*fake_args)

    def call_module(self, target: torch.fx.node.Target,
                    args: Tuple[torch.fx.node.Argument,
                                ...], kwargs: Dict[str, Any]) -> Any:
        """Run submodule ``target``, compiling it if it is scheduled.

        A scheduled submodule is compiled for the general (symbolic) shape
        and then replaced on ``self.module`` by a `PiecewiseBackend`.
        """
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            # positions of symbolic-int arguments; PiecewiseBackend reads
            # the runtime shape from the first of them
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            compiled_graph_for_general_shape = wrap_inductor(
                submod,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=index,
                num_graphs=len(self.compile_submod_names),
                runtime_shape=None,
                use_inductor=self.compilation_config.use_inductor)

            # an entry in the instance __dict__ takes precedence over the
            # registered submodule on attribute lookup, so calls through
            # `self.module.<target>` reach the backend instead
            self.module.__dict__[target] = PiecewiseBackend(
                submod, self.vllm_config, self.graph_pool, index,
                len(self.compile_submod_names), sym_shape_indices,
                compiled_graph_for_general_shape)

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


class VllmBackend:
    """The compilation backend for `torch.compile` with VLLM.
    It is used for compilation level of `CompilationLevel.PIECEWISE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    graph_pool: Any
    # set once `__call__` has compiled a graph; an instance is single-use
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: List[SplitItem]
    returned_callable: Callable

    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]

    # positions of the example inputs that are copied into `input_buffers`
    sym_tensor_indices: List[int]
    # static buffers used when `cudagraph_copy_inputs` is enabled
    input_buffers: List[torch.Tensor]

    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        global global_graph_pool
        if global_graph_pool is None:
            global_graph_pool = torch.cuda.graph_pool_handle()

        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = global_graph_pool

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Install this backend's post-grad pass manager in the Inductor
        config, absorbing any user-supplied custom post pass."""
        config = self.compilation_config
        self.post_grad_pass_manager.configure(config.pass_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            # Config should automatically wrap all inductor passes
            assert isinstance(inductor_config[PASS_KEY], InductorPass)
            self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        # NOTE(review): this writes into the (possibly shared) compilation
        # config, so a later backend built from the same config will find
        # this manager under PASS_KEY. Confirm that PostGradPassManager
        # satisfies the InductorPass assertion above in that case.
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """Split ``graph`` at the splitting ops and compile the pieces.

        Dynamo calls this once its bytecode transform is done. The
        non-splitting submodules are compiled for the general shape and
        replaced by `PiecewiseBackend` instances inside `self.split_gm`.

        Returns:
            `self.split_gm`. When cudagraphs are enabled together with
            `cudagraph_copy_inputs`, returns instead a wrapper that copies
            the selected tensor inputs into static buffers before calling
            `self.split_gm`.
        """
        # we control the compilation process, each instance can only be
        # called once. Check this before touching any counter so that a
        # rejected second call leaves the statistics unchanged.
        assert not self._called, "VllmBackend can only be called once"

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1

        from .monitor import torch_compile_start_time
        dynamo_time = time.time() - torch_compile_start_time
        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph, self.compilation_config.splitting_ops)

        from torch._dynamo.utils import lazy_format_graph_code
        logger.debug("%s", lazy_format_graph_code("before split", self.graph))
        logger.debug("%s", lazy_format_graph_code("after split",
                                                  self.split_gm))

        compilation_counter.num_piecewise_graphs_seen += len(
            self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
                                    self.vllm_config,
                                    self.graph_pool).run(*example_inputs)

        self._called = True

        if not self.compilation_config.use_cudagraph or \
            not self.compilation_config.cudagraph_copy_inputs:
            return self.split_gm

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode
        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # NOTE(review): every tensor input is converted to a FakeTensor
        # just above, so this selects *all* tensor inputs, not only those
        # with a symbolic batch dimension. Confirm this is intended.
        self.sym_tensor_indices = [
            i for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return copy_and_call


@dataclasses.dataclass
class ConcreteSizeEntry:
    """Per-runtime-shape state kept by `PiecewiseBackend`."""

    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_cudagraph: bool  # the size is in capture_sizes

    # whether the shape-specialized compilation has already been done
    compiled: bool = False
    # callable run for this shape; None until the shape is first seen
    runnable: Optional[Callable] = None
    # number of warmup runs performed before the cudagraph capture
    num_finished_warmup: int = 0
    # the captured cudagraph, once captured
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    # output of the capture run (passed through `weak_ref_tensors`),
    # returned on every replay
    output: Optional[Any] = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[List[int]] = None


class PiecewiseBackend:
    """Runtime dispatcher for one compiled piece of a split graph."""

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: Set[int] = set(
            self.compilation_config.compile_sizes)
        self.capture_sizes: Set[int] = set(
            self.compilation_config.capture_sizes
        ) if self.compilation_config.use_cudagraph else set()

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

        # sizes still waiting for a shape-specialized compilation. Only
        # compile sizes are ever removed from this set (in `__call__`),
        # so capture-only sizes must not be in it; otherwise it never
        # empties and compilation monitoring never ends. It is a copy
        # because it shrinks as compilations finish.
        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()

        for shape in self.compile_sizes.union(self.capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.capture_sizes,
            )

    def _maybe_end_monitoring(self) -> None:
        """End compile monitoring once the last piece has nothing left to
        compile."""
        if self.is_last_graph and not self.to_be_compiled_sizes:
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        """Run this piece, compiling or capturing for the runtime shape.

        The first call uses the general-shape graph. Later calls dispatch on
        the first symbolic-int argument. Shapes without an entry use the
        general graph. Shapes in `compile_sizes` are compiled once. Shapes
        in `capture_sizes` are warmed up, captured, and then replayed.
        """
        if not self.first_run_finished:
            self.first_run_finished = True
            # no specific sizes to compile
            self._maybe_end_monitoring()
            return self.compiled_graph_for_general_shape(*args)

        if not self.sym_shape_indices:
            # this piece receives no symbolic-int argument, so there is no
            # runtime shape to dispatch on; use the general-shape graph
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]
        entry = self.concrete_size_entries.get(runtime_shape)
        if entry is None:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = wrap_inductor(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape,
                use_inductor=self.compilation_config.use_inductor)

            # finished compilations for all required shapes
            self._maybe_end_monitoring()

        if not entry.use_cudagraph:
            return entry.runnable(*args)

        if entry.cudagraph is None:
            num_warmups = self.compilation_config.cudagraph_num_of_warmups
            if entry.num_finished_warmup < num_warmups:
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        num_warmups,
                        runtime_shape)
                return entry.runnable(*args)
            return self._capture_cudagraph(entry, runtime_shape, args)

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.cudagraph.replay()
        return entry.output

    def _capture_cudagraph(self, entry: ConcreteSizeEntry, runtime_shape,
                           args) -> Any:
        """Capture a cudagraph for ``entry`` and return the capture output."""
        if self.is_first_graph:
            # Since we capture cudagraph for many different shapes and
            # capturing is fast, we don't need to log it for every shape.
            # We only log it in the debug mode.
            logger.debug("Capturing a cudagraph for shape %s",
                         runtime_shape)

        entry.input_addresses = [
            x.data_ptr() for x in args if isinstance(x, torch.Tensor)
        ]
        cudagraph = torch.cuda.CUDAGraph()

        with ExitStack() as stack:
            if not self.is_first_graph:
                # during every model forward, we will capture
                # many pieces of cudagraphs (roughly one per layer).
                # running gc again and again across layers will
                # make the cudagraph capture very slow.
                # therefore, we only run gc for the first graph,
                # and disable gc for the rest of the graphs.
                stack.enter_context(patch("gc.collect", lambda: None))
                stack.enter_context(
                    patch("torch.cuda.empty_cache", lambda: None))

            # mind-exploding: carefully manage the reference and memory.
            with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                # `output` is managed by pytorch's cudagraph pool
                output = entry.runnable(*args)
                if self.is_last_graph:
                    # by converting it to weak ref,
                    # the original `output` will immediately be released
                    # to save memory. It is only safe to do this for
                    # the last graph, because the output of the last graph
                    # will not be used by any other cuda graph.
                    output = weak_ref_tensors(output)

        # here we always use weak ref for the output
        # to save memory
        entry.output = weak_ref_tensors(output)
        entry.cudagraph = cudagraph

        compilation_counter.num_cudagraph_caputured += 1

        # important: we need to return the output, rather than
        # the weak ref of the output, so that pytorch can correctly
        # manage the memory during cuda graph capture
        return output