# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm._custom_ops import scaled_fp4_quant
from vllm.scalar_type import scalar_types

FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

# Magnitudes of the eight non-negative E2M1 (fp4) codes, indexed by the low
# three bits of a nibble; bit 3 of the nibble is the sign (see break_fp4_bytes).
kE2M1ToFloat = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)


def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    """Unswizzle block scales stored in the 128x4-tiled layout.

    The swizzled buffer is viewed as ``(m_tiles, k_tiles, 32, 4, 4)`` with
    ``m_tiles = ceil(m / 128)`` and ``k_tiles = ceil(k / (4 * block_size))``.
    Inside a 128-row x 4-column tile, the scale for tile row ``r`` and tile
    column ``c`` is stored at position ``(r % 32, r // 32, c)``.

    Args:
        a_sf_swizzled: swizzled scales with ``m_tiles * k_tiles * 512``
            elements.
        m: number of rows of the quantized matrix.
        k: number of (unpacked) elements per row of the quantized matrix.
        block_size: number of elements sharing one scale.

    Returns:
        ``(m, ceil(k / block_size))`` tensor of scales in row-major order,
        with the tile padding removed.
    """
    m_tiles = (m + 128 - 1) // 128
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
    # The column axis counts scales, not elements, so trim the padding to
    # ceil(k / block_size) scale columns (and the row padding to m).
    num_scale_cols = (k + block_size - 1) // block_size
    return out[:m, :num_scale_cols]


