Unverified commit f3b97c26, authored by Przemyslaw Tredak, committed by GitHub

Fix out of bounds access in the FP4 dequantize kernel (#2346)


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent dcaca2a6
@@ -39,6 +39,10 @@ __global__ void __launch_bounds__(512)
   const size_t x = thread_idx % M;
   const size_t y = thread_idx / M;
+  if (y >= N) {
+    return;
+  }
+
   union fp4vec {
     uint64_t vec;
     fp4e2m1x4 small_vec[4];
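Why the guard is needed: the kernel linearizes its threads as x = thread_idx % M and y = thread_idx / M, so any thread with thread_idx >= M * N lands on y >= N. If the grid is sized by rounding M * N up to a multiple of the 512-thread block (which the added guard implies), the tail threads of the last block previously indexed past the end of the tensor. Below is a minimal standalone sketch of the pattern; the kernel name, element type, and launch code are illustrative assumptions, not the actual Transformer Engine kernel:

```cuda
#include <cstddef>

// Illustrative tail-masking pattern: the grid is ceil(M*N / 512) blocks, so
// the last block contains threads whose linear index falls past M*N. For
// those threads y = thread_idx / M >= N, and the early return keeps them
// from touching memory beyond the tensor.
__global__ void __launch_bounds__(512) demo_kernel(float* data, size_t M, size_t N) {
  const size_t thread_idx = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
  const size_t x = thread_idx % M;
  const size_t y = thread_idx / M;
  if (y >= N) {  // same guard as the fix above
    return;
  }
  data[y * M + x] *= 2.0f;  // safe: x < M and y < N
}

int main() {
  const size_t M = 100, N = 7;                // M*N = 700, not a multiple of 512
  const size_t blocks = (M * N + 511) / 512;  // 2 blocks -> 1024 threads
  float* d = nullptr;
  cudaMalloc(&d, M * N * sizeof(float));
  demo_kernel<<<blocks, 512>>>(d, M, N);      // threads 700..1023 exit early
  cudaDeviceSynchronize();
  cudaFree(d);
  return 0;
}
```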
@@ -13,12 +13,12 @@ import warnings
 import torch
-# import transformer_engine_torch as tex
+import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
 from ...quantized_tensor import QuantizedTensorStorage, Quantizer
-# from ...constants import TE_DType as torch_to_transformer_engine_dtype
+from ...constants import TE_DType as torch_to_transformer_engine_dtype
 from ...utils import _empty_tensor
@@ -45,34 +45,7 @@ class _FromNVFP4Func(torch.autograd.Function):
         # Dequantize row-wise data
         if tensor._rowwise_data is not None:
-            ### TODO(tmoon): Debug dequantize kernel and remove unfused impl
-            # return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])
-
-            # Tensor properties
-            shape = list(tensor._rowwise_data.size())
-            shape[-1] *= 2
-            device = tensor._rowwise_data.device
-
-            # Convert FP4E2M1 values to FP32
-            data = tensor._rowwise_data.view(torch.uint8).to(torch.int32)
-            data = torch.stack((data & 0x0F, data >> 4), dim=-1).reshape(shape)
-            data = _fp4_e2m1_vals(device, dtype=torch.float32)[data]
-            data = data.to(torch.float32).contiguous()
-
-            # Convert FP8E4M3 block scales to FP32
-            block_scales = tensor._rowwise_scale_inv
-            block_scales = block_scales.reshape(-1, block_scales.size(-1))
-            block_scales = block_scales[: math.prod(shape[:-1]), : shape[-1] // 16]
-            block_scales = block_scales.view(torch.float8_e4m3fn).to(torch.float32)
-
-            # Convert amax to FP32 tensor scale
-            tensor_scale = tensor._amax_rowwise / (6.0 * 448.0)  # Scale by FP4E2M1 and FP8E4M3 max
-
-            # Apply scales
-            block_data = data.view(-1, 16)
-            block_data *= tensor_scale.view(()) * block_scales.reshape(-1, 1)
-
-            return data.to(dtype)
+            return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])
         if tensor._columnwise_data is not None:
             raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!")
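For reference, the unfused Python path removed above spelled out what tex.dequantize now does in a single fused kernel: unpack two E2M1 nibbles from each packed byte, map each 4-bit code through the 16-entry E2M1 value table, then multiply by the per-16-element FP8E4M3 block scale and by the global tensor scale amax / (6.0 * 448.0). Here is a hedged CUDA sketch of that math; the kernel name, signature, and memory layout are my assumptions, not the library's:

```cuda
#include <cstdint>
#include <cuda_fp8.h>

// The 16 E2M1 values indexed by the raw 4-bit code (bit 3 is the sign).
__constant__ float kFp4E2M1Vals[16] = {
     0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// One thread per packed byte. Assumed layout: `packed` holds rows x (cols/2)
// bytes with the low nibble first (matching the removed Python, which took
// data & 0x0F before data >> 4); `block_scales` holds one FP8E4M3 scale per
// 16 consecutive elements of a row (cols assumed divisible by 16); and
// `tensor_scale` is amax / (6.0f * 448.0f), as in the removed Python path.
__global__ void nvfp4_dequantize_sketch(const uint8_t* packed,
                                        const __nv_fp8_e4m3* block_scales,
                                        float tensor_scale, float* out,
                                        size_t rows, size_t cols) {
  const size_t row_bytes = cols / 2;
  const size_t byte_idx = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
  if (byte_idx >= rows * row_bytes) return;  // same tail guard as the fix above
  const size_t row = byte_idx / row_bytes;
  const size_t col = (byte_idx % row_bytes) * 2;  // column of the low nibble
  const uint8_t b = packed[byte_idx];
  // Both nibbles share one block scale: col is even and the block size is 16.
  const float scale =
      static_cast<float>(block_scales[row * (cols / 16) + col / 16]) * tensor_scale;
  out[row * cols + col]     = kFp4E2M1Vals[b & 0x0F] * scale;
  out[row * cols + col + 1] = kFp4E2M1Vals[b >> 4] * scale;
}
```

The 6.0 and 448.0 constants are the maximum magnitudes representable in E2M1 and E4M3 respectively, so the tensor scale renormalizes the product of the two quantized ranges back to the tensor's original amax.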