# speed_benchmark.py
import json
import time

import torch

from bitsandbytes.nn.triton_utils.v0.quantize_rowwise import quantize_rowwise
from bitsandbytes.nn.triton_utils.v0.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.nn.triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.nn.triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.nn.triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze

# KNOWN ISSUE: "w_quantize_colwise_transpose" still needs optimization when the embedding dimension is very large.

def get_time(k, fn, info_dict):
    # `repeat` is set in the benchmark loop below and read here as a global.
    # Warm up first so compilation/autotuning and caching are not included in the timing.
    for _ in range(repeat // 2):
        fn()

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeat):
        fn()

    torch.cuda.synchronize()
    end = time.time()
    ms = (end - start) / repeat * 1000
    print(f"time {k}: {ms:.3f} ms")
    info_dict[k] = ms

if __name__ == '__main__':
    torch.manual_seed(0)
    wm = 4  # width multiplier: dim_out = wm * dim in the non-switched case
    for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
        # note: "batch_size" here is the number of rows fed to the matmul (batch * sequence length), which is why it is so large
        for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:
            
            # switch swaps dim_in and dim_out, so both the dim -> wm*dim and wm*dim -> dim shapes are benchmarked
            for switch in [False, True]:

                # hparams
                repeat = 64
                dim_out = dim * wm
                dim_in = dim
                if switch:
                    dim_out = dim
                    dim_in = wm * dim

                # simulate the tensors of a forward and backward pass
                x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda()
                g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda()
                w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda()
                
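                # The int8 tensors and "state" scales below are stand-ins purely for timing
                # the kernels; the plain casts and row/column maxima are not a real
                # quantization of x, g, and w.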
                x_int8 = x.clone().to(torch.int8)
                g_int8 = g.clone().to(torch.int8)
                w_int8 = w.clone().to(torch.int8)
                wt_int8 = w.t().contiguous().clone().to(torch.int8)
                state_x_rowwise = x.max(dim=1)[0]
                state_g_rowwise = g.max(dim=1)[0]
                state_w_columnwise = w.max(dim=0)[0]
                state_w_rowwise = w.max(dim=1)[0]
                state_w_global = w.max()

                info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}

                get_time('standard_fwd', lambda : x.matmul(w.t()), info)
                get_time('standard_gw', lambda : g.t().matmul(x), info)
                get_time('standard_gx', lambda : g.matmul(w), info)
                get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
                get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_g_rowwise, state_w_rowwise, None), info)
                get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
                get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_g_rowwise, state_w_global, None), info)
                get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
                get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
                get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
                get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info)
                get_time('w_quantize_global', lambda : quantize_global(w), info)
                get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info)

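                # The totals model one linear layer's forward + backward:
                #   standard        = three fp16 matmuls (fwd, grad-input, grad-weight)
                #   rowwise/global  = int8 fwd and grad-input matmuls plus the quantization
                #                     overhead, with the grad-weight matmul (standard_gw)
                #                     still done in fp16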
                time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
                time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
                time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']

                print('TOTAL STANDARD', time_standard)
                print('TOTAL ROWWISE', time_rowwise)
                print('TOTAL GLOBAL', time_global)

                print('speedup (global vs. standard, %)', -100 * (time_global - time_standard) / time_standard)

                info['time_standard'] = time_standard
                info['time_rowwise'] = time_rowwise
                info['time_global'] = time_global

                info_json = json.dumps(info)

                with open("speed_benchmark/info_a100_py2.jsonl", "a") as file:
                    file.write(info_json + "\n")
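
# A minimal sketch (not part of the benchmark itself) for reading the results back;
# it assumes the same output path used above:
#
#   import json
#   with open("speed_benchmark/info_a100_py2.jsonl") as f:
#       rows = [json.loads(line) for line in f]
#   for r in rows:
#       print(r['dim_in'], r['dim_out'], r['time_standard'], r['time_rowwise'], r['time_global'])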