fakequant.py 4.44 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
from torch.nn import functional as F
from dump_flux import group_scale

def compare(
        ref: torch.Tensor, 
        v: torch.Tensor, 
        refname: str, 
        vname: str, 
        list_diff: bool = False):
    
    print(f"== COMPARE v={vname} vs ref={refname}")
    diff = v - ref
    print(f" - diff = {diff}")

    if list_diff:
        print(f" - diffs at {diff.nonzero()}")

    mse = diff.square().mean()
    print(f" - mse = {mse}")
    nmse = mse / ref.square().mean()
    print(f" - nmse = {nmse}")

    print(f" - mean(v/ref)={v.mean()}/{ref.mean()}")
    print(f" - var(v/ref)={v.var()}/{ref.var()}")
    print(f"== ")
    print()

def print_debug_results(debug_results: dict[str, torch.Tensor], is_ref: bool = False):
    ref = 'REF' if is_ref else ''
    for k, v in debug_results.items():
        has_nan = v.isnan().any()
        has_inf = v.isinf().any()
        
        if v.dtype.is_floating_point:
            print(f" {ref} {k}: {v.shape} ({v.dtype}) has_nan={has_nan} has_inf={has_inf} max={v.max()} min={v.min()} mean={v.mean()} var={v.var()}")
        else:
            print(f" {ref} {k}: {v.shape} ({v.dtype})")
        if has_nan:
            cnt = v.isnan().count_nonzero()
            print(f" {ref} -- {cnt} ({cnt / v.numel() * 100}%) nans at {v.isnan().nonzero()[0:10]}")
        if has_inf:
            cnt = v.isinf().count_nonzero()
            print(f" {ref} -- {cnt} ({cnt / v.numel() * 100}%) infs at {v.isinf().nonzero()[0:10]}")
        print(f" {ref} -- {v}")
        print()

def fakequant(
        act: torch.Tensor,
        weight: torch.Tensor, 
        bias: torch.Tensor,
        group_size: int = 64,
        force_cuda: bool = False,
        ):
    
    oc, ic = weight.shape
    batch_size = act.shape[0]
    assert act.shape[1] == ic

    # [oc, ic // group_size]
    wscales = group_scale(weight, num_bits=4, group_size=group_size)
    qweight = weight.reshape(oc, ic // group_size, group_size).to(dtype=torch.float32) / wscales[..., None]
    qweight = qweight.round().clamp(-8, 7)
    qweight_i = qweight.int()
    qweight = qweight * wscales[..., None]
    qweight = qweight.to(weight.dtype)
    qweight = qweight.reshape(oc, ic)
    # print(f"qweight = {qweight}")
    print_debug_results({"qweight": qweight})

    # [batch_size, ic // group_size]
    ascales = group_scale(act, num_bits=4, group_size=group_size).to(dtype=weight.dtype)
    qact = act.reshape(batch_size, ic // group_size, group_size).to(dtype=torch.float32) / ascales[..., None]
    qact = qact.round().clamp(-8, 7)
    qact_i = qact.int()
    print_debug_results({"qact_i": qact_i})
    qact = qact * ascales[..., None]
    qact = qact.to(act.dtype)
    qact = qact.reshape(batch_size, ic)
    # print(f"qact = {qact}")
    print_debug_results({"qact": qact})

    outref_q = F.linear(qact.to(qweight.dtype), qweight, bias)
    # print(f"outref_q = {outref_q}")
    print_debug_results({"outref_q": outref_q})

    ###

    if force_cuda:
        qweight_i = qweight_i.to("cuda")
        qact_i = qact_i.to("cuda")
        wscales = wscales.to("cuda")
        ascales = ascales.to("cuda")
        bias = bias.to("cuda")

    qweight = qweight_i
    qact = qact_i
    qweight = qweight.reshape(oc, ic // group_size, group_size).transpose(0, 1).transpose(1, 2)
    qact = qact.reshape(batch_size, ic // group_size, group_size).transpose(0, 1)

    # [ic // group_size, batch_size, oc]
    psum = torch.bmm(qact.float(), qweight.float())
    print(f"psum_i ({psum.shape}) = {psum} ")
    # print(psum[:, 0, 23])

    # print(f"ascales = {ascales}")
    print_debug_results({"ascales": ascales})
    print(f"ascales[0:16] = {ascales[0:16, 0]}")

    ws1 = wscales.transpose(0, 1).reshape(ic // group_size, 1, oc).repeat(1, batch_size, 1)
    as1 = ascales.transpose(0, 1).reshape(ic // group_size, batch_size, 1).repeat(1, 1, oc)
    scales = ws1 * as1

    print(f"scales = {scales}")
    # print(scales[:, 0, 23])

    psum = psum.to(dtype=act.dtype).float()
    psum = psum * scales
    print(f"psum ({psum.shape}) = {psum} ")
    # print(psum[:, 0, 23])

    # outref_q2 = psum.sum(dim=0) # .to(layer.weight.dtype)
    outref_q2 = torch.zeros_like(psum[0])
    for i in range(psum.shape[0]):
        outref_q2 = (outref_q2 + psum[i]).to(act.dtype)
    outref_q2 += bias[None, ...]
    # print(f"outref_q2 = {outref_q2}")
    print_debug_results({"outref_q2": outref_q2})

    # print(outref_q2[0, 23])

    if force_cuda:
        outref_q2 = outref_q2.to(act.device)

    return outref_q, outref_q2