benchmark_embedding.py 4 KB
Newer Older
cmx's avatar
cmx committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import torch
import triton

from torch.nn import Embedding
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.experimental.embedding import LigerEmbedding
from liger_kernel.utils import infer_device

# Target device, resolved once at import time; every module and input tensor
# built by the benchmarks below is created on / moved to this device.
device = infer_device()

# NOTE: For torch compile, we will just use default inductor settings. No further customization
# is needed.


def bench_speed_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time one embedding implementation for a single benchmark configuration.

    Args:
        input: Description of the run.
            ``input.x`` is used as ``V``, the number of embeddings.
            ``input.kernel_provider`` selects the implementation: ``"liger"``
            uses ``LigerEmbedding``, ``"torch_compile"`` uses a
            ``torch.compile``-wrapped ``torch.nn.Embedding``, and any other
            value uses eager ``torch.nn.Embedding``.
            ``input.kernel_operation_mode`` is ``"forward"``, ``"backward"``
            or ``"full"`` (forward followed by backward).
            ``input.extra_benchmark_config`` must provide ``B``, ``T``, ``D``
            and ``dtype``; the lookup indices have shape ``(B, T)``.

    Returns:
        A ``SingleBenchmarkRunOutput`` whose ``y_50``/``y_20``/``y_80`` are the
        three values ``triton.testing.do_bench`` returns for ``QUANTILES``
        (unpacked in that order -- presumably matches ``utils.QUANTILES``).

    Raises:
        ValueError: If ``input.kernel_operation_mode`` is not one of the three
            supported modes.
    """
    V = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    cfg = input.extra_benchmark_config
    B, T, D, dtype = cfg["B"], cfg["T"], cfg["D"], cfg["dtype"]

    # Build only the module under test: each embedding table is V x D, so
    # instantiating the unused implementations as well wastes device memory.
    if provider == "liger":
        emb = LigerEmbedding(V, D).to(device).to(dtype)
    else:
        emb = Embedding(V, D).to(device).to(dtype)
        if provider == "torch_compile":
            emb = torch.compile(emb)

    input_ids = torch.randint(0, V, (B, T), device=device)

    def fwd():
        return emb(input_ids)

    def full():
        output = fwd()
        output.backward(torch.randn_like(output))

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
    elif mode == "backward":
        # Run the forward pass once and re-use its graph for every timed
        # backward call (hence retain_graph=True).
        output = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: output.backward(torch.randn_like(output), retain_graph=True),
            quantiles=QUANTILES,
            # Gradients land on the embedding weight, not on the integer
            # input_ids, so the parameters are what must be reset between
            # timed calls.
            grad_to_none=list(emb.parameters()),
            rep=100,
        )
    elif mode == "full":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
    else:
        # Previously an unknown mode surfaced as UnboundLocalError below.
        raise ValueError(
            f"Unsupported kernel_operation_mode {mode!r}; expected 'forward', 'backward' or 'full'"
        )
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure memory of a forward+backward pass for one embedding implementation.

    Args:
        input: Description of the run.
            ``input.x`` is used as ``V``, the number of embeddings.
            ``input.kernel_provider`` selects the implementation: ``"liger"``
            uses ``LigerEmbedding``, ``"torch_compile"`` uses a
            ``torch.compile``-wrapped ``torch.nn.Embedding``, and any other
            value uses eager ``torch.nn.Embedding``.
            ``input.extra_benchmark_config`` must provide ``B``, ``T``, ``D``
            and ``dtype``; the lookup indices have shape ``(B, T)``.

    Returns:
        A ``SingleBenchmarkRunOutput`` whose ``y_50``/``y_20``/``y_80`` are the
        three values ``_test_memory`` returns for ``QUANTILES`` (unpacked in
        that order -- presumably matches ``utils.QUANTILES``).
    """
    V = input.x
    provider = input.kernel_provider

    cfg = input.extra_benchmark_config
    B, T, D, dtype = cfg["B"], cfg["T"], cfg["D"], cfg["dtype"]

    # Build only the module under test so that no unused V x D table is
    # resident on the device while the measurement runs.
    # NOTE(review): assumes _test_memory observes device-wide allocations --
    # confirm against utils._test_memory.
    if provider == "liger":
        emb = LigerEmbedding(V, D).to(device).to(dtype)
    else:
        emb = Embedding(V, D).to(device).to(dtype)
        if provider == "torch_compile":
            emb = torch.compile(emb)

    input_ids = torch.randint(0, V, (B, T), device=device)

    def full():
        # One forward lookup followed by a backward pass with a random
        # upstream gradient of the same shape as the output.
        output = emb(input_ids)
        output.backward(torch.randn_like(output))

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    # Script entry point: run the speed benchmark, then the memory benchmark,
    # over the same sweep of V for every provider and extra config.
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "embedding",
        "x_name": "V",
        # V is passed as the first argument of Embedding(V, D), i.e. the number
        # of embeddings (vocabulary size); D is the embedding dimension.
        "x_label": "vocabulary size",
        # Powers of two from 2**10 up to and including 2**17.
        "x_values": [2**i for i in range(10, 18)],
        # "huggingface" is not special-cased by the benchmark functions and so
        # takes their fallback branch: eager torch.nn.Embedding.
        "kernel_providers": ["liger", "huggingface", "torch_compile"],
        "extra_benchmark_configs": [
            # BERT
            {"B": 32, "T": 512, "D": 768, "dtype": torch.float32},
            # Llama
            {"B": 8, "T": 2048, "D": 4096, "dtype": torch.float32},
        ],
        "overwrite": args.overwrite,
    }

    # Speed: timed separately for forward, backward and forward+backward.
    run_benchmarks(
        bench_test_fn=bench_speed_embedding,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    # Memory: only the combined forward+backward pass is measured.
    run_benchmarks(
        bench_test_fn=bench_memory_embedding,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )