backends.py 18.8 KB
Newer Older
1
import copy
2
import dataclasses
3
import operator
4
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
5
6
7
8

import torch
import torch.fx as fx

9
from vllm.logger import init_logger
10
from vllm.utils import weak_ref_tensors
11

12
13
from .config import CompilationConfig
from .counter import compilation_counter
14
15
16
17
from .levels import CompilationLevel

logger = init_logger(__name__)

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162

def _getitem_users(node):
    """Return the users of `node` that are `operator.getitem` calls."""
    return [
        user for user in list(node.users)
        if user.op == 'call_function' and user.target == operator.getitem
    ]


def fix_functionalization(graph: fx.Graph):
    """
    Rewrite the graph to replace the pattern involving
    torch._higher_order_ops.auto_functionalize.auto_functionalized
    with a direct call to the inplace custom op.

    Only `auto_functionalized` nodes wrapping one of the following ops
    are rewritten; every other node is left untouched:
    `torch.ops._C.rotary_embedding.default`,
    `torch.ops._C.fused_add_rms_norm.default`,
    `torch.ops._C.rms_norm.default` and
    `torch.ops._C.silu_and_mul.default`.

    Args:
        graph: the fx graph to rewrite. It is modified in place and
            nothing is returned.

    Raises:
        RuntimeError: if an output of a functionalized
            `fused_add_rms_norm` other than index 1 or 2 is consumed
            by another node, because there is no known node to
            substitute for it.

    # TODO: check if PyTorch nightly has fixed this issue
    """

    # debug code, if we want to see the graph before the transformation
    # with open("before.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)

    # erasing is deferred until the end so that the graph is not
    # mutated structurally while we are still iterating over its nodes
    nodes_to_remove = []

    for node in graph.nodes:
        # Identify the auto_functionalized node
        if node.op != 'call_function' or node.target != torch._higher_order_ops.auto_functionalize.auto_functionalized:  # noqa
            continue

        if node.args[0] == torch.ops._C.rotary_embedding.default:
            # manual replace for rotary_embedding
            kwargs = node.kwargs

            # NOTE(review): this assumes `query` is produced by two
            # nested single-input nodes on top of the node we want to
            # keep -- confirm against the traced graph
            mm_node = kwargs['query'].args[0].args[0]

            # Create a new call to torch.ops._C.rotary_embedding.default
            with graph.inserting_before(node):
                # just insert the call to the custom op
                # NOTE: don't run dead code elimination,
                # otherwise this op will be removed
                graph.call_function(torch.ops._C.rotary_embedding.default,
                                    kwargs=kwargs)

            # Remove the auto_functionalized node
            # Since the node may have outputs, we need to handle its users
            # Replace uses of the outputs (getitem nodes) with mm_node
            for user in _getitem_users(node):
                for getitem_user in list(user.users):
                    if (getitem_user.op == 'call_function'
                            and getitem_user.target
                            == torch.ops.aten.slice_scatter.default):
                        # Replace the uses of slice_scatter node
                        # with mm_node
                        getitem_user.replace_all_uses_with(mm_node)
                        nodes_to_remove.append(getitem_user)
                # the getitem node itself is removed after its
                # slice_scatter users
                nodes_to_remove.append(user)
            nodes_to_remove.append(node)

        elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
            # manual replace for fused_add_rms_norm
            # this is the most effective optimization for llama
            # failing to do this will result in many unnecessary copies
            kwargs = node.kwargs

            input_node = kwargs['input']
            residual_node = kwargs['residual']

            # Create a new call to torch.ops._C.fused_add_rms_norm.default
            with graph.inserting_before(node):
                # just insert the call to the custom op
                # NOTE: don't run dead code elimination,
                # otherwise this op will be removed
                graph.call_function(
                    torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs)

            for user in _getitem_users(node):
                index = user.args[1]
                if index == 1:
                    replace_node = input_node
                elif index == 2:
                    replace_node = residual_node
                elif not user.users:
                    # an unused output of another index has nothing to
                    # be replaced with, it only needs to be dropped
                    nodes_to_remove.append(user)
                    continue
                else:
                    # never fall through with an unbound replacement or
                    # one left over from a previous iteration
                    raise RuntimeError(
                        "Unexpected output index "
                        f"{index!r} of functionalized fused_add_rms_norm "
                        f"is used by node {user.name}")
                user.replace_all_uses_with(replace_node)
                nodes_to_remove.append(user)
            nodes_to_remove.append(node)

        elif node.args[0] == torch.ops._C.rms_norm.default:
            # manual replace for rms_norm
            kwargs = node.kwargs

            input_node = kwargs['input']
            out = kwargs['out']
            weight = kwargs['weight']
            epsilon = kwargs['epsilon']

            # Create a new call to torch.ops._C.rms_norm.default
            # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
            with graph.inserting_before(node):
                # just insert the call to the custom op
                # NOTE: don't run dead code elimination,
                # otherwise this op will be removed
                graph.call_function(
                    torch.ops._C.rms_norm.default,
                    args=(out, input_node, weight, epsilon),
                )

            # every getitem output is redirected to the `out` argument
            for user in _getitem_users(node):
                user.replace_all_uses_with(out)
                nodes_to_remove.append(user)
            nodes_to_remove.append(node)

        elif node.args[0] == torch.ops._C.silu_and_mul.default:
            # manual replace for silu_and_mul
            kwargs = node.kwargs

            input_node = kwargs['input']
            out = kwargs['out']

            # Create a new call to torch.ops._C.silu_and_mul.default
            # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
            with graph.inserting_before(node):
                # just insert the call to the custom op
                # NOTE: don't run dead code elimination,
                # otherwise this op will be removed
                graph.call_function(
                    torch.ops._C.silu_and_mul.default,
                    args=(out, input_node),
                )

            # every getitem output is redirected to the `out` argument
            for user in _getitem_users(node):
                user.replace_all_uses_with(out)
                nodes_to_remove.append(user)
            nodes_to_remove.append(node)

    # Remove the nodes all at once
    for node in nodes_to_remove:
        graph.erase_node(node)

    # debug code, if we want to see the graph after the transformation
    # with open("after.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)


