OpenDAS / TransformerEngine · Commits

Commit 53fa872c — authored Oct 11, 2025 by wenjh

Merge branch 'nv_release_v2.8' into release_v2.8

Parents: 27ddce40, 40c69e75
Changes: 159 (showing 20 changed files with 2998 additions and 340 deletions, +2998 / -340)
Changed files (additions / deletions):

transformer_engine/pytorch/csrc/extensions/attention.cpp                 +225  -169
transformer_engine/pytorch/csrc/extensions/bias.cpp                      +59   -12
transformer_engine/pytorch/csrc/extensions/cast.cpp                      +18   -2
transformer_engine/pytorch/csrc/extensions/gemm.cpp                      +16   -4
transformer_engine/pytorch/csrc/extensions/normalization.cpp             +176  -94
transformer_engine/pytorch/csrc/extensions/pybind.cpp                    +19   -0
transformer_engine/pytorch/csrc/pybind.h                                 +14   -6
transformer_engine/pytorch/csrc/quantizer.cpp                            +591  -5
transformer_engine/pytorch/csrc/type_converters.cpp                      +40   -0
transformer_engine/pytorch/csrc/util.cpp                                 +31   -24
transformer_engine/pytorch/distributed.py                                +257  -6
transformer_engine/pytorch/experimental/__init__.py                      +10   -0
transformer_engine/pytorch/experimental/config.py                        +201  -0
transformer_engine/pytorch/experimental/gemm.py                          +139  -0
transformer_engine/pytorch/experimental/quantization.py                  +203  -0
transformer_engine/pytorch/experimental/quantization_microblock_ref.py   +811  -0
transformer_engine/pytorch/experimental/utils.py                         +30   -0
transformer_engine/pytorch/fp8.py                                        +108  -1
transformer_engine/pytorch/module/_common.py                             +33   -5
transformer_engine/pytorch/module/base.py                                +17   -12
transformer_engine/pytorch/csrc/extensions/attention.cpp (view file @ 53fa872c)
...
...
@@ -35,22 +35,6 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
  {
    nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream());
  });
}

void unpack(at::PhiloxCudaState arg, int64_t *rng_state_ptr) {
  NVTE_SCOPED_GIL_RELEASE({
    nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val,
                                 arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_,
                                 at::cuda::getCurrentCUDAStream());
  });
}

// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_per_thread) {
  at::PhiloxCudaState philox_args;
  std::lock_guard<std::mutex> lock(gen->mutex_);
  philox_args = gen->philox_cuda_state(elts_per_thread);
  return philox_args;
}

}  // namespace

namespace transformer_engine::pytorch {
...
...
@@ -58,73 +42,144 @@ namespace transformer_engine::pytorch {
// get the fused attention backend
NVTE_Fused_Attn_Backend get_fused_attn_backend(
    bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
    float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
    int64_t window_size_right) {
#ifdef __HIP_PLATFORM_AMD__
  return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
#else
  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
      bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups,
      max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
  return fused_attention_backend;
#endif
}
// helper function for S and dP quantizers
std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
                                                      const std::vector<size_t> &shape, DType dtype,
                                                      bool create_hp_tensor_for_cs,
                                                      std::optional<at::Tensor> data) {
  std::unique_ptr<Quantizer> T_quantizer = convert_quantizer(quantizer);
  TensorWrapper te_T;
  py::object py_T;
  if (quantizer.is_none()) {
    // high precision
    auto *none_quantizer = dynamic_cast<NoneQuantizer *>(T_quantizer.get());
    if (data.has_value()) {
      std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype, data.value());
    } else {
      std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype);
    }
  } else if (detail::IsFloat8Quantizers(quantizer.ptr())) {
    // delayed scaling; this helps initialize scale_inv
    auto *T_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(T_quantizer.get());
    std::tie(te_T, py_T) =
        T_quantizer_fp8->create_tensor(shape, dtype, data, std::nullopt, std::nullopt);
  } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
    // current scaling
    auto *T_quantizer_fp8 = dynamic_cast<Float8CurrentScalingQuantizer *>(T_quantizer.get());
    if (create_hp_tensor_for_cs) {
      if (data.has_value()) {
        std::tie(te_T, py_T) =
            T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value());
      } else {
        std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype);
      }
    } else {
      std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype);
      NVTE_CHECK(!data.has_value(),
                 "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!");
    }
  }
  return {std::move(te_T), std::move(py_T)};
}
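For orientation, below is a minimal standalone sketch (not part of this commit) of the dispatch that quantizer_helper performs above; QuantizerKind and CreationMode are illustrative stand-ins for the Python-side type checks and for the TensorWrapper/py::object pair the real helper returns.

#include <cstdio>

// Illustrative stand-ins only; the real helper dispatches on the Python quantizer type.
enum class QuantizerKind { None, Float8Delayed, Float8CurrentScaling };
enum class CreationMode { HighPrecision, Fp8DelayedScaling, HpBufferWithAmax, Fp8CurrentScaling };

CreationMode creation_mode(QuantizerKind kind, bool create_hp_tensor_for_cs) {
  switch (kind) {
    case QuantizerKind::None:
      return CreationMode::HighPrecision;  // wraps caller data when it is provided
    case QuantizerKind::Float8Delayed:
      return CreationMode::Fp8DelayedScaling;  // create_tensor also initializes scale_inv
    case QuantizerKind::Float8CurrentScaling:
      // The flag selects a high-precision buffer carrying an amax (quantized later)
      // versus an immediately quantized tensor, which rejects a data argument.
      return create_hp_tensor_for_cs ? CreationMode::HpBufferWithAmax
                                     : CreationMode::Fp8CurrentScaling;
  }
  return CreationMode::HighPrecision;
}

int main() {
  std::printf("%d\n", static_cast<int>(creation_mode(QuantizerKind::Float8CurrentScaling, true)));
  return 0;
}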
// fused attention FWD with separate Q, K and V tensors
std::vector<py::object> fused_attn_fwd(
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
    bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
    const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
    const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
    const std::optional<at::Tensor> cu_seqlens_kv_padded,
    const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
    py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
    const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
    size_t rng_elts_per_thread) {
#ifdef __HIP_PLATFORM_AMD__
  assert(false);
#else
  auto none = py::none();

  // create QKV tensor wrappers
  TensorWrapper te_Q, te_K, te_V;
  te_Q = makeTransformerEngineTensor(Q, none);
  te_K = makeTransformerEngineTensor(K, none);
  te_V = makeTransformerEngineTensor(V, none);

  // If qkv has FP8 dtype, fake_dtype_te is equal to the fake dtype of q, k, v - needed since torch does not have fp8 types.
  const DType qkv_type = te_Q.dtype();
  const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);

  // create S tensor
  TensorWrapper te_S;
  py::object py_S;
  std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt);

  // create O tensor
  TensorWrapper te_O;
  py::object py_O;
  std::vector<size_t> q_shape = convertShape(te_Q.shape());
  std::vector<size_t> k_shape = convertShape(te_K.shape());
  std::vector<size_t> v_shape = convertShape(te_V.shape());
  auto o_shape = std::vector<size_t>{q_shape.begin(), q_shape.end()};
  o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1];
  std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt);

  // construct NVTE tensors
  TensorWrapper te_Bias;
...
...
@@ -135,11 +190,12 @@ std::vector<py::object> fused_attn_fwd(
    // FP8
    auto h = q_shape[q_shape.size() - 2];
    auto d = q_shape[q_shape.size() - 1];
    if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
      if ((h * d) % block_size == 0) {
        mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
      } else {
        te_O.zero_(at::cuda::getCurrentCUDAStream());
      }
    }
  } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
    if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
...
...
@@ -188,12 +244,23 @@ std::vector<py::object> fused_attn_fwd(
                                        DType::kInt32, nullptr, nullptr, nullptr);
  }

  // softmax offset
  TensorWrapper te_SoftmaxOffset;
  if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) {
    auto SoftmaxOffset_sizes = SoftmaxOffset.value().sizes().vec();
    std::vector<size_t> SoftmaxOffset_shape{SoftmaxOffset_sizes.begin(), SoftmaxOffset_sizes.end()};
    te_SoftmaxOffset =
        makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), SoftmaxOffset_shape,
                                    DType::kFloat32, nullptr, nullptr, nullptr);
  }

  // extract rng seed and offset
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
      rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
  at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
  auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
  auto rng_state = torch::empty({2}, options);
  philox_unpack(philox_args, static_cast<int64_t *>(rng_state.data_ptr()));
  auto te_rng_state = makeTransformerEngineTensor(rng_state);

  // create auxiliary output tensors
...
...
@@ -206,11 +273,11 @@ std::vector<py::object> fused_attn_fwd(
  // populate tensors with appropriate shapes and dtypes
  NVTE_SCOPED_GIL_RELEASE({
    nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(),
                        te_SoftmaxOffset.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
                        te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
                        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
                        te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(),
                        max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout,
                        bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1],
                        workspace.data(), at::cuda::getCurrentCUDAStream());
  });
...
...
@@ -221,52 +288,53 @@ std::vector<py::object> fused_attn_fwd(
  // output_tensors = [O, nvte_aux_tensor_pack.tensors]
  std::vector<py::object> output_tensors;
  output_tensors.push_back(py_O);
  auto set_tensor_param = [&](size_t i, const at::Tensor &output_tensor) {
    output_tensors.push_back(py::cast(output_tensor));
    NVTEBasicTensor temp_data = {output_tensor.data_ptr(),
                                 nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]),
                                 nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])};
    nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
  };

  // allocate memory for nvte_aux_tensor_pack.tensors
  // f16_max512   : S [b, h, sq, skv]
  // f16_arbitrary: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
  // fp8          : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2]
  size_t i = 0;
  at::Tensor output_tensor;

  // intermediate softmax tensor, S or M
  output_tensor =
      allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
                    static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
  set_tensor_param(i++, output_tensor);

  // fp8 has an additional softmax stats tensor, ZInv
  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
    output_tensor =
        allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
                      static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
    set_tensor_param(i++, output_tensor);
  }

  // rng_state
  if (i < nvte_aux_tensor_pack.size) {
    set_tensor_param(i++, rng_state);
  }

  // bias (optional)
  if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
    set_tensor_param(i++, Bias.value());
  }

  // softmax_offset (optional)
  if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) {
    set_tensor_param(i++, SoftmaxOffset.value());
  }

  // execute the kernel
  NVTE_SCOPED_GIL_RELEASE({
    nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(),
                        te_SoftmaxOffset.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
                        te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
                        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
                        te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(),
                        max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout,
                        bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1],
                        workspace.data(), at::cuda::getCurrentCUDAStream());
  });
...
...
@@ -282,9 +350,10 @@ std::vector<py::object> fused_attn_fwd(
std::vector<py::object> fused_attn_bwd(
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
    NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
    const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
    const at::ScalarType fake_dtype, const DType dqkv_type,
    const std::vector<at::Tensor> Aux_CTX_Tensors,
    const std::optional<at::Tensor> cu_seqlens_q_padded,
    const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
...
...
@@ -293,50 +362,44 @@ std::vector<py::object> fused_attn_bwd(
  assert(false);
#else
  auto none = py::none();

  // create QKV, O, dO tensor wrappers
  TensorWrapper te_Q, te_K, te_V, te_O, te_dO;
  te_Q = makeTransformerEngineTensor(Q, none);
  te_K = makeTransformerEngineTensor(K, none);
  te_V = makeTransformerEngineTensor(V, none);
  te_O = makeTransformerEngineTensor(O, none);
  te_dO = makeTransformerEngineTensor(dO, none);

  // qkv type from the te_Q
  const DType qkv_type = te_Q.dtype();
  const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);

  // create S and dP tensors
  TensorWrapper te_S, te_dP;
  py::object py_S, py_dP;
  std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt);
  std::tie(te_dP, py_dP) = quantizer_helper(dp_quantizer, {0}, DType::kFloat32, false, std::nullopt);

  // create dQ, dK, dV tensors
  TensorWrapper te_dQ, te_dK, te_dV;
  py::object py_dQ, py_dK, py_dV;
  std::unique_ptr<Quantizer> dQKV_quantizer = convert_quantizer(dqkv_quantizer);
  std::vector<size_t> q_shape = convertShape(te_Q.shape());
  std::vector<size_t> k_shape = convertShape(te_K.shape());
  std::vector<size_t> v_shape = convertShape(te_V.shape());
  auto h_q = q_shape[q_shape.size() - 2];
  auto h_kv = k_shape[k_shape.size() - 2];
  auto d_qk = q_shape[q_shape.size() - 1];
  auto d_v = v_shape[v_shape.size() - 1];
  std::vector<size_t> o_shape{q_shape.begin(), q_shape.end()};
  o_shape[o_shape.size() - 1] = d_v;

  at::Tensor dQ, dK, dV, dQKV, dKV;
  NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
  std::vector<int64_t> tmp_shape;
  auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
  if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) {
    options = options.dtype(torch::kUInt8);
  }
  if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr())) {
    options = options.dtype(fake_dtype);
  }
  switch (layout_group) {
    case NVTE_QKV_Layout_Group::NVTE_3HD:
...
...
@@ -409,39 +472,27 @@ std::vector<py::object> fused_attn_bwd(
    default:
      NVTE_ERROR("QKV layout not supported!");
  }

  std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, q_shape, fake_dtype_te, true, dQ);
  std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, k_shape, fake_dtype_te, true, dK);
  std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, v_shape, fake_dtype_te, true, dV);

  // construct NVTE tensors
  if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) {
    // FP8
    if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
      if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) &&
          dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) {
        mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
        mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
        mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
      } else {
        dQ.fill_(0);
        dK.fill_(0);
        dV.fill_(0);
      }
    }
  } else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) {
    if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
      dQ.fill_(0);
      dK.fill_(0);
...
...
@@ -510,6 +561,15 @@ std::vector<py::object> fused_attn_bwd(
    }
  }

  // create dSoftmaxOffset in the same shape as SoftmaxOffset
  at::Tensor dSoftmaxOffset;
  TensorWrapper te_dSoftmaxOffset;
  if (softmax_type != NVTE_VANILLA_SOFTMAX) {
    options = torch::TensorOptions().dtype(at::kFloat).device(torch::kCUDA);
    dSoftmaxOffset = torch::empty({1, static_cast<int64_t>(h_q), 1, 1}, options);
    te_dSoftmaxOffset = makeTransformerEngineTensor(dSoftmaxOffset);
  }

  // create workspace
  TensorWrapper workspace;
...
...
@@ -518,10 +578,10 @@ std::vector<py::object> fused_attn_bwd(
    nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
                        te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
                        te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
                        te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
                        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
                        max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type,
                        attn_mask_type, softmax_type, window_size[0], window_size[1], deterministic,
                        workspace.data(), at::cuda::getCurrentCUDAStream());
  });

  // allocate memory for workspace
...
...
@@ -534,16 +594,16 @@ std::vector<py::object> fused_attn_bwd(
    nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
                        te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
                        te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
                        te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
                        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
                        max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type,
                        attn_mask_type, softmax_type, window_size[0], window_size[1], deterministic,
                        workspace.data(), at::cuda::getCurrentCUDAStream());
  });

  // destroy tensor wrappers
  nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);

  return {py_dQ, py_dK, py_dV, py::cast(dBias), py::cast(dSoftmaxOffset)};
#endif
}
...
...
@@ -610,7 +670,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
  // Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1
  int seq_dim = tensor.dim() == 3 ? 0 : 1;
  int batch = cu_seqlens.size(0) - 1;
  int num_heads = tensor.size(seq_dim + 1);
  int dim_per_head = tensor.size(seq_dim + 2);
  int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type());
...
...
@@ -774,8 +833,6 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
  NVTE_CHECK(world_size > 0);
  NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0);
  int batch = cu_seqlens.size(0) - 1;

  std::vector<int64_t> shape = {total_tokens / world_size};
  at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int));
...
...
@@ -813,7 +870,6 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b,
 **************************************************************************************************/

at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) {
  int max_seq_len = tensor.size(1);
  int h = tensor.size(2);
  int d = tensor.size(3);
  std::vector<int64_t> shape = {t, h, d};
...
...
transformer_engine/pytorch/csrc/extensions/bias.cpp (view file @ 53fa872c)
...
...
@@ -54,10 +54,25 @@ std::vector<py::object> bgrad_quantize(const at::Tensor &grad_output, py::handle
    return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
  }

  // Check if fused kernel is supported
  bool with_fused_kernel = false;
  if (detail::IsFloat8Quantizers(quantizer.ptr())) {
    auto prop = at::cuda::getCurrentDeviceProperties();
    const size_t sm_arch = 10 * prop->major + prop->minor;
    if (sm_arch >= 100) {
      // Fused kernel for dbias + FP8 cast on SM arch 10.0+
      with_fused_kernel = true;
    } else if (quantizer_cpp->rowwise_usage && quantizer_cpp->columnwise_usage) {
      // Fused kernel for dbias + FP8 cast + FP8 transpose
      with_fused_kernel = true;
    }
  } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) {
    // Fused kernel for dbias + MXFP8 quantize
    with_fused_kernel = true;
  }

  // Apply unfused impl if fused kernel is not supported
  if (!with_fused_kernel) {
    at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0});
    quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte);
    return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
...
...
@@ -122,13 +137,27 @@ std::vector<py::object> dact_dbias(
  }

  // Choose implementation
  enum class Impl { UNFUSED, FUSED_DACT_DBIAS_QUANTIZE, FUSED_DACT_AMAX_FP8, FUSED_DACT_AMAX_NVFP4 };
  Impl impl = Impl::UNFUSED;
  if (detail::IsFloat8Quantizers(quantizer_py.ptr()) ||
      detail::IsMXFP8Quantizers(quantizer_py.ptr())) {
    impl = Impl::FUSED_DACT_DBIAS_QUANTIZE;
  } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) {
    impl = Impl::FUSED_DACT_AMAX_FP8;
  } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) {
    auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
    NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
      // Post-RHT amax is handled within NVFP4 quantizer
      impl = Impl::UNFUSED;
    } else {
      impl = Impl::FUSED_DACT_AMAX_NVFP4;
    }
  }

  // Perform compute
...
...
@@ -172,20 +201,38 @@ std::vector<py::object> dact_dbias(
      });
      break;
    }
    case Impl::FUSED_DACT_AMAX_FP8:
      // Fused dact-amax kernel, unfused dbias and FP8 quantize
      {
        auto *fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
        NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Invalid quantizer for fused dact-amax kernel impl");
        auto [temp_nvte, temp_py] =
            fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, grad_output_dtype);
        NVTE_SCOPED_GIL_RELEASE({
          dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream);
        });
        const auto temp_torch = temp_py.cast<at::Tensor>();
        at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0});
        fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
        break;
      }
    case Impl::FUSED_DACT_AMAX_NVFP4:
      // Fused dact-amax kernel, unfused dbias and NVFP4 quantize
      {
        auto *nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
        // Already checked cast is valid
        NVTE_CHECK(nvfp4_quantizer_cpp != nullptr,
                   "Invalid quantizer for fused dact-amax kernel impl");
        auto [temp_nvte, temp_py] = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(
            grad_input_nvte, grad_output_dtype);
        NVTE_SCOPED_GIL_RELEASE({
          dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream);
        });
        const auto temp_torch = temp_py.cast<at::Tensor>();
        at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0});
        nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
        break;
      }
    default:
...
...
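As a quick reference for the dact_dbias changes above, here is a minimal standalone sketch (not TE code) of which implementation each quantizer type now selects; QuantizerKind is a stand-in for the IsFloat8Quantizers / IsMXFP8Quantizers / IsFloat8CurrentScalingQuantizers / IsNVFP4Quantizers checks.

#include <cstdio>

// Stand-ins for the Python-side quantizer type checks used in dact_dbias above.
enum class QuantizerKind { Float8Delayed, MXFP8, Float8CurrentScaling, NVFP4 };
enum class Impl { UNFUSED, FUSED_DACT_DBIAS_QUANTIZE, FUSED_DACT_AMAX_FP8, FUSED_DACT_AMAX_NVFP4 };

Impl choose_impl(QuantizerKind kind, bool nvfp4_rht_with_post_rht_amax) {
  switch (kind) {
    case QuantizerKind::Float8Delayed:
    case QuantizerKind::MXFP8:
      return Impl::FUSED_DACT_DBIAS_QUANTIZE;  // single fused dact + dbias + quantize kernel
    case QuantizerKind::Float8CurrentScaling:
      return Impl::FUSED_DACT_AMAX_FP8;  // fused dact + amax, then unfused dbias and FP8 quantize
    case QuantizerKind::NVFP4:
      // Post-RHT amax is handled inside the NVFP4 quantizer, so that case falls back to unfused.
      return nvfp4_rht_with_post_rht_amax ? Impl::UNFUSED : Impl::FUSED_DACT_AMAX_NVFP4;
  }
  return Impl::UNFUSED;
}

int main() {
  std::printf("%d\n", static_cast<int>(choose_impl(QuantizerKind::MXFP8, false)));
  return 0;
}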
transformer_engine/pytorch/csrc/extensions/cast.cpp (view file @ 53fa872c)
...
...
@@ -37,7 +37,18 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
  // Convert input tensor to C++ object
  auto input_contiguous = tensor.contiguous();
  auto input_cpp = makeTransformerEngineTensor(input_contiguous);

  // Set amax if use_existing_amax = true (only valid for CS)
  bool use_existing_amax = false;
  if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
    use_existing_amax = quantizer.attr("use_existing_amax").cast<bool>();
    if (use_existing_amax) {
      const at::Tensor &amax = quantizer.attr("amax").cast<at::Tensor>();
      input_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
                         getTensorShape(amax));
    }
  }

  // Initialize output tensor
  TensorWrapper output_cpp;
...
...
@@ -57,7 +68,12 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
  }

  // Perform quantization
  if (use_existing_amax) {
    auto *quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
    quantizer_cs->quantize_with_amax(input_cpp, output_cpp, noop_flag_cpp);
  } else {
    quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp);
  }

  return output_py;
}
...
...
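A minimal standalone sketch (not TE code) of the decision quantize() makes in the hunks above: only an FP8 current-scaling quantizer whose use_existing_amax flag is set takes the quantize_with_amax path. The struct and function names here are illustrative stand-ins for the Python-side quantizer attributes.

#include <cstdio>
#include <optional>

// Stand-in for the Python-side quantizer attributes read in cast.cpp above.
struct CurrentScalingState {
  bool use_existing_amax = false;
  std::optional<float> amax;  // precomputed amax, reused when the flag is set
};

// Returns true when quantization should go through quantize_with_amax().
bool reuse_amax(bool is_current_scaling_quantizer, const CurrentScalingState &state) {
  return is_current_scaling_quantizer && state.use_existing_amax && state.amax.has_value();
}

int main() {
  CurrentScalingState state{true, 448.0f};
  std::printf("%d\n", reuse_amax(true, state) ? 1 : 0);
  return 0;
}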
transformer_engine/pytorch/csrc/extensions/gemm.cpp (view file @ 53fa872c)
...
...
@@ -215,6 +215,19 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
  const int sm_count = transformer_engine::cuda::sm_count(device_id);
  int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

  // Construct GEMM config
  transformer_engine::MatmulConfigWrapper config;
  if (grad) {
    config.set_dbias_tensor(bias_tensor.data());
    config.set_with_dgelu_epilogue(gelu);
  } else {
    config.set_bias_tensor(bias_tensor.data());
    config.set_with_gelu_epilogue(gelu);
  }
  config.set_epilogue_aux_tensor(te_pre_gelu_out.data());
  config.set_use_split_accumulator(use_split_accumulator);
  config.set_sm_count(num_math_sms);

  // Keep the swizzled scaling factor tensors alive during the GEMM.
  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
  auto main_stream = at::cuda::getCurrentCUDAStream();
...
...
@@ -278,10 +291,9 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
  } else {
    // Launch GEMM
    NVTE_SCOPED_GIL_RELEASE({
      nvte_cublas_gemm_v2(transa, transb, &alpha, A_tensor.data(), B_tensor.data(), &beta.value(),
                          out_tensor.data(), out_tensor.data(), te_workspace.data(), config,
                          main_stream);
    });
  }
} else {
...
...
transformer_engine/pytorch/csrc/extensions/normalization.cpp (view file @ 53fa872c)
...
...
@@ -66,67 +66,102 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
  // Input and param tensors
  auto none = py::none();
  const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
  const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none);
  TensorWrapper bias_nvte;
  if (bias.has_value()) {
    bias_nvte = makeTransformerEngineTensor(*bias);
  }

  // Tensor dimensions
  const auto shape = nvte_shape_to_vector(input_nvte.shape());
  const auto outer_size = product(shape) / shape.back();
  const auto inner_size = shape.back();

  // Tensors to save for backward pass
  at::Tensor mu_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
  at::Tensor rsigma_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
  TensorWrapper mu_nvte = makeTransformerEngineTensor(mu_py);
  TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py);

  // Output tensor
  auto quantizer_cpp = convert_quantizer(quantizer);
  TensorWrapper out_nvte;
  if (out.is_none()) {
    std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
  } else {
    out_nvte = makeTransformerEngineTensor(out, quantizer);
  }

  // Choose implementation
  enum class Impl {
    // Compute norm in high precision, then quantize
    UNFUSED,
    // Compute norm directly
    FULLY_FUSED,
    // Compute norm and amax in high precision, then quantize to FP8
    FUSED_NORM_AMAX_FP8,
    // Compute norm and amax in high precision, then quantize to NVFP4
    FUSED_NORM_AMAX_NVFP4
  };
  Impl impl = Impl::UNFUSED;
  if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) {
    impl = Impl::FULLY_FUSED;
  } else if (IsMXFP8Quantizers(quantizer.ptr())) {
    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 &&
        inner_size % 128 == 0) {
      // cuDNN MXFP8 kernel requires full 128x128 tiles
      impl = Impl::FULLY_FUSED;
    }
  } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
             !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
    auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
    NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
    impl = Impl::FUSED_NORM_AMAX_FP8;
  } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
    auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
    NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
      // Post-RHT amax is handled within NVFP4 quantizer
      impl = Impl::UNFUSED;
    } else if (!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
      // TE kernel supports amax output
      impl = Impl::FUSED_NORM_AMAX_NVFP4;
    }
  }

  // Construct unquantized output tensor if needed
  TensorWrapper unquantized_out_nvte;
  py::object unquantized_out;
  TensorWrapper *kernel_out_nvte = &out_nvte;
  switch (impl) {
    case Impl::UNFUSED: {
      NoneQuantizer q{none};
      std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype);
      kernel_out_nvte = &unquantized_out_nvte;
    } break;
    case Impl::FUSED_NORM_AMAX_FP8: {
      auto fp8_quantizer_cpp = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
      std::tie(unquantized_out_nvte, unquantized_out) =
          fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype);
      kernel_out_nvte = &unquantized_out_nvte;
    } break;
    case Impl::FUSED_NORM_AMAX_NVFP4: {
      auto nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
      std::tie(unquantized_out_nvte, unquantized_out) =
          nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype);
      kernel_out_nvte = &unquantized_out_nvte;
    } break;
    default: {
    }
  }

  // Query workspace size
  TensorWrapper workspace;
  NVTE_SCOPED_GIL_RELEASE({
    nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps,
                       kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(),
                       workspace.data(),
                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                       zero_centered_gamma, at::cuda::getCurrentCUDAStream());
  });
...
...
@@ -138,24 +173,31 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
  // Launch kernel
  NVTE_SCOPED_GIL_RELEASE({
    nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps,
                       kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(),
                       workspace.data(),
                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                       zero_centered_gamma, at::cuda::getCurrentCUDAStream());
  });

  // Quantize output if needed
  switch (impl) {
    case Impl::UNFUSED: {
      quantizer_cpp->quantize(unquantized_out_nvte, out_nvte);
    } break;
    case Impl::FUSED_NORM_AMAX_FP8: {
      auto fp8_quantizer_cpp = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
      fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte);
    } break;
    case Impl::FUSED_NORM_AMAX_NVFP4: {
      auto nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
      nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte);
    } break;
    default: {
    }
  }

  return {out, py::cast(mu_py), py::cast(rsigma_py)};
}

std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
...
...
@@ -254,61 +296,95 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
  // Input and param tensors
  auto none = py::none();
  const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
  const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none);

  // Tensor dimensions
  const auto shape = nvte_shape_to_vector(input_nvte.shape());
  const auto outer_size = product(shape) / shape.back();
  const auto inner_size = shape.back();

  // Tensors to save for backward pass
  at::Tensor rsigma_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
  TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py);

  // Output tensor
  auto quantizer_cpp = convert_quantizer(quantizer);
  TensorWrapper out_nvte;
  if (out.is_none()) {
    std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
  } else {
    out_nvte = makeTransformerEngineTensor(out, quantizer);
  }

  // Choose implementation
  enum class Impl {
    // Compute norm in high precision, then quantize
    UNFUSED,
    // Compute norm directly
    FULLY_FUSED,
    // Compute norm and amax in high precision, then quantize to FP8
    FUSED_NORM_AMAX_FP8,
    // Compute norm and amax in high precision, then quantize to NVFP4
    FUSED_NORM_AMAX_NVFP4
  };
  Impl impl = Impl::UNFUSED;
  if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) {
    impl = Impl::FULLY_FUSED;
  } else if (IsMXFP8Quantizers(quantizer.ptr())) {
    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 &&
        inner_size % 128 == 0) {
      // cuDNN MXFP8 kernel requires full 128x128 tiles
      impl = Impl::FULLY_FUSED;
    }
  } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
             !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
    auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
    NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
    impl = Impl::FUSED_NORM_AMAX_FP8;
  } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
    auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
    NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
      // Post-RHT amax is handled within NVFP4 quantizer
      impl = Impl::UNFUSED;
    } else if (!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
      // TE kernel supports amax output
      impl = Impl::FUSED_NORM_AMAX_NVFP4;
    }
  }

  // Construct unquantized output tensor if needed
  TensorWrapper unquantized_out_nvte;
  py::object unquantized_out;
  TensorWrapper *kernel_out_nvte = &out_nvte;
  switch (impl) {
    case Impl::UNFUSED: {
      NoneQuantizer q{none};
      std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype);
      kernel_out_nvte = &unquantized_out_nvte;
    } break;
    case Impl::FUSED_NORM_AMAX_FP8: {
      auto fp8_quantizer_cpp = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
      std::tie(unquantized_out_nvte, unquantized_out) =
          fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype);
      kernel_out_nvte = &unquantized_out_nvte;
    } break;
    case Impl::FUSED_NORM_AMAX_NVFP4: {
      auto nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
      std::tie(unquantized_out_nvte, unquantized_out) =
          nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype);
      kernel_out_nvte = &unquantized_out_nvte;
    } break;
    default: {
    }
  }

  // Query workspace size
  TensorWrapper workspace;
  NVTE_SCOPED_GIL_RELEASE({
    nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(),
                     rsigma_nvte.data(), workspace.data(),
                     at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                     zero_centered_gamma, at::cuda::getCurrentCUDAStream());
  });
...
...
@@ -320,24 +396,30 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
// Launch kernel
NVTE_SCOPED_GIL_RELEASE
({
nvte_rmsnorm_fwd
(
input_
cu
.
data
(),
weight_
cu
.
data
(),
eps
,
kernel_out_
cu
.
data
(),
rsigma_cu
.
data
(),
workspace
.
data
(),
nvte_rmsnorm_fwd
(
input_
nvte
.
data
(),
weight_
nvte
.
data
(),
eps
,
kernel_out_
nvte
->
data
(),
rsigma_nvte
.
data
(),
workspace
.
data
(),
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
-
sm_margin
,
zero_centered_gamma
,
at
::
cuda
::
getCurrentCUDAStream
());
});
// Quantize output if using unfused kernel
if
(
force_unfused_kernel
)
{
if
(
IsFloat8CurrentScalingQuantizers
(
quantizer
.
ptr
())
&&
!
transformer_engine
::
getenv
<
bool
>
(
"NVTE_NORM_FWD_USE_CUDNN"
))
{
auto
my_quantizer_cs
=
dynamic_cast
<
Float8CurrentScalingQuantizer
*>
(
my_quantizer
.
get
());
my_quantizer_cs
->
quantize_with_amax
(
unquantized_out_cu
,
out_cu
);
}
else
{
my_quantizer
->
quantize
(
unquantized_out_cu
,
out_cu
);
// Quantize output if needed
switch
(
impl
)
{
case
Impl
::
UNFUSED
:
{
quantizer_cpp
->
quantize
(
unquantized_out_nvte
,
out_nvte
);
}
break
;
case
Impl
::
FUSED_NORM_AMAX_FP8
:
{
auto
fp8_quantizer_cpp
=
static_cast
<
Float8CurrentScalingQuantizer
*>
(
quantizer_cpp
.
get
());
fp8_quantizer_cpp
->
quantize_with_amax
(
unquantized_out_nvte
,
out_nvte
);
}
break
;
case
Impl
::
FUSED_NORM_AMAX_NVFP4
:
{
auto
nvfp4_quantizer_cpp
=
static_cast
<
NVFP4Quantizer
*>
(
quantizer_cpp
.
get
());
nvfp4_quantizer_cpp
->
quantize_with_amax
(
unquantized_out_nvte
,
out_nvte
);
}
break
;
default:
{
}
}
return
{
out
,
py
::
none
(),
py
::
cast
(
rsigma
)};
return
{
out
,
py
::
none
(),
py
::
cast
(
rsigma
_py
)};
}
}
// namespace transformer_engine::pytorch
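The hunks above replace the single unfused fallback with an Impl dispatch: the RMSNorm kernel writes a high-precision output (with a fused amax reduction on the FP8 current-scaling and NVFP4 paths), and the quantizer then converts that output. A minimal Python sketch of the control flow, using hypothetical helper names (rmsnorm_kernel, rmsnorm_kernel_with_amax) rather than the real bindings:

def rmsnorm_fwd_sketch(x, weight, eps, quantizer, impl):
    if impl == "UNFUSED":
        hp_out = rmsnorm_kernel(x, weight, eps)           # plain high-precision output
        return quantizer.quantize(hp_out)                 # separate amax + quantize pass
    # FUSED_NORM_AMAX_FP8 / FUSED_NORM_AMAX_NVFP4: the norm kernel also reduces amax
    hp_out, amax = rmsnorm_kernel_with_amax(x, weight, eps)
    return quantizer.quantize_with_amax(hp_out, amax)     # reuse amax, skip a second reduction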
transformer_engine/pytorch/csrc/extensions/pybind.cpp
View file @
53fa872c
...
...
@@ -32,6 +32,9 @@ PyTypeObject *MXFP8QuantizerClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr;
PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorBasePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;

void init_float8_extension() {
  if (Float8TensorPythonClass) return;
...
...
@@ -86,10 +89,26 @@ void init_float8blockwise_extension() {
             "Internal error: could not initialize pyTorch float8blockwise extension.");
}

void init_nvfp4_extensions() {
  if (NVFP4TensorPythonClass) return;
  auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor");
  NVFP4QuantizerClass = reinterpret_cast<PyTypeObject *>(
      PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer"));
  NVFP4TensorPythonClass = reinterpret_cast<PyTypeObject *>(
      PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Tensor"));
  auto nvfp4_base_module =
      py::module_::import("transformer_engine.pytorch.tensor._internal.nvfp4_tensor_base");
  NVFP4TensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
      PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorBase"));
  NVTE_CHECK(NVFP4TensorPythonClass != nullptr,
             "Internal error: could not initialize pyTorch NVFP4 extension.");
}

void init_extension() {
  init_float8_extension();
  init_mxfp8_extension();
  init_float8blockwise_extension();
  init_nvfp4_extensions();
}

}  // namespace transformer_engine::pytorch
...
...
transformer_engine/pytorch/csrc/pybind.h
View file @
53fa872c
...
...
@@ -40,13 +40,12 @@ extern PyTypeObject *MXFP8QuantizerClass;
extern PyTypeObject *Float8BlockwiseQTensorPythonClass;
extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass;
extern PyTypeObject *Float8BlockwiseQuantizerClass;
extern PyTypeObject *NVFP4TensorPythonClass;
extern PyTypeObject *NVFP4TensorBasePythonClass;
extern PyTypeObject *NVFP4QuantizerClass;

void init_extension();

void init_float8_extension();

void init_mxfp8_extension();

namespace detail {

inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; }
...
...
@@ -69,11 +68,17 @@ inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) {
  return Py_TYPE(obj) == Float8BlockwiseQuantizerClass;
}

inline bool IsNVFP4Quantizers(PyObject *obj) { return Py_TYPE(obj) == NVFP4QuantizerClass; }

inline bool IsFloat8BlockwiseQTensor(PyObject *obj) {
  return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass ||
         Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass;
}

inline bool IsNVFP4Tensor(PyObject *obj) {
  return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorBasePythonClass;
}

TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer);

template <typename T>
...
...
@@ -88,6 +93,8 @@ std::unique_ptr<Quantizer> CreateMXFP8Params(const py::handle params);
TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantization_params);

TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer);

inline bool IsFloatingPointType(at::ScalarType type) {
  return type == at::kFloat || type == at::kHalf || type == at::kBFloat16;
}
...
...
@@ -100,8 +107,9 @@ constexpr std::array custom_types_converters = {
    std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor,
                    CreateQuantizer<MXFP8Quantizer>),
    std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers,
                    NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer<Float8BlockQuantizer>)};
                    NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer<Float8BlockQuantizer>),
    std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor,
                    CreateQuantizer<NVFP4Quantizer>)};

}  // namespace detail

}  // namespace transformer_engine::pytorch
...
...
transformer_engine/pytorch/csrc/quantizer.cpp
View file @
53fa872c
...
...
@@ -31,8 +31,20 @@ std::vector<T> make_transpose_shape(const std::vector<S>& shape) {
  return ret;
}

/*! @brief Convert shape for FP4 data by dividing the last dimension by 2 */
template <typename T = size_t>
std::vector<T> convert_shape_for_fp4(const std::vector<T>& shape) {
  std::vector<T> ret;
  for (size_t i = 0; i < shape.size() - 1; ++i) {
    ret.push_back(shape[i]);
  }
  ret.push_back(shape.back() / 2);
  return ret;
}

}  // namespace

constexpr size_t NVFP4_BLOCK_SIZE = 16;
constexpr size_t MXFP8_BLOCK_SIZE = 32;

Quantizer::Quantizer(const py::handle& quantizer) {
...
...
@@ -376,10 +388,15 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
  return {std::move(out_cpp), std::move(out_py)};
}

std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_hp_tensor_with_amax(
    const std::vector<size_t>& shape, DType dtype) {
std::pair<TensorWrapper, py::object>
Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(
    const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data) {
  amax.zero_();
  auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype);
  auto out = data.has_value()
                 ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value())
                 : NoneQuantizer(py::none()).create_tensor(shape, dtype);
  TensorWrapper out_cpp = std::move(out.first);
  py::object out_py = std::move(out.second);
  out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
                   getTensorShape(amax));
  return {std::move(out_cpp), std::move(out_py)};
...
...
@@ -899,7 +916,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
  }
  const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;
  NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0,
             "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
             "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE,
             " (got shape=", shape, ")");
  const auto rowwise_scale_inv_shape = get_scale_shape(shape, false);
  const auto columnwise_scale_inv_shape = get_scale_shape(shape, true);
...
...
@@ -1095,7 +1112,7 @@ std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& s
  auto last_dim = shape.back();
  NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0,
             "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
             "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE,
             " (got shape=", shape, ")");
  std::vector<size_t> scale_shape;
...
...
@@ -1116,4 +1133,573 @@ std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& s
  return scale_shape;
}

NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
  this->dtype = quantizer.attr("dtype").cast<DType>();
  this->with_rht = quantizer.attr("with_rht").cast<bool>();
  this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast<bool>();
  this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast<bool>();
  this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast<bool>();
  // Get amax reduction group if needed for NVFP4 AG
  const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast<bool>();
  c10::intrusive_ptr<dist_group_type> amax_reduction_group;
  if (with_amax_reduction) {
    auto group = quantizer.attr("_canonicalized_amax_reduction_group")();
    NVTE_CHECK(!group.is_none(), "NVFP4Quantizer could not canonicalize amax reduction group");
    amax_reduction_group = group.cast<c10::intrusive_ptr<dist_group_type>>();
  }
  this->with_amax_reduction = with_amax_reduction;
  this->amax_reduction_group = amax_reduction_group;
  this->rht_matrix_random_sign_mask_t =
      quantizer.attr("rht_matrix_random_sign_mask_t").cast<int>();
  this->rht_matrix = quantizer.attr("rht_matrix").cast<at::Tensor>();
}

void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const {
  // set dtype for rowwise and columnwise data in tensor wrapper
  auto rowwise_data = tensor->get_rowwise_data();
  rowwise_data.dtype = static_cast<NVTEDType>(this->dtype);
  auto columnwise_data = tensor->get_columnwise_data();
  columnwise_data.dtype = static_cast<NVTEDType>(this->dtype);
  tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
                           rowwise_data.shape);
  tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
                              columnwise_data.shape);
}

std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(
    const std::vector<size_t>& shape, DType dtype) const {
  using namespace pybind11::literals;

  // Tensor dimensions
  const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
  size_t flat_first_dim = 1;
  if (shape.size() > 0) {
    for (size_t i = 0; i < shape.size() - 1; ++i) {
      flat_first_dim *= shape[i];
    }
  }
  const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;
  NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0,
             "First dim for NVFP4 must be divisible by ", NVFP4_BLOCK_SIZE,
             " (got shape=", shape, ")");
  NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0,
             "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE,
             " (got shape=", shape, ")");
  const auto rowwise_scale_inv_shape = get_scale_shape(shape, false);
  const auto columnwise_scale_inv_shape = get_scale_shape(shape, true);

  // Allocate tensors
  at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor, amax_rowwise;
  at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor, amax_columnwise;
  const auto bit8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
  const auto bit32_tensor_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  if (rowwise_usage) {
    const std::vector<int64_t> scale_inv_shape_int64(rowwise_scale_inv_shape.begin(),
                                                     rowwise_scale_inv_shape.end());
    rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts);
    rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts);
    amax_rowwise = at::empty({1}, bit32_tensor_opts);
  }
  if (columnwise_usage) {
    const std::vector<int64_t> scale_inv_shape_int64(columnwise_scale_inv_shape.begin(),
                                                     columnwise_scale_inv_shape.end());
    // enforce 2D shape to avoid [S, B, H] shape and B and be 1
    // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero
    std::vector<int64_t> shape_int64_2d = {static_cast<int64_t>(flat_first_dim),
                                           static_cast<int64_t>(flat_last_dim)};
    const auto transpose_shape_int64 = make_transpose_shape<int64_t>(shape_int64_2d);
    columnwise_data_tensor =
        at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts);
    columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts);
    amax_columnwise = at::empty({1}, bit32_tensor_opts);
  }

  // Convert tensors to Python
  auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object {
    return need_cast ? py::cast(tensor) : py::none();
  };
  auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage);
  auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage);
  auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage);
  auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage);
  auto amax_rowwise_py = py_cast(amax_rowwise, rowwise_usage);
  auto amax_columnwise_py = py_cast(amax_columnwise, columnwise_usage);

  // Construct Python NVFP4 tensor
  py::object out_py;
  if (internal) {
    py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorBasePythonClass));
    out_py = NVFP4TensorClass(
        "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
        "rowwise_scale_inv"_a = rowwise_scale_inv_py,
        "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
        "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
        "quantizer"_a = this->quantizer);
  } else {
    py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorPythonClass));
    out_py = NVFP4TensorClass(
        "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
        "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
        "rowwise_scale_inv"_a = rowwise_scale_inv_py,
        "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
        "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
        "quantizer"_a = this->quantizer);
  }

  // Construct C++ tensor
  TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING);
  if (rowwise_usage) {
    out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape);
    out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3,
                                  rowwise_scale_inv_shape);
    out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
  }
  if (columnwise_usage) {
    // enforce 2D shape to avoid [S, B, H] shape and B and be 1
    // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero
    std::vector<size_t> shape_2d = {flat_first_dim, flat_last_dim};
    auto col_data_shape_fp4 = make_transpose_shape<size_t>(shape_2d);
    out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), DType::kFloat4E2M1,
                                col_data_shape_fp4);
    out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3,
                                     columnwise_scale_inv_shape);
    out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32,
                                std::vector<size_t>{1});
  }
  this->set_quantization_params(&out_cpp);

  return {std::move(out_cpp), std::move(out_py)};
}

std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_unquantized_tensor_with_amax(
    TensorWrapper& quantized_tensor, DType dtype) {
  // Construct tensor
  auto shape = convertShape(quantized_tensor.shape());
  auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype);

  // Register amax pointer from quantized tensor
  void* amax_ptr = quantized_tensor.amax();
  if (amax_ptr == nullptr) {
    amax_ptr = quantized_tensor.get_columnwise_amax().data_ptr;
  }
  NVTE_CHECK(amax_ptr != nullptr, "Could not extract amax pointer from NVFP4 tensor.");
  out_cpp.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1});

  // Zero out amax
  NVTE_CHECK_CUDA(cudaMemsetAsync(amax_ptr, 0, sizeof(float), at::cuda::getCurrentCUDAStream()));

  return {std::move(out_cpp), std::move(out_py)};
}

std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
    py::object tensor) const {
  NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor.");

  // Extract buffers from Python tensor
  auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
    auto attr_py = tensor.attr(name);
    if (attr_py.is_none()) {
      return std::nullopt;
    }
    return attr_py.cast<at::Tensor>();
  };
  auto rowwise_data = get_tensor("_rowwise_data");
  auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv");
  auto columnwise_data = get_tensor("_columnwise_data");
  auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv");
  auto amax_rowwise = get_tensor("_amax_rowwise");
  auto amax_columnwise = get_tensor("_amax_columnwise");
  NVTE_CHECK(rowwise_data || columnwise_data, "NVFP4Tensor has no data.");

  // Tensor dimensions, shape means original shape
  std::vector<size_t> shape;
  if (columnwise_data) {
    shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true);
    if (rowwise_data) {
      auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
      NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape,
                 ") and column-wise data (shape=", shape, ") do not match");
    }
  } else {
    // Already checked columnwise_data_tensor == true
    shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
  }
  size_t flat_first_dim = 1;
  if (shape.size() > 0) {
    for (size_t i = 0; i < shape.size() - 1; ++i) {
      flat_first_dim *= shape[i];
    }
  }
  const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;

  // Coerce row-wise data
  if (rowwise_usage) {
    if (!rowwise_data) {
      const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
      const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
      rowwise_data = at::empty(convert_shape_for_fp4(shape_int64), opts);
      tensor.attr("_rowwise_data") = *rowwise_data;
    }
    if (!rowwise_scale_inv) {
      const auto scale_inv_shape = get_scale_shape(shape, false);
      const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
                                                       scale_inv_shape.end());
      const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
      rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
      tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
    }
    if (!amax_rowwise) {
      const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
      amax_rowwise = at::empty({1}, opts);
      tensor.attr("_amax_rowwise") = *amax_rowwise;
    }
  } else {  // rowwise_usage == false
    if (rowwise_data) {
      rowwise_data.reset();
      tensor.attr("_rowwise_data") = py::none();
    }
    if (rowwise_scale_inv) {
      rowwise_scale_inv.reset();
      tensor.attr("_rowwise_scale_inv") = py::none();
    }
    if (amax_rowwise) {
      amax_rowwise.reset();
      tensor.attr("_amax_rowwise") = py::none();
    }
  }

  // Coerce column-wise data
  if (columnwise_usage) {
    if (!columnwise_data) {
      // enforce 2D shape to avoid [S, B, H] shape and B and be 1
      // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero
      std::vector<int64_t> shape_int64_2d = {static_cast<int64_t>(flat_first_dim),
                                             static_cast<int64_t>(flat_last_dim)};
      const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
      const auto transpose_shape_int64 = make_transpose_shape<int64_t>(shape_int64_2d);
      columnwise_data = at::empty(convert_shape_for_fp4(transpose_shape_int64), opts);
      tensor.attr("_columnwise_data") = *columnwise_data;
    }
    if (!columnwise_scale_inv) {
      const auto scale_inv_shape = get_scale_shape(shape, true);
      const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
                                                       scale_inv_shape.end());
      const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
      columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
      tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv;
    }
    if (!amax_columnwise) {
      const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
      amax_columnwise = at::zeros({1}, opts);
      tensor.attr("_amax_columnwise") = *amax_columnwise;
    }
  } else {  // columnwise_usage == false
    if (columnwise_data) {
      columnwise_data.reset();
      tensor.attr("_columnwise_data") = py::none();
    }
    if (columnwise_scale_inv) {
      columnwise_scale_inv.reset();
      tensor.attr("_columnwise_scale_inv") = py::none();
    }
    if (amax_columnwise) {
      amax_columnwise.reset();
      tensor.attr("_amax_columnwise") = py::none();
    }
  }

  // Construct C++ tensor
  TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING);
  if (rowwise_usage) {
    out_cpp.set_rowwise_data(rowwise_data->data_ptr(), DType::kFloat4E2M1, shape);
    out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3,
                                  getTensorShape(*rowwise_scale_inv));
    out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, std::vector<size_t>{1});
  }
  if (columnwise_usage) {
    // enforce 2D shape to avoid [S, B, H] shape and B and be 1
    // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero
    std::vector<size_t> shape_2d = {flat_first_dim, flat_last_dim};
    auto col_data_shape_fp4 = make_transpose_shape<size_t>(shape_2d);
    out_cpp.set_columnwise_data(columnwise_data->data_ptr(), DType::kFloat4E2M1,
                                col_data_shape_fp4);
    out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3,
                                     getTensorShape(*columnwise_scale_inv));
    out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32,
                                std::vector<size_t>{1});
  }
  this->set_quantization_params(&out_cpp);

  return {std::move(out_cpp), std::move(tensor)};
}

void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out,
                                   const std::optional<TensorWrapper>& noop_flag,
                                   bool compute_amax) {
  // Nothing to be done if input is empty
  if (input.numel() == 0) {
    return;
  }
  auto stream = at::cuda::getCurrentCUDAStream();

  QuantizationConfigWrapper quant_config;
  if (noop_flag) {
    quant_config.set_noop_tensor(noop_flag->data());
  }
  quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization);
  quant_config.set_stochastic_rounding(this->stochastic_rounding);

  // We only need RHT for columnwise usage.
  // flat first dim and last dim for multi dimensional input
  size_t rows = 1;
  for (size_t i = 0; i < input.ndim() - 1; ++i) {
    rows *= input.size(i);
  }
  size_t cols = input.size(input.ndim() - 1);

  TensorWrapper te_rng_state;
  if (this->stochastic_rounding) {
    const size_t rng_elts_per_thread = 1024;  // Wild guess, probably can be tightened
    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
    at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
    auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
    auto rng_state = torch::empty({2}, opts);
    philox_unpack(philox_args, static_cast<int64_t*>(rng_state.data_ptr()));
    te_rng_state = makeTransformerEngineTensor(rng_state);
    quant_config.set_rng_state(te_rng_state.data());
  }

  // Restriction for the RHT cast fusion kernel.
  bool eligible_for_rht_cast_fusion =
      input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;

  // Compute amax.
  if (this->with_rht) {
    if (input.dtype() != DType::kBFloat16) {
      NVTE_CHECK(false, "RHT is only supported for bfloat16 input");
    }
    if (this->with_post_rht_amax) {
      // We need:
      // 1. Rowwise amax = amax for input
      // 2. Columnwise amax = amax for RHT(input.t)
      NVTE_SCOPED_GIL_RELEASE({
        nvte_hadamard_transform_amax(input.data(), out.data(), 0,
                                     this->rht_matrix_random_sign_mask_t, stream);
      });
    } else {
      // raise error since it's not supported yet
      NVTE_CHECK(false, "Pre-RHT amax is not supported yet");
    }
  } else {
    // Without RHT
    if (compute_amax) {
      // Amax pointers
      auto rowwise_amax_ptr = out.get_amax().data_ptr;
      auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr;
      void* amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr;
      NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer");

      // Compute amax of input tensor
      out.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1});
      NVTE_SCOPED_GIL_RELEASE(
          { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); });
      out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector<size_t>{1});

      // Make sure row-wise and column-wise amaxes match
      if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) {
        NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float),
                                        cudaMemcpyDeviceToDevice, stream));
      }
      if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) {
        NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float),
                                        cudaMemcpyDeviceToDevice, stream));
      }
    }
  }

  // amax reduction
  if (this->with_amax_reduction) {
    std::vector<at::Tensor> amax_tensors;
    // push amax tensors inside if they need to be reduced
    auto make_amax_tensor = [](void* data_ptr) {
      return at::from_blob(
          data_ptr, std::vector<int64_t>{1},
          [](void*) {},  // deleter doing nothing since it doesn't own the data
          at::device(at::kCUDA).dtype(torch::kFloat32));
    };
    if (rowwise_usage) {
      amax_tensors.push_back(make_amax_tensor(out.get_amax().data_ptr));
    }
    if (columnwise_usage) {
      amax_tensors.push_back(make_amax_tensor(out.get_columnwise_amax().data_ptr));
    }
    c10d::AllreduceCoalescedOptions opts;
    opts.reduceOp = c10d::ReduceOp::MAX;
    NVTE_SCOPED_GIL_RELEASE(
        { this->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); });
  }

  if (this->with_rht) {
    if (rowwise_usage) {
      // For rowwise usage, we need to quantize the input directly, but we need to avoid quantizing columnwise
      TensorWrapper out_identity(out.scaling_mode());
      auto out_identity_data = out.get_rowwise_data();
      auto out_identity_scale_inv = out.get_rowwise_scale_inv();
      auto out_identity_amax = out.get_amax();
      out_identity.set_rowwise_data(out_identity_data.data_ptr,
                                    static_cast<DType>(out_identity_data.dtype),
                                    out_identity_data.shape);
      out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr,
                                         static_cast<DType>(out_identity_scale_inv.dtype),
                                         out_identity_scale_inv.shape);
      out_identity.set_amax(out_identity_amax.data_ptr,
                            static_cast<DType>(out_identity_amax.dtype),
                            out_identity_amax.shape);
      NVTE_SCOPED_GIL_RELEASE(
          { nvte_quantize_v2(input.data(), out_identity.data(), quant_config, stream); });
    }
    if (columnwise_usage) {
      // Get the output columnwise data, scale_inv, and amax
      auto out_columnwise_data = out.get_columnwise_data();
      auto out_columnwise_scale_inv = out.get_columnwise_scale_inv();
      // NOTE: should already be populated.
      auto out_columnwise_amax = out.get_columnwise_amax();
      // Create a wrapper for the columnwise output, as the rowwise output.
      // The reason is due to the input `rht_output_t` is already in the transposed layout.
      // Thus, we only need a rowwise quantization to generate the columnwise output.
      TensorWrapper out_transpose(out.scaling_mode());
      // Note: since we are faking columnwise tensor into rowwise, the flat first dim check will fail
      // need to convert the shape to 2D here
      auto colwise_data_shape = out_columnwise_data.shape;
      std::vector<size_t> colwise_data_shape_2d;
      // shape could be [512, 32, 64], that's actually 512, 32, 128 because 2 FP4 take 1 byte
      // the 2D shape should be [512, 32*128], but columnwise data shape expect last dim to be halved again
      // so the multiple 2 get cancelled out
      colwise_data_shape_2d.push_back(colwise_data_shape.data[0]);
      size_t last_dim = 1;
      for (size_t i = 1; i < colwise_data_shape.ndim; ++i) {
        last_dim *= colwise_data_shape.data[i];
      }
      colwise_data_shape_2d.push_back(last_dim);
      out_transpose.set_rowwise_data(out_columnwise_data.data_ptr,
                                     static_cast<DType>(out_columnwise_data.dtype),
                                     colwise_data_shape_2d);
      out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr,
                                          static_cast<DType>(out_columnwise_scale_inv.dtype),
                                          out_columnwise_scale_inv.shape);
      out_transpose.set_amax(out_columnwise_amax.data_ptr,
                             static_cast<DType>(out_columnwise_amax.dtype),
                             out_columnwise_amax.shape);
      if (!eligible_for_rht_cast_fusion) {
        // Invoking fallback RHT kernel.
        // If using RHT, then amax will be computed in the RHT step
        // If not using RHT, then amax will be computed based on input x
        at::Tensor rht_output_t;  // The RHT(x_t) output, in columnwise layout
        // This wrapper is going to be passed as input to the quantization kernel.
        TensorWrapper rht_output_t_cpp;  // Wrapper to contain the RHT(x) and RHT(x_t) outputs
        rht_output_t = allocateTorchTensor(static_cast<int>(cols), static_cast<int>(rows),
                                           input.dtype());
        // NOTE (frsun): This is non-intuitive, we are writing the
        // result of transposed RHT to the output of rowwise.
        rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(),
                                          std::vector<size_t>{cols, rows});
        NVTE_SCOPED_GIL_RELEASE({
          // Perform the RHT(input.t), and write to rht_output_cpp.columnwise.
          nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0,
                                  this->rht_matrix_random_sign_mask_t, stream);
        });
        // Quantize kernel will treat everything as rowwise input/output, which is
        // intended.
        NVTE_SCOPED_GIL_RELEASE({
          nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config, stream);
        });
      } else {
        // RHT cast fusion kernel.
        NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0,
                   "RHT matrix is not set");
        auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix);
        NVTE_SCOPED_GIL_RELEASE({
          nvte_hadamard_transform_cast_fusion_columnwise(input.data(), out_transpose.data(),
                                                         rht_matrix_nvte.data(), quant_config,
                                                         stream);
        });
      }
    }
  } else {
    NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); });
  }
}

void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
                              const std::optional<TensorWrapper>& noop_flag) {
  this->quantize_impl(input, out, noop_flag, true);
}

void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) {
  // Update output tensor amaxes with input tensor amax
  auto input_amax_ptr = input.amax();
  auto output_rowwise_amax_ptr = out.get_amax().data_ptr;
  auto output_columnwise_amax_ptr = out.get_columnwise_amax().data_ptr;
  NVTE_CHECK(input_amax_ptr != nullptr ||
                 (output_rowwise_amax_ptr == nullptr && output_columnwise_amax_ptr == nullptr),
             "Input tensor does not have pre-computed amax");
  if (input_amax_ptr != output_rowwise_amax_ptr && input_amax_ptr != nullptr &&
      output_rowwise_amax_ptr != nullptr) {
    NVTE_CHECK_CUDA(cudaMemcpyAsync(output_rowwise_amax_ptr, input_amax_ptr, sizeof(float),
                                    cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream()));
  }
  if (input_amax_ptr != output_columnwise_amax_ptr && input_amax_ptr != nullptr &&
      output_columnwise_amax_ptr != nullptr) {
    NVTE_CHECK_CUDA(cudaMemcpyAsync(output_columnwise_amax_ptr, input_amax_ptr, sizeof(float),
                                    cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream()));
  }
  input.set_amax(nullptr, DType::kFloat32, input.defaultShape);

  // Perform quantization
  this->quantize_impl(input, out, std::nullopt, false);
}

std::vector<size_t> NVFP4Quantizer::get_scale_shape(const std::vector<size_t>& shape,
                                                    bool columnwise) const {
  size_t numel = 1;
  for (auto s : shape) {
    numel *= s;
  }
  auto last_dim = shape.back();
  auto flat_first_dim = numel / last_dim;
  NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ",
             NVFP4_BLOCK_SIZE, " (got dim=", last_dim, ")");
  NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0,
             "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE,
             " (got shape=", shape, ")");
  std::vector<size_t> scale_shape;
  bool rowwise_usage = !columnwise;
  if (rowwise_usage) {
    // rowwise scaling factor shape
    size_t sinv0 = roundup(flat_first_dim, 128);
    size_t sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4);
    scale_shape = {sinv0, sinv1};
  } else {
    // columnwise scaling factor shape
    size_t sinv0 = roundup(last_dim, 128);
    size_t sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4);
    scale_shape = {sinv0, sinv1};
  }
  return scale_shape;
}

}  // namespace transformer_engine::pytorch
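A minimal Python sketch of the NVFP4 scale-shape rule implemented by get_scale_shape above (the roundup helper here is an assumption that simply rounds up to the nearest multiple): row-wise usage keeps one FP8 scale per 16-element block of the last dimension, padded to a 128-by-4 aligned layout, and column-wise usage swaps the roles of the two dimensions.

def nvfp4_scale_shape(shape, columnwise=False, block=16):
    def roundup(x, m):
        return (x + m - 1) // m * m
    last_dim = shape[-1]
    flat_first_dim = 1
    for s in shape[:-1]:
        flat_first_dim *= s
    if not columnwise:
        return (roundup(flat_first_dim, 128), roundup(last_dim // block, 4))
    return (roundup(last_dim, 128), roundup(flat_first_dim // block, 4))

# e.g. nvfp4_scale_shape((512, 1024)) == (512, 64)
#      nvfp4_scale_shape((512, 1024), columnwise=True) == (1024, 32)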
transformer_engine/pytorch/csrc/type_converters.cpp
View file @
53fa872c
...
...
@@ -116,6 +116,46 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
  return ret;
}

TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) {
  const DType dtype = tensor.attr("_fp4_dtype").cast<DType>();
  auto ret = TensorWrapper(NVTE_NVFP4_1D_SCALING);

  bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
  bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
  NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor.");

  // Row-scaled data
  if (rowwise_usage) {
    const auto &data = tensor.attr("_rowwise_data").cast<at::Tensor>();
    const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
    const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast<at::Tensor>();
    ret.set_rowwise_data(data.data_ptr(), dtype,
                         convert_shape_back_from_fp4(getTensorShape(data), false));
    ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv));
    ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise));
  }

  // Column-scaled data
  if (columnwise_usage) {
    const auto &data = tensor.attr("_columnwise_data").cast<at::Tensor>();
    const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
    const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast<at::Tensor>();
    ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1,
                            convert_shape_back_from_fp4(getTensorShape(data), false));
    ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3,
                                 getTensorShape(scale_inv));
    ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32,
                            getTensorShape(amax_columnwise));
  }

  // Quantizer state
  quantizer->set_quantization_params(&ret);

  return ret;
}

}  // namespace detail

}  // namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/util.cpp
View file @
53fa872c
...
...
@@ -14,22 +14,31 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
  if (input.scaling_mode() == NVTE_INVALID_SCALING) {
    NVTE_ERROR("Invalid scaling mode for swizzle.");
  } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) {
  } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING &&
             input.scaling_mode() != NVTE_NVFP4_1D_SCALING) {
    return std::nullopt;
  }

  NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");
  NVTE_CHECK(input.element_size_bits() == 4 || input.element_size_bits() == 8,
             "4-bit or 8-bit input required for swizzling scaling factors.");

  const auto nvfp4 = input.scaling_mode() == NVTE_NVFP4_1D_SCALING;

  NVTEBasicTensor scale_inv;
  NVTEShape nvte_input_shape;
  if (rowwise) {
    nvte_input_shape = input.shape();
    scale_inv = input.get_rowwise_scale_inv();
  } else {
    nvte_input_shape = input.get_columnwise_data().shape;
    scale_inv = input.get_columnwise_scale_inv();
  }

  auto input_shape = nvte_shape_to_vector(input.shape());
  auto input_shape = nvte_shape_to_vector(nvte_input_shape);
  auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
  NVTE_CHECK(input_shape.size() >= 2, "Wrong ndims for swizzle input shape.");

  // Allocate memory for swizzled output.
  auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
  std::vector<int64_t> scale_inv_shape_int;
...
...
@@ -41,36 +50,34 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
  void *swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);

  // Reconstruct input only to avoid swizzling both directions if not needed.
  // Use any 8 bit type, it's irrelevant.
  transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING);
  transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
  // The specific dtype used is irrelevant, just needs to be correct bits.
  transformer_engine::TensorWrapper input_cu(input.scaling_mode());
  transformer_engine::TensorWrapper output_cu(input.scaling_mode());
  const auto input_dtype = (nvfp4) ? transformer_engine::DType::kFloat4E2M1
                                   : transformer_engine::DType::kFloat8E4M3;
  const auto scale_inv_dtype = (nvfp4) ? transformer_engine::DType::kFloat8E4M3
                                       : transformer_engine::DType::kFloat8E8M0;
  if (rowwise) {
    input_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
    input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
                                   scale_inv_shape);
    output_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
    output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr,
                                    transformer_engine::DType::kFloat8E8M0, scale_inv_shape);
    input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
    input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
    output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
    output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
  } else {
    input_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
                                 input_shape);
    input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
                                      scale_inv_shape);
    output_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
                                  input_shape);
    output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr,
                                       transformer_engine::DType::kFloat8E8M0, scale_inv_shape);
    input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
    input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
    output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
    output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
  }

  // Launch kernel
  nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(),
                               at::cuda::getCurrentCUDAStream());

  if (rowwise) {
    input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
                                scale_inv_shape);
    input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
  } else {
    input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
                                   scale_inv_shape);
    input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
  }

  return swizzled_scale_inv;
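As a small reference for the dtype selection in this hunk (summarizing the nvfp4 flag above, not introducing new behavior): NVFP4 tensors carry FP4 E2M1 data with FP8 E4M3 block scales, while MXFP8 tensors carry FP8 E4M3 data with E8M0 (power-of-two) scales; both swizzle paths only re-layout the scale factors.

SWIZZLE_DTYPES = {
    "nvfp4": {"data": "kFloat4E2M1", "scale_inv": "kFloat8E4M3"},
    "mxfp8": {"data": "kFloat8E4M3", "scale_inv": "kFloat8E8M0"},
}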
...
...
transformer_engine/pytorch/distributed.py
View file @
53fa872c
...
...
@@ -39,11 +39,14 @@ from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager, fp8_autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer
from .tensor.nvfp4_tensor import NVFP4Quantizer
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from .tensor.quantized_tensor import QuantizedTensorBase, QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from .tensor._internal.nvfp4_tensor_base import NVFP4TensorBase
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from .triton.pad import pad_columnwise_scale_inv
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer
...
...
@@ -1204,6 +1207,245 @@ def _all_gather_fp8_blockwise(
    return out, handle


def _swap_first_dims(tensor: torch.Tensor, world_size: int):
    """
    Swap first 2 dimensions of a tensor to fix interleaved
    data format after gathering transposed data.
    For more than 2 dimensions, we squash the trailing dimensions,
    instead of the first few dimensions, that's because the shape
    passed in this function is already transposed.
    """
    shape = tensor.shape
    assert tensor.ndim >= 2, "Wrong number of dimensions for fixing interleave."
    first_dim = shape[0]
    flattened_trailing = math.prod(shape[1:])
    assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave."
    tensor = tensor.reshape(world_size, first_dim // world_size, flattened_trailing)
    tensor = tex.swap_first_dims(tensor, out=None)
    return tensor.reshape(first_dim // world_size, flattened_trailing * world_size)
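A small self-contained reference check of what the swap fixes, with tex.swap_first_dims replaced by a plain transpose (this is an illustration, not the kernel itself): gathering per-rank transposed shards along dim 0 interleaves them, and the reshape-swap-reshape recovers the layout of transposing the full gathered tensor.

import torch

def swap_first_dims_ref(t, world_size):
    first_dim = t.shape[0]
    trailing = t.shape[1:].numel()
    t = t.reshape(world_size, first_dim // world_size, trailing)
    t = t.transpose(0, 1).contiguous()          # stand-in for tex.swap_first_dims
    return t.reshape(first_dim // world_size, trailing * world_size)

world_size, rows_per_rank, cols = 2, 3, 4
full = torch.arange(world_size * rows_per_rank * cols).reshape(world_size * rows_per_rank, cols)
shards_t = [full[r * rows_per_rank:(r + 1) * rows_per_rank].T.contiguous() for r in range(world_size)]
gathered_t = torch.cat(shards_t, dim=0)         # interleaved layout after all-gather
assert torch.equal(swap_first_dims_ref(gathered_t, world_size), full.T)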
def _post_process_nvfp4_gather(
    out: NVFP4TensorBase,
    columnwise_data_interleaved: torch.Tensor,
    columnwise_scale_inv_interleaved: torch.Tensor,
    world_size: int,
    handle: Optional[torch.distributed.Work] = None,
) -> NVFP4TensorBase:
    """Post-process FP8 blockwise gather."""
    if handle is not None:
        handle.wait()
        handle = None
    # Fix the interleaved transposed data from gathering along first dim.
    out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
    out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
    # Optionally pad the scaling inverse if needed.
    out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)


@dataclass
class _NVFP4AllGatherAsyncHandle:
    """Handle for asynchronous NVFP4 all-gather."""

    output: NVFP4TensorBase
    columnwise_data_interleaved: torch.Tensor
    columnwise_scale_inv_interleaved: torch.Tensor
    world_size: int
    async_handle: torch.distributed.Work
    _synchronized: bool = False

    def wait(self) -> None:
        """Wait for the async operation to complete and post-process the tensor."""
        if self._synchronized:
            return
        self.async_handle.wait()
        _post_process_nvfp4_gather(
            self.output,
            self.columnwise_data_interleaved,
            self.columnwise_scale_inv_interleaved,
            self.world_size,
        )
        self._synchronized = True


def _all_gather_nvfp4(
    inp: torch.Tensor,
    process_group: dist_group_type,
    *,
    async_op: bool = False,
    quantizer: NVFP4Quantizer,
    out_shape: Optional[list[int]] = None,
) -> tuple[NVFP4TensorBase, Optional[torch.distributed.Work]]:
    """All-gather NVFP4 tensor along first dimension."""

    # Input tensor attributes
    in_shape: Iterable[int] = None
    in_shape_t: Iterable[int] = None
    device: torch.device
    dtype: torch.dtype
    # Construct packed shapes for input and input_t.
    if isinstance(inp, torch.Tensor) and not isinstance(inp, NVFP4TensorBase):
        # High-precision tensor.
        in_shape = NVFP4Quantizer.convert_shape_for_fp4(inp.size())
        in_shape_t = NVFP4Quantizer.convert_shape_for_fp4(
            NVFP4Quantizer.get_columnwise_shape(inp.size())
        )
        device = inp.device
        dtype = inp.dtype
    elif isinstance(inp, NVFP4TensorBase):
        if inp._rowwise_data is not None:
            in_shape = inp._rowwise_data.size()
            device = inp._rowwise_data.device
        if inp._columnwise_data is not None:
            in_shape_t = inp._columnwise_data.size()
            device = inp._columnwise_data.device
        dtype = torch.bfloat16
    else:
        raise ValueError(
            "Invalid type for input tensor (expected torch.Tensor or NVFP4TensorBase, "
            f"found {inp.__class__.__name__})"
        )
    assert in_shape is not None or in_shape_t is not None, "No data found."
    world_size = get_distributed_world_size(process_group)
    if out_shape is None:
        out_shape = [in_shape[0] * world_size] + in_shape[1:]

    # For cases where inp has dimensions that cannot be quantized,
    # we gather in high precision followed by a cast to NVFP4.
    if (
        not isinstance(inp, NVFP4TensorBase)
        and quantizer is not None
        and not quantizer.is_quantizable(inp)
    ):
        out = torch.empty(
            out_shape,
            dtype=dtype,
            device=device,
            memory_format=torch.contiguous_format,
        )
        torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
        out = quantizer(out)
        return out, None

    # Cast input tensor to NVFP4 with required data
    if not isinstance(inp, NVFP4TensorBase):
        inp = quantizer(inp)
    elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
        quantizer.columnwise_usage and inp._columnwise_data is None
    ):
        warnings.warn(
            "Input and quantizer do not have matching usages. "
            "Dequantizing and requantizing to NVFP4."
        )
        inp = quantizer(inp.dequantize())

    # Construct NVFP4 output tensor
    out = quantizer.make_empty(out_shape, dtype=dtype, device=device)

    # Coalesce NCCL collectives for gathering data and scale inverses.
    with torch.distributed._coalescing_manager(
        group=process_group,
        device=device,
        async_ops=async_op,
    ) as gather_coalescing_manager:

        # Gather NVFP4 data for row-wise usage
        if quantizer.rowwise_usage:
            # Remove padding from NVFP4 scale-inverses
            assert in_shape is not None, "Shape not found."
            in_scale_inv = inp._rowwise_scale_inv
            out_scale_inv = out._rowwise_scale_inv
            flattened_in_shape0 = math.prod(in_shape[:-1])
            if in_scale_inv.size(0) != flattened_in_shape0:
                in_scale_inv = in_scale_inv[:flattened_in_shape0]
                out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

            # Launch all-gathers
            torch.distributed.all_gather_into_tensor(
                out_scale_inv,
                in_scale_inv,
                group=process_group,
            )
            torch.distributed.all_gather_into_tensor(
                out._rowwise_data,
                inp._rowwise_data,
                group=process_group,
            )
            # Transfer amax to output.
            out._amax_rowwise = inp._amax_rowwise

        # Gather the transposed NVFP4 data along first dimension. Fix format later.
        if quantizer.columnwise_usage:
            # Remove padding from NVFP4 scale-inverses
            # For doing an all-gather on transposed scale inverses,
            # we need to remove padding from both dimension.
            in_scale_inv = inp._columnwise_scale_inv
            # take caution that for in_shape_t, flatten in the trailing dimensions!
            flattened_in_shape0 = in_shape_t[0]
            flattened_in_shape1 = math.prod(in_shape_t[1:])
            # Remove dim0 padding
            if in_scale_inv.size(0) != flattened_in_shape0:
                in_scale_inv = in_scale_inv[:flattened_in_shape0]
            # Remove dim1 padding (pack first).
            unpadded_dim1 = flattened_in_shape1 * 2 // 16
            if in_scale_inv.size(1) != unpadded_dim1:
                in_scale_inv = in_scale_inv[:, :unpadded_dim1].contiguous()

            # Construct tensor to gather transposed scale_inv (interleaved) and launch AG.
            out_scale_inv = torch.empty(
                [flattened_in_shape0 * world_size] + [in_scale_inv.shape[1]],
                dtype=in_scale_inv.dtype,
                layout=in_scale_inv.layout,
                device=in_scale_inv.device,
            )
            torch.distributed.all_gather_into_tensor(
                out_scale_inv,
                in_scale_inv,
                group=process_group,
            )

            # Construct tensor to gather transposed data (interleaved) and launch AG.
            out_columnwise_data = torch.empty(
                [inp._columnwise_data.shape[0] * world_size]
                + list(inp._columnwise_data.shape[1:]),
                dtype=inp._columnwise_data.dtype,
                layout=inp._columnwise_data.layout,
                device=inp._columnwise_data.device,
            )
            torch.distributed.all_gather_into_tensor(
                out_columnwise_data,
                inp._columnwise_data,
                group=process_group,
            )
            # Transfer amax to output.
            out._amax_columnwise = inp._amax_columnwise

    handle = gather_coalescing_manager if async_op else None

    # Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed.
    if async_op and quantizer.columnwise_usage:
        handle = _NVFP4AllGatherAsyncHandle(
            out,
            out_columnwise_data,
            out_scale_inv,
            world_size,
            handle,
        )
    elif quantizer.columnwise_usage:
        _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle)

    return out, handle
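A hedged usage sketch of the function above (the process group, quantizer, and tensor names are illustrative; it assumes an initialized distributed backend and an NVFP4Quantizer with both row-wise and column-wise usage). On the async path the returned handle's wait() both finishes the collective and fixes the interleaved column-wise layout; on the sync path the post-processing has already run.

out, handle = _all_gather_nvfp4(
    local_activations,        # bf16 tensor or NVFP4TensorBase shard on this rank (hypothetical)
    tp_group,                 # hypothetical tensor-parallel process group
    async_op=True,
    quantizer=nvfp4_quantizer,
)
# ... overlap other work here ...
if handle is not None:
    handle.wait()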
def _all_gather_mxfp8(
    inp: torch.Tensor,
    process_group: dist_group_type,
...
@@ -1291,7 +1533,6 @@ def _all_gather_mxfp8(
flattened_in_shape0
=
math
.
prod
(
in_shape
[:
-
1
])
if
in_scale_inv
.
size
(
0
)
!=
flattened_in_shape0
:
in_scale_inv
=
in_scale_inv
[:
flattened_in_shape0
]
out_scale_inv
[
flattened_in_shape0
*
world_size
:].
zero_
()
out_scale_inv
=
out_scale_inv
[:
flattened_in_shape0
*
world_size
]
# Launch all-gathers
...
...
@@ -1315,7 +1556,6 @@ def _all_gather_mxfp8(
            flattened_in_shape0 = math.prod(in_shape[:-1]) // 32
            if in_scale_inv.size(0) != flattened_in_shape0:
                in_scale_inv = in_scale_inv[:flattened_in_shape0]
                out_scale_inv[flattened_in_shape0 * world_size :].zero_()
                out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

            # Launch all-gathers
...
...
@@ -1347,7 +1587,7 @@ def gather_along_first_dim(
    # Return immediately if no communication is required
    world_size = get_distributed_world_size(process_group)
    if world_size == 1:
        if quantizer is not None and not isinstance(inp, QuantizedTensor):
        if quantizer is not None and not isinstance(inp, QuantizedTensorBase):
            inp = quantizer(inp)
        return inp, None
...
...
@@ -1426,13 +1666,24 @@ def gather_along_first_dim(
            out_shape=out_shape,
        )

    # NVFP4 case
    if isinstance(inp, NVFP4TensorBase) or isinstance(quantizer, NVFP4Quantizer):
        assert isinstance(quantizer, NVFP4Quantizer)
        return _all_gather_nvfp4(
            inp,
            process_group,
            async_op=async_op,
            quantizer=quantizer,
            out_shape=out_shape,
        )

    # High-precision communication for quantized tensors
    if quantizer is not None:
        warnings.warn(
            "Attempting to all-gather an unsupported quantized tensor. "
            "Falling back to high-precision all-gather."
        )
        if isinstance(inp, QuantizedTensor):
        if isinstance(inp, QuantizedTensorBase):
            inp = inp.dequantize()
        # Falling back to high-precision all-gather for Float8BlockQuantizer
        # means that it should directly output GEMM_READY format
...
...
@@ -1450,7 +1701,7 @@ def gather_along_first_dim(
        return out, None

    # Dequantize quantized tensor if not supported
    if isinstance(inp, QuantizedTensor):
    if isinstance(inp, QuantizedTensorBase):
        warnings.warn(
            "Attempting to all-gather an unsupported quantized tensor. "
            "Falling back to high-precision all-gather."
...
...
transformer_engine/pytorch/experimental/__init__.py
0 → 100644
View file @
53fa872c
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Experimental features and APIs."""
from .config import set_qlinear_params, get_experimental_quantizers

__all__ = ["set_qlinear_params", "get_experimental_quantizers"]
transformer_engine/pytorch/experimental/config.py
0 → 100644
View file @
53fa872c
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Config API for experimental middleware between Transformer Engine and Kitchen."""
import
dataclasses
import
enum
import
os
from
typing
import
Optional
from
transformer_engine.pytorch.experimental
import
utils
from
transformer_engine.pytorch.experimental
import
quantization
from
transformer_engine.pytorch.experimental
import
quantization_microblock_ref
from
transformer_engine.pytorch.experimental.quantization
import
MMParams
@
dataclasses
.
dataclass
()
class
QLinearParams
:
"""Quantization parameters of linear layer.
Contains ready-to-use quantizers for input (x), weight (w), and gradient (g) tensors.
"""
x_quantizer
:
Optional
[
quantization
.
ExperimentalQuantizer
]
=
None
w_quantizer
:
Optional
[
quantization
.
ExperimentalQuantizer
]
=
None
g_quantizer
:
Optional
[
quantization
.
ExperimentalQuantizer
]
=
None
mm_fprop
:
Optional
[
MMParams
]
=
None
mm_dgrad
:
Optional
[
MMParams
]
=
None
mm_wgrad
:
Optional
[
MMParams
]
=
None
@
enum
.
unique
class
QuantizeRecipe
(
enum
.
Enum
):
"""Pre-defined quantization recipes for linear layers."""
NON_QUANTIZE
=
"non_quantize"
NVFP4_REF
=
"nvfp4_ref"
NVFP4_REF_RHT_ONLY
=
"nvfp4_ref_rht_only"
NVFP4_REF_2D_QUANTIZATION_ONLY
=
"nvfp4_ref_2d_quantization_only"
NVFP4_REF_RHT_AND_2D_QUANTIZATION
=
"nvfp4_ref_rht_and_2d_quantization"
def
get_qlinear_params_from_predefined
(
recipe
:
QuantizeRecipe
,
)
->
Optional[QLinearParams]:
    """Get quantization parameters for linear layer based on recipe."""
    if recipe == QuantizeRecipe.NON_QUANTIZE:
        return None
    if recipe == QuantizeRecipe.NVFP4_REF:
        return QLinearParams(
            x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(1, 16),
                pow_2_scales=False,
            ),
            w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(1, 16),
                pow_2_scales=False,
            ),
            g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(1, 16),
                pow_2_scales=False,
            ),
        )
    if recipe == QuantizeRecipe.NVFP4_REF_RHT_ONLY:
        return QLinearParams(
            x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(1, 16),
                pow_2_scales=False,
                with_rht=True,
            ),
            w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(1, 16),
                pow_2_scales=False,
                with_rht=False,
            ),
            g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(1, 16),
                pow_2_scales=False,
                with_rht=True,
            ),
        )
    if recipe == QuantizeRecipe.NVFP4_REF_2D_QUANTIZATION_ONLY:
        return QLinearParams(
            x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(1, 16),
                pow_2_scales=False,
                with_rht=False,
            ),
            w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(16, 16),
                pow_2_scales=False,
                with_rht=False,
            ),
            g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(1, 16),
                pow_2_scales=False,
                with_rht=False,
            ),
        )
    if recipe == QuantizeRecipe.NVFP4_REF_RHT_AND_2D_QUANTIZATION:
        return QLinearParams(
            x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(1, 16),
                pow_2_scales=False,
                with_rht=True,
            ),
            w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(16, 16),
                pow_2_scales=False,
                with_rht=False,
            ),
            g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
                dtype=utils.Fp4Formats.E2M1,
                quant_tile_shape=(1, 16),
                pow_2_scales=False,
                with_rht=True,
            ),
        )
    raise ValueError(f"Unsupported quantize recipe: {recipe}")


def get_qlinear_params_from_qat_params(qat_params_idx: int) -> Optional[QLinearParams]:
    """Load quantization options from Kitchen to Transformer Engine.

    TODO(etsykunov): Confirm docstring is correct.
    """
    assert qat_params_idx > 0, "QAT_PARAMS is not set."
    if qat_params_idx == 6010:
        return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF)
    if qat_params_idx == 960109:
        return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_RHT_ONLY)
    if qat_params_idx == 9002:
        return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_2D_QUANTIZATION_ONLY)
    if qat_params_idx == 9003:
        return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_RHT_AND_2D_QUANTIZATION)
    raise ValueError(f"Unsupported QAT params index: {qat_params_idx}")


def set_qlinear_params(
    qlinear_params: Optional[QLinearParams] = None,
    layer_number: Optional[int] = None,
    layer_name: Optional[str] = None,
) -> Optional[QLinearParams]:
    """Set quantization parameters based on configuration.

    Args:
        qlinear_params: Quantization parameters. If None, loaded from environment.
        layer_number: The numerical index of this layer in the model structure.
        layer_name: The name for this layer.

    Returns:
        QLinearParams: The finalized quantization parameters for this layer.
    """
    if qlinear_params is None:
        qat_params_idx = int(os.getenv("QAT_PARAMS", "0"))
        if qat_params_idx == 0:
            return None
        return get_qlinear_params_from_qat_params(qat_params_idx)
    # Apply layer-specific overrides
    if layer_number is not None:
        raise NotImplementedError("Layer-specific overrides are not supported yet.")
    if layer_name is not None:
        raise NotImplementedError("Layer-specific overrides are not supported yet.")
    return qlinear_params


def get_experimental_quantizers(fp8: bool, qlinear_params: QLinearParams):
    """Replacement of _get_quantizers() in TE modules."""
    if not fp8:
        raise ValueError("FP8 is required to be enabled for experimental quantization.")
    input_quantizer = qlinear_params.x_quantizer
    weight_quantizer = qlinear_params.w_quantizer
    output_quantizer = None
    grad_input_quantizer = None
    grad_weight_quantizer = None
    grad_output_quantizer = qlinear_params.g_quantizer
    return (
        input_quantizer,
        weight_quantizer,
        output_quantizer,
        grad_input_quantizer,
        grad_weight_quantizer,
        grad_output_quantizer,
    )
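For orientation, a minimal usage sketch of these helpers (illustrative only, not part of the new file; it assumes the `QAT_PARAMS` environment-variable convention shown above):

import os
from transformer_engine.pytorch.experimental import config

os.environ["QAT_PARAMS"] = "6010"  # maps to QuantizeRecipe.NVFP4_REF

qlinear_params = config.set_qlinear_params()  # falls back to the environment when None
if qlinear_params is not None:
    (input_q, weight_q, output_q,
     grad_input_q, grad_weight_q, grad_output_q) = config.get_experimental_quantizers(
        fp8=True, qlinear_params=qlinear_params
    )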
transformer_engine/pytorch/experimental/gemm.py
0 → 100644
View file @
53fa872c
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""GEMM API for experimental middleware between Transformer Engine and Kitchen."""
from typing import Iterable, Optional

import torch

from transformer_engine.pytorch.experimental.quantization import (
    MMParams,
    GEMMType,
    ExperimentalQuantizedTensor,
)
from transformer_engine.pytorch.tensor.quantized_tensor import Quantizer


def experimental_gemm(
    A: ExperimentalQuantizedTensor,
    B: ExperimentalQuantizedTensor,
    workspace: torch.Tensor,  # pylint: disable=unused-argument
    out_dtype: Optional[torch.dtype] = None,
    quantization_params: Optional[Quantizer] = None,  # pylint: disable=unused-argument
    gelu: bool = False,  # pylint: disable=unused-argument
    gelu_in: torch.Tensor = None,  # pylint: disable=unused-argument
    accumulate: bool = False,  # pylint: disable=unused-argument
    layout: str = "TN",
    out: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
    bias: Optional[torch.Tensor] = None,
    use_split_accumulator: bool = False,
    grad: bool = False,
) -> Iterable[Optional[torch.Tensor]]:
    """Dispatch GEMM to quantizer's qgemm method."""
    assert isinstance(A, ExperimentalQuantizedTensor) and isinstance(
        B, ExperimentalQuantizedTensor
    ), "A and B must be ExperimentalQuantizedTensor instances"
    A, B = B, A
    # Determine GEMM type based on grad flag and layout
    if not grad:
        gemm_type = GEMMType.FPROP
    else:
        if layout == "NN":
            gemm_type = GEMMType.DGRAD
        elif layout == "NT":
            gemm_type = GEMMType.WGRAD
        else:
            # Default to FPROP for other layouts
            gemm_type = GEMMType.FPROP

    # Extract quantizer from QuantizedTensor to get qgemm logic
    # TODO(etsykunov): make it more flexible, what if we might want to use gemm logic from B.quantizer?
    quantizer = None
    if hasattr(A, "quantizer") and A.quantizer is not None:
        quantizer = A.quantizer
    elif hasattr(B, "quantizer") and B.quantizer is not None:
        quantizer = B.quantizer
    else:
        raise ValueError("No quantizer found in QuantizedETensor objects")

    # Create MMParams
    m_params = MMParams(
        out_dtype=out_dtype,
        use_split_accumulator=use_split_accumulator,
    )
    out_dtype = A.dtype if m_params.out_dtype is None else m_params.out_dtype

    if gemm_type == GEMMType.FPROP:
        qx, sx = A.data, A.scale
        qw, sw = B.data, B.scale
        assert qx is not None
        assert sx is not None
        assert qw is not None
        assert sw is not None
        assert A.original_shape is not None
        # Call quantizer's qgemm method
        result = quantizer.qgemm(
            qx,
            qw,
            m_params,
            out_dtype,
            sx,
            sw,
            bias,
            gemm_type=GEMMType.FPROP,
            qresult_x=A,
            qresult_w=B,
        )
        if len(A.original_shape) > 2:
            # Original input was 3D, so we need to reshape result back to 3D
            batch_size = A.original_shape[0]
            seq_len = A.original_shape[1]
            result = result.view(batch_size, seq_len, result.shape[-1])
    elif gemm_type == GEMMType.DGRAD:
        qdy, sdy = A.data, A.scale
        qw_t, sw_t = B.data_t, B.scale_t
        assert qdy is not None
        assert sdy is not None
        assert qw_t is not None
        assert sw_t is not None
        result = quantizer.qgemm(
            qdy,
            qw_t,
            m_params,
            out_dtype,
            sdy,
            sw_t,
            None,
            gemm_type=GEMMType.DGRAD,
            qresult_x=A,
            qresult_w=B,
        )
    elif gemm_type == GEMMType.WGRAD:
        qdy_t, sdy_t = A.data_t, A.scale_t
        qx_t, sx_t = B.data_t, B.scale_t
        assert qdy_t is not None
        assert sdy_t is not None
        assert qx_t is not None
        assert sx_t is not None
        result = quantizer.qgemm(
            qdy_t,
            qx_t,
            m_params,
            out_dtype,
            sdy_t,
            sx_t,
            None,
            gemm_type=GEMMType.WGRAD,
            qresult_x=A,
            qresult_w=B,
        )

    # Return in the same format as general_gemm
    return result, None, None, None
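A rough sketch of how this dispatch is exercised with the reference NVFP4 quantizer added later in this commit (illustrative shapes; the argument order mirrors `general_gemm`, i.e. weight first, and `workspace` is unused on this reference path, which runs in plain PyTorch):

import torch
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.experimental.gemm import experimental_gemm
from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef

quantizer = NVFP4QuantizerRef(dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16))
x_q = quantizer.quantize(torch.randn(128, 256, dtype=torch.bfloat16))  # activations
w_q = quantizer.quantize(torch.randn(512, 256, dtype=torch.bfloat16))  # weights
out, *_ = experimental_gemm(w_q, x_q, workspace=torch.empty(0))        # fprop, (128, 512) output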
transformer_engine/pytorch/experimental/quantization.py
0 → 100644
View file @
53fa872c
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Quantization API for experimental middleware between Transformer Engine and Kitchen."""
from __future__ import annotations

import abc
import dataclasses
import enum
from typing import Iterable, Optional, Tuple, Union

import torch

from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from transformer_engine.pytorch.experimental import utils


@enum.unique
class GEMMType(enum.Enum):
    """Type of GEMM operation being performed."""

    FPROP = "fprop"
    DGRAD = "dgrad"
    WGRAD = "wgrad"


@dataclasses.dataclass(frozen=True)
class MMParams:
    """Matrix multiplication parameters."""

    out_dtype: torch.dtype | None = None
    # Use split accumulator for more accurate FP8 GEMM
    use_split_accumulator: bool = True


@dataclasses.dataclass
class ExperimentalQuantizedTensor(QuantizedTensorBase):
    """Base class for experimental quantized tensor containers.

    An experimental container to hold quantization result, including quantized tensor, optional
    transposed quantized tensor, and corresponding decoding scales.

    data: torch.Tensor
        the quantized tensor.
    scale: torch.Tensor
        the decoding scale for the quantized tensor. Shape depends on the scaling granularity.
        - if scaling type is PER_TENSOR, it should be a 1D scalar tensor.
    data_t: torch.Tensor
        the transposed quantized tensor (computed lazily if needed).
    scale_t: torch.Tensor
        the decoding scale for the transposed quantized tensor.
    dtype: torch.dtype
        nominal tensor datatype.
    device: torch.device
        device of the tensor.
    quant_dtype: Union[utils.Fp4Formats, torch.dtype]
        low precision tensor datatype.
    original_shape: Tuple[int, ...]
        original shape of the tensor.
    quantizer: ExperimentalQuantizer
        Builder class for quantized tensor.
    """

    data: Optional[torch.Tensor] = None
    scale: Optional[torch.Tensor] = None
    data_t: Optional[torch.Tensor] = None
    scale_t: Optional[torch.Tensor] = None
    global_amax_row: Optional[torch.Tensor] = None
    global_amax_col: Optional[torch.Tensor] = None
    dtype: Optional[torch.dtype] = None
    device: Optional[torch.device] = None
    quant_dtype: Optional[Union[utils.Fp4Formats, torch.dtype]] = None
    original_shape: Optional[Tuple[int, ...]] = None
    quantizer: Optional[ExperimentalQuantizer] = None

    @property
    def experimental(self) -> bool:
        """Flag to indicate this quantizer is using experimental Kitchen middleware."""
        return True

    def get_quantizer(self) -> ExperimentalQuantizer:
        """Get builder for QuantizedExperimentalTensor

        Quantizer can be used for in-place operations.
        """
        if self.quantizer is not None:
            return self.quantizer
        raise ValueError("Quantizer is not set")

    def prepare_for_saving(
        self,
    ) -> Tuple[list[Optional[torch.Tensor]], ExperimentalQuantizedTensor]:
        """Prepare the quantization result for saving for backward"""
        tensors = [self.data, self.data_t, self.scale, self.scale_t]
        self.data = None
        self.data_t = None
        self.scale = None
        self.scale_t = None
        return tensors, self

    def restore_from_saved(
        self, tensors: list[Optional[torch.Tensor]]
    ) -> list[Optional[torch.Tensor]]:
        """Restore the quantization result from the saved tensors"""
        self.data = tensors[0]
        self.data_t = tensors[1]
        self.scale = tensors[2]
        self.scale_t = tensors[3]
        return tensors[4:]

    def dequantize(self, *args, **kwargs) -> torch.Tensor:
        """Dequantize the quantized tensor"""
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement dequantize function"
        )

    # Compatibility
    @property
    def _data(self):
        return self.data

    @_data.setter
    def _data(self, value):
        self.data = value

    @property
    def _scale_inv(self):
        return self.scale

    @_scale_inv.setter
    def _scale_inv(self, value):
        self.scale = value


class ExperimentalQuantizer(Quantizer):
    """Experimental Quantizer class

    Defines the interface for experimental quantizers.
    """

    def __init__(self, *, rowwise: bool, columnwise: bool) -> None:
        super().__init__(rowwise=rowwise, columnwise=columnwise)
        self.internal = True

    @property
    def experimental(self) -> bool:
        """Flag to indicate this quantizer is using experimental Kitchen middleware"""
        return True

    @abc.abstractmethod
    def qgemm(
        self,
        qx: torch.Tensor,
        qw: torch.Tensor,
        m_params: MMParams,
        out_dtype: torch.dtype,
        sx: torch.Tensor,
        sw: torch.Tensor,
        bias: torch.Tensor | None = None,
        out: torch.Tensor | None = None,
        accumulate: bool = False,
        gemm_type: GEMMType = GEMMType.FPROP,
        qresult_x: ExperimentalQuantizedTensor | None = None,
        qresult_w: ExperimentalQuantizedTensor | None = None,
    ) -> torch.Tensor:
        """Quantized GEMM interface."""

    def dequantize(self, *args, **kwargs) -> torch.Tensor:
        """Dequantize the quantized tensor"""
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement dequantize function"
        )

    def update_quantized(self, *args, **kwargs) -> torch.Tensor:
        """Update the quantized tensor with the given tensor in-place"""
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement update_quantized function"
        )

    def make_empty(
        self,
        shape: Iterable[int],
        *,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
    ) -> QuantizedTensorBase:
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement make_empty function"
        )

    def calibrate(self, tensor: torch.Tensor) -> None:
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement calibrate function"
        )

    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement _get_compatible_recipe function"
        )
transformer_engine/pytorch/experimental/quantization_microblock_ref.py
0 → 100644
View file @
53fa872c
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""NVFP4 implementations for experimental middleware between Transformer Engine and Kitchen."""
from typing import Optional, Tuple

import torch

from transformer_engine.pytorch.experimental import quantization
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.experimental.quantization import (
    ExperimentalQuantizedTensor,
    ExperimentalQuantizer,
)


def cast_to_fp4x2(x):
    """Quantize a tensor to FP4 E2M1 and store in a byte tensor"""
    result = torch.zeros_like(x, dtype=torch.uint8)
    result[(x >= 0.0) & (x <= 0.25)] = 0
    result[(x > 0.25) & (x < 0.75)] = 1
    result[(x >= 0.75) & (x <= 1.25)] = 2
    result[(x > 1.25) & (x < 1.75)] = 3
    result[(x >= 1.75) & (x <= 2.5)] = 4
    result[(x > 2.5) & (x < 3.5)] = 5
    result[(x >= 3.5) & (x <= 5.0)] = 6
    result[x > 5.0] = 7
    result[(x >= -0.25) & (x < -0.0)] = 8
    result[(x < -0.25) & (x > -0.75)] = 9
    result[(x <= -0.75) & (x >= -1.25)] = 10
    result[(x < -1.25) & (x > -1.75)] = 11
    result[(x <= -1.75) & (x >= -2.5)] = 12
    result[(x < -2.5) & (x > -3.5)] = 13
    result[(x <= -3.5) & (x >= -5.0)] = 14
    result[x < -5.0] = 15
    return result[:, ::2] + result[:, 1::2] * 16
def cast_from_fp4x2(x, dq_dtype):
    """Dequantize FP4 E2M1 tensor that has been represented in a byte tensor"""
    fp4_values = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
        device=x.device,
        dtype=dq_dtype,
    )
    # Convert to long integers for indexing
    second_bit = torch.div(x, 16, rounding_mode="floor").to(torch.long)
    first_bit = (x - second_bit * 16).to(torch.long)
    # Use the long integers to index fp4_values
    first_bit_values = fp4_values[first_bit]
    second_bit_values = fp4_values[second_bit]
    result = torch.zeros(
        (first_bit_values.shape[0], first_bit_values.shape[1] * 2),
        device=x.device,
        dtype=dq_dtype,
    )
    result[:, ::2] = first_bit_values
    result[:, 1::2] = second_bit_values
    return result
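A quick round trip through these two helpers (illustrative only; each value lands on the nearest representable E2M1 code):

x = torch.tensor([[0.4, 1.6, -2.2, 5.5]], dtype=torch.float32)
packed = cast_to_fp4x2(x)               # shape (1, 2): one byte per pair of FP4 codes
cast_from_fp4x2(packed, torch.float32)  # tensor([[ 0.5,  1.5, -2.0,  6.0]])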
def cast_to_e8(decode_scale):
    """Cast to a value that is representable in FP8 E8M0.

    The result is in FP32, not FP8 E8M0.
    """
    max_exponent = torch.tensor(127, device=decode_scale.device, dtype=torch.float32)
    exponent = torch.ceil(torch.log2(decode_scale))
    exponent = torch.clamp(exponent, min=-max_exponent, max=max_exponent)
    return torch.tensor(2.0, device=decode_scale.device, dtype=torch.float32) ** exponent


def cast_to_e4m3(decode_scale, global_amax):
    """Scale and cast to FP8 E4M3.

    decode_scale is actually the encoding scaling factor. global_amax
    can be any data tensor and not just the amax.

    TODO(etsykunov): Make less unintuitive.
    """
    decode_scale = decode_scale * global_amax
    FLOAT8_E4M3_MAX = torch.tensor(448.0, device=decode_scale.device, dtype=torch.float32)
    decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX)
    return decode_scale.to(torch.float8_e4m3fn)
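For reference, `cast_to_e8` rounds a scale up to the next power of two and keeps it in FP32 (illustrative values):

cast_to_e8(torch.tensor([0.7, 3.0, 5.0]))  # tensor([1., 4., 8.])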
def high_precision_gemm_ref(
    a: torch.Tensor,
    b: torch.Tensor,
    out_dtype: torch.dtype,
    accumulate: bool = False,
    is_a_transposed: bool = False,
    is_b_transposed: bool = False,
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    scale_alpha: float = 1.0,
) -> torch.Tensor:
    """GEMM implementation with unquantized data"""
    # Handle transpositions
    mat1, mat2 = a, b
    if is_a_transposed:
        mat1 = a.T
    if is_b_transposed:
        mat2 = b.T
    # Ensure dtype compatibility for torch.addmm
    mat1 = mat1.to(out_dtype)
    mat2 = mat2.to(out_dtype)
    # Determine output shape
    y_shape = (mat1.size(0), mat2.size(1))
    if bias is not None:
        assert not accumulate, "Bias is not supported with accumulation"
        bias = bias.to(out_dtype)
        # With bias case
        if out_dtype == torch.float32:
            y_ref = torch.addmm(bias.repeat(mat1.size(0), 1), mat1, mat2, beta=1, alpha=1)
        else:
            y_ref = torch.addmm(bias, mat1, mat2, beta=1, alpha=scale_alpha)
    else:
        # Without bias case
        if accumulate and out is not None:
            y_ref = out.clone().to(out_dtype)
        else:
            y_ref = torch.zeros(y_shape, dtype=out_dtype, device=a.device)
        torch.addmm(y_ref, mat1, mat2, beta=1, alpha=scale_alpha, out=y_ref)
    return y_ref
class NVFP4TensorRef(ExperimentalQuantizedTensor):
    """NVFP4 tensor for middleware between Transformer Engine and Kitchen"""

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"dtype={self.dtype}, "
            f"device={self.device}, "
            f"quant_dtype={self.quant_dtype}, "
            f"data={self.dequantize(dtype=self.dtype)}, "
            f"original_shape={self.original_shape}"
            ")"
        )

    def quantize_(
        self,
        tensor: torch.Tensor,
        *,
        noop_flag: Optional[torch.Tensor] = None,
    ) -> ExperimentalQuantizedTensor:
        """In-place update of quantized data

        Parameters
        ----------
        tensor: torch.Tensor
            Tensor to copy from
        noop_flag: torch.Tensor, optional
            float32 flag indicating whether to avoid performing update
        """
        if isinstance(tensor, ExperimentalQuantizedTensor):
            return self.quantize_(tensor.dequantize(), noop_flag=noop_flag)
        self.get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
        return self

    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        """
        Construct plain PyTorch tensor from quantized tensor
        """
        if dtype is None:
            dtype = self.dtype
        # Ignore data_t for now
        assert self.data is not None, "QuantizedTensor has no valid tensor data"
        assert self.scale is not None, "QuantizedTensor has no valid scale"
        tensor_data = self.data
        tensor_scale = self.scale
        # Dispatch to the quantizer
        return self.get_quantizer().dequantize(tensor_data, tensor_scale, dtype=dtype)

    def update_usage(
        self,
        rowwise_usage: Optional[bool] = None,
        columnwise_usage: Optional[bool] = None,
    ):
        """Generate or remove quantized data based on provided usage."""
        has_data = self.data is not None
        has_data_transpose = self.data_t is not None
        needs_data = has_data
        needs_data_transpose = has_data_transpose
        if rowwise_usage is not None:
            needs_data = rowwise_usage
        if columnwise_usage is not None:
            needs_data_transpose = columnwise_usage

        # Generate data that is required
        if needs_data and not has_data:
            raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose")
        if needs_data_transpose and not has_data_transpose:
            if not has_data:
                raise RuntimeError("FP8 data is required to generate FP8 data transpose")
            self._create_transpose()

        # Delete data that is not required
        if not needs_data:
            self.data = None
        if not needs_data_transpose:
            self.data_t = None

    def _create_transpose(self):
        """Create transposed quantized tensor"""
        if not self.data.is_contiguous():
            self.data = self.data.contiguous()
        self.data_t = self.data.t().contiguous()
        self.scale_t = self.scale

    def size(self, *args, **kwargs):  # pylint: disable=unused-argument
        """Return the original tensor shape, not the internal packed data shape.

        FP4 quantization packs two 4-bit values into each 8-bit value, which reduces
        the second dimension by half. This method returns the logical shape that
        users expect, not the internal packed storage shape.
        """
        assert self.original_shape is not None
        return torch.Size(self.original_shape)
def get_wgrad_sign_vector() -> torch.Tensor:
    """Hard-coded signs for Hadamard transform"""
    return torch.tensor(
        [1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0],
        dtype=torch.float32,
    )
class NVFP4QuantizerRef(ExperimentalQuantizer):
    """NVFP4 quantizer for middleware between Transformer Engine and Kitchen"""

    def __init__(
        self,
        dtype: utils.Fp4Formats,
        rowwise: bool = True,
        columnwise: bool = True,
        pow_2_scales: bool = False,
        eps: float = 0.0,
        quant_tile_shape: Tuple[int, int] = (1, 16),
        with_rht: bool = False,
        with_random_sign_mask: bool = True,
    ):
        super().__init__(rowwise=rowwise, columnwise=columnwise)
        self.dtype = dtype
        self.pow_2_scales = pow_2_scales
        self.eps = eps
        self.quant_tile_shape = quant_tile_shape
        self.with_rht = with_rht
        self.with_random_sign_mask = with_random_sign_mask

    @staticmethod
    def _build_hadamard_matrix(
        size: int,
        device: torch.device,
        dtype: torch.dtype,
        with_random_sign_mask: bool = True,
    ) -> torch.Tensor:
        """Construct a Hadamard matrix of given power-of-two size with entries +-1.

        Uses Sylvester construction to avoid SciPy dependency.
        """
        assert (size & (size - 1)) == 0, "Hadamard size must be a power of two"
        h = torch.ones((1, 1), device=device, dtype=torch.float32)
        while h.shape[0] < size:
            h = torch.cat(
                [
                    torch.cat([h, h], dim=1),
                    torch.cat([h, -h], dim=1),
                ],
                dim=0,
            )
        if with_random_sign_mask:
            sign_mat = get_wgrad_sign_vector().to(device) * torch.eye(
                size, device=device, dtype=torch.float32
            )
            h = sign_mat @ h
        return h.to(dtype)
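A small sanity check of the Sylvester recursion used above (illustrative; the sign mask is disabled here because the hard-coded sign vector has length 16):

H4 = NVFP4QuantizerRef._build_hadamard_matrix(
    4, torch.device("cpu"), torch.float32, with_random_sign_mask=False
)
assert torch.equal(H4 @ H4.T, 4.0 * torch.eye(4))  # rows are mutually orthogonal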
    def _apply_rht(self, x: torch.Tensor) -> torch.Tensor:
        """Apply randomized Hadamard transform without random signs (reference path).

        This matches the reference used in tests: x_reshaped @ (H * (1/sqrt(g))).
        """
        # Only apply when enabled
        if not self.with_rht:
            return x
        # RHT dimension equals the quantization tile length (NVFP4 uses 16)
        rht_dim = self.quant_tile_shape[1]
        assert (
            x.shape[-1] % rht_dim == 0
        ), f"Inner dimension {x.shape[-1]} must be divisible by hadamard dimension {rht_dim}"
        # Build H and scale
        H = self._build_hadamard_matrix(rht_dim, x.device, x.dtype, self.with_random_sign_mask)
        scale = 1.0 / float(rht_dim) ** 0.5
        # Perform blockwise transform along the last dimension
        original_shape = x.shape
        x_mat = x.contiguous().view(-1, rht_dim)
        # Random sign matrix is identity in this reference (no sign flipping)
        transform = H * scale
        out = x_mat @ transform
        return out.view(original_shape)
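Because `H / sqrt(16)` is orthogonal, the blockwise transform preserves Euclidean norms. An illustrative check (sign mask disabled, default 1x16 tiles):

q = NVFP4QuantizerRef(dtype=utils.Fp4Formats.E2M1, with_rht=True, with_random_sign_mask=False)
x = torch.randn(4, 32)
assert torch.allclose(x.norm(), q._apply_rht(x).norm(), atol=1e-5)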
    @staticmethod
    def _recover_swizzled_scales(
        swizzled_scale: bool, scale: torch.Tensor, m: int, n: int, block_length: int
    ) -> torch.Tensor:
        if not swizzled_scale:
            return scale
        rounded_m = utils.roundup_div(m, 128) * 128
        scale_n = utils.roundup_div(n, block_length)
        rounded_n = utils.roundup_div(scale_n, 4) * 4
        # Recover swizzled scaling factor layout -> linear layout
        tmp = torch.reshape(scale, (rounded_m // 128, rounded_n // 4, 32, 4, 4))
        # after permutation, the layout is [rounded_m // 128, 4, 32, rounded_n // 4, 4]
        tmp = torch.permute(tmp, (0, 3, 2, 1, 4))
        result = torch.reshape(tmp, (rounded_m, rounded_n))
        return result[:m, :scale_n]
    @classmethod
    def _quantize_blockwise_reference(
        cls,
        x: torch.Tensor,
        global_amax: torch.Tensor,
        tile_len_x: int,
        tile_len_y: int,
        *,
        pow_2_scales: bool,
        eps: float,  # pylint: disable=unused-argument
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert x.ndim == 2
        using_2d_quantization = tile_len_x == 16 and tile_len_y == 16
        m, n = x.shape
        # Compute vec_max based on the original x (before reshape)
        # For 1D quantization: amax over each row chunk of 16
        # For 2D quantization: amax over each 16x16 block, but output shape is still (128, 8, 1), filled with block amax
        if using_2d_quantization:
            # x shape: (128, 128)
            x_blocks = (
                x.unfold(0, tile_len_y, tile_len_y)
                .unfold(1, tile_len_x, tile_len_x)
                .to(torch.float32)
            )  # (8, 8, 16, 16)
            block_amax = torch.amax(torch.abs(x_blocks), dim=(-1, -2))  # (8, 8)
            # Now, expand to (128, 8, 1) by repeating each block_amax for 16 rows
            vec_max = block_amax.repeat_interleave(tile_len_y, dim=0).unsqueeze(-1)  # (128, 8, 1)
        else:
            # x shape: (128, 128)
            x_reshaped = x.view(m, n // tile_len_x, tile_len_x)  # (128, 8, 16)
            vec_max = torch.amax(torch.abs(x_reshaped), dim=-1, keepdim=True).to(
                torch.float32
            )  # (128, 8, 1)
        x = x.view(m, n // tile_len_x, tile_len_x)
        FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32)
        FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32)
        decode_scale = torch.div(vec_max, FLOAT4_E2M1_MAX)
        if pow_2_scales:
            decode_scale = cast_to_e8(decode_scale)
            encode_scale = torch.div(
                torch.tensor(1.0, device=x.device, dtype=torch.float32),
                decode_scale.to(torch.float32),
            )
        else:
            global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax)
            global_encode_scale = torch.min(
                global_encode_scale,
                torch.tensor(
                    torch.finfo(torch.float32).max,
                    device=global_encode_scale.device,
                    dtype=torch.float32,
                ),
            )
            if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32):
                global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32)
            global_decode_scale = torch.div(1.0, global_encode_scale)
            decode_scale = decode_scale * global_encode_scale
            decode_scale = torch.min(
                decode_scale,
                torch.tensor(
                    torch.finfo(torch.float32).max,
                    device=decode_scale.device,
                    dtype=torch.float32,
                ),
            )
            decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX)
            decode_scale = decode_scale.to(torch.float8_e4m3fn)
            encode_scale = torch.min(
                torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale),
                torch.tensor(
                    torch.finfo(torch.float32).max,
                    device=decode_scale.device,
                    dtype=torch.float32,
                ),
            )
        scaled_x = x.to(torch.float32) * encode_scale
        clipped_x = torch.clamp(scaled_x, -FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX).reshape(m, n)
        return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1)
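To make the two-level scaling above concrete, a worked example with assumed values (global amax 8.0, block amax 4.0) on the non-power-of-two NVFP4 path:

global_amax = 8.0
vec_max = 4.0                                                # amax of one 1x16 block
global_encode_scale = 448.0 * 6.0 / global_amax              # 336.0
decode_scale = (vec_max / 6.0) * global_encode_scale         # 224.0, stored as FP8 E4M3
encode_scale = 1.0 / (decode_scale * (1.0 / global_encode_scale))  # 1.5
# The block is multiplied by 1.5 before the FP4 cast, so its amax 4.0 maps to 6.0 (E2M1 max).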
    @staticmethod
    def _pad_tensor(
        tensor: torch.Tensor, row_divisor: Optional[int], col_divisor: Optional[int]
    ) -> torch.Tensor:
        assert tensor.dim() == 2, "only supports 2D tensors"
        M, N = tensor.shape
        padding_needed_rows = 0
        padding_needed_cols = 0
        if row_divisor is not None and M % row_divisor != 0:
            padding_needed_rows = row_divisor - (M % row_divisor)
        # Check and calculate column padding if col_divisor is provided
        if col_divisor is not None and N % col_divisor != 0:
            padding_needed_cols = col_divisor - (N % col_divisor)
        # Return original tensor if no padding is needed
        if padding_needed_rows == 0 and padding_needed_cols == 0:
            return tensor
        # pad the tensor
        out = torch.nn.functional.pad(
            tensor,
            (0, padding_needed_cols, 0, padding_needed_rows),
            mode="constant",
            value=0.0,
        ).contiguous()
        return out

    @staticmethod
    def _rm_pad_tensor(tensor: torch.Tensor, original_size: tuple[int, ...]) -> torch.Tensor:
        assert tensor.dim() == 2, "only supports 2D tensors"
        M, N = original_size
        out = tensor[:M, :N].contiguous()
        return out
    def _quantize(self, tensor: torch.Tensor) -> Tuple[
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        torch.Tensor,
        torch.Tensor,
    ]:
        """
        Python implementation of microblock FP4 quantization.

        Parameters
        ----------
        tensor : torch.Tensor
            Input tensor to quantize (should be 2D)

        Returns
        -------
        Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor]
            (qx, sx, qx_t, sx_t, global_amax) where:
            - qx: quantized data in row-major order (if rowwise_usage), None otherwise
            - sx: scale tensor for qx (if rowwise_usage), None otherwise
            - qx_t: quantized data in column-major order (if columnwise_usage), None otherwise
            - sx_t: scale tensor for qx_t (if columnwise_usage), None otherwise
            - global_amax: global amax tensor
        """
        if self.pow_2_scales:
            assert self.quant_tile_shape == (
                1,
                32,
            ), "MXFP4 only supports 1x32 tile shape."
            # TODO(etsykunov): Fix bug where global_amax_row and
            # global_amax_col are not defined
            # global_amax = torch.empty(0, device=tensor.device, dtype=torch.float32)
        else:
            assert self.quant_tile_shape in (
                (1, 16),
                (16, 16),
            ), "NVFP4 only supports 1x16 or 16x16 tile shape."
        # Prepare inputs once so we can reuse for both amax and quantization
        # Row-input will always be the original input.
        row_input = tensor
        col_input = (
            self._apply_rht(tensor.t().contiguous()) if self.with_rht else tensor.t().contiguous()
        )
        # Compute amax for rowwise and columnwise paths separately
        global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1)
        global_amax_col = (
            torch.max(torch.abs(col_input)).to(torch.float32).view(1)
            if self.columnwise_usage
            else global_amax_row
        )
        transpose_scales = False
        M, N = tensor.shape
        if self.rowwise_usage:
            x_input = row_input
            x_padded = self._pad_tensor(
                x_input,
                row_divisor=self.quant_tile_shape[0],
                col_divisor=self.quant_tile_shape[1],
            )
            qx, sx = self._quantize_blockwise_reference(
                x_padded,
                global_amax_row,
                self.quant_tile_shape[1],
                self.quant_tile_shape[0],
                pow_2_scales=self.pow_2_scales,
                eps=self.eps,
            )
            if transpose_scales:
                sx = sx.T
            qx = self._rm_pad_tensor(qx, (M, N // 2))
        else:
            qx = None
            sx = None
        if self.columnwise_usage:
            x_t = col_input
            x_t_padded = self._pad_tensor(
                x_t,
                row_divisor=self.quant_tile_shape[0],
                col_divisor=self.quant_tile_shape[1],
            )
            qx_t, sx_t = self._quantize_blockwise_reference(
                x_t_padded,
                global_amax_col,
                self.quant_tile_shape[1],
                self.quant_tile_shape[0],
                pow_2_scales=self.pow_2_scales,
                eps=self.eps,
            )
            qx_t = self._rm_pad_tensor(qx_t, (N, M // 2))
            if transpose_scales:
                sx_t = sx_t.T
        else:
            qx_t = None
            sx_t = None
        return qx, sx, qx_t, sx_t, global_amax_row, global_amax_col
    def quantize(
        self,
        tensor: torch.Tensor,
        **kwargs,  # pylint: disable=unused-argument
    ) -> NVFP4TensorRef:
        # sanity checks
        assert tensor.dtype in utils.HIGH_PRECISION_FLOAT_DTYPES, "Unsupported input dtype."
        # Make it work with 3D tensors
        original_shape = tensor.shape
        if tensor.ndim > 2:
            tensor = tensor.view(-1, tensor.shape[-1])
        qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(tensor)
        return NVFP4TensorRef(
            data=qx,
            scale=sx,
            data_t=qx_t,
            scale_t=sx_t,
            global_amax_row=global_amax_row,
            global_amax_col=global_amax_col,
            dtype=tensor.dtype,
            device=tensor.device,
            quant_dtype=self.dtype,
            quantizer=self,
            original_shape=original_shape,
        )
    def update_quantized(
        self,
        src: torch.Tensor,
        dst: ExperimentalQuantizedTensor,
        *,
        noop_flag: Optional[torch.Tensor] = None,
    ) -> ExperimentalQuantizedTensor:
        """Update the quantized tensor with the given tensor in-place

        Parameters
        ----------
        src: torch.Tensor
            Source tensor to copy from
        dst: ExperimentalQuantizedTensor
            Destination ExperimentalQuantizedTensor to update
        noop_flag: torch.Tensor, optional
            float32 flag indicating whether to avoid performing update
        """
        # Handle noop flag
        if noop_flag is not None and noop_flag.item() != 0:
            return dst
        # Make sure input is in expected format
        if not src.is_contiguous():
            src = src.contiguous()
        # Store the original shape and reshape for processing
        original_shape = src.shape
        if src.ndim > 2:
            src = src.view(-1, src.shape[-1])
        qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(src)
        # Update the destination with new data
        dst.data = qx
        dst.scale = sx
        dst.data_t = qx_t
        dst.scale_t = sx_t
        dst.global_amax_row = global_amax_row
        dst.global_amax_col = global_amax_col
        dst.dtype = src.dtype
        dst.quant_dtype = self.dtype
        dst.original_shape = original_shape
        return dst
    @property
    def supports_allgather_fp8(self) -> bool:
        """Whether the tensor data can be all-gathered with an FP8 all-gather.

        TODO(etsykunov): Confirm docstring is correct. Also, this API
        seems too FP8-specific and should be reconsidered.
        """
        return False

    def transpose_qresult(
        self, qresult: quantization.ExperimentalQuantizedTensor
    ) -> quantization.ExperimentalQuantizedTensor:
        """Convert row-wise data to column-wise data (?)

        TODO(etsykunov): Confirm docstring is correct.
        """
        raise NotImplementedError("Transpose qresult is not implemented for FP4.")

    @property
    def supports_dequantize(self) -> bool:
        """Whether the quantized tensor can be converted to a high-precision tensor"""
        return False

    @property
    def is_data_t_transposed_in_memory(self) -> bool:
        """Whether column-wise data is stored in transposed layout.

        TODO(etsykunov): Confirm docstring is correct.
        """
        raise NotImplementedError("Not implemented yet")

    def dequantize(
        self, tensor: torch.Tensor, scale: torch.Tensor, dtype: Optional[torch.dtype] = None
    ) -> torch.Tensor:
        """Dequantize the quantized tensor"""
        raise NotImplementedError("Not implemented yet")
    def qgemm(
        self,
        qx: torch.Tensor,
        qw: torch.Tensor,
        m_params: quantization.MMParams,
        out_dtype: torch.dtype,
        sx: torch.Tensor,
        sw: torch.Tensor,
        bias: torch.Tensor | None = None,
        out: torch.Tensor | None = None,
        accumulate: bool = False,
        gemm_type: quantization.GEMMType = quantization.GEMMType.FPROP,
        qresult_x: quantization.ExperimentalQuantizedTensor | None = None,
        qresult_w: quantization.ExperimentalQuantizedTensor | None = None,
    ) -> torch.Tensor:
        assert bias is None, "Bias is not implemented for FP4 GEMM."
        high_precision_x = cast_from_fp4x2(qx, out_dtype)
        high_precision_w = cast_from_fp4x2(qw, out_dtype)
        if self.pow_2_scales:
            if sx.dtype == torch.uint8:
                # if scaling factor is stored in uint8 container
                sx = torch.tensor(2.0, device=sx.device, dtype=torch.float32) ** (
                    sx.to(torch.float32)
                    - torch.tensor(127, device=sx.device, dtype=torch.float32)
                )
                sw = torch.tensor(2.0, device=sw.device, dtype=torch.float32) ** (
                    sw.to(torch.float32)
                    - torch.tensor(127, device=sw.device, dtype=torch.float32)
                )
            else:
                # if scaling factor is torch.float8_e8m0fnu
                sx = sx.to(torch.float32)
                sw = sw.to(torch.float32)
            alpha = torch.tensor(1.0, device=high_precision_x.device, dtype=torch.float32)
        else:
            assert qresult_x is not None
            assert qresult_w is not None
            assert qresult_x.global_amax_row is not None
            assert qresult_w.global_amax_col is not None
            sx = sx.to(torch.float32)
            sw = sw.to(torch.float32)
            factor = 6.0 * 6.0 * 448.0 * 448.0
            if gemm_type == quantization.GEMMType.WGRAD:
                partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col
            else:
                partial_alpha = qresult_x.global_amax_row * qresult_w.global_amax_row
            alpha = torch.div(partial_alpha, factor).squeeze(-1)
        M, K = high_precision_x.shape
        N, K_w = high_precision_w.shape
        assert K == K_w, "K dimension mismatch between qx and qw"
        assert K % 32 == 0, "K dimension must be divisible by 32"
        assert N % 8 == 0, "N dimension must be divisible by 8"
        block_length = 32 if self.pow_2_scales else 16
        grid_k = K // block_length
        assert sx.shape == (
            M,
            K // block_length,
        ), f"sx shape mismatch: expected ({M}, {K // block_length}), got {sx.shape}"
        assert sw.shape == (
            N,
            K // block_length,
        ), f"sw shape mismatch: expected ({N}, {K // block_length}), got {sw.shape}"
        y = torch.zeros(M, N, dtype=torch.float32, device=qx.device)
        # below implementation is to match the FP4 tensor core implementation
        # Each output element (i, j) is fp32 accumulation of (K // block_length) inner products
        # Each inner product is sx * sw * (1, block_length) x (block_length, 1) with precision in fp32
        # Then batch the computation in M, N dimension
        for k in range(grid_k):
            k_start = k * block_length
            k_end = k_start + block_length
            qx_block = high_precision_x[:, k_start:k_end].clone().contiguous()
            qw_block = high_precision_w[:, k_start:k_end].clone().contiguous()
            # Extract scaling factors for the current blocks
            sx_block = sx[:, k]
            sw_block = sw[:, k]
            y += torch.outer(sx_block, sw_block) * high_precision_gemm_ref(
                qx_block, qw_block, torch.float32, is_b_transposed=True
            )
        if not self.pow_2_scales and K > 0:
            # only apply global scale for NVFP4 and non-empty cases
            y = alpha * y
        # accumulation happens at epilogue in float32
        if accumulate:
            assert out is not None, "Output tensor must be provided for accumulation."
            y += out.to(torch.float32)
        else:
            assert out is None, "Output tensor should be None when accumulate is False."
        y = y.to(out_dtype)
        return y
transformer_engine/pytorch/experimental/utils.py
0 → 100644
View file @
53fa872c
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility functions for experimental middleware between Transformer Engine and Kitchen."""
import enum

import torch

HIGH_PRECISION_FLOAT_DTYPES = (
    torch.float,
    torch.float16,
    torch.bfloat16,
    torch.float32,
)


class Fp4Formats(enum.Enum):
    """FP4 data format"""

    E2M1 = "e2m1"


def roundup_div(x: int, y: int) -> int:
    """Round up division"""
    assert x >= 0
    assert y > 0
    return (x + y - 1) // y
transformer_engine/pytorch/fp8.py
View file @
53fa872c
...
@@ -21,6 +21,7 @@ from transformer_engine.common.recipe import (
     MXFP8BlockScaling,
     Float8CurrentScaling,
     Float8BlockScaling,
+    NVFP4BlockScaling,
 )
 from .constants import dist_group_type
...
@@ -64,6 +65,13 @@ def check_mxfp8_support() -> Tuple[bool, str]:
     return False, "Device compute capability 10.0 or higher required for MXFP8 execution."


+def check_nvfp4_support() -> Tuple[bool, str]:
+    """Return if nvfp4 support is available"""
+    if get_device_compute_capability() >= (10, 0):
+        # blackwell and above
+        return True, ""
+    return False, "Device compute capability 10.0 or higher required for NVFP4 execution."
+
+
 def check_fp8_block_scaling_support() -> Tuple[bool, str]:
     """Return if fp8 block scaling support is available"""
     if IS_HIP_EXTENSION:
...
@@ -121,6 +129,13 @@ def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType
     return tex.DType.kFloat8E5M2


+def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType:
+    """Get fp4 data type according to recipe and tensor"""
+    if fp4_recipe.fp4_format == Format.E2M1:
+        return tex.DType.kFloat4E2M1
+    raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}")
+
+
 def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
     """Get max representible FP8 value."""
     if fp8_recipe.fp8_format == Format.E4M3 or (
...
@@ -158,6 +173,8 @@ class FP8GlobalStateManager:
     reason_for_no_mxfp8 = ""
     fp8_block_scaling_available = None
     reason_for_no_fp8_block_scaling = None
+    nvfp4_available = None
+    reason_for_no_nvfp4 = ""

     @classmethod
     def reset(cls) -> None:
...
@@ -221,6 +238,13 @@ class FP8GlobalStateManager:
             )
         return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling

+    @classmethod
+    def is_nvfp4_available(cls) -> Tuple[bool, str]:
+        """Return if NVFP4 support is available."""
+        if cls.nvfp4_available is None:
+            cls.nvfp4_available, cls.reason_for_no_nvfp4 = check_nvfp4_support()
+        return cls.nvfp4_available, cls.reason_for_no_nvfp4
+
     @staticmethod
     def get_meta_tensor_key(forward: bool = True) -> str:
         """Returns scaling key in `fp8_meta`."""
...
@@ -497,6 +521,9 @@ class FP8GlobalStateManager:
         if isinstance(fp8_recipe, Float8BlockScaling):
             fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available()
             assert fp8_block_available, reason_for_no_fp8_block
+        if isinstance(fp8_recipe, NVFP4BlockScaling):
+            nvfp4_available, reason_for_no_nvfp4 = cls.is_nvfp4_available()
+            assert nvfp4_available, reason_for_no_nvfp4

     @classmethod
     def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
...
@@ -853,6 +880,8 @@ class RecipeState(abc.ABC):
             cls = Float8CurrentScalingRecipeState
         elif recipe.float8_block_scaling():
             cls = Float8BlockScalingRecipeState
+        elif recipe.nvfp4():
+            cls = NVFP4BlockScalingRecipeState
         else:
             raise ValueError(f"{recipe.__class__.__name__} is not supported")
         return cls(
...
@@ -957,7 +986,9 @@ class Float8CurrentScalingRecipeState(RecipeState):
         from .tensor.float8_tensor import Float8CurrentScalingQuantizer

         return [
-            Float8CurrentScalingQuantizer(self.dtype, device=self.device)
+            Float8CurrentScalingQuantizer(
+                self.dtype, device=self.device, force_pow_2_scales=self.recipe.use_power_2_scales
+            )
             for i in range(self.num_quantizers)
         ]
...
@@ -1100,3 +1131,79 @@ class Float8BlockScalingRecipeState(RecipeState):
                 ]
             )
         )
+
+
+class NVFP4BlockScalingRecipeState(RecipeState):
+    """Configuration for NVFP4 quantization.
+
+    NVFP4 quantization does not require state.
+    """
+
+    recipe: NVFP4BlockScaling
+    mode: str
+    dtype: tex.DType
+
+    def __init__(
+        self,
+        recipe: NVFP4BlockScaling,
+        *,
+        mode: str,
+        num_quantizers: int = 1,
+        device: Optional[torch.device] = None,
+    ) -> None:
+        self.recipe = recipe
+        self.mode = mode
+        self.num_quantizers = num_quantizers
+        self.dtype = get_fp4_te_dtype(recipe)
+
+        # Allocate buffers
+        if device is None:
+            device = torch.device("cuda")
+
+    def make_quantizers(self) -> list:
+        from .tensor.nvfp4_tensor import NVFP4Quantizer
+
+        # The index convention (coming from base.py set_meta_tensor)
+        # is somewhat awkward. It assumes forward quantizers are
+        # ordered [input, weight, output, ...] and backward quantizers
+        # are ordered [grad_output, grad_input, ...]. This doesn't
+        # play nicely with fusible ops: Linear op doesn't own output
+        # or grad input quantizers, Quantize op only owns input and
+        # grad output quantizers.
+        if self.mode == "forward":
+
+            def _make_quantizer(idx: int) -> NVFP4Quantizer:
+                qparams = (
+                    self.recipe.fp4_quant_fwd_weight
+                    if idx % 3 == 1
+                    else self.recipe.fp4_quant_fwd_inp
+                )
+                return NVFP4Quantizer(
+                    fp4_dtype=self.dtype,
+                    rowwise=True,
+                    columnwise=True,
+                    with_rht=qparams.random_hadamard_transform,
+                    with_post_rht_amax=qparams.random_hadamard_transform,
+                    with_2d_quantization=qparams.fp4_2d_quantization,
+                    stochastic_rounding=qparams.stochastic_rounding,
+                )
+
+            return [_make_quantizer(idx) for idx in range(self.num_quantizers)]
+        if self.mode == "backward":
+            return [
+                NVFP4Quantizer(
+                    fp4_dtype=self.dtype,
+                    rowwise=True,
+                    columnwise=True,
+                    with_rht=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform,
+                    with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform,
+                    with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization,
+                    stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding,
+                )
+                for _ in range(self.num_quantizers)
+            ]
+        raise RuntimeError(f"Unexpected recipe mode ({self.mode})")
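For context, the new recipe plumbs into the existing autocast entry point. A hedged sketch, not part of the diff (the exact `NVFP4BlockScaling` constructor arguments live in `transformer_engine.common.recipe` and are not shown here):

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling

available, reason = te.fp8.check_nvfp4_support()
if available:
    with te.fp8_autocast(enabled=True, fp8_recipe=NVFP4BlockScaling()):
        ...  # forward passes of TE modules now use NVFP4BlockScalingRecipeState quantizers
else:
    print(reason)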
transformer_engine/pytorch/module/_common.py
View file @
53fa872c
...
@@ -4,16 +4,18 @@
 """Internal function used by multiple modules."""
-from typing import Any, List, Optional, Tuple, Union, Callable
-from dataclasses import dataclass
+import dataclasses
+import queue
+from typing import Any, Callable, List, Optional, Tuple, Union

 import torch

 from .. import cpp_extensions as tex
+from .. import experimental
 from ..constants import TE_DType
-from ..utils import get_default_init_method
 from ..export import is_in_onnx_export_mode
+from ..tensor.utils import is_experimental
+from ..utils import get_default_init_method
 import warnings

 try:
     from lightop import rmsnorm_forward, rmsnorm_backward
...
@@ -179,7 +181,33 @@ def noop_cat(
     return _NoopCatFunc.apply(dim, *tensors)


-@dataclass
+def get_module_quantizers(
+    module: torch.nn.Module,
+    fp8_output: bool,
+    fp8_grad: bool,
+    debug: bool,
+):
+    """Return the 6-tuple of quantizers for a module in a centralized way.
+
+    Routing policy:
+    - If experimental quantization is enabled via environment and module.fp8 is True,
+      return experimental quantizers.
+    - Otherwise, return the module's own quantizers (debug or regular).
+    """
+    if getattr(module, "fp8", False) and is_experimental():
+        # TODO(etsykunov): Quantizer instantiation should be better
+        # done in the module's constructor
+        qlinear_params = experimental.config.set_qlinear_params()
+        if qlinear_params is not None:
+            return experimental.config.get_experimental_quantizers(module.fp8, qlinear_params)
+    if not debug:
+        return module._get_quantizers(fp8_output, fp8_grad)
+    return module._get_debug_quantizers(fp8_output, fp8_grad)
+
+
+@dataclasses.dataclass
 class _ParameterInitMeta:
     """
     Stores essential metadata needed to support deferred parameter initialization.
...
transformer_engine/pytorch/module/base.py
View file @
53fa872c
...
@@ -27,6 +27,7 @@ from ..fp8 import (
     DelayedScalingRecipeState,
     Float8CurrentScalingRecipeState,
     Float8BlockScalingRecipeState,
+    NVFP4BlockScalingRecipeState,
     FP8GlobalStateManager,
     RecipeState,
 )
...
@@ -39,6 +40,7 @@ from ..distributed import (
 from ..constants import dist_group_type
 from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer
 from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
+from ..tensor.nvfp4_tensor import NVFP4Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
...
@@ -82,7 +84,8 @@ def get_cublas_workspace_size_bytes() -> None:
     if IS_HIP_EXTENSION:
         return 134_217_728
     if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
-        return 33_554_432
+        # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales
+        return 32 * 1024 * 1024 + 1024
     return 4_194_304
...
@@ -802,6 +805,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             recipe_state, Float8BlockScalingRecipeState
         ):
             return
+        if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState):
+            return

         # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
         # 2 (grad_output and grad_input) for bwd
...
@@ -1011,12 +1016,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             return
         dtype = inp.dtype
-        for name, param in self.named_parameters():
-            if param is not None:
-                assert dtype == param.dtype, (
-                    "Data types for parameters must match when outside of autocasted region. "
-                    f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
-                )
+        if not self.allow_different_data_and_param_types:
+            for name, param in self.named_parameters():
+                if param is not None:
+                    assert dtype == param.dtype, (
+                        "Data types for parameters must match when outside of autocasted region. "
+                        f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
+                    )
         self.activation_dtype = dtype

     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
...
@@ -1105,6 +1111,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         inp: torch.Tensor,
         num_gemms: int = 1,
         allow_non_contiguous: bool = False,
+        allow_different_data_and_param_types: bool = False,
     ) -> Generator[torch.Tensor, None, None]:
         """Checks and prep for FWD.

         The context manager is needed because there isn't a way for a module to know
...
@@ -1112,6 +1119,7 @@
         to setup the forward aggregated amax reduction for every module
         just in case. The autocast exit will pick up the most recent one.
         """
+        self.allow_different_data_and_param_types = allow_different_data_and_param_types
         self.forwarded_at_least_once = True
         # Activation recomputation is used and this is the second forward phase.
         if self.fp8 and in_fp8_activation_recompute_phase():
...
@@ -1260,15 +1268,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         ):
             grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
         else:
-            if isinstance(quantizer, Float8BlockQuantizer) or (
-                isinstance(quantizer, Float8CurrentScalingQuantizer) and IS_HIP_EXTENSION
-            ):
+            if isinstance(quantizer, (Float8BlockQuantizer, NVFP4Quantizer)) or (
+                isinstance(quantizer, Float8CurrentScalingQuantizer) and IS_HIP_EXTENSION
+            ):
                 # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer.
                 grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
             else:
                 grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
-    if not isinstance(
-        grad_output,
-        (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
-    ):
+    if not isinstance(grad_output, QuantizedTensorBase):
         grad_output = quantizer(grad_output)
     return grad_output, grad_bias
...