# dequantize_utils.py
import torch


def torch_convert_bit_twiddling(tensor):
    """
    This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`.

    Parameters:
        tensor (torch.Tensor): 2-D input tensor with dtype `torch.uint8`. Shape (N, K).

    Returns:
        torch.Tensor: New tensor of dtype `torch.bfloat16` with shape (N, K*2), where each input column pair produces two bf16 output columns.

    Raises:
        AssertionError: If any byte inputs used for a conversion are not dtype `torch.uint8`.
    """
    assert tensor.dim() == 2 and tensor.dtype == torch.uint8
    N, K = tensor.shape
    assert K % 2 == 0, "Number of columns must be even"

    # Combine pairs of uint8 values into int32 words for safe bitwise ops on CUDA
    val0 = tensor[:, 0::2].to(torch.int32)
    val1 = tensor[:, 1::2].to(torch.int32)
    val_concat = (val0 << 8) | val1  # (N, K//2), int32

    # Expand to match output shape where each pair generates 4 values
    val_concat_expanded = val_concat.repeat_interleave(4, dim=1)  # (N, K//2*4)

    # Positional encoding for bit-twiddling logic
    pos = torch.arange(K * 2, device=tensor.device) % 4  # (K*2,)

    # Bit masks for decoding (Python int literals, applied to the int32 tensors)
    mask = 0b1000000111000000
    mask1 = 0b1000000000000000
    mask2 = 0b0000000110000000
    mask3 = 0b0000000001000000

    # Calculate results for all 4 positions in parallel
    res0 = val_concat_expanded & mask
    res1 = (val_concat_expanded << 3) & mask
    res2 = (val_concat_expanded << 6) & mask
    res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | (
        (val_concat_expanded >> 7) & mask3)

    # Select the correct result based on position
    bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1,
                                                   torch.where(pos == 2, res2, res3)))

    # Convert to uint16 for .view(torch.bfloat16)
    bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16)
    bf16_bf16 = bf16_uint16.view(torch.bfloat16)

    # Scale by 2**126 (as a Python float, to avoid integer overflow) to restore the exponent bias
    bf16_new = bf16_bf16 * (2.0**126)

    return bf16_new
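
# Hypothetical usage sketch (not part of the original module; the helper name is
# illustrative only): decode a small randomly packed tensor and check the expected
# output shape and dtype of torch_convert_bit_twiddling.
def _example_bit_twiddling_decode():
    packed = torch.randint(0, 256, (2, 4), dtype=torch.uint8)  # (N, K) with K even
    decoded = torch_convert_bit_twiddling(packed)  # -> (N, K * 2) in bfloat16
    assert decoded.shape == (2, 8) and decoded.dtype == torch.bfloat16
    return decoded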


def torch_convert(tensor, scale_size=None, Scale=None):
    """
    Decode a 2D uint8 tensor into a 2D bfloat16 tensor by expanding each byte into two bf16 values using a 4-bit (nibble) encoding.

    Each input byte holds two 4-bit encoded values (low and high nibble). For each nibble this function derives sign/scale bits, a 3-bit exponent fragment and a 1-bit mantissa fragment, assembles a 16-bit bf16 pattern, and returns the resulting tensor with shape (N, K*2) and dtype torch.bfloat16 on the same device as the input.

    Parameters:
        tensor (torch.Tensor): 2D tensor of dtype torch.uint8 and shape (N, K). Each byte contains two encoded 4-bit entries that become two bf16 values.
        scale_size (int, optional): If provided, controls how elements of the optional Scale tensor are indexed. When supplied, per-output-element scaling is applied to the exponent using Scale.
        Scale (torch.Tensor, optional): A 2D tensor used to supply per-element integer scale adjustments to the exponent. If scale_size is provided, the scale used for output element (i, j) is Scale[i][j // scale_size].

    Returns:
        torch.Tensor: A new tensor of shape (N, K*2) and dtype torch.bfloat16 containing the decoded bf16 values.
    """

    def _convert(val, pos, scale=None):
        assert val.dtype == torch.uint8
        mask = (1 << 4) - 1
        f4 = ((val >> (pos * 4)) & mask).to(torch.int16)  # select low (pos=0) or high (pos=1) nibble
        s = f4 >> 3  # sign bit
        e_f4 = (f4 & 6) >> 1  # 2-bit exponent fragment
        e_f16 = e_f4 + 126  # offset into the bf16 exponent range
        if scale is not None:
            e_f16 = min(e_f16 + scale, (1 << 8) - 1)  # clamp to the 8-bit exponent field
        m_f4 = f4 & 1  # 1-bit mantissa fragment
        m_f16 = m_f4
        # sign -> bit 15, exponent -> bits 14..7, mantissa -> bit 6 of the bf16 pattern
        val_f16 = (((e_f16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF
        lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
        return lower_16_bits.view(torch.bfloat16)

    N = tensor.shape[0]
    K = tensor.shape[1]
    new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device)
    for i in range(new_tensor.shape[0]):
        for j in range(new_tensor.shape[1]):
            if scale_size is not None:
                new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2, Scale[i][j // scale_size])
            else:
                new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
    return new_tensor
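
# Hypothetical usage sketch (not part of the original module; the helper name is
# illustrative only): decode with a per-group exponent offset. With 16 output
# columns and scale_size=8, output column j of row i uses Scale[i][j // 8], so
# Scale needs shape (2, 2) here.
def _example_nibble_decode_with_scale():
    packed = torch.randint(0, 256, (2, 8), dtype=torch.uint8)  # (N, K) -> (N, 16) outputs
    scale = torch.zeros(2, 2, dtype=torch.int16)  # one exponent offset per group of 8 outputs
    return torch_convert(packed, scale_size=8, Scale=scale)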


def print_bit(name, val):
    """
    Print the 32-bit binary representation of a CPU scalar extracted from a PyTorch tensor.

    Converts `val` to CPU, reads its Python scalar with `.item()`, formats it as a 32-bit binary string, and prints it prefixed by `name`.

    Parameters:
        name (str): Label printed before the binary representation.
        val (torch.Tensor): A scalar integer-valued PyTorch tensor whose 32-bit binary representation will be shown.
    """
    val_cpu = val.cpu().item()
    binary_repr = f'{val_cpu:032b}'
    print(name, binary_repr)
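
# Hypothetical usage sketch (not part of the original module): print_bit expects an
# integer-valued scalar tensor, e.g. one of the bit masks used in the decoders above.
def _example_print_bit():
    print_bit("mask", torch.tensor(0b1000000111000000))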


def print_red_warning(message):
    """Print `message` as a WARNING highlighted in red using ANSI escape codes."""
    print(f"\033[31mWARNING: {message}\033[0m")


def calc_sim(x, y, name="tensor"):
    x, y = x.data.double(), y.data.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        print_red_warning(f'{name} all zero')
        return 1
    sim = 2 * (x * y).sum() / denominator
    return sim


def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
    """
    Check that `x` and `y` agree: their nonfinite (inf/NaN) patterns must match exactly,
    and the similarity of the finite parts (see `calc_sim`) must satisfy 0 <= 1 - sim <= eps.
    Prints a red warning on failure and raises AssertionError when `raise_assert` is True.
    """
    x_mask = torch.isfinite(x)
    y_mask = torch.isfinite(y)
    if not torch.all(x_mask == y_mask):
        print_red_warning(f'{name} Error: isfinite mask mismatch')
        if raise_assert:
            raise AssertionError(f'{name}: isfinite mask mismatch')
    if not torch.isclose(
            x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0,
            equal_nan=True).all():
        print_red_warning(f'{name} Error: nonfinite value mismatch')
        if raise_assert:
            raise AssertionError(f'{name}: nonfinite value mismatch')
    x = x.masked_fill(~x_mask, 0)
    y = y.masked_fill(~y_mask, 0)
    sim = calc_sim(x, y, name)
    diff = float(1. - sim)  # float() handles both the 0-d tensor and the all-zero fallback
    print(f'{diff=}')
    if not (0 <= diff <= eps):
        print_red_warning(f'{name} Error: {diff=}')
        if raise_assert:
            raise AssertionError(f'{name}: diff={diff} exceeds eps={eps}')
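

# Minimal self-check sketch (an addition, not part of the original module): runs only
# when the file is executed directly. A tensor and a slightly perturbed copy should
# stay well within the similarity tolerance.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(4, 8)
    y = x + 1e-5 * torch.randn_like(x)
    assert_similar(x, y, eps=1e-6, name="perturbed")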