import math

import torch
import torch.nn as nn
import triton

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP
from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP
from liger_kernel.utils import infer_device

device = infer_device()


# DeepSpeed TiledMLP implementation
# Based on: https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838
class DeepSpeedTiledMLP(torch.autograd.Function):
    """
    DeepSpeed's TiledMLP implementation for fair comparison.
    This is the actual DeepSpeed algorithm that performs tiled MLP computation
    to massively reduce memory usage with very long sequence lengths.

    This module re-computes forward in the backward, so forward occurs twice per iteration.
    """

    @staticmethod
    def forward(ctx, fn, self, x, shards, compute_params) -> torch.Tensor:
        ctx.fn = fn
        ctx.self = self
        ctx.shards = shards
        ctx.compute_params = [p for p in compute_params if p.requires_grad] if compute_params else []
        ctx.save_for_backward(x)

        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
        x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
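        # run each shard's forward under no_grad: nothing is saved for backward here;
        # the activations are re-materialized one shard at a time in backward()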
        with torch.no_grad():
            output_shards = [fn(self, x_shard) for x_shard in x_shards]
        output_unsharded = torch.cat(output_shards, dim=-2)

        return output_unsharded

    @staticmethod
    def backward(ctx, *grads):
        fn = ctx.fn
        (x,) = ctx.saved_tensors
        self = ctx.self
        shards = ctx.shards
        compute_params = ctx.compute_params

        x_requires_grad = x.requires_grad
        x = x.detach()
        # detach() unsets x.requires_grad, so restore it
        x.requires_grad_(x_requires_grad)

        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
        hidden_size = x.shape[-1]
        x_shape_orig = x.shape

        # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
        x = x.view(-1, hidden_size)
        incoming_grad = grads[0].view(-1, hidden_size)
        x_grad = torch.zeros_like(x)

        x_shards = list(torch.chunk(x, chunks=shards, dim=0))

        for i, x_shard in enumerate(x_shards):
            # Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run
            # XXX: DDP, FSDP will need something similar to make it work
            if compute_params:
                if i + 1 < shards:
                    for param in compute_params:
                        if hasattr(param, "ds_grad_is_ready"):
                            param.ds_grad_is_ready = False
                else:
                    # last shard, can add the grad
                    for param in compute_params:
                        if hasattr(param, "ds_grad_is_ready"):
                            param.ds_grad_is_ready = True

            x_shard.requires_grad_(x_requires_grad)

            # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
            shard_step = x_shards[i].shape[0]
            shard_offset = i * x_shards[0].shape[0]

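            # point x_shard.grad at a narrow()'d view of the preallocated x_grad buffer so that
            # autograd accumulates this shard's input gradient directly into x_grad (no extra concat/copy)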
            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
            incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
            with torch.enable_grad():
                output = fn(self, x_shard)
            torch.autograd.backward(output, incoming_grad_shard)

        # unflatten
        x_grad = x_grad.view(x_shape_orig)

        return (None, None, x_grad, None, None)
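

# Rough memory intuition for the tiling (a back-of-the-envelope sketch, assuming the GEGLU
# benchmark config used below: bsz=2, seq_len=16384, hidden_size=2048, intermediate_size=4096,
# bf16): a plain LlamaMLP keeps the gate_proj and up_proj activations of shape (2, 16384, 4096)
# for backward, ~134M elements * 2 bytes ~= 268 MB each, i.e. on the order of 0.5 GB of saved
# intermediates per MLP layer. The tiled version saves nothing in forward (no_grad) and
# re-materializes one shard at a time in backward, so peak intermediate memory drops to roughly
# 1/num_shards of that, paid for with a second forward pass.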


# DeepSpeed TiledMLP wrapper to match our interface
class DeepSpeedTiledMLPWrapper(nn.Module):
    """
    Wrapper for DeepSpeed's TiledMLP to match the interface used in benchmarks.
    Uses the DeepSpeed TiledMLP algorithm for memory-efficient MLP computation.
    """

    def __init__(self, config, num_shards=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_shards = num_shards

        self.mlp = LlamaMLP(config=config)

    def forward(self, x):
        # Calculate num_shards if not provided
        num_shards = self.num_shards
        if num_shards is None:
            hidden_size = x.shape[-1]
            seqlen = x.shape[-2]
            num_shards = math.ceil(seqlen / hidden_size)
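            # e.g. seqlen=16384 with hidden_size=2048 -> ceil(16384 / 2048) = 8 shards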
        num_shards = max(1, num_shards)

        # Collect compute parameters for DeepSpeed ZeRO compatibility
        compute_params = [
            self.mlp.down_proj.weight,
            self.mlp.gate_proj.weight,
            self.mlp.up_proj.weight,
        ]

        # Define the MLP forward function for DeepSpeed TiledMLP
        def mlp_forward(mlp_module, x_input):
            return mlp_module.down_proj(mlp_module.act_fn(mlp_module.gate_proj(x_input)) * mlp_module.up_proj(x_input))

        # Use DeepSpeed's TiledMLP implementation
        return DeepSpeedTiledMLP.apply(
            mlp_forward,
            self.mlp,
            x,
            num_shards,
            compute_params,
        )
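
# Minimal usage sketch for the wrapper above (hypothetical values, mirroring what the benchmark
# functions below construct):
#   cfg = LlamaConfig(hidden_size=2048, intermediate_size=4096, hidden_act="silu")
#   mlp = DeepSpeedTiledMLPWrapper(config=cfg, num_shards=4).to(device).to(torch.bfloat16)
#   x = torch.randn(2, 4096, 2048, device=device, dtype=torch.bfloat16, requires_grad=True)
#   mlp(x).sum().backward()  # gradients land in x.grad and the wrapped LlamaMLP's weights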


def bench_speed_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    seq_len = input.x
    bsz = input.extra_benchmark_config["bsz"]
    hidden_size = input.extra_benchmark_config["hidden_size"]
    intermediate_size = input.extra_benchmark_config["intermediate_size"]
    hidden_act = input.extra_benchmark_config["hidden_act"]
    dtype = input.extra_benchmark_config["dtype"]
    num_shards = input.extra_benchmark_config.get("num_shards", None)
    activation_type = input.extra_benchmark_config["activation_type"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    llama_config = LlamaConfig(
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        hidden_act=hidden_act,
    )

    x_shape = (bsz, seq_len, hidden_size)

    # initialize input
    x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)

    if activation_type == "geglu":
        if provider == "huggingface":
            layer = LlamaMLP(config=llama_config).to(device).to(dtype)
        elif provider == "liger":
            layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype)
        elif provider == "liger_tiled":
            layer = LigerTiledGEGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
        elif provider == "deepspeed_tiled":
            layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype)
        else:
            raise ValueError(f"Invalid provider: {provider} for GEGLU")
    elif activation_type == "swiglu":
        if provider == "huggingface":
            layer = LlamaMLP(config=llama_config).to(device).to(dtype)
        elif provider == "liger":
            layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype)
        elif provider == "liger_tiled":
            layer = LigerTiledSwiGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
        elif provider == "deepspeed_tiled":
            layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype)
        else:
            raise ValueError(f"Invalid provider: {provider} for SwiGLU")
    else:
        raise ValueError(f"Invalid activation_type: {activation_type}")

    def fwd():
        return layer(x)

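    # NOTE: do_bench returns one value per entry in QUANTILES (assumed to be [0.5, 0.2, 0.8]
    # in utils), hence the (ms_50, ms_20, ms_80) unpacking order below.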
    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            grad_to_none=[x],
            rep=10,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        do = torch.randn_like(x)
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(do, retain_graph=True),
            grad_to_none=[x],
            rep=10,
            quantiles=QUANTILES,
        )
    else:

        def full():
            y = fwd()
            y.backward(torch.randn_like(y), retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[x],
            rep=10,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    seq_len = input.x
    bsz = input.extra_benchmark_config["bsz"]
    hidden_size = input.extra_benchmark_config["hidden_size"]
    intermediate_size = input.extra_benchmark_config["intermediate_size"]
    hidden_act = input.extra_benchmark_config["hidden_act"]
    dtype = input.extra_benchmark_config["dtype"]
    num_shards = input.extra_benchmark_config.get("num_shards", None)
    activation_type = input.extra_benchmark_config["activation_type"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    llama_config = LlamaConfig(
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        hidden_act=hidden_act,
    )

    x_shape = (bsz, seq_len, hidden_size)
    # initialize input
    x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)

    if activation_type == "geglu":
        if provider == "huggingface":
            layer = LlamaMLP(config=llama_config).to(device).to(dtype)
        elif provider == "liger":
            layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype)
        elif provider == "liger_tiled":
            layer = LigerTiledGEGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
        elif provider == "deepspeed_tiled":
            layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype)
        else:
            raise ValueError(f"Invalid provider: {provider} for GEGLU")
    elif activation_type == "swiglu":
        if provider == "huggingface":
            layer = LlamaMLP(config=llama_config).to(device).to(dtype)
        elif provider == "liger":
            layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype)
        elif provider == "liger_tiled":
            layer = LigerTiledSwiGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
        elif provider == "deepspeed_tiled":
            layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype)
        else:
            raise ValueError(f"Invalid provider: {provider} for SwiGLU")
    else:
        raise ValueError(f"Invalid activation_type: {activation_type}")

    def fwd():
        return layer(x)

    def full():
        y = fwd()
        y.backward(torch.randn_like(y), retain_graph=True)

    if mode == "forward":
        mem_50, mem_20, mem_80 = _test_memory(
            fwd,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        do = torch.randn_like(x)
        y = fwd()
        mem_50, mem_20, mem_80 = _test_memory(
            lambda: y.backward(do, retain_graph=True),
            quantiles=QUANTILES,
        )
    else:
        mem_50, mem_20, mem_80 = _test_memory(
            full,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Benchmark GEGLU variants
    kernel_providers_geglu = ["huggingface", "liger", "liger_tiled", "deepspeed_tiled"]

    common_configs_geglu = {
        "kernel_name": "tiled_geglu",
        "x_name": "T",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(10, 15)],  # 1024 to 16384
        "kernel_providers": kernel_providers_geglu,
        "extra_benchmark_configs": [
            {
                "bsz": 2,
                "hidden_size": 2048,
                "intermediate_size": 4096,
                "hidden_act": "gelu_pytorch_tanh",
                "activation_type": "geglu",
                "num_shards": 4,
                "dtype": torch.bfloat16,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_tiled_mlp,
        kernel_operation_modes=["full", "forward", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_geglu,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_tiled_mlp,
        kernel_operation_modes=["full", "forward", "backward"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_geglu,
    )

    # Benchmark SwiGLU variants
    kernel_providers_swiglu = ["huggingface", "liger", "liger_tiled", "deepspeed_tiled"]

    common_configs_swiglu = {
        "kernel_name": "tiled_swiglu",
        "x_name": "T",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(10, 15)],  # 1024 to 16384
        "kernel_providers": kernel_providers_swiglu,
        "extra_benchmark_configs": [
            {
                "bsz": 2,
                "hidden_size": 2048,
                "intermediate_size": 4096,
                "hidden_act": "silu",
                "activation_type": "swiglu",
                "num_shards": 4,
                "dtype": torch.bfloat16,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_tiled_mlp,
        kernel_operation_modes=["full", "forward", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_swiglu,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_tiled_mlp,
        kernel_operation_modes=["full", "forward", "backward"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_swiglu,
    )