# libentry.py
# Copied From https://github.com/FlagOpen/FlagGems

import inspect

import triton


class LibEntry(triton.KernelInterface):
    """Launch cache placed in front of a Triton kernel.

    The wrapped object may be a ``triton.runtime.JITFunction`` directly or a
    chain of ``Autotuner`` / ``Heuristics`` wrappers around one.  The first
    call for a given argument signature is delegated to ``fn.run``; the
    resulting kernel, together with the constexpr values it was built with,
    is stored in ``kernel_cache``.  Later calls with the same signature look
    the kernel up and launch it directly.
    """

    def __init__(
        self,
        fn,
    ):
        """Wrap ``fn`` and classify the parameters of the underlying kernel.

        Args:
            fn: A ``JITFunction`` or an object exposing ``arg_names``, ``run``
                and a ``fn`` attribute that eventually leads to one.
        """
        self.fn = fn
        self.arg_names = fn.arg_names
        # Modulus used on ``data_ptr()`` when building the cache key.
        self.divisibility = 16
        # Maps key tuple -> (kernel, constexprs dict).
        self.kernel_cache = {}
        # Peel off wrapper layers until the JITFunction itself is reached.
        fn = self.fn
        while not isinstance(fn, triton.runtime.JITFunction):
            fn = fn.fn
        self.jit_function: triton.runtime.JITFunction = fn
        # Positions of runtime (non-constexpr) parameters, split by whether
        # they are flagged ``do_not_specialize``.
        self.specialize_indices = [
            p.num for p in self.jit_function.params
            if not p.is_constexpr and not p.do_not_specialize
        ]
        self.do_not_specialize_indices = [
            p.num for p in self.jit_function.params
            if not p.is_constexpr and p.do_not_specialize
        ]

    @staticmethod
    def _dns_type_key(arg):
        """Return the key component for one do-not-specialize argument.

        Objects with ``data_ptr`` contribute their ``dtype``; non-int values
        contribute their Python type; ints contribute a width tag chosen
        from their value range ("i32", "u64" or "i64").
        """
        if hasattr(arg, "data_ptr"):
            return arg.dtype
        if not isinstance(arg, int):
            return type(arg)
        if -(2**31) <= arg <= 2**31 - 1:
            return "i32"
        if 2**63 <= arg <= 2**64 - 1:
            return "u64"
        return "i64"

    def key(self, spec_args, dns_args, const_args):
        """Build the hashable cache key for one call.

        Args:
            spec_args: Values of the specialized runtime parameters.
            dns_args: Values of the do-not-specialize runtime parameters.
            const_args: Values of the constexpr parameters (a list, in
                positional order).

        Returns:
            A tuple made of the per-argument components of ``spec_args``,
            then ``dns_args``, then the raw ``const_args`` values.
        """
        # Specialized args: (dtype, pointer-is-aligned) for objects exposing
        # ``data_ptr``; (type, value) for everything else.
        spec_key = [
            (arg.dtype, arg.data_ptr() % self.divisibility == 0)
            if hasattr(arg, "data_ptr") else (type(arg), arg)
            for arg in spec_args
        ]
        dns_key = [self._dns_type_key(arg) for arg in dns_args]
        # const args passed by position
        return tuple(spec_key + dns_key + const_args)

    def run(self, *args, **kwargs):
        """Launch the kernel, compiling through ``self.fn.run`` on a miss.

        Args:
            *args: Positional kernel arguments, in parameter order.
            **kwargs: Keyword kernel arguments plus launch options; ``grid``
                is required and may be a tuple, a list, or a callable taking
                the meta dict.

        Returns:
            The kernel object, i.e. what ``self.fn.run`` returned when the
            cache entry was created.

        Raises:
            KeyError: If ``grid`` is not given in ``kwargs``.
            RuntimeError: If a wrapper layer around the JITFunction is
                neither an ``Autotuner`` nor a ``Heuristics`` instance.
        """
        grid = kwargs["grid"]
        # collect all the arguments
        spec_args = []  # specialize arguments
        dns_args = []  # do not specialize arguments
        const_args = []  # constexpr arguments
        k_args = []  # kernel arguments
        for i, arg in enumerate(args):
            if i in self.specialize_indices:
                k_args.append(arg)
                spec_args.append(arg)
            elif i in self.do_not_specialize_indices:
                k_args.append(arg)
                dns_args.append(arg)
            else:
                const_args.append(arg)
        # Parameters not covered positionally: take the keyword value, else
        # the declared default; parameters with neither are skipped.
        for p in self.jit_function.params[len(args):]:
            if p.name in kwargs:
                val = kwargs[p.name]
            elif p.default is inspect.Parameter.empty:
                continue
            else:
                val = p.default

            if p.is_constexpr:
                const_args.append(val)
            elif p.do_not_specialize:
                dns_args.append(val)
                k_args.append(val)
            else:
                spec_args.append(val)
                k_args.append(val)

        entry_key = self.key(spec_args, dns_args, const_args)

        if entry_key not in self.kernel_cache:
            # compile the kernel also completes the related computations
            kernel = self.fn.run(*args, **kwargs)
            fn = self.fn
            # collect constexpr arguments for grid computation
            constexprs = {}
            while not isinstance(fn, triton.runtime.JITFunction):
                if isinstance(fn, triton.runtime.Autotuner):
                    config = fn.best_config
                    constexprs["num_warps"] = config.num_warps
                    constexprs["num_stages"] = config.num_stages
                    constexprs["num_ctas"] = config.num_ctas
                    constexprs = {**constexprs, **config.kwargs}
                elif isinstance(fn, triton.runtime.Heuristics):
                    for v, heur in fn.values.items():
                        constexprs[v] = heur({
                            **dict(zip(fn.arg_names, args)),
                            **kwargs,
                            **constexprs,
                        })
                else:
                    raise RuntimeError("Invalid Runtime Function")
                fn = fn.fn
            # In vLLM, certain kernels like fused_moe_kernel get the
            # best_config (as kwargs) from a configuration json file, rather
            # than using Autotuner & Heuristics. Therefore, every constexpr
            # not set above is filled in by the following loop.
            for p in self.jit_function.params:
                if not p.is_constexpr or p.name in constexprs:
                    continue
                # Prefer the value the caller actually supplied. The cached
                # constexprs are merged last into the grid meta dict, so
                # storing the declared default here would let it override
                # an explicit caller value on later cache hits.
                if p.name in kwargs:
                    constexprs[p.name] = kwargs[p.name]
                elif p.num < len(args):
                    constexprs[p.name] = args[p.num]
                else:
                    # May be inspect.Parameter.empty; filtered out on use.
                    constexprs[p.name] = p.default
            self.kernel_cache[entry_key] = (kernel, constexprs)
        else:
            # load kernel from cache directly
            kernel, constexprs = self.kernel_cache[entry_key]

            if callable(grid):
                # Collect all arguments to the grid fn, i.e.:
                # 1. args,
                # 2. kwargs,
                # 3. all other arguments captured from Autotuner &
                #    Heuristics; when kwargs & captured args conflict,
                #    captured args have higher priority.
                # 4. Captured entries that carry no value (empty default)
                #    must be filtered out first.
                constexprs = {
                    k: v
                    for k, v in constexprs.items()
                    if v is not inspect.Parameter.empty
                }
                meta = {
                    **dict(zip(self.arg_names, args)),
                    **kwargs,
                    **constexprs,
                }
                grid = grid(meta)
            # Pad so that the [0:3] slice below always yields three entries.
            if isinstance(grid, tuple):
                grid = grid + (1, 1)
            elif isinstance(grid, list):
                grid = grid + [1, 1]
            kernel[grid[0:3]](*k_args)
        # maintaining the same return type as the JITFunction.run
        return kernel


def libentry():
    """Return a decorator that wraps a Triton kernel in ``LibEntry``.

    Motivation:
        Triton's runtime overhead is what holds back the performance of
        small kernels, which shows most with smaller models. Applying this
        decorator reduces that overhead.
    How:
        ``JITFunction.run`` has to perform:
            - parameter binding using inspect
            - KernelArg type wrapping
            - cache key calculation
        For small problem sizes these steps can become the bottleneck.
        LibEntry simplifies them to cut the per-launch cost.
    NOTE:
        Once Triton is upgraded to version 3.0.0, libentry can be removed,
        see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245
    """

    def _wrap(fn):
        # Hand the kernel (or its Autotuner/Heuristics wrapper) to LibEntry.
        return LibEntry(fn)

    return _wrap