# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

5
from vllm._custom_ops import scaled_fp4_quant
6
7
8
9
10
from vllm.scalar_type import scalar_types

# Maximum values of the two element formats used in this module: the fp4
# (E2M1) data values and the fp8 (E4M3) block scale factors. Their product is
# the numerator of the global scale computed in get_nvfp4_global_scale.
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

# Lookup table indexed by the low three bits (magnitude code) of an fp4
# nibble; bit 3 of the nibble is the sign and is applied separately in
# break_fp4_bytes.
kE2M1ToFloat = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)
14
15
16
17
18
19
20
21
22
23
24
25


def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    """Un-swizzle block scale factors stored in 128x4 tiles into a 2-D tensor.

    The input is interpreted as ``m_tiles * k_tiles`` tiles, each holding
    ``32 * 4 * 4 = 512`` scale factors, where ``m_tiles = ceil(m / 128)`` and
    ``k_tiles = ceil(k / (4 * block_size))``.

    Args:
        a_sf_swizzled: Tensor with ``m_tiles * k_tiles * 512`` elements.
        m: Number of rows of the quantized matrix.
        k: Number of (unpacked) elements per row of the quantized matrix.
        block_size: Number of elements that share one scale factor.

    Returns:
        A ``(m, ceil(k / block_size))`` tensor with one scale factor per
        block, with the row and column padding of the tiled layout removed.
    """
    m_tiles = (m + 128 - 1) // 128
    f = block_size * 4  # elements of k covered by one tile (4 scale factors)
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
    # Bring the two row sub-indices (4, 32) together ahead of the k-tile index
    # so the following reshape yields (rows, scale-factor columns).
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    out = tmp.reshape(m_tiles * 128, k_tiles * 4)
    # Columns count scale factors, not elements: trim to one per block of k.
    # (Slicing to ``k`` left the padding columns in place whenever k was not
    # a multiple of 4 * block_size.)
    num_sf_cols = (k + block_size - 1) // block_size
    return out[0:m, 0:num_sf_cols]


26
27
28
29
30
31
32
33
34
35
36
37
def convert_swizzled_8x4_layout_to_linear(
    a_sf_swizzled: torch.Tensor, m, k, block_size
):
    """Un-swizzle block scale factors stored in 8x4 tiles into a 2-D tensor.

    The input is interpreted as ``m_tiles * k_tiles`` tiles of ``8 * 4``
    scale factors, where ``m_tiles = ceil(m / 8)`` and
    ``k_tiles = ceil(k / (4 * block_size))``.

    Args:
        a_sf_swizzled: Tensor with ``m_tiles * k_tiles * 32`` elements.
        m: Number of rows of the quantized matrix.
        k: Number of (unpacked) elements per row of the quantized matrix.
        block_size: Number of elements that share one scale factor.

    Returns:
        A ``(m, ceil(k / block_size))`` tensor with one scale factor per
        block, with the row and column padding of the tiled layout removed.
    """
    m_tiles = (m + 8 - 1) // 8
    f = block_size * 4  # elements of k covered by one tile (4 scale factors)
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 8, 4))
    # Move the in-tile row index ahead of the k-tile index so the following
    # reshape yields (rows, scale-factor columns).
    tmp = torch.permute(tmp, (0, 1, 3, 2, 4))
    out = tmp.reshape(m_tiles * 8, k_tiles * 4)
    # Columns count scale factors, not elements: trim to one per block of k.
    num_sf_cols = (k + block_size - 1) // block_size
    return out[0:m, 0:num_sf_cols]