163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def wrap_inductor(graph,
                  example_inputs,
                  additional_inductor_config,
                  do_logging=False,
                  runtime_shape: Optional[int] = None,
                  use_inductor: bool = True):
    """Compile `graph` with inductor and return the compiled callable.

    Args:
        graph: the graph to compile.
        example_inputs: the inputs handed to `compile_fx`.
        additional_inductor_config: optional dict of inductor config
            entries applied on top of a copy of the current inductor
            config.
        do_logging: whether to log which shape is being compiled.
        runtime_shape: the shape reported in the log message, `None`
            is reported as the general shape.
        use_inductor: when False, `graph` is returned as is and
            nothing is compiled, counted or logged.
    """
    if not use_inductor:
        return graph

    compilation_counter.num_inductor_compilations += 1

    if do_logging:
        if runtime_shape is not None:
            logger.info("Compiling a graph for shape %s", runtime_shape)
        else:
            logger.info("Compiling a graph for general shape")

    from torch._inductor import config as inductor_config
    config_patches = inductor_config.shallow_copy_dict()
    from torch._inductor.compile_fx import compile_fx

    if additional_inductor_config is not None:
        config_patches.update(additional_inductor_config)

    # inductor can inplace modify the graph, so we need to copy it
    # see https://github.com/pytorch/pytorch/issues/138980
    graph_copy = copy.deepcopy(graph)
    return compile_fx(graph_copy,
                      example_inputs,
                      config_patches=config_patches)
191
192


193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
@dataclasses.dataclass
class SplitItem:
    """One top-level submodule of a split graph module."""
    # attribute name of the submodule on the split graph module
    submod_name: str
    # True when the id parsed from the submodule name belongs to a
    # subgraph that was created for one of the splitting ops
    is_splitting_graph: bool
    # the submodule itself
    graph: fx.GraphModule