def convert_swizzled_8x4_layout_to_linear(
    a_sf_swizzled: torch.Tensor, m, k, block_size
):
    """Unswizzle block scales stored in the 8x4-tiled layout.

    The swizzled buffer is viewed as ``(m_tiles, k_tiles, 8, 4)`` with
    ``m_tiles = ceil(m / 8)`` and ``k_tiles = ceil(k / (4 * block_size))``;
    each 8-row x 4-column tile is stored contiguously in row-major order.

    Args:
        a_sf_swizzled: swizzled scales with ``m_tiles * k_tiles * 32``
            elements.
        m: number of rows of the quantized matrix.
        k: number of (unpacked) elements per row of the quantized matrix.
        block_size: number of elements sharing one scale.

    Returns:
        ``(m, ceil(k / block_size))`` tensor of scales in row-major order,
        with the tile padding removed.
    """
    m_tiles = (m + 8 - 1) // 8
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 8, 4))
    tmp = torch.permute(tmp, (0, 1, 3, 2, 4))
    out = tmp.reshape(m_tiles * 8, k_tiles * f // block_size)
    # The column axis counts scales, not elements, so trim the padding to
    # ceil(k / block_size) scale columns (and the row padding to m).
    num_scale_cols = (k + block_size - 1) // block_size
    return out[:m, :num_scale_cols]


def dequantize_nvfp4_to_dtype(
    tensor_fp4,
    tensor_sf,
    global_scale,
    dtype,
    device,
    block_size=16,
    is_sf_128x4_layout=True,
):
    """Dequantize the fp4 tensor back to high precision.

    Every element is reconstructed as ``fp4_value * block_scale / global_scale``.

    Args:
        tensor_fp4: ``(m, k // 2)`` uint8 tensor holding two fp4 values per
            byte.
        tensor_sf: swizzled block scales; their bytes are reinterpreted as
            ``float8_e4m3fn``.
        global_scale: divisor applied to every block scale.
        dtype: output dtype.
        device: unused; kept for interface compatibility.
        block_size: number of elements sharing one block scale.
        is_sf_128x4_layout: ``True`` for the 128x4-tiled scale layout,
            ``False`` for the 8x4-tiled layout.

    Returns:
        ``(m, k)`` tensor of ``dtype``.
    """
    # Two fp4 values are packed into one uint8.
    assert tensor_fp4.dtype == torch.uint8
    m, packed_k = tensor_fp4.shape
    k = packed_k * 2
    values = break_fp4_bytes(tensor_fp4, dtype)
    values = values.reshape(m, k // block_size, block_size)

    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
    if is_sf_128x4_layout:
        tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
    else:
        tensor_sf = convert_swizzled_8x4_layout_to_linear(tensor_sf, m, k, block_size)
    scales = tensor_sf.to(torch.float32) / global_scale

    # Broadcast each block scale over the block_size elements it covers.
    out = (values * scales.unsqueeze(-1)).reshape(m, k)
    return out.to(dtype=dtype)


def break_fp4_bytes(a, dtype):
    """Unpack a 2-D uint8 tensor of E2M1 (fp4) pairs into real values.

    Each byte holds two fp4 codes; the low nibble becomes output column
    ``2 * j`` and the high nibble column ``2 * j + 1``.

    Args:
        a: ``(m, n)`` uint8 tensor.
        dtype: output dtype.

    Returns:
        ``(m, 2 * n)`` tensor of ``dtype`` on ``a``'s device.
    """
    assert a.dtype == torch.uint8
    m, n = a.shape

    # Vectorized nibble processing
    a_flat = a.flatten()
    high = (a_flat & 0xF0) >> 4  # Upper nibbles
    low = a_flat & 0x0F  # Lower nibbles

    # Interleave so that each byte's low nibble precedes its high nibble.
    combined = torch.stack((low, high), dim=1).flatten()

    # Vectorized sign and magnitude extraction
    signs = (combined & 0x08).to(torch.bool)  # Sign bits
    abs_vals = (combined & 0x07).to(torch.long)  # Magnitude indices

    # Device-aware lookup and sign application
    kE2M1 = kE2M1ToFloat.to(device=a.device)
    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)

    # Reshape to final form
    return values.reshape(m, n * 2).to(dtype=dtype)


def dequant_nvfp4_kv_cache(
    fp4_data: torch.Tensor,
    block_scale: torch.Tensor,
    global_scale: float,
    head_size: int,
    block_size: int,
) -> torch.Tensor:
    """Dequantize an NVFP4 KV cache with 4x4-swizzled block scales.

    The input must be in HND layout so that the last two dims are
    (block_size, last_dim). For NHD caches, permute to HND first.

    Args:
        fp4_data: [..., num_heads, block_size, head_size//2] uint8 packed fp4.
        block_scale: [..., num_heads, block_size, head_size//16] fp8 block
            scales (as uint8 or float8_e4m3fn).
        global_scale: checkpoint dequant scale (k_scale or v_scale); it
            multiplies every block scale.
        head_size: head dimension.
        block_size: page size.

    Returns:
        [..., num_heads, block_size, head_size] float32.
    """
    data_dim = head_size // 2
    scale_dim = head_size // 16

    sf_swizzled = block_scale.view(torch.uint8)

    # Unswizzle the (block_size, scale_dim) plane. Viewing the swizzled plane
    # as [T//4, 4, sg, 4] indexed (g, r, j, c), the permute below yields
    #   linear[g*4 + c, r*sg + j] = swizzled[g*4 + r, j*4 + c].
    batch_shape = sf_swizzled.shape[:-2]
    T, S = block_size, scale_dim
    sg = S // 4
    sf_reshape = sf_swizzled.reshape(*batch_shape, T // 4, 4, sg, 4)
    ndim = sf_reshape.ndim
    # (..., T//4, 4, sg, 4) -> (..., T//4, 4, 4, sg): move the innermost axis
    # ahead of the other two inside each 4-row group.
    perm = list(range(ndim - 4)) + [ndim - 4, ndim - 1, ndim - 3, ndim - 2]
    sf_linear = sf_reshape.permute(*perm).reshape(*batch_shape, T, S)
    sf_f32 = sf_linear.view(torch.float8_e4m3fn).to(torch.float32)

    # Unpack fp4: [..., T, data_dim] uint8 -> [..., T, head_size] float32.
    shape = fp4_data.shape
    fp4_flat = fp4_data.reshape(-1, data_dim)
    fp4_vals = break_fp4_bytes(fp4_flat, torch.float32)
    fp4_vals = fp4_vals.reshape(*shape[:-1], head_size)

    # Dequant: fp4_val * block_scale * global_scale per 16-element group
    return (
        fp4_vals.reshape(*shape[:-1], scale_dim, 16)
        * (sf_f32 * global_scale).unsqueeze(-1)
    ).reshape(*shape[:-1], head_size)


def get_nvfp4_global_scale(a: torch.Tensor):
    """Return the float32 global scale mapping ``max(|a|)`` onto
    ``FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX``.

    NOTE(review): an all-zero ``a`` divides by zero and yields ``inf``.
    """
    return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32)


def quant_nvfp4_tensor(a: torch.Tensor):
    """Quantize ``a`` to NVFP4 using a per-tensor global scale.

    Returns:
        ``(a_quant, a_block_scale, a_global_scale)`` where the first two are
        what ``scaled_fp4_quant`` returns and the last is the scale from
        ``get_nvfp4_global_scale``.
    """
    a_global_scale = get_nvfp4_global_scale(a)
    a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
    return a_quant, a_block_scale, a_global_scale