backends.py 18 KB
Newer Older
1
import copy
2
import dataclasses
3
from contextlib import ExitStack
4
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
5
from unittest.mock import patch
6
7
8
9

import torch
import torch.fx as fx

10
import vllm.envs as envs
11
from vllm.config import CompilationConfig
12
from vllm.logger import init_logger
13
from vllm.utils import weak_ref_tensors
14

15
from .counter import compilation_counter
16
17
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
18
19
20

# Module-level logger, created through vLLM's `init_logger` helper.
logger = init_logger(__name__)

21

22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def wrap_inductor(graph,
                  example_inputs,
                  additional_inductor_config,
                  do_logging=False,
                  runtime_shape: Optional[int] = None,
                  use_inductor: bool = True):
    """Compile ``graph`` with Inductor, or hand it back unchanged.

    Args:
        graph: the FX graph module to compile.
        example_inputs: inputs forwarded to ``compile_fx``.
        additional_inductor_config: optional mapping overlaid on top of
            Inductor's current global config; the merged dict is passed to
            ``compile_fx`` as ``config_patches``.
        do_logging: when True, emit an info log line before compiling.
        runtime_shape: the concrete size being compiled for, or None for
            the general (symbolic) shape; only used in the log message.
        use_inductor: when False, return ``graph`` itself without compiling
            and without touching the compilation counter.

    Returns:
        ``graph`` when ``use_inductor`` is False, otherwise whatever
        ``compile_fx`` returns for a deep copy of ``graph``.
    """
    # Inductor disabled: nothing to do, no counter bump, no copy.
    if not use_inductor:
        return graph

    compilation_counter.num_inductor_compilations += 1

    if do_logging and runtime_shape is None:
        logger.info("Compiling a graph for general shape")
    elif do_logging:
        logger.info("Compiling a graph for shape %s", runtime_shape)

    # Snapshot Inductor's global config, then import the compiler entry point.
    from torch._inductor import config as inductor_global_config
    patches = inductor_global_config.shallow_copy_dict()
    from torch._inductor.compile_fx import compile_fx

    if additional_inductor_config is not None:
        patches.update(additional_inductor_config)

    # Inductor may mutate the graph it is given, so compile a private copy.
    # see https://github.com/pytorch/pytorch/issues/138980
    private_graph = copy.deepcopy(graph)
    return compile_fx(private_graph, example_inputs, config_patches=patches)
50
51


52
53
54
@dataclasses.dataclass
class SplitItem:
    """Description of one top-level piece produced by `split_graph`.

    Fields, in positional order:
      submod_name: attribute name of the piece on the split module
          (e.g. ``"submod_3"``).
      graph_id: integer id parsed from ``submod_name``.
      is_splitting_graph: True when the piece holds one of the splitting ops.
      graph: the piece's module object.
    """

    submod_name: str
    graph_id: int
    is_splitting_graph: bool
    graph: fx.GraphModule