def split_graph(graph: fx.GraphModule,
                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
    """Split `graph` into piecewise submodules at the given ops.

    Every `call_function` node whose `str(node.target)` is in `ops`
    gets a subgraph of its own; the nodes between two such ops share
    one subgraph. Placeholder and output nodes are not assigned to
    any subgraph.

    Args:
        graph: the graph module to split.
        ops: string forms of the call targets to split at.

    Returns:
        A tuple of the stitching graph module returned by
        `split_module` and the list of its top-level submodules,
        ordered by the numeric id in their `submod_<id>` name.
    """
    # imported explicitly so that we do not rely on the submodule
    # having been loaded as a side effect of some other import
    from torch.fx.passes.split_module import split_module

    # split graph by ops
    split_ops = set(ops)
    subgraph_id = 0
    node_to_subgraph_id = {}
    split_op_graphs = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == 'call_function' and str(node.target) in split_ops:
            # close the current subgraph, give the op its own id, and
            # start a new subgraph for whatever follows
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = split_module(graph,
                            None,
                            lambda node: node_to_subgraph_id[node],
                            keep_original_order=True)

    id_and_name = []
    for name, _ in split_gm.named_modules():
        if "." in name or name == "":
            # recursive child module or the root module
            continue
        id_and_name.append((int(name.replace("submod_", "")), name))

    # sort by the integer id rather than by the name to make the order
    # deterministic: a string sort would put "submod_10" before
    # "submod_2" and break the graph order
    id_and_name.sort()

    outputs = [
        SplitItem(name, graph_id in split_op_graphs, getattr(split_gm, name))
        for graph_id, name in id_and_name
    ]

    return split_gm, outputs
244
245


246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
# we share the global graph pool among all the backends.
# It starts out as None and is created on first use in
# `VllmBackend.__init__` via `torch.cuda.graph_pool_handle()`.
global_graph_pool = None


class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It interprets the given graph module with fake inputs. Whenever
    a submodule listed in `compile_submod_names` is called, it is
    compiled for the general shape with the given compilation configs
    and replaced on the module by a `PiecewiseBackend`.
    """

    def __init__(self, module: torch.fx.GraphModule,
                 compile_submod_names: List[str],
                 compilation_configs: CompilationConfig, graph_pool):
        super().__init__(module)
        from torch._guards import detect_fake_mode
        self.fake_mode = detect_fake_mode()
        self.graph_pool = graph_pool
        self.compilation_configs = compilation_configs
        self.compile_submod_names = compile_submod_names
        # flips to True once the first submodule has been compiled
        self.have_seen_first_graph = False

    def run(self, *args):
        """Run the module with every tensor argument turned into a
        fake tensor; other arguments are passed through unchanged."""

        def as_fake(value):
            if isinstance(value, torch.Tensor):
                return self.fake_mode.from_tensor(value)
            return value

        return super().run(*[as_fake(value) for value in args])

    def call_module(self, target: torch.fx.node.Target,
                    args: Tuple[torch.fx.node.Argument,
                                ...], kwargs: Dict[str, Any]) -> Any:
        """Run the submodule `target` and, if it is one of the
        submodules to compile, install a `PiecewiseBackend` for it."""
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target not in self.compile_submod_names:
            return output

        submod = self.fetch_attr(target)

        # positions of the arguments that are symbolic integers
        sym_shape_indices = []
        for position, arg in enumerate(args):
            if isinstance(arg, torch.SymInt):
                sym_shape_indices.append(position)

        configs = self.compilation_configs
        is_first_graph = not self.have_seen_first_graph

        compiled_graph_for_general_shape = wrap_inductor(
            submod,
            args,
            configs.inductor_compile_config,
            runtime_shape=None,
            do_logging=is_first_graph,
            use_inductor=configs.use_inductor)

        # stored in the instance `__dict__` of the module under the
        # name of the submodule
        self.module.__dict__[target] = PiecewiseBackend(
            submod, configs, self.graph_pool, is_first_graph,
            sym_shape_indices, compiled_graph_for_general_shape)

        self.have_seen_first_graph = True
        compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


305
306
307
308
class VllmBackend:
    """The compilation backend for `torch.compile` with VLLM.
    It is used for compilation level of `CompilationLevel.PIECEWISE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.
    """

    compilation_configs: CompilationConfig
    graph_pool: Any
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: List[SplitItem]
    returned_callable: Callable

    def __init__(self):
        """Attach the graph pool shared by all backends, creating it
        first if no backend has done so yet."""
        global global_graph_pool
        if global_graph_pool is None:
            global_graph_pool = torch.cuda.graph_pool_handle()

        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = global_graph_pool

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """Split `graph` into piecewise graphs, compile the pieces
        that are not splitting graphs and return the stitching graph
        module."""
        compilation_counter.num_graphs_seen += 1

        # we control the compilation process, each instance can only be
        # called once
        assert not self._called, "VllmBackend can only be called once"

        self.graph = graph
        # config is read now, because only here can
        # we get the sizes to capture for cudagraph
        # from compilation context
        self.compilation_configs = CompilationConfig.select_and_init_config()

        split_result = split_graph(graph,
                                   self.compilation_configs.non_cudagraph_ops)
        self.split_gm, self.piecewise_graphs = split_result

        from torch._dynamo.utils import lazy_format_graph_code
        logger.debug("%s",
                     lazy_format_graph_code("stiching module", self.split_gm))

        compilation_counter.num_piecewise_graphs_seen += len(
            self.piecewise_graphs)

        # the splitting graphs themselves are not compiled
        submod_names_to_compile = []
        for item in self.piecewise_graphs:
            if item.is_splitting_graph:
                continue
            submod_names_to_compile.append(item.submod_name)

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        interpreter = PiecewiseCompileInterpreter(self.split_gm,
                                                  submod_names_to_compile,
                                                  self.compilation_configs,
                                                  self.graph_pool)
        interpreter.run(*example_inputs)

        self._called = True

        return self.split_gm
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390


@dataclasses.dataclass
class ConcreteSizeEntry:
    """Per-shape bookkeeping for compilation and cudagraph capture."""
    # the concrete shape this entry belongs to
    runtime_shape: int
    # the size is in compile_sizes
    need_to_compile: bool
    # the size is in capture_sizes
    use_cudagraph: bool

    # whether compilation for this shape has already been triggered
    compiled: bool = False
    # the callable to run for this shape, None until it is assigned
    runnable: Callable = None  # type: ignore
    # number of warmup runs finished so far for this shape
    num_finished_warmup: int = 0
    # the captured cudagraph, None until one is captured
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    # the output stored when the cudagraph is captured
    output: Optional[Any] = None


class PiecewiseBackend:
    """Callable that runs one piecewise graph, dispatching on the
    runtime shape between the general-shape compiled graph,
    shape-specific compiled graphs and captured cudagraphs."""

    def __init__(self, graph: fx.GraphModule,
                 compilation_configs: CompilationConfig, graph_pool: Any,
                 is_first_graph: bool, sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_configs.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.

        Args:
            graph: the piecewise graph module handled by this backend.
            compilation_configs: the configs providing the compile
                sizes, the capture sizes and the inductor settings.
            graph_pool: the pool passed to `torch.cuda.graph` when
                capturing.
            is_first_graph: whether this is the first piecewise
                graph; only then progress is logged.
            sym_shape_indices: positions of the symbolic-shape
                arguments; the first one gives the runtime shape.
            compiled_graph_for_general_shape: the callable used for
                every shape without an entry of its own.
        """
        self.graph = graph
        self.compilation_configs = compilation_configs
        self.graph_pool = graph_pool
        self.is_first_graph = is_first_graph

        self.compile_sizes: Set[int] = set(
            self.compilation_configs.compile_sizes)
        # no shape is captured when cudagraph is disabled
        self.capture_sizes: Set[int] = set(
            self.compilation_configs.capture_sizes
        ) if self.compilation_configs.use_cudagraph else set()

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
        for shape in self.compile_sizes.union(self.capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.capture_sizes,
            )

    def __call__(self, *args) -> Any:
        """Run the graph with `args`.

        The very first call always goes to the general-shape graph.
        Later calls look up the runtime shape: shapes without an entry
        use the general-shape graph, the others are compiled and/or
        captured as a cudagraph (after the configured number of warmup
        runs) and replayed from then on.
        """
        if not self.first_run_finished:
            self.first_run_finished = True
            return self.compiled_graph_for_general_shape(*args)

        if not self.sym_shape_indices:
            # no symbolic-shape argument was recorded for this graph,
            # so there is no runtime shape to dispatch on; use the
            # general-shape graph instead of failing with IndexError
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            # marked before compiling, so it is attempted only once
            entry.compiled = True
            # args are real arguments
            entry.runnable = wrap_inductor(
                self.graph,
                args,
                self.compilation_configs.inductor_compile_config,
                runtime_shape=runtime_shape,
                do_logging=self.is_first_graph,
                use_inductor=self.compilation_configs.use_inductor)

        if not entry.use_cudagraph:
            return entry.runnable(*args)

        if entry.cudagraph is None:
            if entry.num_finished_warmup < self.compilation_configs.cudagraph_num_of_warmups:  # noqa
                # still warming up: run normally, capture later
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        self.compilation_configs.cudagraph_num_of_warmups,
                        runtime_shape)
                return entry.runnable(*args)

            if self.is_first_graph:
                logger.info("Capturing a cudagraph for shape %s",
                            runtime_shape)

            cudagraph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                # NOTE(review): presumably only weak references to the
                # output tensors are kept here -- confirm against
                # `weak_ref_tensors`
                entry.output = weak_ref_tensors(entry.runnable(*args))

            # the attribute name is defined on `compilation_counter`,
            # keep its spelling
            compilation_counter.num_cudagraph_caputured += 1

            entry.cudagraph = cudagraph
            return entry.output

        # the same stored output object is returned after every replay
        entry.cudagraph.replay()
        return entry.output
490
491
492
493


def select_default_backend(level: int) -> Union[str, Callable]:
    """Return the `torch.compile` backend for the compilation level.

    `DYNAMO_AS_IS` and `DYNAMO_ONCE` map to the "eager" backend
    string. The only other level accepted is `PIECEWISE`, which gets
    a new `VllmBackend` instance.
    """
    eager_levels = (CompilationLevel.DYNAMO_AS_IS,
                    CompilationLevel.DYNAMO_ONCE)
    if level not in eager_levels:
        assert level == CompilationLevel.PIECEWISE
        return VllmBackend()

    return "eager"