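"""Benchmark LigerKLDIVLoss against torch.nn.KLDivLoss.

Sweeps the vocabulary size V and records latency (forward, backward, full)
and peak memory for each kernel provider.
"""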
import torch
import torch.nn as nn
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.kl_div import LigerKLDIVLoss
from liger_kernel.utils import infer_device

device = infer_device()

# Sweep exponents for the vocab-size axis: V = 2**S, ..., 2**(E - 1)
S, E = 12, 18


def bench_speed_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
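    """Time KL-div for one (provider, vocab size, operation mode) point.

    KLDivLoss convention: `_input` holds log-probabilities, `target` probabilities.
    """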
    reduction = "batchmean"
    V = input.x
    B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
    torch_kl_div = nn.KLDivLoss(reduction=reduction)
    liger_kl_div = LigerKLDIVLoss(reduction=reduction)

    _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1)
    target = torch.randn(B * T, V, device=device).softmax(dim=-1)

    def fwd():
        if input.kernel_provider == "liger":
            return liger_kl_div(_input, target)
        else:
            return torch_kl_div(_input, target)

    if input.kernel_operation_mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
    elif input.kernel_operation_mode == "backward":
        y = fwd()

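        # grad_to_none resets _input's gradient between reps so accumulation
        # does not distort the backward timings.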
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[_input],
            rep=100,
        )
    elif input.kernel_operation_mode == "full":

        def full():
            y = fwd()
            y.backward(retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
    else:
        raise ValueError(f"Unsupported kernel_operation_mode: {input.kernel_operation_mode}")
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
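    """Measure peak memory of the full forward + backward pass for one (provider, V) point."""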
    reduction = "batchmean"
    torch_kl_div = nn.KLDivLoss(reduction=reduction)
    liger_kl_div = LigerKLDIVLoss(reduction=reduction)

    V = input.x
    B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]

    _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1)
    target = torch.randn(B * T, V, device=device).softmax(dim=-1)

    def fwd():
        if input.kernel_provider == "liger":
            return liger_kl_div(_input, target)
        else:
            return torch_kl_div(_input, target)

    def full():
        y = fwd()
        y.backward(retain_graph=True)

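    # _test_memory (from utils) is assumed to rerun `full` and return
    # peak-memory quantiles in the same 50/20/80 order as the speed benchmarks.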
    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)

    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()
    common_args = {
        "kernel_name": "kl_div",
        "x_name": "V",
        "x_label": "vocab size",
        "x_values": [2**i for i in range(S, E)],
        "kernel_providers": ["liger", "torch"],
        "extra_benchmark_configs": [{"B": 8, "T": 512}],
        "overwrite": args.overwrite,
    }

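    # Peak memory is only recorded for the full pass; latency is split by mode.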
    run_benchmarks(
        bench_test_fn=bench_memory_kldiv,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_args,
    )

    run_benchmarks(
        bench_test_fn=bench_speed_kldiv,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_args,
    )