def split_graph(graph: fx.GraphModule,
                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
    """Split ``graph`` into pieces at every call to one of ``ops``.

    Each ``call_function`` node whose ``str(node.target)`` is in ``ops`` is
    placed in a subgraph of its own. Each run of other nodes between such
    calls forms another subgraph. Subgraph ids increase in graph order.

    Args:
        graph: the FX graph module to split.
        ops: string forms of the call targets to split at.

    Returns:
        ``(split_gm, items)``, where ``split_gm`` is the stitching module
        returned by ``split_module`` and ``items`` describes each of its
        top-level ``submod_<id>`` children, sorted by integer ``graph_id``.
    """
    # Import explicitly instead of relying on `torch.fx.passes` having been
    # imported as a side effect somewhere else.
    from torch.fx.passes.split_module import split_module

    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id = {}
    # A set gives O(1) membership tests in the loop below.
    split_op_graphs = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == "call_function" and str(node.target) in ops:
            # Isolate the splitting op in its own subgraph, then start a
            # fresh subgraph for whatever follows it.
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = split_module(graph,
                            None,
                            lambda node: node_to_subgraph_id[node],
                            keep_original_order=True)

    outputs = []

    names = [name for (name, _) in split_gm.named_modules()]

    for name in names:
        if "." in name or name == "":
            # recursive child module or the root module
            continue

        module = getattr(split_gm, name)

        graph_id = int(name.replace("submod_", ""))
        outputs.append(
            SplitItem(name, graph_id, (graph_id in split_op_graphs), module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
106
107


108
109
110
111
112
113
114
115
116
# CUDA graph memory pool shared by all the backends in this process.
# It starts as None and is lazily set by the first `VllmBackend.__init__`
# call via `torch.cuda.graph_pool_handle()`.
global_graph_pool = None


class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Interpreter that walks the split module on fake inputs and compiles
    selected submodules along the way.

    Adapted from `torch.fx.passes.shape_prop.ShapeProp`: the split module is
    executed under the ambient fake mode. Whenever a submodule listed in
    `compile_submod_names` is called, it is compiled for the general shape
    and replaced on the module by a `PiecewiseBackend`.

    NOTE: the order of `compile_submod_names` is significant. A piece's
    position in that list is its piecewise index: index 0 does the logging,
    and the last index gets special cudagraph output handling.
    """

    def __init__(self, module: torch.fx.GraphModule,
                 compile_submod_names: List[str],
                 compilation_configs: CompilationConfig, graph_pool):
        super().__init__(module)
        from torch._guards import detect_fake_mode
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_configs = compilation_configs
        self.graph_pool = graph_pool

    def run(self, *args):
        """Convert tensor args to fake tensors and run the graph under the
        fake mode. Non-tensor args are passed through unchanged."""
        mode = self.fake_mode

        def _as_fake(value):
            if isinstance(value, torch.Tensor):
                return mode.from_tensor(value)
            return value

        fake_inputs = [_as_fake(value) for value in args]
        with mode:
            return super().run(*fake_inputs)

    def call_module(self, target: torch.fx.node.Target,
                    args: Tuple[torch.fx.node.Argument, ...],
                    kwargs: Dict[str, Any]) -> Any:
        """Run the submodule. If it is one of the pieces to compile, also
        compile it for the general shape and install a `PiecewiseBackend`
        in its place."""
        assert isinstance(target, str)
        result = super().call_module(target, args, kwargs)

        # Not a piece we compile: just propagate the value.
        if target not in self.compile_submod_names:
            return result

        piece_index = self.compile_submod_names.index(target)
        submodule = self.fetch_attr(target)
        # Positions of SymInt arguments (the symbolic sizes).
        symint_positions = [
            pos for pos, arg in enumerate(args)
            if isinstance(arg, torch.SymInt)
        ]
        configs = self.compilation_configs
        general_shape_runnable = wrap_inductor(
            submodule,
            args,
            configs.inductor_compile_config,
            runtime_shape=None,
            do_logging=piece_index == 0,
            use_inductor=configs.use_inductor)

        # Write straight into the instance __dict__ so that attribute lookup
        # of `target` on the module finds the backend object.
        self.module.__dict__[target] = PiecewiseBackend(
            submodule, configs, self.graph_pool, piece_index,
            len(self.compile_submod_names), symint_positions,
            general_shape_runnable)

        compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return result


172
173
174
175
class VllmBackend:
    """The compilation backend for `torch.compile` with VLLM.
    It is used for compilation level of `CompilationLevel.PIECEWISE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    compilation_configs: CompilationConfig
    graph_pool: Any
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: List[SplitItem]
    returned_callable: Callable

    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]

    sym_tensor_indices: List[int]
    input_buffers: List[torch.Tensor]

    def __init__(
        self,
        compilation_configs: CompilationConfig,
    ):
        global global_graph_pool
        if global_graph_pool is None:
            global_graph_pool = torch.cuda.graph_pool_handle()

        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = global_graph_pool

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.compilation_configs = compilation_configs

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Configure the post-grad pass manager from `pass_config` and
        install it as Inductor's `post_grad_custom_post_pass`."""
        config = self.compilation_configs
        self.post_grad_pass_manager.configure(config.pass_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        # NOTE(review): this writes into `config.inductor_compile_config` in
        # place. A second VllmBackend configured with the same config object
        # would find a PostGradPassManager under PASS_KEY. Whether that
        # satisfies the isinstance assert below depends on the
        # PostGradPassManager class hierarchy (not visible here) — confirm.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            # Config should automatically wrap all inductor passes
            assert isinstance(inductor_config[PASS_KEY], InductorPass)
            self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """Split `graph`, compile its pieces, and return the runnable.

        Returns `self.split_gm` directly unless both `use_cudagraph` and
        `cudagraph_copy_inputs` are enabled. In that case it returns a
        wrapper that copies the symbolic-shape inputs into persistent
        buffers before calling `self.split_gm`.
        """
        compilation_counter.num_graphs_seen += 1

        # we control the compilation process, each instance can only be
        # called once
        assert not self._called, "VllmBackend can only be called once"

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph, self.compilation_configs.splitting_ops)

        from torch._dynamo.utils import lazy_format_graph_code
        logger.debug("%s", lazy_format_graph_code("before split", self.graph))
        logger.debug("%s", lazy_format_graph_code("after split",
                                                  self.split_gm))

        compilation_counter.num_piecewise_graphs_seen += len(
            self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
                                    self.compilation_configs,
                                    self.graph_pool).run(*example_inputs)

        self._called = True

        if not self.compilation_configs.use_cudagraph or \
            not self.compilation_configs.cudagraph_copy_inputs:
            return self.split_gm

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode
        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size).
        # Tensors whose shapes are fully concrete are excluded. Presumably
        # these are weights / static buffers passed in as graph inputs; they
        # need no static copy, and copying them on every call would waste
        # memory and time.
        from torch.fx.experimental.symbolic_shapes import is_symbolic
        self.sym_tensor_indices = [
            i for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                # assumes the symbolic (batch) dimension is dim 0
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return copy_and_call
310
311
312
313
314
315
316
317
318
319
320
321
322
323


@dataclasses.dataclass
class ConcreteSizeEntry:
    """Per-size state for one concrete runtime shape of a piecewise graph."""

    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_cudagraph: bool  # the size is in capture_sizes

    compiled: bool = False
    # Callable used to run this shape; None until first assigned.
    runnable: Optional[Callable] = None
    num_finished_warmup: int = 0
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    output: Optional[Any] = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[List[int]] = None

328
329
330

class PiecewiseBackend:

    def __init__(self, graph: fx.GraphModule,
                 compilation_configs: CompilationConfig, graph_pool: Any,
                 piecewise_compile_index: int, total_piecewise_compiles: int,
                 sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_configs.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.

        Args:
            graph: the piece (submodule) this backend runs.
            compilation_configs: compilation settings (compile/capture
                sizes, cudagraph and inductor options).
            graph_pool: memory pool passed to `torch.cuda.graph`.
            piecewise_compile_index: position of this piece among the
                compiled pieces; 0 marks the first piece.
            total_piecewise_compiles: number of compiled pieces; the piece
                at index `total_piecewise_compiles - 1` is the last one.
            sym_shape_indices: positions of the call arguments that hold
                symbolic sizes; the first one is used as the runtime shape.
            compiled_graph_for_general_shape: callable compiled for the
                general (symbolic) shape.
        """
        self.graph = graph
        self.compilation_configs = compilation_configs
        self.graph_pool = graph_pool

        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: Set[int] = set(
            self.compilation_configs.compile_sizes)
        self.capture_sizes: Set[int] = set(
            self.compilation_configs.capture_sizes
        ) if self.compilation_configs.use_cudagraph else set()

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
        for shape in self.compile_sizes.union(self.capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.capture_sizes,
            )

    def __call__(self, *args) -> Any:
        """Run the piece for the current call.

        The first call, calls whose size has no entry, and calls on a piece
        without any symbolic-size argument all use the general-shape graph.
        Otherwise the per-size entry is compiled (if requested), warmed up,
        captured into a CUDA graph, and later replayed.
        """
        if not self.first_run_finished:
            self.first_run_finished = True
            return self.compiled_graph_for_general_shape(*args)

        if not self.sym_shape_indices:
            # No symbolic-size argument, so there is no runtime size to key
            # per-shape compilation or cudagraph capture on. Indexing [0]
            # below would raise IndexError; use the general graph instead.
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            # args are real arguments
            entry.runnable = wrap_inductor(
                self.graph,
                args,
                self.compilation_configs.inductor_compile_config,
                runtime_shape=runtime_shape,
                do_logging=self.is_first_graph,
                use_inductor=self.compilation_configs.use_inductor)

        if not entry.use_cudagraph:
            return entry.runnable(*args)

        if entry.cudagraph is None:
            if entry.num_finished_warmup < self.compilation_configs.cudagraph_num_of_warmups:  # noqa
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        self.compilation_configs.cudagraph_num_of_warmups,
                        runtime_shape)
                return entry.runnable(*args)

            if self.is_first_graph:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in the debug mode.
                logger.debug("Capturing a cudagraph for shape %s",
                             runtime_shape)

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of cudagraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.cuda.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's cudagraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_caputured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.cudagraph.replay()
        return entry.output