38
def dequantize_nvfp4_to_dtype(
    tensor_fp4,
    tensor_sf,
    global_scale,
    dtype,
    device,
    block_size=16,
    is_sf_128x4_layout=True,
):
    """Dequantize the fp4 tensor back to high precision.

    Args:
        tensor_fp4: 2-D ``torch.uint8`` tensor of shape ``(m, k // 2)``; each
            byte packs two fp4 values.
        tensor_sf: Swizzled block scale factors. They are reinterpreted as
            ``torch.float8_e4m3fn`` and then un-swizzled.
        global_scale: Value the block scale factors are divided by.
        dtype: dtype of the returned tensor.
        device: Not used by this function; kept so existing call sites keep
            working.
        block_size: Number of consecutive elements sharing one scale factor.
        is_sf_128x4_layout: If true the scale factors are un-swizzled with
            ``convert_swizzled_to_linear`` (128x4 tiles), otherwise with
            ``convert_swizzled_8x4_layout_to_linear`` (8x4 tiles).

    Returns:
        Tensor of shape ``(m, k)`` and dtype ``dtype``.
    """
    # Two fp4 values are packed into one uint8.
    assert tensor_fp4.dtype == torch.uint8
    m, packed_k = tensor_fp4.shape
    k = packed_k * 2
    num_blocks = k // block_size

    tensor_vals = break_fp4_bytes(tensor_fp4, dtype)
    tensor_vals = tensor_vals.reshape(m, num_blocks, block_size)

    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
    if is_sf_128x4_layout:
        tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
    else:
        tensor_sf = convert_swizzled_8x4_layout_to_linear(tensor_sf, m, k, block_size)
    # The tiled layouts pad the scale-factor columns up to a multiple of 4;
    # keep exactly one column per block so the broadcast below lines up.
    tensor_sf = tensor_sf[:, :num_blocks]

    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale

    # scale the tensor
    out = (tensor_vals * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
    return out.to(dtype=dtype)


def break_fp4_bytes(a, dtype):
    """Unpack a 2-D uint8 tensor of packed fp4 pairs into decoded values.

    Each byte yields two outputs, low nibble first and high nibble second.
    Within a nibble, bit 3 is the sign and bits 0-2 index ``kE2M1ToFloat``.

    Args:
        a: ``torch.uint8`` tensor of shape ``(m, n)``.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(m, 2 * n)`` with dtype ``dtype``.
    """
    assert a.dtype == torch.uint8
    rows, packed_cols = a.shape

    packed = a.flatten()
    lower_nibble = packed & 0x0F
    upper_nibble = (packed & 0xF0) >> 4

    # Interleave so every byte contributes (low, high) in that order.
    nibbles = torch.stack((lower_nibble, upper_nibble), dim=1).flatten()

    is_negative = (nibbles & 0x08).to(torch.bool)
    magnitude_idx = (nibbles & 0x07).to(torch.long)

    # The lookup table has to live on the same device as the indices.
    lookup = kE2M1ToFloat.to(device=a.device)
    sign_factor = torch.where(is_negative, -1.0, 1.0)
    decoded = lookup[magnitude_idx] * sign_factor

    return decoded.reshape(rows, packed_cols * 2).to(dtype=dtype)
89
90


91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def dequant_nvfp4_kv_cache(
    fp4_data: torch.Tensor,
    block_scale: torch.Tensor,
    global_scale: float,
    head_size: int,
    block_size: int,
) -> torch.Tensor:
    """Dequantize an NVFP4 KV cache whose block scales are 4x4-swizzled.

    Inputs are expected in HND layout, i.e. the trailing two dims are
    (block_size, last_dim).  NHD caches must be permuted to HND beforehand.

    Args:
        fp4_data: [..., num_heads, block_size, head_size//2] uint8 packed fp4.
        block_scale: [..., num_heads, block_size, head_size//16] fp8 block
            scales (as uint8 or float8_e4m3fn).
        global_scale: checkpoint dequant scale (k_scale or v_scale).
        head_size: head dimension.
        block_size: page size.

    Returns:
        [..., num_heads, block_size, head_size] float32.
    """
    packed_width = head_size // 2
    num_scales = head_size // 16
    scale_groups = num_scales // 4

    # Block scales: work on raw bytes while undoing the 4x4 swizzle on the
    # (block_size, num_scales) plane, then reinterpret as fp8.
    raw_scales = block_scale.view(torch.uint8)
    lead_shape = raw_scales.shape[:-2]
    tiled = raw_scales.reshape(*lead_shape, block_size // 4, 4, scale_groups, 4)
    rank = tiled.ndim
    # Leading dims stay put; the trailing (T//4, 4, sg, 4) become
    # (T//4, 4, 4, sg) by moving the last axis in front of the middle two.
    order = [*range(rank - 4), rank - 4, rank - 1, rank - 3, rank - 2]
    unswizzled = tiled.permute(*order).reshape(*lead_shape, block_size, num_scales)
    scales = unswizzled.view(torch.float8_e4m3fn).to(torch.float32)

    # fp4 payload: decode row by row, then restore the leading dims.
    row_shape = fp4_data.shape[:-1]
    decoded = break_fp4_bytes(fp4_data.reshape(-1, packed_width), torch.float32)
    decoded = decoded.reshape(*row_shape, head_size)

    # Every group of 16 values is multiplied by its block scale times the
    # global scale.
    grouped = decoded.reshape(*row_shape, num_scales, 16)
    scaled = grouped * (scales * global_scale).unsqueeze(-1)
    return scaled.reshape(*row_shape, head_size)


145
146
147
148
def get_nvfp4_global_scale(a: torch.Tensor):
    """Return ``FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / max(|a|)`` as float32."""
    abs_max = torch.abs(a).max().to(torch.float32)
    representable_range = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX
    return representable_range / abs_max


149
def quant_nvfp4_tensor(a: torch.Tensor):
    """Quantize ``a`` with ``scaled_fp4_quant`` using a scale derived from it.

    The global scale is computed from ``a`` by ``get_nvfp4_global_scale`` and
    passed to ``scaled_fp4_quant``.

    Returns:
        A tuple ``(a_quant, a_block_scale, a_global_scale)``: the two outputs
        of ``scaled_fp4_quant`` followed by the global scale that was used.
    """
    a_global_scale = get_nvfp4_global_scale(a)
    a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
    return a_quant, a_block_scale, a_global_scale