Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
f3b97c26
Unverified
Commit
f3b97c26
authored
Nov 06, 2025
by
Przemyslaw Tredak
Committed by
GitHub
Nov 06, 2025
Browse files
Fix out of bounds access in the FP4 dequantize kernel (#2346)
Signed-off-by:
Przemek Tredak
<
ptredak@nvidia.com
>
parent
dcaca2a6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
30 deletions
+7
-30
transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh
transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh
+4
-0
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
...mer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
+3
-30
No files found.
transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh
View file @
f3b97c26
...
@@ -39,6 +39,10 @@ __global__ void __launch_bounds__(512)
...
@@ -39,6 +39,10 @@ __global__ void __launch_bounds__(512)
const
size_t
x
=
thread_idx
%
M
;
const
size_t
x
=
thread_idx
%
M
;
const
size_t
y
=
thread_idx
/
M
;
const
size_t
y
=
thread_idx
/
M
;
if
(
y
>=
N
)
{
return
;
}
union
fp4vec
{
union
fp4vec
{
uint64_t
vec
;
uint64_t
vec
;
fp4e2m1x4
small_vec
[
4
];
fp4e2m1x4
small_vec
[
4
];
...
...
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
View file @
f3b97c26
...
@@ -13,12 +13,12 @@ import warnings
...
@@ -13,12 +13,12 @@ import warnings
import
torch
import
torch
#
import transformer_engine_torch as tex
import
transformer_engine_torch
as
tex
from
transformer_engine_torch
import
DType
as
TE_DType
from
transformer_engine_torch
import
DType
as
TE_DType
from
...quantized_tensor
import
QuantizedTensorStorage
,
Quantizer
from
...quantized_tensor
import
QuantizedTensorStorage
,
Quantizer
#
from ...constants import TE_DType as torch_to_transformer_engine_dtype
from
...constants
import
TE_DType
as
torch_to_transformer_engine_dtype
from
...utils
import
_empty_tensor
from
...utils
import
_empty_tensor
...
@@ -45,34 +45,7 @@ class _FromNVFP4Func(torch.autograd.Function):
...
@@ -45,34 +45,7 @@ class _FromNVFP4Func(torch.autograd.Function):
# Dequantize row-wise data
# Dequantize row-wise data
if
tensor
.
_rowwise_data
is
not
None
:
if
tensor
.
_rowwise_data
is
not
None
:
### TODO(tmoon): Debug dequantize kernel and remove unfused impl
return
tex
.
dequantize
(
tensor
,
torch_to_transformer_engine_dtype
[
dtype
])
# return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])
# Tensor properties
shape
=
list
(
tensor
.
_rowwise_data
.
size
())
shape
[
-
1
]
*=
2
device
=
tensor
.
_rowwise_data
.
device
# Convert FP4E2M1 values to FP32
data
=
tensor
.
_rowwise_data
.
view
(
torch
.
uint8
).
to
(
torch
.
int32
)
data
=
torch
.
stack
((
data
&
0x0F
,
data
>>
4
),
dim
=-
1
).
reshape
(
shape
)
data
=
_fp4_e2m1_vals
(
device
,
dtype
=
torch
.
float32
)[
data
]
data
=
data
.
to
(
torch
.
float32
).
contiguous
()
# Convert FP8E4M3 block scales to FP32
block_scales
=
tensor
.
_rowwise_scale_inv
block_scales
=
block_scales
.
reshape
(
-
1
,
block_scales
.
size
(
-
1
))
block_scales
=
block_scales
[:
math
.
prod
(
shape
[:
-
1
]),
:
shape
[
-
1
]
//
16
]
block_scales
=
block_scales
.
view
(
torch
.
float8_e4m3fn
).
to
(
torch
.
float32
)
# Convert amax to FP32 tensor scale
tensor_scale
=
tensor
.
_amax_rowwise
/
(
6.0
*
448.0
)
# Scale by FP4E2M1 and FP8E4M3 max
# Apply scales
block_data
=
data
.
view
(
-
1
,
16
)
block_data
*=
tensor_scale
.
view
(())
*
block_scales
.
reshape
(
-
1
,
1
)
return
data
.
to
(
dtype
)
if
tensor
.
_columnwise_data
is
not
None
:
if
tensor
.
_columnwise_data
is
not
None
:
raise
NotImplementedError
(
"Dequantizing column-wise NVFP4 data is not implemented yet!"
)
raise
NotImplementedError
(
"Dequantizing column-wise NVFP4 data is not implemented yet!"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment