OpenDAS / TransformerEngine / Commits

Commit 063ef88d, authored Dec 03, 2025 by wenjh

    Merge nv main up to v2.10.0.dev0

    Signed-off-by: wenjh <wenjh@sugon.com>

Parents: 91670b05, 5624dbb4
Changes: 298. Showing 20 changed files with 2623 additions and 1375 deletions (+2623, -1375).
transformer_engine/pytorch/csrc/extensions/normalization.cpp      +176 -94
transformer_engine/pytorch/csrc/extensions/pybind.cpp             +39  -14
transformer_engine/pytorch/csrc/pybind.h                          +20  -12
transformer_engine/pytorch/csrc/quantizer.cpp                     +595 -9
transformer_engine/pytorch/csrc/type_converters.cpp               +40  -0
transformer_engine/pytorch/csrc/util.cpp                          +101 -24
transformer_engine/pytorch/csrc/util.h                            +12  -0
transformer_engine/pytorch/distributed.py                         +290 -37
transformer_engine/pytorch/experimental/__init__.py               +2   -1
transformer_engine/pytorch/experimental/gemm.py                   +137 -0
transformer_engine/pytorch/experimental/quantization.py           +29  -0
transformer_engine/pytorch/experimental/quantization_nvfp4.py     +887 -0
transformer_engine/pytorch/experimental/utils.py                  +30  -0
transformer_engine/pytorch/float8_tensor.py                       +10  -0
transformer_engine/pytorch/fp8.py                                 +55  -1086
transformer_engine/pytorch/graph.py                               +140 -49
transformer_engine/pytorch/module/_common.py                      +6   -5
transformer_engine/pytorch/module/base.py                         +52  -42
transformer_engine/pytorch/module/fp8_padding.py                  +1   -1
transformer_engine/pytorch/module/fp8_unpadding.py                +1   -1
transformer_engine/pytorch/csrc/extensions/normalization.cpp

@@ -66,67 +66,102 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   // Input and param tensors
   auto none = py::none();
-  const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none);
-  const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none);
-  TensorWrapper bias_cu;
+  const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
+  const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none);
+  TensorWrapper bias_nvte;
   if (bias.has_value()) {
-    bias_cu = makeTransformerEngineTensor(*bias);
+    bias_nvte = makeTransformerEngineTensor(*bias);
   }

   // Tensor dimensions
-  const size_t N = static_cast<size_t>(input_cu.size(0));
-  const size_t H = static_cast<size_t>(input_cu.size(1));
-  const std::vector<size_t> size = {N, H};
+  const auto shape = nvte_shape_to_vector(input_nvte.shape());
+  const auto outer_size = product(shape) / shape.back();
+  const auto inner_size = shape.back();

   // Tensors to save for backward pass
-  at::Tensor mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
-  at::Tensor rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
-  TensorWrapper mu_cu = makeTransformerEngineTensor(mu);
-  TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma);
+  at::Tensor mu_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
+  at::Tensor rsigma_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
+  TensorWrapper mu_nvte = makeTransformerEngineTensor(mu_py);
+  TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py);

   // Output tensor
-  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  TensorWrapper out_cu;
+  auto quantizer_cpp = convert_quantizer(quantizer);
+  TensorWrapper out_nvte;
   if (out.is_none()) {
-    std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype);
+    std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
   } else {
-    out_cu = makeTransformerEngineTensor(out, quantizer);
+    out_nvte = makeTransformerEngineTensor(out, quantizer);
   }

-  // Determine whether to avoid fused kernel
-  bool force_unfused_kernel = true;
-  if (quantizer.is_none()) {
-    // No need for separate quantization step if output is unquantized
-    force_unfused_kernel = false;
-  } else if (IsFloat8Quantizers(quantizer.ptr())) {
-    // Always used fused kernel for FP8 delayed scaling
-    force_unfused_kernel = false;
-  } else if (IsMXFP8Quantizers(quantizer.ptr())) {
-    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
-      // cuDNN MXFP8 kernel requires full 128x128 tiles
-      force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
-    }
-  }
+  // Choose implementation
+  enum class Impl {
+    // Compute norm in high precision, then quantize
+    UNFUSED,
+    // Compute norm directly
+    FULLY_FUSED,
+    // Compute norm and amax in high precision, then quantize to FP8
+    FUSED_NORM_AMAX_FP8,
+    // Compute norm and amax in high precision, then quantize to NVFP4
+    FUSED_NORM_AMAX_NVFP4
+  };
+  Impl impl = Impl::UNFUSED;
+  if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) {
+    impl = Impl::FULLY_FUSED;
+  } else if (IsMXFP8Quantizers(quantizer.ptr())) {
+    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN") &&
+        outer_size % 128 == 0 &&  // cuDNN MXFP8 kernel requires full tile
+        inner_size % 128 == 0) {
+      impl = Impl::FULLY_FUSED;
+    }
+  } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
+             !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+    auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
+    NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
+    impl = Impl::FUSED_NORM_AMAX_FP8;
+  } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
+    auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
+    NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
+      // Post-RHT amax is handled within NVFP4 quantizer
+      impl = Impl::UNFUSED;
+    } else if (!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+      // TE kernel supports amax output
+      impl = Impl::FUSED_NORM_AMAX_NVFP4;
+    }
+  }

-  TensorWrapper unquantized_out_cu;
+  // Construct unquantized output tensor if needed
+  TensorWrapper unquantized_out_nvte;
   py::object unquantized_out;
-  if (force_unfused_kernel) {
-    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
-        !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
-      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
-      std::tie(unquantized_out_cu, unquantized_out) =
-          my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
-    } else {
-      NoneQuantizer q{none};
-      std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
-    }
-  }
-  TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
+  TensorWrapper *kernel_out_nvte = &out_nvte;
+  switch (impl) {
+    case Impl::UNFUSED: {
+      NoneQuantizer q{none};
+      std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype);
+      kernel_out_nvte = &unquantized_out_nvte;
+    } break;
+    case Impl::FUSED_NORM_AMAX_FP8: {
+      auto fp8_quantizer_cpp = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
+      std::tie(unquantized_out_nvte, unquantized_out) =
+          fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype);
+      kernel_out_nvte = &unquantized_out_nvte;
+    } break;
+    case Impl::FUSED_NORM_AMAX_NVFP4: {
+      auto nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
+      std::tie(unquantized_out_nvte, unquantized_out) =
+          nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype);
+      kernel_out_nvte = &unquantized_out_nvte;
+    } break;
+    default: {
+    }
+  }

   // Query workspace size
   TensorWrapper workspace;
   NVTE_SCOPED_GIL_RELEASE({
-    nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
-                       mu_cu.data(), rsigma_cu.data(), workspace.data(),
+    nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps,
+                       kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(), workspace.data(),
                        at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        zero_centered_gamma, at::cuda::getCurrentCUDAStream());
   });
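The Impl enum above replaces the old force_unfused_kernel boolean with a four-way dispatch. A minimal standalone sketch of that decision rule, for reference while reading the hunk (the QuantKind enum and choose_impl helper are illustrative stand-ins, not TE API; only the environment-variable name comes from the diff):

#include <cstddef>
#include <cstdlib>
#include <string>

enum class Impl { UNFUSED, FULLY_FUSED, FUSED_NORM_AMAX_FP8, FUSED_NORM_AMAX_NVFP4 };
enum class QuantKind { NONE, FP8_DELAYED, MXFP8, FP8_CURRENT_SCALING, NVFP4 };  // hypothetical

bool use_cudnn_norm_fwd() {
  const char *v = std::getenv("NVTE_NORM_FWD_USE_CUDNN");
  return v != nullptr && std::string(v) != "0";
}

Impl choose_impl(QuantKind kind, std::size_t outer_size, std::size_t inner_size,
                 bool nvfp4_rht_with_post_amax) {
  switch (kind) {
    case QuantKind::NONE:
    case QuantKind::FP8_DELAYED:
      return Impl::FULLY_FUSED;  // norm kernel writes the final output directly
    case QuantKind::MXFP8:
      // cuDNN MXFP8 kernel needs full 128x128 tiles; otherwise quantize separately
      return (use_cudnn_norm_fwd() && outer_size % 128 == 0 && inner_size % 128 == 0)
                 ? Impl::FULLY_FUSED
                 : Impl::UNFUSED;
    case QuantKind::FP8_CURRENT_SCALING:
      return use_cudnn_norm_fwd() ? Impl::UNFUSED : Impl::FUSED_NORM_AMAX_FP8;
    case QuantKind::NVFP4:
      if (nvfp4_rht_with_post_amax) return Impl::UNFUSED;  // amax handled by quantizer
      return use_cudnn_norm_fwd() ? Impl::UNFUSED : Impl::FUSED_NORM_AMAX_NVFP4;
  }
  return Impl::UNFUSED;
}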
@@ -138,24 +173,31 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   // Launch kernel
   NVTE_SCOPED_GIL_RELEASE({
-    nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
-                       mu_cu.data(), rsigma_cu.data(), workspace.data(),
+    nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps,
+                       kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(), workspace.data(),
                        at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        zero_centered_gamma, at::cuda::getCurrentCUDAStream());
   });

-  // Quantize output if using unfused kernel
-  if (force_unfused_kernel) {
-    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
-        !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
-      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
-      my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
-    } else {
-      my_quantizer->quantize(unquantized_out_cu, out_cu);
-    }
-  }
+  // Quantize output if needed
+  switch (impl) {
+    case Impl::UNFUSED: {
+      quantizer_cpp->quantize(unquantized_out_nvte, out_nvte);
+    } break;
+    case Impl::FUSED_NORM_AMAX_FP8: {
+      auto fp8_quantizer_cpp = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
+      fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte);
+    } break;
+    case Impl::FUSED_NORM_AMAX_NVFP4: {
+      auto nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
+      nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte);
+    } break;
+    default: {
+    }
+  }

-  return {out, py::cast(mu), py::cast(rsigma)};
+  return {out, py::cast(mu_py), py::cast(rsigma_py)};
 }

 std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,

@@ -254,61 +296,95 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   // Input and param tensors
   auto none = py::none();
-  const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none);
-  const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none);
+  const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
+  const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none);

   // Tensor dimensions
-  const size_t N = static_cast<size_t>(input_cu.shape().data[0]);
-  const size_t H = static_cast<size_t>(input_cu.shape().data[1]);
-  const std::vector<size_t> size = {N, H};
+  const auto shape = nvte_shape_to_vector(input_nvte.shape());
+  const auto outer_size = product(shape) / shape.back();
+  const auto inner_size = shape.back();

   // Tensors to save for backward pass
-  auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
-  auto rsigma_cu = makeTransformerEngineTensor(rsigma);
+  at::Tensor rsigma_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
+  TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py);

   // Output tensor
-  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  TensorWrapper out_cu;
+  auto quantizer_cpp = convert_quantizer(quantizer);
+  TensorWrapper out_nvte;
   if (out.is_none()) {
-    std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype);
+    std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
   } else {
-    out_cu = makeTransformerEngineTensor(out, quantizer);
+    out_nvte = makeTransformerEngineTensor(out, quantizer);
   }

-  // Determine whether to avoid fused kernel
-  bool force_unfused_kernel = true;
-  if (quantizer.is_none()) {
-    // No need for separate quantization step if output is unquantized
-    force_unfused_kernel = false;
-  } else if (IsFloat8Quantizers(quantizer.ptr())) {
-    // Always used fused kernel for FP8 delayed scaling
-    force_unfused_kernel = false;
-  } else if (IsMXFP8Quantizers(quantizer.ptr())) {
-    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
-      // cuDNN MXFP8 kernel requires full 128x128 tiles
-      force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
-    }
-  }
+  // Choose implementation
+  enum class Impl {
+    // Compute norm in high precision, then quantize
+    UNFUSED,
+    // Compute norm directly
+    FULLY_FUSED,
+    // Compute norm and amax in high precision, then quantize to FP8
+    FUSED_NORM_AMAX_FP8,
+    // Compute norm and amax in high precision, then quantize to NVFP4
+    FUSED_NORM_AMAX_NVFP4
+  };
+  Impl impl = Impl::UNFUSED;
+  if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) {
+    impl = Impl::FULLY_FUSED;
+  } else if (IsMXFP8Quantizers(quantizer.ptr())) {
+    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN") &&
+        outer_size % 128 == 0 &&  // cuDNN MXFP8 kernel requires full tile
+        inner_size % 128 == 0) {
+      impl = Impl::FULLY_FUSED;
+    }
+  } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
+             !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+    auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
+    NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
+    impl = Impl::FUSED_NORM_AMAX_FP8;
+  } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
+    auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
+    NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
+      // Post-RHT amax is handled within NVFP4 quantizer
+      impl = Impl::UNFUSED;
+    } else if (!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+      // TE kernel supports amax output
+      impl = Impl::FUSED_NORM_AMAX_NVFP4;
+    }
+  }

-  TensorWrapper unquantized_out_cu;
+  // Construct unquantized output tensor if needed
+  TensorWrapper unquantized_out_nvte;
   py::object unquantized_out;
-  if (force_unfused_kernel) {
-    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
-        !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
-      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
-      std::tie(unquantized_out_cu, unquantized_out) =
-          my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
-    } else {
-      NoneQuantizer q{none};
-      std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
-    }
-  }
-  TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
+  TensorWrapper *kernel_out_nvte = &out_nvte;
+  switch (impl) {
+    case Impl::UNFUSED: {
+      NoneQuantizer q{none};
+      std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype);
+      kernel_out_nvte = &unquantized_out_nvte;
+    } break;
+    case Impl::FUSED_NORM_AMAX_FP8: {
+      auto fp8_quantizer_cpp = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
+      std::tie(unquantized_out_nvte, unquantized_out) =
+          fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype);
+      kernel_out_nvte = &unquantized_out_nvte;
+    } break;
+    case Impl::FUSED_NORM_AMAX_NVFP4: {
+      auto nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
+      std::tie(unquantized_out_nvte, unquantized_out) =
+          nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype);
+      kernel_out_nvte = &unquantized_out_nvte;
+    } break;
+    default: {
+    }
+  }

   // Query workspace size
   TensorWrapper workspace;
   NVTE_SCOPED_GIL_RELEASE({
-    nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
-                     workspace.data(),
+    nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(),
+                     rsigma_nvte.data(), workspace.data(),
                      at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                      zero_centered_gamma, at::cuda::getCurrentCUDAStream());
   });

@@ -320,24 +396,30 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   // Launch kernel
   NVTE_SCOPED_GIL_RELEASE({
-    nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
-                     workspace.data(),
+    nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(),
+                     rsigma_nvte.data(), workspace.data(),
                      at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                      zero_centered_gamma, at::cuda::getCurrentCUDAStream());
   });

-  // Quantize output if using unfused kernel
-  if (force_unfused_kernel) {
-    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
-        !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
-      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
-      my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
-    } else {
-      my_quantizer->quantize(unquantized_out_cu, out_cu);
-    }
-  }
+  // Quantize output if needed
+  switch (impl) {
+    case Impl::UNFUSED: {
+      quantizer_cpp->quantize(unquantized_out_nvte, out_nvte);
+    } break;
+    case Impl::FUSED_NORM_AMAX_FP8: {
+      auto fp8_quantizer_cpp = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
+      fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte);
+    } break;
+    case Impl::FUSED_NORM_AMAX_NVFP4: {
+      auto nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
+      nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte);
+    } break;
+    default: {
+    }
+  }

-  return {out, py::none(), py::cast(rsigma)};
+  return {out, py::none(), py::cast(rsigma_py)};
 }

 }  // namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/pybind.cpp

@@ -23,15 +23,18 @@
 namespace transformer_engine::pytorch {

 PyTypeObject *Float8TensorPythonClass = nullptr;  /// TODO Remove
-PyTypeObject *Float8TensorBasePythonClass = nullptr;
+PyTypeObject *Float8TensorStoragePythonClass = nullptr;
 PyTypeObject *Float8QuantizerClass = nullptr;
 PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr;
 PyTypeObject *MXFP8TensorPythonClass = nullptr;  /// TODO Remove
-PyTypeObject *MXFP8TensorBasePythonClass = nullptr;
+PyTypeObject *MXFP8TensorStoragePythonClass = nullptr;
 PyTypeObject *MXFP8QuantizerClass = nullptr;
 PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr;
-PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr;
+PyTypeObject *Float8BlockwiseQTensorStoragePythonClass = nullptr;
 PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
+PyTypeObject *NVFP4TensorPythonClass = nullptr;
+PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
+PyTypeObject *NVFP4QuantizerClass = nullptr;

 void init_float8_extension() {
   if (Float8TensorPythonClass) return;

@@ -43,9 +46,9 @@ void init_float8_extension() {
   Float8TensorPythonClass =
       reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor"));
   auto fp8_base_module =
-      py::module_::import("transformer_engine.pytorch.tensor._internal.float8_tensor_base");
-  Float8TensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
-      PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorBase"));
+      py::module_::import("transformer_engine.pytorch.tensor.storage.float8_tensor_storage");
+  Float8TensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorStorage"));
   NVTE_CHECK(Float8TensorPythonClass != nullptr,
              "Internal error: could not initialize pyTorch Float8 extension.");
 }

@@ -58,38 +61,54 @@ void init_mxfp8_extension() {
   MXFP8TensorPythonClass =
       reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Tensor"));
   auto fp8_base_module =
-      py::module_::import("transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base");
-  MXFP8TensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
-      PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorBase"));
+      py::module_::import("transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage");
+  MXFP8TensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorStorage"));
   NVTE_CHECK(MXFP8TensorPythonClass != nullptr,
              "Internal error: could not initialize pyTorch MXFP8 extension.");
 }

 void init_float8blockwise_extension() {
-  if (Float8BlockwiseQTensorBasePythonClass) return;
+  if (Float8BlockwiseQTensorStoragePythonClass) return;
   auto fp8_module =
       py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
   auto fp8_base_module = py::module_::import(
-      "transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base");
+      "transformer_engine.pytorch.tensor.storage.float8_blockwise_tensor_storage");
   Float8BlockwiseQuantizerClass = reinterpret_cast<PyTypeObject *>(
       PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer"));
-  Float8BlockwiseQTensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
-      PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorBase"));
+  Float8BlockwiseQTensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorStorage"));
   Float8BlockwiseQTensorPythonClass = reinterpret_cast<PyTypeObject *>(
       PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor"));
   NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr,
              "Internal error: could not initialize pyTorch float8blockwise extension.");
-  NVTE_CHECK(Float8BlockwiseQTensorBasePythonClass != nullptr,
+  NVTE_CHECK(Float8BlockwiseQTensorStoragePythonClass != nullptr,
              "Internal error: could not initialize pyTorch float8blockwise extension.");
   NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr,
              "Internal error: could not initialize pyTorch float8blockwise extension.");
 }

+void init_nvfp4_extensions() {
+  if (NVFP4TensorPythonClass) return;
+  auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor");
+  NVFP4QuantizerClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer"));
+  NVFP4TensorPythonClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Tensor"));
+  auto nvfp4_base_module =
+      py::module_::import("transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage");
+  NVFP4TensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorStorage"));
+  NVTE_CHECK(NVFP4TensorPythonClass != nullptr,
+             "Internal error: could not initialize pyTorch NVFP4 extension.");
+}
+
 void init_extension() {
   init_float8_extension();
   init_mxfp8_extension();
   init_float8blockwise_extension();
+  init_nvfp4_extensions();
 }

 }  // namespace transformer_engine::pytorch

@@ -136,6 +155,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("quantizer"));
   m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
         py::arg("quantizer"));
+  m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu,
+        "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"),
+        py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
   /* Backward of GELU and variants */
   m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));

@@ -159,6 +181,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("fwd_input"), py::arg("quantizer"));
   m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("clamped_dswiglu", transformer_engine::pytorch::clamped_dswiglu,
+        "Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"),
+        py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
   /* DBias + DAct fusions*/
   m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize",
         py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
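The new clamped_swiglu binding exposes limit and alpha with defaults 7.0f and 1.702f. The kernel itself is not part of this diff; as a hedged scalar sketch, assuming the GPT-OSS-style definition (gate clamped from above at limit, linear path clamped to [-limit, limit], gate passed through sigmoid(alpha * x) * x), the math would look like:

#include <algorithm>
#include <cmath>

// Assumed GPT-OSS-style clamped SwiGLU; defaults mirror the pybind arguments above.
float clamped_swiglu_scalar(float gate, float linear, float limit = 7.0f,
                            float alpha = 1.702f) {
  gate = std::min(gate, limit);                 // clamp the gate path from above only
  linear = std::clamp(linear, -limit, limit);   // clamp the linear path symmetrically
  const float silu = gate / (1.0f + std::exp(-alpha * gate));  // sigmoid(alpha*x) * x
  return silu * (linear + 1.0f);
}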
transformer_engine/pytorch/csrc/pybind.h

@@ -31,22 +31,21 @@ namespace transformer_engine::pytorch {
   } while (false);

 extern PyTypeObject *Float8TensorPythonClass;
-extern PyTypeObject *Float8TensorBasePythonClass;
+extern PyTypeObject *Float8TensorStoragePythonClass;
 extern PyTypeObject *Float8QuantizerClass;
 extern PyTypeObject *Float8CurrentScalingQuantizerClass;
 extern PyTypeObject *MXFP8TensorPythonClass;
-extern PyTypeObject *MXFP8TensorBasePythonClass;
+extern PyTypeObject *MXFP8TensorStoragePythonClass;
 extern PyTypeObject *MXFP8QuantizerClass;
 extern PyTypeObject *Float8BlockwiseQTensorPythonClass;
-extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass;
+extern PyTypeObject *Float8BlockwiseQTensorStoragePythonClass;
 extern PyTypeObject *Float8BlockwiseQuantizerClass;
+extern PyTypeObject *NVFP4TensorPythonClass;
+extern PyTypeObject *NVFP4TensorStoragePythonClass;
+extern PyTypeObject *NVFP4QuantizerClass;

 void init_extension();
-void init_float8_extension();
-void init_mxfp8_extension();

 namespace detail {

 inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; }

@@ -56,22 +55,28 @@ inline bool IsFloat8CurrentScalingQuantizers(PyObject *obj) {
 }

 inline bool IsFloat8Tensor(PyObject *obj) {
-  return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorBasePythonClass;
+  return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorStoragePythonClass;
 }

 inline bool IsMXFP8Quantizers(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; }

 inline bool IsMXFP8Tensor(PyObject *obj) {
-  return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass;
+  return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorStoragePythonClass;
 }

 inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) {
   return Py_TYPE(obj) == Float8BlockwiseQuantizerClass;
 }

+inline bool IsNVFP4Quantizers(PyObject *obj) { return Py_TYPE(obj) == NVFP4QuantizerClass; }
+
 inline bool IsFloat8BlockwiseQTensor(PyObject *obj) {
   return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass ||
-         Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass;
+         Py_TYPE(obj) == Float8BlockwiseQTensorStoragePythonClass;
 }

+inline bool IsNVFP4Tensor(PyObject *obj) {
+  return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorStoragePythonClass;
+}
+
 TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer);

@@ -88,6 +93,8 @@ std::unique_ptr<Quantizer> CreateMXFP8Params(const py::handle params);
 TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor,
                                                    Quantizer *quantization_params);

+TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer);
+
 inline bool IsFloatingPointType(at::ScalarType type) {
   return type == at::kFloat || type == at::kHalf || type == at::kBFloat16;
 }

@@ -100,8 +107,9 @@ constexpr std::array custom_types_converters = {
     std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor,
                     CreateQuantizer<MXFP8Quantizer>),
     std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers,
-                    NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer<Float8BlockQuantizer>)};
+                    NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer<Float8BlockQuantizer>),
+    std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor,
+                    CreateQuantizer<NVFP4Quantizer>)};

 }  // namespace detail

 }  // namespace transformer_engine::pytorch
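The custom_types_converters array above gains a fourth entry, so NVFP4 tensors join the same predicate-based dispatch as FP8, MXFP8, and blockwise tensors. A generic, self-contained sketch of that registry pattern (toy int tags and converters; the real entries hold PyObject predicates and the NVTETensorFrom* hooks):

#include <array>
#include <tuple>

using Predicate = bool (*)(int);  // stand-in for bool(PyObject *)
using Converter = int (*)(int);   // stand-in for the NVTETensorFrom* hooks

bool is_kind_a(int tag) { return tag == 0; }
bool is_kind_b(int tag) { return tag == 1; }
int convert_a(int tag) { return tag + 10; }
int convert_b(int tag) { return tag + 20; }

// Compile-time table of (predicate, converter) pairs; supporting a new tensor
// kind is one more entry, which is exactly what the diff does for NVFP4.
constexpr std::array<std::tuple<Predicate, Converter>, 2> registry = {
    std::make_tuple(is_kind_a, convert_a),
    std::make_tuple(is_kind_b, convert_b),
};

int convert(int tag) {
  for (const auto &[pred, conv] : registry) {
    if (pred(tag)) return conv(tag);  // first matching entry wins
  }
  return tag;  // no custom converter applies
}

int main() { return convert(1) == 21 ? 0 : 1; }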
transformer_engine/pytorch/csrc/quantizer.cpp

@@ -31,8 +31,20 @@ std::vector<T> make_transpose_shape(const std::vector<S>& shape) {
   return ret;
 }

+/*! @brief Convert shape for FP4 data by dividing the last dimension by 2 */
+template <typename T = size_t>
+std::vector<T> convert_shape_for_fp4(const std::vector<T>& shape) {
+  std::vector<T> ret;
+  for (size_t i = 0; i < shape.size() - 1; ++i) {
+    ret.push_back(shape[i]);
+  }
+  ret.push_back(shape.back() / 2);
+  return ret;
+}
+
 }  // namespace

+constexpr size_t NVFP4_BLOCK_SIZE = 16;
 constexpr size_t MXFP8_BLOCK_SIZE = 32;

 Quantizer::Quantizer(const py::handle& quantizer) {
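Since two FP4 values share one byte, the packed byte buffer halves the last dimension, and with NVFP4_BLOCK_SIZE = 16 each row of N elements carries N/16 scale factors. A self-contained sketch of that arithmetic (the packed_fp4_shape helper and the example shape are illustrative; the constant mirrors the diff):

#include <cassert>
#include <cstddef>
#include <vector>

// Same packing rule as convert_shape_for_fp4 above: two FP4 values per byte.
std::vector<std::size_t> packed_fp4_shape(const std::vector<std::size_t> &shape) {
  std::vector<std::size_t> ret(shape.begin(), shape.end() - 1);
  ret.push_back(shape.back() / 2);
  return ret;
}

int main() {
  constexpr std::size_t kNvfp4BlockSize = 16;            // mirrors NVFP4_BLOCK_SIZE
  const std::vector<std::size_t> shape = {1024, 4096};   // logical FP4 tensor shape
  const auto packed = packed_fp4_shape(shape);
  assert(packed.back() == 2048);                         // 4096 FP4 values -> 2048 bytes
  assert(shape.back() / kNvfp4BlockSize == 256);         // one E4M3 scale per 16 elements
  return 0;
}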
@@ -140,7 +152,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
   // Construct Python FP8 tensor
   py::object out_py;
   if (internal) {
-    py::handle Float8TensorClass(reinterpret_cast<PyObject *>(Float8TensorBasePythonClass));
+    py::handle Float8TensorClass(reinterpret_cast<PyObject *>(Float8TensorStoragePythonClass));
     out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv,
                                "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
                                "quantizer"_a = this->quantizer);

@@ -345,7 +357,7 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
   py::object data_py = with_data ? py::cast(data_tensor) : py::none();
   py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none();
   if (internal) {
-    py::handle Float8TensorClass(reinterpret_cast<PyObject *>(Float8TensorBasePythonClass));
+    py::handle Float8TensorClass(reinterpret_cast<PyObject *>(Float8TensorStoragePythonClass));
     out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor,
                                "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
                                "quantizer"_a = this->quantizer);

@@ -376,10 +388,15 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
   return {std::move(out_cpp), std::move(out_py)};
 }

-std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_hp_tensor_with_amax(
-    const std::vector<size_t>& shape, DType dtype) {
+std::pair<TensorWrapper, py::object>
+Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(
+    const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data) {
   amax.zero_();
-  auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype);
+  auto out = data.has_value()
+                 ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value())
+                 : NoneQuantizer(py::none()).create_tensor(shape, dtype);
+  TensorWrapper out_cpp = std::move(out.first);
+  py::object out_py = std::move(out.second);
   out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
                    getTensorShape(amax));
   return {std::move(out_cpp), std::move(out_py)};
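The renamed create_unquantized_tensor_with_amax supports a fused two-phase flow: the norm kernel writes a high-precision output whose wrapper carries the quantizer's amax pointer, so one pass yields both the values and their amax, and quantize_with_amax then scales from that recorded amax instead of re-scanning the tensor. A minimal single-threaded sketch of the idea (buffer handling and the E4M3 representable maximum of 448 are used for illustration; this is not the TE API):

#include <cmath>
#include <cstddef>
#include <vector>

// "Kernel" stand-in: records the running amax while producing its output,
// so no second pass over the data is needed before quantizing.
float fill_and_record_amax(std::vector<float> &hp) {
  float amax = 0.0f;
  for (std::size_t i = 0; i < hp.size(); ++i) {
    hp[i] = static_cast<float>(i) - 2.0f;  // placeholder for the normalized values
    amax = std::fmax(amax, std::fabs(hp[i]));
  }
  return amax;
}

int main() {
  std::vector<float> hp(8);
  const float amax = fill_and_record_amax(hp);
  const float scale = 448.0f / amax;  // FP8 E4M3 max divided by the recorded amax
  std::vector<float> q(hp.size());
  for (std::size_t i = 0; i < hp.size(); ++i) q[i] = hp[i] * scale;  // quantize step
  return 0;
}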
@@ -613,7 +630,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
   py::object ret;
   if (internal) {
     py::handle Float8BlockwiseQTensorClass(
-        reinterpret_cast<PyObject *>(Float8BlockwiseQTensorBasePythonClass));
+        reinterpret_cast<PyObject *>(Float8BlockwiseQTensorStoragePythonClass));
     ret = Float8BlockwiseQTensorClass(
         "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
         "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,

@@ -899,7 +916,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
   }
   const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;
   NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0,
-             "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
+             "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE,
              " (got shape=", shape, ")");
   const auto rowwise_scale_inv_shape = get_scale_shape(shape, false);
   const auto columnwise_scale_inv_shape = get_scale_shape(shape, true);

@@ -933,7 +950,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
   // Construct Python MXFP8 tensor
   py::object out_py;
   if (internal) {
-    py::handle MXFP8TensorClass(reinterpret_cast<PyObject *>(MXFP8TensorBasePythonClass));
+    py::handle MXFP8TensorClass(reinterpret_cast<PyObject *>(MXFP8TensorStoragePythonClass));
     out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py,
                               "columnwise_data"_a = columnwise_data_py,
                               "rowwise_scale_inv"_a = rowwise_scale_inv_py,

@@ -1095,7 +1112,7 @@ std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& s
   auto last_dim = shape.back();
   NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0,
-             "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
+             "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE,
              " (got shape=", shape, ")");

   std::vector<size_t> scale_shape;
...
@@ -1116,4 +1133,573 @@ std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& s
...
@@ -1116,4 +1133,573 @@ std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& s
return
scale_shape
;
return
scale_shape
;
}
}
NVFP4Quantizer
::
NVFP4Quantizer
(
const
py
::
handle
&
quantizer
)
:
Quantizer
(
quantizer
)
{
this
->
dtype
=
quantizer
.
attr
(
"dtype"
).
cast
<
DType
>
();
this
->
with_rht
=
quantizer
.
attr
(
"with_rht"
).
cast
<
bool
>
();
this
->
with_post_rht_amax
=
quantizer
.
attr
(
"with_post_rht_amax"
).
cast
<
bool
>
();
this
->
with_2d_quantization
=
quantizer
.
attr
(
"with_2d_quantization"
).
cast
<
bool
>
();
this
->
stochastic_rounding
=
quantizer
.
attr
(
"stochastic_rounding"
).
cast
<
bool
>
();
// Get amax reduction group if needed for NVFP4 AG
const
bool
with_amax_reduction
=
quantizer
.
attr
(
"with_amax_reduction"
).
cast
<
bool
>
();
c10
::
intrusive_ptr
<
dist_group_type
>
amax_reduction_group
;
if
(
with_amax_reduction
)
{
auto
group
=
quantizer
.
attr
(
"_canonicalized_amax_reduction_group"
)();
NVTE_CHECK
(
!
group
.
is_none
(),
"NVFP4Quantizer could not canonicalize amax reduction group"
);
amax_reduction_group
=
group
.
cast
<
c10
::
intrusive_ptr
<
dist_group_type
>>
();
}
this
->
with_amax_reduction
=
with_amax_reduction
;
this
->
amax_reduction_group
=
amax_reduction_group
;
this
->
rht_matrix_random_sign_mask_t
=
quantizer
.
attr
(
"rht_matrix_random_sign_mask_t"
).
cast
<
int
>
();
this
->
rht_matrix
=
quantizer
.
attr
(
"rht_matrix"
).
cast
<
at
::
Tensor
>
();
}
void
NVFP4Quantizer
::
set_quantization_params
(
TensorWrapper
*
tensor
)
const
{
// set dtype for rowwise and columnwise data in tensor wrapper
auto
rowwise_data
=
tensor
->
get_rowwise_data
();
rowwise_data
.
dtype
=
static_cast
<
NVTEDType
>
(
this
->
dtype
);
auto
columnwise_data
=
tensor
->
get_columnwise_data
();
columnwise_data
.
dtype
=
static_cast
<
NVTEDType
>
(
this
->
dtype
);
tensor
->
set_rowwise_data
(
rowwise_data
.
data_ptr
,
static_cast
<
DType
>
(
rowwise_data
.
dtype
),
rowwise_data
.
shape
);
tensor
->
set_columnwise_data
(
columnwise_data
.
data_ptr
,
static_cast
<
DType
>
(
columnwise_data
.
dtype
),
columnwise_data
.
shape
);
}
std
::
pair
<
TensorWrapper
,
py
::
object
>
NVFP4Quantizer
::
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
)
const
{
using
namespace
pybind11
::
literals
;
// Tensor dimensions
const
std
::
vector
<
int64_t
>
shape_int64
(
shape
.
begin
(),
shape
.
end
());
size_t
flat_first_dim
=
1
;
if
(
shape
.
size
()
>
0
)
{
for
(
size_t
i
=
0
;
i
<
shape
.
size
()
-
1
;
++
i
)
{
flat_first_dim
*=
shape
[
i
];
}
}
const
size_t
flat_last_dim
=
shape
.
size
()
>
0
?
shape
.
back
()
:
1
;
NVTE_CHECK
(
flat_first_dim
%
NVFP4_BLOCK_SIZE
==
0
,
"First dim for NVFP4 must be divisible by "
,
NVFP4_BLOCK_SIZE
,
" (got shape="
,
shape
,
")"
);
NVTE_CHECK
(
flat_last_dim
%
NVFP4_BLOCK_SIZE
==
0
,
"NVFP4 requires tensor dims that are divisible by "
,
NVFP4_BLOCK_SIZE
,
" (got shape="
,
shape
,
")"
);
const
auto
rowwise_scale_inv_shape
=
get_scale_shape
(
shape
,
false
);
const
auto
columnwise_scale_inv_shape
=
get_scale_shape
(
shape
,
true
);
// Allocate tensors
at
::
Tensor
rowwise_data_tensor
,
rowwise_scale_inv_tensor
,
amax_rowwise
;
at
::
Tensor
columnwise_data_tensor
,
columnwise_scale_inv_tensor
,
amax_columnwise
;
const
auto
bit8_tensor_opts
=
at
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
torch
::
kCUDA
);
const
auto
bit32_tensor_opts
=
at
::
TensorOptions
().
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
);
if
(
rowwise_usage
)
{
const
std
::
vector
<
int64_t
>
scale_inv_shape_int64
(
rowwise_scale_inv_shape
.
begin
(),
rowwise_scale_inv_shape
.
end
());
rowwise_data_tensor
=
at
::
empty
(
convert_shape_for_fp4
(
shape_int64
),
bit8_tensor_opts
);
rowwise_scale_inv_tensor
=
at
::
empty
(
scale_inv_shape_int64
,
bit8_tensor_opts
);
amax_rowwise
=
at
::
empty
({
1
},
bit32_tensor_opts
);
}
if
(
columnwise_usage
)
{
const
std
::
vector
<
int64_t
>
scale_inv_shape_int64
(
columnwise_scale_inv_shape
.
begin
(),
columnwise_scale_inv_shape
.
end
());
// enforce 2D shape to avoid [S, B, H] shape and B and be 1
// and the transposed shape is [H, S, B], so divide last dim by 2 gives zero
std
::
vector
<
int64_t
>
shape_int64_2d
=
{
static_cast
<
int64_t
>
(
flat_first_dim
),
static_cast
<
int64_t
>
(
flat_last_dim
)};
const
auto
transpose_shape_int64
=
make_transpose_shape
<
int64_t
>
(
shape_int64_2d
);
columnwise_data_tensor
=
at
::
empty
(
convert_shape_for_fp4
(
transpose_shape_int64
),
bit8_tensor_opts
);
columnwise_scale_inv_tensor
=
at
::
empty
(
scale_inv_shape_int64
,
bit8_tensor_opts
);
amax_columnwise
=
at
::
empty
({
1
},
bit32_tensor_opts
);
}
// Convert tensors to Python
auto
py_cast
=
[](
at
::
Tensor
&
tensor
,
bool
need_cast
)
->
py
::
object
{
return
need_cast
?
py
::
cast
(
tensor
)
:
py
::
none
();
};
auto
rowwise_data_py
=
py_cast
(
rowwise_data_tensor
,
rowwise_usage
);
auto
rowwise_scale_inv_py
=
py_cast
(
rowwise_scale_inv_tensor
,
rowwise_usage
);
auto
columnwise_data_py
=
py_cast
(
columnwise_data_tensor
,
columnwise_usage
);
auto
columnwise_scale_inv_py
=
py_cast
(
columnwise_scale_inv_tensor
,
columnwise_usage
);
auto
amax_rowwise_py
=
py_cast
(
amax_rowwise
,
rowwise_usage
);
auto
amax_columnwise_py
=
py_cast
(
amax_columnwise
,
columnwise_usage
);
// Construct Python NVFP4 tensor
py
::
object
out_py
;
if
(
internal
)
{
py
::
handle
NVFP4TensorClass
(
reinterpret_cast
<
PyObject
*>
(
NVFP4TensorStoragePythonClass
));
out_py
=
NVFP4TensorClass
(
"rowwise_data"
_a
=
rowwise_data_py
,
"columnwise_data"
_a
=
columnwise_data_py
,
"rowwise_scale_inv"
_a
=
rowwise_scale_inv_py
,
"columnwise_scale_inv"
_a
=
columnwise_scale_inv_py
,
"amax_rowwise"
_a
=
amax_rowwise_py
,
"amax_columnwise"
_a
=
amax_columnwise_py
,
"fp4_dtype"
_a
=
this
->
dtype
,
"quantizer"
_a
=
this
->
quantizer
);
}
else
{
py
::
handle
NVFP4TensorClass
(
reinterpret_cast
<
PyObject
*>
(
NVFP4TensorPythonClass
));
out_py
=
NVFP4TensorClass
(
"shape"
_a
=
shape_int64
,
"dtype"
_a
=
GetATenDType
(
dtype
),
"rowwise_data"
_a
=
rowwise_data_py
,
"columnwise_data"
_a
=
columnwise_data_py
,
"rowwise_scale_inv"
_a
=
rowwise_scale_inv_py
,
"columnwise_scale_inv"
_a
=
columnwise_scale_inv_py
,
"amax_rowwise"
_a
=
amax_rowwise_py
,
"amax_columnwise"
_a
=
amax_columnwise_py
,
"fp4_dtype"
_a
=
this
->
dtype
,
"quantizer"
_a
=
this
->
quantizer
);
}
// Construct C++ tensor
TensorWrapper
out_cpp
(
NVTE_NVFP4_1D_SCALING
);
if
(
rowwise_usage
)
{
out_cpp
.
set_rowwise_data
(
rowwise_data_tensor
.
data_ptr
(),
DType
::
kFloat4E2M1
,
shape
);
out_cpp
.
set_rowwise_scale_inv
(
rowwise_scale_inv_tensor
.
data_ptr
(),
DType
::
kFloat8E4M3
,
rowwise_scale_inv_shape
);
out_cpp
.
set_amax
(
amax_rowwise
.
data_ptr
(),
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
}
if
(
columnwise_usage
)
{
// enforce 2D shape to avoid [S, B, H] shape and B and be 1
// and the transposed shape is [H, S, B], so divide last dim by 2 gives zero
std
::
vector
<
size_t
>
shape_2d
=
{
flat_first_dim
,
flat_last_dim
};
auto
col_data_shape_fp4
=
make_transpose_shape
<
size_t
>
(
shape_2d
);
out_cpp
.
set_columnwise_data
(
columnwise_data_tensor
.
data_ptr
(),
DType
::
kFloat4E2M1
,
col_data_shape_fp4
);
out_cpp
.
set_columnwise_scale_inv
(
columnwise_scale_inv_tensor
.
data_ptr
(),
DType
::
kFloat8E4M3
,
columnwise_scale_inv_shape
);
out_cpp
.
set_columnwise_amax
(
amax_columnwise
.
data_ptr
(),
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
}
this
->
set_quantization_params
(
&
out_cpp
);
return
{
std
::
move
(
out_cpp
),
std
::
move
(
out_py
)};
}
std
::
pair
<
TensorWrapper
,
py
::
object
>
NVFP4Quantizer
::
create_unquantized_tensor_with_amax
(
TensorWrapper
&
quantized_tensor
,
DType
dtype
)
{
// Construct tensor
auto
shape
=
convertShape
(
quantized_tensor
.
shape
());
auto
[
out_cpp
,
out_py
]
=
NoneQuantizer
(
py
::
none
()).
create_tensor
(
shape
,
dtype
);
// Register amax pointer from quantized tensor
void
*
amax_ptr
=
quantized_tensor
.
amax
();
if
(
amax_ptr
==
nullptr
)
{
amax_ptr
=
quantized_tensor
.
get_columnwise_amax
().
data_ptr
;
}
NVTE_CHECK
(
amax_ptr
!=
nullptr
,
"Could not extract amax pointer from NVFP4 tensor."
);
out_cpp
.
set_amax
(
amax_ptr
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
// Zero out amax
NVTE_CHECK_CUDA
(
cudaMemsetAsync
(
amax_ptr
,
0
,
sizeof
(
float
),
at
::
cuda
::
getCurrentCUDAStream
()));
return
{
std
::
move
(
out_cpp
),
std
::
move
(
out_py
)};
}
std
::
pair
<
TensorWrapper
std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
    py::object tensor) const {
  NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to NVFP4Tensor.");

  // Extract buffers from Python tensor
  auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
    auto attr_py = tensor.attr(name);
    if (attr_py.is_none()) {
      return std::nullopt;
    }
    return attr_py.cast<at::Tensor>();
  };
  auto rowwise_data = get_tensor("_rowwise_data");
  auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv");
  auto columnwise_data = get_tensor("_columnwise_data");
  auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv");
  auto amax_rowwise = get_tensor("_amax_rowwise");
  auto amax_columnwise = get_tensor("_amax_columnwise");
  NVTE_CHECK(rowwise_data || columnwise_data, "NVFP4Tensor has no data.");

  // Tensor dimensions ("shape" means the original, unpacked shape)
  std::vector<size_t> shape;
  if (columnwise_data) {
    shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true);
    if (rowwise_data) {
      auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
      NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape,
                 ") and column-wise data (shape=", shape, ") do not match");
    }
  } else {
    // Already checked that rowwise_data is available
    shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
  }
  size_t flat_first_dim = 1;
  if (shape.size() > 0) {
    for (size_t i = 0; i < shape.size() - 1; ++i) {
      flat_first_dim *= shape[i];
    }
  }
  const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;

  // Coerce row-wise data
  if (rowwise_usage) {
    if (!rowwise_data) {
      const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
      const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
      rowwise_data = at::empty(convert_shape_for_fp4(shape_int64), opts);
      tensor.attr("_rowwise_data") = *rowwise_data;
    }
    if (!rowwise_scale_inv) {
      const auto scale_inv_shape = get_scale_shape(shape, false);
      const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
                                                       scale_inv_shape.end());
      const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
      rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
      tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
    }
    if (!amax_rowwise) {
      const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
      amax_rowwise = at::empty({1}, opts);
      tensor.attr("_amax_rowwise") = *amax_rowwise;
    }
  } else {  // rowwise_usage == false
    if (rowwise_data) {
      rowwise_data.reset();
      tensor.attr("_rowwise_data") = py::none();
    }
    if (rowwise_scale_inv) {
      rowwise_scale_inv.reset();
      tensor.attr("_rowwise_scale_inv") = py::none();
    }
    if (amax_rowwise) {
      amax_rowwise.reset();
      tensor.attr("_amax_rowwise") = py::none();
    }
  }

  // Coerce column-wise data
  if (columnwise_usage) {
    if (!columnwise_data) {
      // Enforce a 2D shape: with e.g. a [S, B, H] shape and B == 1, the transposed
      // shape would be [H, S, B], so halving the last dim (FP4 packing) gives zero.
      std::vector<int64_t> shape_int64_2d = {static_cast<int64_t>(flat_first_dim),
                                             static_cast<int64_t>(flat_last_dim)};
      const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
      const auto transpose_shape_int64 = make_transpose_shape<int64_t>(shape_int64_2d);
      columnwise_data = at::empty(convert_shape_for_fp4(transpose_shape_int64), opts);
      tensor.attr("_columnwise_data") = *columnwise_data;
    }
    if (!columnwise_scale_inv) {
      const auto scale_inv_shape = get_scale_shape(shape, true);
      const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
                                                       scale_inv_shape.end());
      const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
      columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
      tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv;
    }
    if (!amax_columnwise) {
      const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
      amax_columnwise = at::zeros({1}, opts);
      tensor.attr("_amax_columnwise") = *amax_columnwise;
    }
  } else {  // columnwise_usage == false
    if (columnwise_data) {
      columnwise_data.reset();
      tensor.attr("_columnwise_data") = py::none();
    }
    if (columnwise_scale_inv) {
      columnwise_scale_inv.reset();
      tensor.attr("_columnwise_scale_inv") = py::none();
    }
    if (amax_columnwise) {
      amax_columnwise.reset();
      tensor.attr("_amax_columnwise") = py::none();
    }
  }

  // Construct C++ tensor
  TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING);
  if (rowwise_usage) {
    out_cpp.set_rowwise_data(rowwise_data->data_ptr(), DType::kFloat4E2M1, shape);
    out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3,
                                  getTensorShape(*rowwise_scale_inv));
    out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, std::vector<size_t>{1});
  }
  if (columnwise_usage) {
    // Enforce a 2D shape (see the comment above on [S, B, H] inputs).
    std::vector<size_t> shape_2d = {flat_first_dim, flat_last_dim};
    auto col_data_shape_fp4 = make_transpose_shape<size_t>(shape_2d);
    out_cpp.set_columnwise_data(columnwise_data->data_ptr(), DType::kFloat4E2M1,
                                col_data_shape_fp4);
    out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3,
                                     getTensorShape(*columnwise_scale_inv));
    out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32,
                                std::vector<size_t>{1});
  }
  this->set_quantization_params(&out_cpp);
  return {std::move(out_cpp), std::move(tensor)};
}
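Two FP4 values are packed per byte, which is why the allocated byte buffers halve the last dimension. A minimal Python sketch of what the shape helpers used above are assumed to do (`convert_shape_for_fp4` halves the last dim for storage; `convert_shape_back_from_fp4` recovers the logical shape, undoing the transpose for column-wise data); this is an illustration, not the library implementation:

def convert_shape_for_fp4(shape):
    # Logical shape -> packed byte-tensor shape: two FP4 values per byte.
    return list(shape[:-1]) + [shape[-1] // 2]

def convert_shape_back_from_fp4(packed_shape, columnwise):
    logical = list(packed_shape[:-1]) + [packed_shape[-1] * 2]
    if columnwise:
        # Column-wise data stores the transpose of the (flattened 2D) tensor.
        logical = logical[::-1]
    return logical

assert convert_shape_for_fp4([128, 64]) == [128, 32]
assert convert_shape_back_from_fp4([64, 64], columnwise=True) == [128, 64]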
void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out,
                                   const std::optional<TensorWrapper>& noop_flag,
                                   bool compute_amax) {
  // Nothing to be done if input is empty
  if (input.numel() == 0) {
    return;
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  QuantizationConfigWrapper quant_config;
  if (noop_flag) {
    quant_config.set_noop_tensor(noop_flag->data());
  }
  quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization);
  quant_config.set_stochastic_rounding(this->stochastic_rounding);

  // We only need RHT for column-wise usage.
  // Flatten the leading dims and keep the last dim for multi-dimensional input.
  size_t rows = 1;
  for (size_t i = 0; i < input.ndim() - 1; ++i) {
    rows *= input.size(i);
  }
  size_t cols = input.size(input.ndim() - 1);

  TensorWrapper te_rng_state;
  if (this->stochastic_rounding) {
    const size_t rng_elts_per_thread = 1024;  // Wild guess, probably can be tightened
    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
    at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
    auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
    auto rng_state = torch::empty({2}, opts);
    philox_unpack(philox_args, static_cast<int64_t*>(rng_state.data_ptr()));
    te_rng_state = makeTransformerEngineTensor(rng_state);
    quant_config.set_rng_state(te_rng_state.data());
  }

  // Restriction for the RHT cast fusion kernel.
  bool eligible_for_rht_cast_fusion =
      input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;

  // Compute amax.
  if (this->with_rht) {
    if (input.dtype() != DType::kBFloat16) {
      NVTE_CHECK(false, "RHT is only supported for bfloat16 input");
    }
    if (this->with_post_rht_amax) {
      // We need:
      //   1. Row-wise amax = amax of the input
      //   2. Column-wise amax = amax of RHT(input.t)
      NVTE_SCOPED_GIL_RELEASE({
        nvte_hadamard_transform_amax(input.data(), out.data(), 0,
                                     this->rht_matrix_random_sign_mask_t, stream);
      });
    } else {
      // Raise an error since this is not supported yet
      NVTE_CHECK(false, "Pre-RHT amax is not supported yet");
    }
  } else {
    // Without RHT
    if (compute_amax) {
      // Amax pointers
      auto rowwise_amax_ptr = out.get_amax().data_ptr;
      auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr;
      void* amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr;
      NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer");

      // Compute amax of input tensor
      out.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1});
      NVTE_SCOPED_GIL_RELEASE(
          { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); });
      out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector<size_t>{1});

      // Make sure row-wise and column-wise amaxes match
      if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) {
        NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float),
                                        cudaMemcpyDeviceToDevice, stream));
      }
      if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) {
        NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float),
                                        cudaMemcpyDeviceToDevice, stream));
      }
    }
  }

  // Amax reduction
  if (this->with_amax_reduction) {
    std::vector<at::Tensor> amax_tensors;
    // Push amax tensors that need to be reduced
    auto make_amax_tensor = [](void* data_ptr) {
      return at::from_blob(
          data_ptr, std::vector<int64_t>{1},
          [](void*) {},  // deleter does nothing since it doesn't own the data
          at::device(at::kCUDA).dtype(torch::kFloat32));
    };
    if (rowwise_usage) {
      amax_tensors.push_back(make_amax_tensor(out.get_amax().data_ptr));
    }
    if (columnwise_usage) {
      amax_tensors.push_back(make_amax_tensor(out.get_columnwise_amax().data_ptr));
    }
    c10d::AllreduceCoalescedOptions opts;
    opts.reduceOp = c10d::ReduceOp::MAX;
    NVTE_SCOPED_GIL_RELEASE(
        { this->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); });
  }

  if (this->with_rht) {
    if (rowwise_usage) {
      // For row-wise usage, quantize the input directly while avoiding a
      // column-wise quantization.
      TensorWrapper out_identity(out.scaling_mode());
      auto out_identity_data = out.get_rowwise_data();
      auto out_identity_scale_inv = out.get_rowwise_scale_inv();
      auto out_identity_amax = out.get_amax();
      out_identity.set_rowwise_data(out_identity_data.data_ptr,
                                    static_cast<DType>(out_identity_data.dtype),
                                    out_identity_data.shape);
      out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr,
                                         static_cast<DType>(out_identity_scale_inv.dtype),
                                         out_identity_scale_inv.shape);
      out_identity.set_amax(out_identity_amax.data_ptr,
                            static_cast<DType>(out_identity_amax.dtype),
                            out_identity_amax.shape);
      NVTE_SCOPED_GIL_RELEASE(
          { nvte_quantize_v2(input.data(), out_identity.data(), quant_config, stream); });
    }
    if (columnwise_usage) {
      // Get the output column-wise data, scale_inv, and amax
      auto out_columnwise_data = out.get_columnwise_data();
      auto out_columnwise_scale_inv = out.get_columnwise_scale_inv();
      // NOTE: should already be populated.
      auto out_columnwise_amax = out.get_columnwise_amax();

      // Create a wrapper for the column-wise output, viewed as a row-wise output.
      // The input `rht_output_t` is already in the transposed layout, so a
      // row-wise quantization is all that is needed to generate the column-wise output.
      TensorWrapper out_transpose(out.scaling_mode());
      // Note: since we are presenting the column-wise tensor as row-wise, the flat
      // first-dim check would fail, so convert the shape to 2D here.
      auto colwise_data_shape = out_columnwise_data.shape;
      std::vector<size_t> colwise_data_shape_2d;
      // The shape could be [512, 32, 64], which is really [512, 32, 128] because two
      // FP4 values take one byte. The 2D shape should be [512, 32*128], but the
      // column-wise data shape expects the last dim to be halved again, so the
      // factors of 2 cancel out.
      colwise_data_shape_2d.push_back(colwise_data_shape.data[0]);
      size_t last_dim = 1;
      for (size_t i = 1; i < colwise_data_shape.ndim; ++i) {
        last_dim *= colwise_data_shape.data[i];
      }
      colwise_data_shape_2d.push_back(last_dim);
      out_transpose.set_rowwise_data(out_columnwise_data.data_ptr,
                                     static_cast<DType>(out_columnwise_data.dtype),
                                     colwise_data_shape_2d);
      out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr,
                                          static_cast<DType>(out_columnwise_scale_inv.dtype),
                                          out_columnwise_scale_inv.shape);
      out_transpose.set_amax(out_columnwise_amax.data_ptr,
                             static_cast<DType>(out_columnwise_amax.dtype),
                             out_columnwise_amax.shape);

      if (!eligible_for_rht_cast_fusion) {
        // Invoke the fallback RHT kernel.
        // If using RHT, the amax is computed in the RHT step;
        // otherwise the amax is computed from the input x.
        at::Tensor rht_output_t;  // The RHT(x_t) output, in column-wise layout
        // This wrapper is going to be passed as input to the quantization kernel.
        TensorWrapper rht_output_t_cpp;
        // Buffer to contain the RHT(x) and RHT(x_t) outputs
        rht_output_t = allocateTorchTensor(static_cast<int>(cols), static_cast<int>(rows),
                                           input.dtype());
        // NOTE (frsun): This is non-intuitive, we are writing the
        // result of the transposed RHT to the row-wise output.
        rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(),
                                          std::vector<size_t>{cols, rows});
        NVTE_SCOPED_GIL_RELEASE({
          // Perform RHT(input.t) and write to rht_output_t_cpp.
          nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0,
                                  this->rht_matrix_random_sign_mask_t, stream);
        });
        // The quantize kernel treats everything as row-wise input/output, which is
        // intended.
        NVTE_SCOPED_GIL_RELEASE({
          nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config,
                           stream);
        });
      } else {
        // RHT cast fusion kernel.
        NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0,
                   "RHT matrix is not set");
        auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix);
        NVTE_SCOPED_GIL_RELEASE({
          nvte_hadamard_transform_cast_fusion_columnwise(
              input.data(), out_transpose.data(), rht_matrix_nvte.data(), quant_config,
              stream);
        });
      }
    }
  } else {
    NVTE_SCOPED_GIL_RELEASE(
        { nvte_quantize_v2(input.data(), out.data(), quant_config, stream); });
  }
}
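For intuition, a rough PyTorch reference for the fallback column-wise path above: transpose, apply a randomized Hadamard transform, then quantize the result row-wise. The 16-element block size and the per-column sign mask are assumptions about the kernel's behavior, and the NVFP4 cast itself is omitted; this is a sketch, not the kernel:

import torch

def hadamard(n: int) -> torch.Tensor:
    """Sylvester construction of an orthonormal n x n Hadamard matrix (n a power of 2)."""
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h / (n ** 0.5)

def rht_columnwise_reference(x: torch.Tensor, signs: torch.Tensor) -> torch.Tensor:
    """Return RHT(x.T): the tensor the row-wise quantize call then consumes."""
    rows, cols = x.shape
    h = hadamard(16) * signs  # sign-masked transform (assumed per-column mask)
    xt = x.t().contiguous().reshape(cols, rows // 16, 16)  # 16-element blocks
    return (xt @ h).reshape(cols, rows)

x = torch.randn(64, 128)
signs = (torch.randint(0, 2, (16,)) * 2 - 1).float()
y_t = rht_columnwise_reference(x, signs)  # would then go through row-wise NVFP4 quantization
assert y_t.shape == (128, 64)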
void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
                              const std::optional<TensorWrapper>& noop_flag) {
  this->quantize_impl(input, out, noop_flag, true);
}

void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) {
  // Update output tensor amaxes with input tensor amax
  auto input_amax_ptr = input.amax();
  auto output_rowwise_amax_ptr = out.get_amax().data_ptr;
  auto output_columnwise_amax_ptr = out.get_columnwise_amax().data_ptr;
  NVTE_CHECK(input_amax_ptr != nullptr ||
                 (output_rowwise_amax_ptr == nullptr && output_columnwise_amax_ptr == nullptr),
             "Input tensor does not have pre-computed amax");
  if (input_amax_ptr != output_rowwise_amax_ptr && input_amax_ptr != nullptr &&
      output_rowwise_amax_ptr != nullptr) {
    NVTE_CHECK_CUDA(cudaMemcpyAsync(output_rowwise_amax_ptr, input_amax_ptr, sizeof(float),
                                    cudaMemcpyDeviceToDevice,
                                    at::cuda::getCurrentCUDAStream()));
  }
  if (input_amax_ptr != output_columnwise_amax_ptr && input_amax_ptr != nullptr &&
      output_columnwise_amax_ptr != nullptr) {
    NVTE_CHECK_CUDA(cudaMemcpyAsync(output_columnwise_amax_ptr, input_amax_ptr, sizeof(float),
                                    cudaMemcpyDeviceToDevice,
                                    at::cuda::getCurrentCUDAStream()));
  }
  input.set_amax(nullptr, DType::kFloat32, input.defaultShape);

  // Perform quantization
  this->quantize_impl(input, out, std::nullopt, false);
}
std::vector<size_t> NVFP4Quantizer::get_scale_shape(const std::vector<size_t>& shape,
                                                    bool columnwise) const {
  size_t numel = 1;
  for (auto s : shape) {
    numel *= s;
  }
  auto last_dim = shape.back();
  auto flat_first_dim = numel / last_dim;
  NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ",
             NVFP4_BLOCK_SIZE, " (got dim=", last_dim, ")");
  NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0,
             "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE,
             " (got shape=", shape, ")");
  std::vector<size_t> scale_shape;
  bool rowwise_usage = !columnwise;
  if (rowwise_usage) {
    // Row-wise scaling factor shape
    size_t sinv0 = roundup(flat_first_dim, 128);
    size_t sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4);
    scale_shape = {sinv0, sinv1};
  } else {
    // Column-wise scaling factor shape
    size_t sinv0 = roundup(last_dim, 128);
    size_t sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4);
    scale_shape = {sinv0, sinv1};
  }
  return scale_shape;
}

}  // namespace transformer_engine::pytorch
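A small Python check of the scale-inverse shapes implied by get_scale_shape above (NVFP4_BLOCK_SIZE is 16; padding to 128 rows and 4 columns matches the swizzled layout the GEMM consumes). The helper names here are illustrative only:

def roundup(x: int, m: int) -> int:
    return (x + m - 1) // m * m

def nvfp4_scale_shape(shape, columnwise=False, block=16):
    flat_first = 1
    for s in shape[:-1]:
        flat_first *= s
    last = shape[-1]
    if not columnwise:
        return (roundup(flat_first, 128), roundup(last // block, 4))
    return (roundup(last, 128), roundup(flat_first // block, 4))

# e.g. a [1024, 4096] tensor:
assert nvfp4_scale_shape((1024, 4096)) == (1024, 256)
assert nvfp4_scale_shape((1024, 4096), columnwise=True) == (4096, 64)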
transformer_engine/pytorch/csrc/type_converters.cpp  View file @ 063ef88d
...
@@ -116,6 +116,46 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
  return ret;
}

TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer* quantizer) {
  const DType dtype = tensor.attr("_fp4_dtype").cast<DType>();
  auto ret = TensorWrapper(NVTE_NVFP4_1D_SCALING);

  bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
  bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
  NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor.");

  // Row-scaled data
  if (rowwise_usage) {
    const auto& data = tensor.attr("_rowwise_data").cast<at::Tensor>();
    const auto& scale_inv = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
    const auto& amax_rowwise = tensor.attr("_amax_rowwise").cast<at::Tensor>();
    ret.set_rowwise_data(data.data_ptr(), dtype,
                         convert_shape_back_from_fp4(getTensorShape(data), false));
    ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3,
                              getTensorShape(scale_inv));
    ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise));
  }

  // Column-scaled data
  if (columnwise_usage) {
    const auto& data = tensor.attr("_columnwise_data").cast<at::Tensor>();
    const auto& scale_inv = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
    const auto& amax_columnwise = tensor.attr("_amax_columnwise").cast<at::Tensor>();
    ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1,
                            convert_shape_back_from_fp4(getTensorShape(data), false));
    ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3,
                                 getTensorShape(scale_inv));
    ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32,
                            getTensorShape(amax_columnwise));
  }

  // Quantizer state
  quantizer->set_quantization_params(&ret);
  return ret;
}

}  // namespace detail

}  // namespace transformer_engine::pytorch
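For reference, a sketch of the Python-side attributes this converter reads, with illustrative sizes for a 1024x4096 tensor (the scale-inverse shapes follow get_scale_shape above, with no extra padding needed at these sizes; all numbers are for illustration only):

import torch

rows, cols = 1024, 4096
nvfp4_attrs = {
    "_rowwise_data": torch.empty(rows, cols // 2, dtype=torch.uint8),          # packed FP4
    "_rowwise_scale_inv": torch.empty(rows, cols // 16, dtype=torch.uint8),    # FP8 E4M3 bits
    "_amax_rowwise": torch.empty(1, dtype=torch.float32),
    "_columnwise_data": torch.empty(cols, rows // 2, dtype=torch.uint8),       # transposed
    "_columnwise_scale_inv": torch.empty(cols, rows // 16, dtype=torch.uint8),
    "_amax_columnwise": torch.empty(1, dtype=torch.float32),
    # plus "_fp4_dtype", a transformer_engine DType enum value
}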
transformer_engine/pytorch/csrc/util.cpp  View file @ 063ef88d
...
@@ -7,6 +7,7 @@
 #include "util.h"
 #include "common.h"
+#include "common/common.h"

 std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input,
                                                   bool rowwise) {
...
@@ -14,22 +15,31 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
   if (input.scaling_mode() == NVTE_INVALID_SCALING) {
     NVTE_ERROR("Invalid scaling mode for swizzle.");
-  } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) {
+  } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING &&
+             input.scaling_mode() != NVTE_NVFP4_1D_SCALING) {
     return std::nullopt;
   }

-  NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");
+  NVTE_CHECK(input.element_size_bits() == 4 || input.element_size_bits() == 8,
+             "4-bit or 8-bit input required for swizzling scaling factors.");
+
+  const auto nvfp4 = input.scaling_mode() == NVTE_NVFP4_1D_SCALING;

   NVTEBasicTensor scale_inv;
+  NVTEShape nvte_input_shape;
   if (rowwise) {
+    nvte_input_shape = input.shape();
     scale_inv = input.get_rowwise_scale_inv();
   } else {
+    nvte_input_shape = input.get_columnwise_data().shape;
     scale_inv = input.get_columnwise_scale_inv();
   }
-  auto input_shape = nvte_shape_to_vector(input.shape());
+  auto input_shape = nvte_shape_to_vector(nvte_input_shape);
   auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
+  NVTE_CHECK(input_shape.size() >= 2, "Wrong ndims for swizzle input shape.");

   // Allocate memory for swizzled output.
   auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
   std::vector<int64_t> scale_inv_shape_int;
...
@@ -41,36 +51,34 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
   void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);

   // Reconstruct input only to avoid swizzling both directions if not needed.
-  // Use any 8 bit type, it's irrelevant.
-  transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING);
-  transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
+  // The specific dtype used is irrelevant; it just needs the correct bit width.
+  transformer_engine::TensorWrapper input_cu(input.scaling_mode());
+  transformer_engine::TensorWrapper output_cu(input.scaling_mode());
+  const auto input_dtype = (nvfp4) ? transformer_engine::DType::kFloat4E2M1
+                                   : transformer_engine::DType::kFloat8E4M3;
+  const auto scale_inv_dtype = (nvfp4) ? transformer_engine::DType::kFloat8E4M3
+                                       : transformer_engine::DType::kFloat8E8M0;
   if (rowwise) {
-    input_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
-    input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
-                                   scale_inv_shape);
-    output_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
-    output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr,
-                                    transformer_engine::DType::kFloat8E8M0, scale_inv_shape);
+    input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
+    input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
+    output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
+    output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
   } else {
-    input_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
-                                 input_shape);
-    input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
-                                      scale_inv_shape);
-    output_cu.set_columnwise_data(input.columnwise_dptr(),
-                                  transformer_engine::DType::kFloat8E4M3, input_shape);
-    output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr,
-                                       transformer_engine::DType::kFloat8E8M0, scale_inv_shape);
+    input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
+    input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
+    output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
+    output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
+                                       scale_inv_shape);
   }

   // Launch kernel
   nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());

   if (rowwise) {
-    input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
-                                scale_inv_shape);
+    input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
   } else {
-    input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
-                                   scale_inv_shape);
+    input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
   }
   return swizzled_scale_inv;
...
@@ -170,3 +178,72 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
   return buffer;
 }

at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input,
                                                 bool rowwise) {
  using namespace transformer_engine::pytorch;
  using transformer_engine::DIVUP;

  // Check input tensor
  const NVTEScalingMode scaling_mode = input.scaling_mode();
  NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D,
             "Input tensor must be a block scaling tensor");

  // Get tensor data
  NVTEBasicTensor data;
  size_t data_flat_first_dim = 1;
  size_t data_flat_last_dim = 1;
  if (rowwise) {
    data = input.get_rowwise_data();
    for (int i = 0; i < data.shape.ndim - 1; ++i) {
      data_flat_first_dim *= data.shape.data[i];
    }
    data_flat_last_dim = data.shape.data[data.shape.ndim - 1];
  } else {
    data = input.get_columnwise_data();
    data_flat_first_dim = data.shape.data[0];
    for (int i = 1; i < data.shape.ndim; ++i) {
      data_flat_last_dim *= data.shape.data[i];
    }
  }
  NVTEShape data_shape{};
  data_shape.data[0] = data_flat_first_dim;
  data_shape.data[1] = data_flat_last_dim;
  data_shape.ndim = 2;

  // Recreate input tensor with rowwise usage
  transformer_engine::TensorWrapper input_cu(scaling_mode);
  input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
  const NVTEBasicTensor scale_inv =
      rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv();
  input_cu.set_rowwise_scale_inv(scale_inv.data_ptr,
                                 static_cast<transformer_engine::DType>(scale_inv.dtype),
                                 scale_inv.shape);

  // Create output tensor
  transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
  output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);

  // Output swizzled mxfp8 scaling factor dimensions
  const size_t swizzled_scale_inv_first_dim = DIVUP<size_t>(data_flat_first_dim, 128) * 128;
  const size_t swizzled_scale_inv_last_dim = DIVUP<size_t>(data_flat_last_dim, 128) * 4;

  // Allocate memory for swizzled mxfp8 scaling factors
  const auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
  at::Tensor swizzled_scale_inv = at::empty(
      std::vector<int64_t>{swizzled_scale_inv_first_dim, swizzled_scale_inv_last_dim}, options);

  // Set rowwise scaling factors on output
  void* const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
  NVTEShape swizzled_scale_inv_shape{};
  swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim;
  swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim;
  swizzled_scale_inv_shape.ndim = 2;
  output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr,
                                  transformer_engine::DType::kFloat8E8M0,
                                  swizzled_scale_inv_shape);

  // Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format
  nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(),
                                                      at::cuda::getCurrentCUDAStream());

  // Set the input tensor to be the converted mxfp8 tensor and return the swizzled scaling
  // factor so that it is kept alive during the GEMM
  input = std::move(output_cu);
  return swizzled_scale_inv;
}
transformer_engine/pytorch/csrc/util.h  View file @ 063ef88d
...
@@ -27,4 +27,16 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
 std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
     std::vector<transformer_engine::TensorWrapper>& inputs, bool rowwise);

/*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place.
 *
 * If rowwise==false, the columnwise data will be reinterpreted as rowwise data to avoid
 * transposing it in memory. Due to differences in how block scaling and mxfp8 store data,
 * this requires the calling code to treat the output tensor as having been transposed in
 * this case.
 *
 * Returns the swizzled scaling factor of the converted mxfp8 tensor.
 * The returned swizzled scaling factor tensor should be kept alive during the GEMM.
 */
at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input,
                                                 bool rowwise);

#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
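A quick Python check of the swizzled MXFP8 scale-inverse dimensions computed in convert_block_scaling_to_mxfp8_tensor: rows pad up to a multiple of 128, and each group of 128 columns (four 32-element MXFP8 blocks) contributes 4 scale bytes. Sketch only, for illustration:

def divup(x: int, y: int) -> int:
    return (x + y - 1) // y

def mxfp8_swizzled_scale_shape(first_dim: int, last_dim: int) -> tuple[int, int]:
    return (divup(first_dim, 128) * 128, divup(last_dim, 128) * 4)

assert mxfp8_swizzled_scale_shape(1000, 4096) == (1024, 128)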
transformer_engine/pytorch/distributed.py  View file @ 063ef88d
...
@@ -36,14 +36,17 @@ from .utils import (
     needs_quantized_gemm,
 )
 from .constants import dist_group_type
-from .fp8 import FP8GlobalStateManager, fp8_autocast
+from .quantization import FP8GlobalStateManager, autocast
 from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
 from .tensor.mxfp8_tensor import MXFP8Quantizer
+from .tensor.nvfp4_tensor import NVFP4Quantizer
 from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from .tensor.quantized_tensor import QuantizedTensor, Quantizer
-from .tensor._internal.float8_tensor_base import Float8TensorBase
-from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
-from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
+from .tensor.quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer
+from .tensor.storage.float8_tensor_storage import Float8TensorStorage
+from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
+from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
+from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
+from .triton.pad import pad_columnwise_scale_inv
 from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer
@@ -416,8 +419,8 @@ class _CheckpointFunction(torch.autograd.Function):
...
@@ -416,8 +419,8 @@ class _CheckpointFunction(torch.autograd.Function):
detached_inputs
=
detach_variable
(
inputs
)
detached_inputs
=
detach_variable
(
inputs
)
with
torch
.
enable_grad
(),
ctx
.
recompute_ctx
,
ctx
.
torch_gpu_amp_ctx
,
ctx
.
torch_cpu_amp_ctx
,
activation_recompute_forward
(
with
torch
.
enable_grad
(),
ctx
.
recompute_ctx
,
ctx
.
torch_gpu_amp_ctx
,
ctx
.
torch_cpu_amp_ctx
,
activation_recompute_forward
(
activation_recompute
=
True
,
recompute_phase
=
True
activation_recompute
=
True
,
recompute_phase
=
True
),
fp8_
autocast
(
),
autocast
(
enabled
=
ctx
.
fp8
,
fp8_
recipe
=
ctx
.
fp8_recipe
enabled
=
ctx
.
fp8
,
recipe
=
ctx
.
fp8_recipe
):
):
outputs
=
ctx
.
run_function
(
*
detached_inputs
,
**
ctx
.
kwargs
)
outputs
=
ctx
.
run_function
(
*
detached_inputs
,
**
ctx
.
kwargs
)
...
@@ -751,8 +754,8 @@ def checkpoint(
...
@@ -751,8 +754,8 @@ def checkpoint(
def
recompute_fn
(
*
args
,
**
kwargs
):
def
recompute_fn
(
*
args
,
**
kwargs
):
with
torch
.
autograd
.
enable_grad
(),
(
with
torch
.
autograd
.
enable_grad
(),
(
te_recompute_ctx
te_recompute_ctx
),
user_recompute_ctx
,
torch_gpu_amp_forward_ctx
,
torch_cpu_amp_forward_ctx
,
fp8_
autocast
(
),
user_recompute_ctx
,
torch_gpu_amp_forward_ctx
,
torch_cpu_amp_forward_ctx
,
autocast
(
enabled
=
fp8
,
fp8_
recipe
=
fp8_recipe
enabled
=
fp8
,
recipe
=
fp8_recipe
):
):
function
(
*
args
,
**
kwargs
)
function
(
*
args
,
**
kwargs
)
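For callers, the rename keeps the same context-manager shape; a minimal migration sketch, assuming the renamed manager is re-exported as te.autocast with the enabled/recipe keywords used in the hunks above:

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

recipe = DelayedScaling()
# Before: with te.fp8_autocast(enabled=True, fp8_recipe=recipe): ...
with te.autocast(enabled=True, recipe=recipe):
    y = model(x)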
...
@@ -904,7 +907,7 @@ def _all_gather_fp8(
     async_op: bool = False,
     quantizer: Optional[Quantizer] = None,
     out_shape: Optional[list[int]] = None,
-) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]:
+) -> tuple[Float8TensorStorage, Optional[torch.distributed.Work]]:
     """All-gather FP8 tensor along first dimension."""
     world_size = get_distributed_world_size(process_group)
...
@@ -922,7 +925,7 @@ def _all_gather_fp8(
     # Cast input tensor to FP8 if needed
     # Note: We cannot directly all-gather the transposed FP8 tensor,
     # so temporarily modify quantizer to avoid creating FP8 transpose.
-    if not isinstance(inp, Float8TensorBase):
+    if not isinstance(inp, Float8TensorStorage):
         assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
         # we cannot directly gather the transposed fp8 tensor
         # so we need to disable columnwise usage for the quantizer
...
@@ -937,7 +940,7 @@ def _all_gather_fp8(
     )

     # Construct output tensor
-    out: Float8TensorBase
+    out: Float8TensorStorage
     if quantizer is not None:
         dtype = torch.float32
         device = "cuda"
...
@@ -955,7 +958,7 @@ def _all_gather_fp8(
             out._transpose = None
             out._transpose_invalid = True
     else:
-        raise RuntimeError("FP8TensorBase is not supported yet without Quantizer")
+        raise RuntimeError("Float8TensorStorage is not supported yet without Quantizer")

     # Assume scaling factors are identical across ranks
     out._scale_inv = inp._scale_inv
...
@@ -1000,10 +1003,10 @@ def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:

 def _post_process_fp8_blockwise_gather(
-    out: Float8BlockwiseQTensorBase,
+    out: Float8BlockwiseQTensorStorage,
     quantizer: Float8BlockQuantizer,
     handle: Optional[torch.distributed.Work] = None,
-) -> Float8BlockwiseQTensorBase:
+) -> Float8BlockwiseQTensorStorage:
     """Post-process FP8 blockwise gather."""
     if handle is not None:
         handle.wait()
...
@@ -1037,7 +1040,7 @@ def _post_process_fp8_blockwise_gather(
 class _FP8BlockwiseAllGatherAsyncHandle:
     """Handle for asynchronous FP8 blockwise all-gather."""

-    tensor: Float8BlockwiseQTensorBase
+    tensor: Float8BlockwiseQTensorStorage
     quantizer: Float8BlockQuantizer
     async_handle: torch.distributed.Work
     _synchronized: bool = False
...
@@ -1075,18 +1078,18 @@ def _all_gather_fp8_blockwise(
     if isinstance(inp, torch.Tensor):
         device = inp.device
         dtype = inp.dtype
-    elif isinstance(inp, Float8BlockwiseQTensorBase):
+    elif isinstance(inp, Float8BlockwiseQTensorStorage):
         if inp._rowwise_data is not None:
             device = inp._rowwise_data.device
         elif inp._columnwise_data is not None:
             device = inp._columnwise_data.device
         else:
-            raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data")
+            raise ValueError("Got Float8BlockwiseQTensorStorage input tensor without any data")
         dtype = torch.bfloat16  # Only has fp8 dtype. Guess BF16 for dequant.
     else:
         raise ValueError(
-            "Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, "
-            f"found {inp.__class__.__name__})"
+            "Invalid type for input tensor (expected torch.Tensor or"
+            f" Float8BlockwiseQTensorStorage, found {inp.__class__.__name__})"
         )

     world_size = get_distributed_world_size(process_group)
...
@@ -1103,7 +1106,7 @@ def _all_gather_fp8_blockwise(
     # Doing BF16 gather for now as baseline because it's simpler
     if (
-        not isinstance(inp, Float8BlockwiseQTensorBase)
+        not isinstance(inp, Float8BlockwiseQTensorStorage)
         and quantizer is not None
         and not quantizer.is_quantizable(inp)
     ):
...
@@ -1128,7 +1131,7 @@ def _all_gather_fp8_blockwise(
     # Set to compact usage in case the quantizer is not correctly configured
     orig_all_gather_usage = quantizer.all_gather_usage
     quantizer.all_gather_usage = True
-    if not isinstance(inp, Float8BlockwiseQTensorBase):
+    if not isinstance(inp, Float8BlockwiseQTensorStorage):
         inp = quantizer(inp)
     elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
         quantizer.columnwise_usage and inp._columnwise_data is None
...
@@ -1204,6 +1207,245 @@ def _all_gather_fp8_blockwise(
     return out, handle
def _swap_first_dims(tensor: torch.Tensor, world_size: int):
    """
    Swap the first two dimensions of a tensor to fix the interleaved
    data format after gathering transposed data.
    For more than two dimensions, we flatten the trailing dimensions
    rather than the leading ones, because the shape passed to this
    function is already transposed.
    """
    shape = tensor.shape
    assert tensor.ndim >= 2, "Wrong number of dimensions for fixing interleave."
    first_dim = shape[0]
    flattened_trailing = math.prod(shape[1:])
    assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave."
    tensor = tensor.reshape(world_size, first_dim // world_size, flattened_trailing)
    tensor = tex.swap_first_dims(tensor, out=None)
    return tensor.reshape(first_dim // world_size, flattened_trailing * world_size)
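An illustration of the interleave fix: each rank holds the transpose of its shard, so a naive all-gather along dim 0 stacks the transposed shards instead of producing the transpose of the full tensor. Reshaping by world_size and swapping the first two dims restores the layout. This uses a pure-torch stand-in (.transpose(0, 1)) for tex.swap_first_dims:

import torch

world_size = 2
full = torch.arange(16).reshape(4, 4)                                  # "global" tensor
shards_t = [full[i * 2:(i + 1) * 2].t().contiguous() for i in range(world_size)]
gathered = torch.cat(shards_t, dim=0)                                  # what all_gather yields: [8, 2]
fixed = gathered.reshape(world_size, 4, 2).transpose(0, 1).reshape(4, 4)
assert torch.equal(fixed, full.t())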
def _post_process_nvfp4_gather(
    out: NVFP4TensorStorage,
    columnwise_data_interleaved: torch.Tensor,
    columnwise_scale_inv_interleaved: torch.Tensor,
    world_size: int,
    handle: Optional[torch.distributed.Work] = None,
) -> NVFP4TensorStorage:
    """Post-process an NVFP4 gather."""
    if handle is not None:
        handle.wait()
        handle = None
    # Fix the interleaved transposed data from gathering along the first dim.
    out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
    out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
    # Optionally pad the scaling inverse if needed.
    out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
@dataclass
class _NVFP4AllGatherAsyncHandle:
    """Handle for asynchronous NVFP4 all-gather."""

    output: NVFP4TensorStorage
    columnwise_data_interleaved: torch.Tensor
    columnwise_scale_inv_interleaved: torch.Tensor
    world_size: int
    async_handle: torch.distributed.Work
    _synchronized: bool = False

    def wait(self) -> None:
        """Wait for the async operation to complete and post-process the tensor."""
        if self._synchronized:
            return
        self.async_handle.wait()
        _post_process_nvfp4_gather(
            self.output,
            self.columnwise_data_interleaved,
            self.columnwise_scale_inv_interleaved,
            self.world_size,
        )
        self._synchronized = True
def _all_gather_nvfp4(
    inp: torch.Tensor,
    process_group: dist_group_type,
    *,
    async_op: bool = False,
    quantizer: NVFP4Quantizer,
    out_shape: Optional[list[int]] = None,
) -> tuple[NVFP4TensorStorage, Optional[torch.distributed.Work]]:
    """All-gather NVFP4 tensor along first dimension."""

    # Input tensor attributes
    in_shape: Iterable[int] = None
    in_shape_t: Iterable[int] = None
    device: torch.device
    dtype: torch.dtype

    # Construct packed shapes for input and input_t.
    if isinstance(inp, torch.Tensor) and not isinstance(inp, NVFP4TensorStorage):
        # High-precision tensor.
        in_shape = NVFP4Quantizer.convert_shape_for_fp4(inp.size())
        in_shape_t = NVFP4Quantizer.convert_shape_for_fp4(
            NVFP4Quantizer.get_columnwise_shape(inp.size())
        )
        device = inp.device
        dtype = inp.dtype
    elif isinstance(inp, NVFP4TensorStorage):
        if inp._rowwise_data is not None:
            in_shape = inp._rowwise_data.size()
            device = inp._rowwise_data.device
        if inp._columnwise_data is not None:
            in_shape_t = inp._columnwise_data.size()
            device = inp._columnwise_data.device
        dtype = torch.bfloat16
    else:
        raise ValueError(
            "Invalid type for input tensor (expected torch.Tensor or NVFP4TensorStorage, "
            f"found {inp.__class__.__name__})"
        )
    assert in_shape is not None or in_shape_t is not None, "No data found."

    world_size = get_distributed_world_size(process_group)
    if out_shape is None:
        out_shape = [in_shape[0] * world_size] + in_shape[1:]

    # For cases where inp has dimensions that cannot be quantized,
    # we gather in high precision followed by a cast to NVFP4.
    if (
        not isinstance(inp, NVFP4TensorStorage)
        and quantizer is not None
        and not quantizer.is_quantizable(inp)
    ):
        out = torch.empty(
            out_shape,
            dtype=dtype,
            device=device,
            memory_format=torch.contiguous_format,
        )
        torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
        out = quantizer(out)
        return out, None

    # Cast input tensor to NVFP4 with required data
    if not isinstance(inp, NVFP4TensorStorage):
        inp = quantizer(inp)
    elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
        quantizer.columnwise_usage and inp._columnwise_data is None
    ):
        warnings.warn(
            "Input and quantizer do not have matching usages. "
            "Dequantizing and requantizing to NVFP4."
        )
        inp = quantizer(inp.dequantize())

    # Construct NVFP4 output tensor
    out = quantizer.make_empty(out_shape, dtype=dtype, device=device)

    # Coalesce NCCL collectives for gathering data and scale inverses.
    with torch.distributed._coalescing_manager(
        group=process_group,
        device=device,
        async_ops=async_op,
    ) as gather_coalescing_manager:

        # Gather NVFP4 data for row-wise usage
        if quantizer.rowwise_usage:
            # Remove padding from NVFP4 scale-inverses
            assert in_shape is not None, "Shape not found."
            in_scale_inv = inp._rowwise_scale_inv
            out_scale_inv = out._rowwise_scale_inv
            flattened_in_shape0 = math.prod(in_shape[:-1])
            if in_scale_inv.size(0) != flattened_in_shape0:
                in_scale_inv = in_scale_inv[:flattened_in_shape0]
                out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

            # Launch all-gathers
            torch.distributed.all_gather_into_tensor(
                out_scale_inv,
                in_scale_inv,
                group=process_group,
            )
            torch.distributed.all_gather_into_tensor(
                out._rowwise_data,
                inp._rowwise_data,
                group=process_group,
            )
            # Transfer amax to output.
            out._amax_rowwise = inp._amax_rowwise

        # Gather the transposed NVFP4 data along the first dimension. Fix format later.
        if quantizer.columnwise_usage:
            # Remove padding from NVFP4 scale-inverses.
            # For an all-gather on transposed scale inverses,
            # padding must be removed from both dimensions.
            in_scale_inv = inp._columnwise_scale_inv
            # Take caution that for in_shape_t, we flatten the trailing dimensions!
            flattened_in_shape0 = in_shape_t[0]
            flattened_in_shape1 = math.prod(in_shape_t[1:])
            # Remove dim0 padding
            if in_scale_inv.size(0) != flattened_in_shape0:
                in_scale_inv = in_scale_inv[:flattened_in_shape0]
            # Remove dim1 padding (pack first).
            unpadded_dim1 = flattened_in_shape1 * 2 // 16
            if in_scale_inv.size(1) != unpadded_dim1:
                in_scale_inv = in_scale_inv[:, :unpadded_dim1].contiguous()

            # Construct tensor to gather transposed scale_inv (interleaved) and launch AG.
            out_scale_inv = torch.empty(
                [flattened_in_shape0 * world_size] + [in_scale_inv.shape[1]],
                dtype=in_scale_inv.dtype,
                layout=in_scale_inv.layout,
                device=in_scale_inv.device,
            )
            torch.distributed.all_gather_into_tensor(
                out_scale_inv,
                in_scale_inv,
                group=process_group,
            )

            # Construct tensor to gather transposed data (interleaved) and launch AG.
            out_columnwise_data = torch.empty(
                [inp._columnwise_data.shape[0] * world_size]
                + list(inp._columnwise_data.shape[1:]),
                dtype=inp._columnwise_data.dtype,
                layout=inp._columnwise_data.layout,
                device=inp._columnwise_data.device,
            )
            torch.distributed.all_gather_into_tensor(
                out_columnwise_data,
                inp._columnwise_data,
                group=process_group,
            )
            # Transfer amax to output.
            out._amax_columnwise = inp._amax_columnwise

    handle = gather_coalescing_manager if async_op else None

    # Fixes interleaved data for the transposed tensor/scale inv and pads scale inv if needed.
    if async_op and quantizer.columnwise_usage:
        handle = _NVFP4AllGatherAsyncHandle(
            out,
            out_columnwise_data,
            out_scale_inv,
            world_size,
            handle,
        )
    elif quantizer.columnwise_usage:
        _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle)

    return out, handle
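Why unpadded_dim1 == flattened_in_shape1 * 2 // 16 above: the column-wise data shape counts packed bytes (two FP4 values per byte), so the logical row length is bytes * 2, and with an NVFP4 block size of 16 each row carries logical_len / 16 scale entries. A quick check with illustrative numbers:

bytes_per_row = 512                 # packed FP4 payload per transposed row
logical_len = bytes_per_row * 2     # 1024 FP4 values
scales_per_row = logical_len // 16  # 64 scale-inv entries before padding
assert scales_per_row == bytes_per_row * 2 // 16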
def _all_gather_mxfp8(
    inp: torch.Tensor,
    process_group: dist_group_type,
...
@@ -1211,7 +1453,7 @@ def _all_gather_mxfp8(
     async_op: bool = False,
     quantizer: MXFP8Quantizer,
     out_shape: Optional[list[int]] = None,
-) -> tuple[MXFP8TensorBase, Optional[torch.distributed.Work]]:
+) -> tuple[MXFP8TensorStorage, Optional[torch.distributed.Work]]:
     """All-gather MXFP8 tensor along first dimension."""

     # Input tensor attributes
...
@@ -1222,7 +1464,7 @@ def _all_gather_mxfp8(
         in_shape = inp.size()
         device = inp.device
         dtype = inp.dtype
-    elif isinstance(inp, MXFP8TensorBase):
+    elif isinstance(inp, MXFP8TensorStorage):
         if inp._rowwise_data is not None:
             in_shape = inp._rowwise_data.size()
             device = inp._rowwise_data.device
...
@@ -1234,7 +1476,7 @@ def _all_gather_mxfp8(
         dtype = torch.bfloat16  # Guess high-precision dtype.
     else:
         raise ValueError(
-            "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, "
+            "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorStorage, "
             f"found {inp.__class__.__name__})"
         )
...
@@ -1246,7 +1488,7 @@ def _all_gather_mxfp8(
     # For cases where inp has dimensions that cannot be quantized,
     # we gather in high precision followed by a cast to FP8.
     if (
-        not isinstance(inp, MXFP8TensorBase)
+        not isinstance(inp, MXFP8TensorStorage)
         and quantizer is not None
         and not quantizer.is_quantizable(inp)
     ):
...
@@ -1261,7 +1503,7 @@ def _all_gather_mxfp8(
         return out, None

     # Cast input tensor to MXFP8 with required data
-    if not isinstance(inp, MXFP8TensorBase):
+    if not isinstance(inp, MXFP8TensorStorage):
         inp = quantizer(inp)
     elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
         quantizer.columnwise_usage and inp._columnwise_data is None
...
@@ -1291,7 +1533,6 @@ def _all_gather_mxfp8(
     flattened_in_shape0 = math.prod(in_shape[:-1])
     if in_scale_inv.size(0) != flattened_in_shape0:
         in_scale_inv = in_scale_inv[:flattened_in_shape0]
-        out_scale_inv[flattened_in_shape0 * world_size :].zero_()
         out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

     # Launch all-gathers
...
@@ -1315,7 +1556,6 @@ def _all_gather_mxfp8(
     flattened_in_shape0 = math.prod(in_shape[:-1]) // 32
     if in_scale_inv.size(0) != flattened_in_shape0:
         in_scale_inv = in_scale_inv[:flattened_in_shape0]
-        out_scale_inv[flattened_in_shape0 * world_size :].zero_()
         out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

     # Launch all-gathers
...
@@ -1347,7 +1587,7 @@ def gather_along_first_dim(
     # Return immediately if no communication is required
     world_size = get_distributed_world_size(process_group)
     if world_size == 1:
-        if quantizer is not None and not isinstance(inp, QuantizedTensor):
+        if quantizer is not None and not isinstance(inp, QuantizedTensorStorage):
             inp = quantizer(inp)
         return inp, None
...
@@ -1394,7 +1634,7 @@ def gather_along_first_dim(
     out_shape[0] *= world_size

     # FP8 case: delayed scaling or current scaling
-    if isinstance(inp, Float8TensorBase) or isinstance(
+    if isinstance(inp, Float8TensorStorage) or isinstance(
         quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
     ):
         return _all_gather_fp8(
...
@@ -1406,7 +1646,9 @@ def gather_along_first_dim(
     )

     # FP8 block scaling case, block length = 128
-    if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer):
+    if isinstance(inp, Float8BlockwiseQTensorStorage) or isinstance(
+        quantizer, Float8BlockQuantizer
+    ):
         return _all_gather_fp8_blockwise(
             inp,
             process_group,
...
@@ -1416,7 +1658,7 @@ def gather_along_first_dim(
     )

     # MXFP8 case
-    if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer):
+    if isinstance(inp, MXFP8TensorStorage) or isinstance(quantizer, MXFP8Quantizer):
         assert isinstance(quantizer, MXFP8Quantizer)
         return _all_gather_mxfp8(
             inp,
...
@@ -1426,13 +1668,24 @@ def gather_along_first_dim(
         out_shape=out_shape,
     )

+    # NVFP4 case
+    if isinstance(inp, NVFP4TensorStorage) or isinstance(quantizer, NVFP4Quantizer):
+        assert isinstance(quantizer, NVFP4Quantizer)
+        return _all_gather_nvfp4(
+            inp,
+            process_group,
+            async_op=async_op,
+            quantizer=quantizer,
+            out_shape=out_shape,
+        )
+
     # High-precision communication for quantized tensors
     if quantizer is not None:
         warnings.warn(
             "Attempting to all-gather an unsupported quantized tensor. "
             "Falling back to high-precision all-gather."
         )
-        if isinstance(inp, QuantizedTensor):
+        if isinstance(inp, QuantizedTensorStorage):
             inp = inp.dequantize()
         # Falling back to high-precision all-gather for Float8BlockQuantizer
         # means that it should directly output GEMM_READY format
...
@@ -1450,7 +1703,7 @@ def gather_along_first_dim(
         return out, None

     # Dequantize quantized tensor if not supported
-    if isinstance(inp, QuantizedTensor):
+    if isinstance(inp, QuantizedTensorStorage):
         warnings.warn(
             "Attempting to all-gather an unsupported quantized tensor. "
             "Falling back to high-precision all-gather."
...
@@ -1720,7 +1973,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
     if hasattr(fsdp_root, "primary_weights_in_fp8"):
         assert not fsdp_root.primary_weights_in_fp8, (
             "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
-            "Please initialize your model without the te.fp8_model_init(...) context."
+            "Please initialize your model without the te.quantized_model_init(...) context."
         )
     root_state = _get_module_fsdp_state(fsdp_root)
     assert root_state is not None, "Root module does not have a valid _FSDPState."
...
@@ -1733,7 +1986,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
         if hasattr(fsdp_module.module, "primary_weights_in_fp8"):
             assert not fsdp_module.module.primary_weights_in_fp8, (
                 "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
-                "Please initialize your model without the te.fp8_model_init(...) context."
+                "Please initialize your model without the te.quantized_model_init(...) context."
             )
         setattr(fsdp_module.module, "fsdp_group", state.process_group)
...
transformer_engine/pytorch/tensor/_internal/__init__.py → transformer_engine/pytorch/experimental/__init__.py  View file @ 063ef88d
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
-"""Internal data structures for quantized tensors."""
+"""Experimental features and APIs."""
transformer_engine/pytorch/experimental/gemm.py
0 → 100644
View file @ 063ef88d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""GEMM API for experimental middleware between Transformer Engine and Kitchen."""
from typing import Iterable, Optional

import torch

from transformer_engine.pytorch.experimental.quantization import (
    MMParams,
    GEMMType,
)
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer
from transformer_engine.pytorch.tensor.utils import is_experimental


def experimental_gemm(
    A: QuantizedTensorStorage,
    B: QuantizedTensorStorage,
    workspace: torch.Tensor,  # pylint: disable=unused-argument
    out_dtype: Optional[torch.dtype] = None,
    quantization_params: Optional[Quantizer] = None,  # pylint: disable=unused-argument
    gelu: bool = False,  # pylint: disable=unused-argument
    gelu_in: torch.Tensor = None,  # pylint: disable=unused-argument
    accumulate: bool = False,  # pylint: disable=unused-argument
    layout: str = "TN",
    out: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
    bias: Optional[torch.Tensor] = None,
    use_split_accumulator: bool = False,
    grad: bool = False,
) -> Iterable[Optional[torch.Tensor]]:
    """Dispatch GEMM to quantizer's qgemm method."""
    assert is_experimental(A) and is_experimental(B), "A and B must be experimental tensors"
    A, B = B, A

    # Determine GEMM type based on grad flag and layout
    if not grad:
        gemm_type = GEMMType.FPROP
    else:
        if layout == "NN":
            gemm_type = GEMMType.DGRAD
        elif layout == "NT":
            gemm_type = GEMMType.WGRAD
        else:
            # Default to FPROP for other layouts
            gemm_type = GEMMType.FPROP

    # Extract quantizer from QuantizedTensor to get qgemm logic
    # TODO(etsykunov): make it more flexible, what if we might want to use gemm logic from B._quantizer?
    quantizer = None
    if hasattr(A, "_quantizer") and A._quantizer is not None:
        quantizer = A._quantizer
    elif hasattr(B, "_quantizer") and B._quantizer is not None:
        quantizer = B._quantizer
    else:
        raise ValueError("No quantizer found in QuantizedTensor objects")

    # Create MMParams
    m_params = MMParams(
        out_dtype=out_dtype,
        use_split_accumulator=use_split_accumulator,
    )
    out_dtype = A.dtype if m_params.out_dtype is None else m_params.out_dtype

    if gemm_type == GEMMType.FPROP:
        qx, sx = A.data, A.scale
        qw, sw = B.data, B.scale
        assert qx is not None
        assert sx is not None
        assert qw is not None
        assert sw is not None
        assert A.original_shape is not None
        # Call quantizer's qgemm method
        result = quantizer.qgemm(
            qx,
            qw,
            m_params,
            out_dtype,
            sx,
            sw,
            bias,
            gemm_type=GEMMType.FPROP,
            qresult_x=A,
            qresult_w=B,
        )
        if len(A.original_shape) > 2:
            # Original input was 3D, so we need to reshape result back to 3D
            batch_size = A.original_shape[0]
            seq_len = A.original_shape[1]
            result = result.view(batch_size, seq_len, result.shape[-1])
    elif gemm_type == GEMMType.DGRAD:
        qdy, sdy = A.data, A.scale
        qw_t, sw_t = B.data_t, B.scale_t
        assert qdy is not None
        assert sdy is not None
        assert qw_t is not None
        assert sw_t is not None
        result = quantizer.qgemm(
            qdy,
            qw_t,
            m_params,
            out_dtype,
            sdy,
            sw_t,
            None,
            gemm_type=GEMMType.DGRAD,
            qresult_x=A,
            qresult_w=B,
        )
    elif gemm_type == GEMMType.WGRAD:
        qdy_t, sdy_t = A.data_t, A.scale_t
        qx_t, sx_t = B.data_t, B.scale_t
        assert qdy_t is not None
        assert sdy_t is not None
        assert qx_t is not None
        assert sx_t is not None
        result = quantizer.qgemm(
            qdy_t,
            qx_t,
            m_params,
            out_dtype,
            sdy_t,
            sx_t,
            None,
            gemm_type=GEMMType.WGRAD,
            qresult_x=A,
            qresult_w=B,
        )

    # Return in the same format as general_gemm
    return result, None, None, None
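As an annotation on the dispatch logic above: a minimal hedged sketch of driving `experimental_gemm` end to end with the NVFP4 reference quantizer added later in this commit. The shapes are illustrative, and it assumes the base `Quantizer`/`is_experimental` machinery accepts the reference types as shown; it is not part of the diff.

# Hedged sketch: exercise the (grad, layout) -> GEMMType dispatch with the
# NVFP4QuantizerRef defined in quantization_nvfp4.py below.
import torch
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.experimental.gemm import experimental_gemm
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef

quantizer = NVFP4QuantizerRef(dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16))
x_q = quantizer.quantize(torch.randn(32, 64, dtype=torch.bfloat16))  # input
w_q = quantizer.quantize(torch.randn(16, 64, dtype=torch.bfloat16))  # weight
workspace = torch.empty(0)

# grad=False selects FPROP; grad=True picks DGRAD for layout "NN", WGRAD for "NT".
out, *_ = experimental_gemm(w_q, x_q, workspace, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([32, 16])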
transformer_engine/pytorch/experimental/quantization.py
0 → 100644
View file @ 063ef88d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Quantization API for experimental middleware between Transformer Engine and Kitchen."""
from __future__ import annotations

import dataclasses
import enum

import torch


@enum.unique
class GEMMType(enum.Enum):
    """Type of GEMM operation being performed."""

    FPROP = "fprop"
    DGRAD = "dgrad"
    WGRAD = "wgrad"


@dataclasses.dataclass(frozen=True)
class MMParams:
    """Matrix multiplication parameters."""

    out_dtype: torch.dtype | None = None
    # Use split accumulator for more accurate FP8 GEMM
    use_split_accumulator: bool = True
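A small hedged illustration of how these two types travel together through the middleware; the values are made up:

# Hedged example: MMParams is frozen, so its configuration is fixed at construction.
import torch
from transformer_engine.pytorch.experimental.quantization import GEMMType, MMParams

params = MMParams(out_dtype=torch.bfloat16, use_split_accumulator=True)
assert GEMMType("wgrad") is GEMMType.WGRAD  # the enum round-trips from its string value
print(params.out_dtype, GEMMType.FPROP.value)  # torch.bfloat16 fprop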
transformer_engine/pytorch/experimental/quantization_nvfp4.py
0 → 100644
View file @ 063ef88d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""NVFP4 recipe reference implementation."""
import dataclasses
from typing import Optional, Tuple, Union

import torch

from transformer_engine.pytorch.experimental import quantization
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer
def nvfp4_ref_rht_2d_quantizer_factory(role):
    """
    Quantizer factory for NVFP4 recipe reference implementation (RHT and 2D quantization
    for weights).

    Usage with CustomRecipe and fp8_autocast:
        custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory)
        with fp8_autocast(fp8_recipe=custom_recipe):
            output = model(input)
    """
    if role == "linear_input":
        return NVFP4QuantizerRef(
            dtype=utils.Fp4Formats.E2M1,
            quant_tile_shape=(1, 16),
            pow_2_scales=False,
            with_rht=True,
        )
    if role == "linear_weight":
        return NVFP4QuantizerRef(
            dtype=utils.Fp4Formats.E2M1,
            quant_tile_shape=(16, 16),
            pow_2_scales=False,
            with_rht=False,
        )
    if role == "linear_grad_output":
        return NVFP4QuantizerRef(
            dtype=utils.Fp4Formats.E2M1,
            quant_tile_shape=(1, 16),
            pow_2_scales=False,
            with_rht=True,
        )
    return None
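The factory contract is simply role string in, quantizer (or None) out. A hedged sketch; the "some_other_role" name is hypothetical:

# Hedged sketch: known roles get an NVFP4 quantizer; anything else falls through to None.
for role in ("linear_input", "linear_weight", "linear_grad_output", "some_other_role"):
    q = nvfp4_ref_rht_2d_quantizer_factory(role)
    print(role, None if q is None else (q.quant_tile_shape, q.with_rht))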
def cast_to_fp4x2(x):
    """Quantize a tensor to FP4 E2M1 and store in a byte tensor"""
    result = torch.zeros_like(x, dtype=torch.uint8)
    result[(x >= 0.0) & (x <= 0.25)] = 0
    result[(x > 0.25) & (x < 0.75)] = 1
    result[(x >= 0.75) & (x <= 1.25)] = 2
    result[(x > 1.25) & (x < 1.75)] = 3
    result[(x >= 1.75) & (x <= 2.5)] = 4
    result[(x > 2.5) & (x < 3.5)] = 5
    result[(x >= 3.5) & (x <= 5.0)] = 6
    result[x > 5.0] = 7
    result[(x >= -0.25) & (x < -0.0)] = 8
    result[(x < -0.25) & (x > -0.75)] = 9
    result[(x <= -0.75) & (x >= -1.25)] = 10
    result[(x < -1.25) & (x > -1.75)] = 11
    result[(x <= -1.75) & (x >= -2.5)] = 12
    result[(x < -2.5) & (x > -3.5)] = 13
    result[(x <= -3.5) & (x >= -5.0)] = 14
    result[x < -5.0] = 15
    return result[:, ::2] + result[:, 1::2] * 16
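The thresholds above round to the nearest of the eight non-negative E2M1 magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}, with codes 8-15 covering the negative mirror image; the final line packs the even column into the low nibble and the odd column into the high nibble of each byte. A small hedged check of that convention:

# Hedged check of the packing convention: low nibble holds the even column,
# high nibble (times 16) holds the odd column.
import torch

x = torch.tensor([[0.6, -2.2]])      # rounds to 0.5 (code 1) and -2.0 (code 12)
packed = cast_to_fp4x2(x)
assert packed.item() == 1 + 12 * 16  # 193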
def cast_from_fp4x2(x, dq_dtype):
    """Dequantize FP4 E2M1 tensor that has been represented in a byte tensor"""
    fp4_values = torch.tensor(
        [
            0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
            -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
        ],
        device=x.device,
        dtype=dq_dtype,
    )
    # Convert to long integers for indexing
    second_bit = torch.div(x, 16, rounding_mode="floor").to(torch.long)
    first_bit = (x - second_bit * 16).to(torch.long)
    # Use the long integers to index fp4_values
    first_bit_values = fp4_values[first_bit]
    second_bit_values = fp4_values[second_bit]
    result = torch.zeros(
        (first_bit_values.shape[0], first_bit_values.shape[1] * 2),
        device=x.device,
        dtype=dq_dtype,
    )
    result[:, ::2] = first_bit_values
    result[:, 1::2] = second_bit_values
    return result
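A hedged round-trip check tying the two casts together: any value that already sits on the E2M1 grid survives quantize-then-dequantize exactly.

# Hedged round-trip check for values already on the E2M1 grid.
import torch

grid = torch.tensor([[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]])
assert torch.equal(cast_from_fp4x2(cast_to_fp4x2(grid), torch.float32), grid)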
def cast_to_e8(decode_scale):
    """Cast to a value that is representable in FP8 E8M0.

    The result is in FP32, not FP8 E8M0.
    """
    max_exponent = torch.tensor(127, device=decode_scale.device, dtype=torch.float32)
    exponent = torch.ceil(torch.log2(decode_scale))
    exponent = torch.clamp(exponent, min=-max_exponent, max=max_exponent)
    return torch.tensor(2.0, device=decode_scale.device, dtype=torch.float32) ** exponent
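Because E8M0 carries only an exponent, `torch.ceil(torch.log2(...))` rounds a scale up to the next power of two. Worked values as a hedged check:

# Hedged check: scales snap up to the next power of two (E8M0 has no mantissa).
import torch

s = torch.tensor([0.3, 1.0, 5.0])
print(cast_to_e8(s))  # tensor([0.5000, 1.0000, 8.0000])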
def cast_to_e4m3(decode_scale, global_amax):
    """Scale and cast to FP8 E4M3.

    decode_scale is actually the encoding scaling factor. global_amax
    can be any data tensor and not just the amax.
    TODO(etsykunov): Make less unintuitive.
    """
    decode_scale = decode_scale * global_amax
    FLOAT8_E4M3_MAX = torch.tensor(448.0, device=decode_scale.device, dtype=torch.float32)
    decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX)
    return decode_scale.to(torch.float8_e4m3fn)
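This is one half of NVFP4's two-level scaling: a per-block scale stored in E4M3 plus an FP32 global factor. A hedged sketch of the arithmetic with made-up values:

# Hedged sketch: the per-block factor is multiplied by a global factor,
# clamped to the E4M3 range (+-448), then stored as FP8.
import torch

block_scale = torch.tensor([0.015])  # per-block encode/decode factor (made up)
global_amax = torch.tensor([100.0])
e4m3 = cast_to_e4m3(block_scale, global_amax)
print(e4m3.to(torch.float32))        # 1.5, exactly representable in E4M3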
def high_precision_gemm_ref(
    a: torch.Tensor,
    b: torch.Tensor,
    out_dtype: torch.dtype,
    accumulate: bool = False,
    is_a_transposed: bool = False,
    is_b_transposed: bool = False,
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    scale_alpha: float = 1.0,
) -> torch.Tensor:
    """GEMM implementation with unquantized data"""
    # Handle transpositions
    mat1, mat2 = a, b
    if is_a_transposed:
        mat1 = a.T
    if is_b_transposed:
        mat2 = b.T
    # Ensure dtype compatibility for torch.addmm
    mat1 = mat1.to(out_dtype)
    mat2 = mat2.to(out_dtype)
    # Determine output shape
    y_shape = (mat1.size(0), mat2.size(1))
    if bias is not None:
        assert not accumulate, "Bias is not supported with accumulation"
        bias = bias.to(out_dtype)
        # With bias case
        if out_dtype == torch.float32:
            y_ref = torch.addmm(bias.repeat(mat1.size(0), 1), mat1, mat2, beta=1, alpha=1)
        else:
            y_ref = torch.addmm(bias, mat1, mat2, beta=1, alpha=scale_alpha)
    else:
        # Without bias case
        if accumulate and out is not None:
            y_ref = out.clone().to(out_dtype)
        else:
            y_ref = torch.zeros(y_shape, dtype=out_dtype, device=a.device)
        torch.addmm(y_ref, mat1, mat2, beta=1, alpha=scale_alpha, out=y_ref)
    return y_ref
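A hedged agreement check against `torch.matmul`, using the transposed-B layout that `qgemm` relies on below:

# Hedged check: with is_b_transposed=True the reference computes a @ b.T.
import torch

a = torch.randn(4, 8)
b = torch.randn(3, 8)
y = high_precision_gemm_ref(a, b, torch.float32, is_b_transposed=True)
assert torch.allclose(y, a @ b.T, atol=1e-6)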
@dataclasses.dataclass
class NVFP4TensorRef(QuantizedTensorStorage):
    """NVFP4 tensor for middleware between Transformer Engine and Kitchen.

    Custom container to hold quantization result, including quantized tensor, optional
    transposed quantized tensor, and corresponding decoding scales.

    data: torch.Tensor
        the quantized tensor.
    scale: torch.Tensor
        the decoding scale for the quantized tensor. Shape depends on the scaling granularity.
        - if scaling type is PER_TENSOR, it should be a 1D scalar tensor.
    data_t: torch.Tensor
        the transposed quantized tensor (computed lazily if needed).
    scale_t: torch.Tensor
        the decoding scale for the transposed quantized tensor.
    dtype: torch.dtype
        nominal tensor datatype.
    device: torch.device
        device of the tensor.
    quant_dtype: Union[utils.Fp4Formats, torch.dtype]
        low precision tensor datatype.
    original_shape: Tuple[int, ...]
        original shape of the tensor.
    _quantizer: Quantizer
        Builder class for quantized tensor.
    """

    data: Optional[torch.Tensor] = None
    scale: Optional[torch.Tensor] = None
    data_t: Optional[torch.Tensor] = None
    scale_t: Optional[torch.Tensor] = None
    global_amax_row: Optional[torch.Tensor] = None
    global_amax_col: Optional[torch.Tensor] = None
    dtype: Optional[torch.dtype] = None
    device: Optional[torch.device] = None
    quant_dtype: Optional[Union[utils.Fp4Formats, torch.dtype]] = None
    original_shape: Optional[Tuple[int, ...]] = None
    _quantizer: Optional[Quantizer] = None

    @property
    def experimental(self) -> bool:
        """Flag to indicate this quantizer is using experimental Kitchen middleware."""
        return True

    def prepare_for_saving(
        self,
    ) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
        """Prepare the quantization result for saving for backward"""
        tensors = [self.data, self.data_t, self.scale, self.scale_t]
        self.data = None
        self.data_t = None
        self.scale = None
        self.scale_t = None
        return tensors, self

    def restore_from_saved(
        self, tensors: list[Optional[torch.Tensor]]
    ) -> list[Optional[torch.Tensor]]:
        """Restore the quantization result from the saved tensors"""
        self.data = tensors[0]
        self.data_t = tensors[1]
        self.scale = tensors[2]
        self.scale_t = tensors[3]
        return tensors[4:]

    # Compatibility
    @property
    def _data(self):
        return self.data

    @_data.setter
    def _data(self, value):
        self.data = value

    @property
    def _scale_inv(self):
        return self.scale

    @_scale_inv.setter
    def _scale_inv(self, value):
        self.scale = value

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"dtype={self.dtype}, "
            f"device={self.device}, "
            f"quant_dtype={self.quant_dtype}, "
            f"original_shape={self.original_shape}"
            ")"
        )

    def update_usage(
        self,
        rowwise_usage: Optional[bool] = None,
        columnwise_usage: Optional[bool] = None,
    ):
        """Generate or remove quantized data based on provided usage."""
        has_data = self.data is not None
        has_data_transpose = self.data_t is not None
        needs_data = has_data
        needs_data_transpose = has_data_transpose
        if rowwise_usage is not None:
            needs_data = rowwise_usage
        if columnwise_usage is not None:
            needs_data_transpose = columnwise_usage

        # Generate data that is required
        if needs_data and not has_data:
            raise RuntimeError("Cannot generate FP4 data, even from FP4 data transpose")
        if needs_data_transpose and not has_data_transpose:
            if not has_data:
                raise RuntimeError("FP4 data is required to generate FP4 data transpose")
            self._create_transpose()

        # Delete data that is not required
        if not needs_data:
            self.data = None
        if not needs_data_transpose:
            self.data_t = None

    def _create_transpose(self):
        """Create transposed quantized tensor"""
        if not self.data.is_contiguous():
            self.data = self.data.contiguous()
        self.data_t = self.data.t().contiguous()
        self.scale_t = self.scale

    def size(self, *args, **kwargs):  # pylint: disable=unused-argument
        """Return the original tensor shape, not the internal packed data shape.

        FP4 quantization packs two 4-bit values into each 8-bit value, which reduces
        the second dimension by half. This method returns the logical shape that
        users expect, not the internal packed storage shape.
        """
        assert self.original_shape is not None
        return torch.Size(self.original_shape)
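A hedged sketch of the save/restore protocol above, which hands the quantized buffers to autograd's saved-tensors list and reattaches them later. It assumes the `NVFP4QuantizerRef` defined further down this file:

# Hedged sketch: buffers are detached from the container, stashed, then reattached.
import torch

q = NVFP4QuantizerRef(dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16))
t = q.quantize(torch.randn(32, 32, dtype=torch.bfloat16))
saved, container = t.prepare_for_saving()
assert container.data is None                 # buffers were moved out
leftover = container.restore_from_saved(saved)
assert container.data is not None and leftover == []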
def get_wgrad_sign_vector() -> torch.Tensor:
    """Hard-coded signs for Hadamard transform"""
    return torch.tensor(
        [1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0,
         -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0],
        dtype=torch.float32,
    )
class NVFP4QuantizerRef(Quantizer):
    """NVFP4 quantizer for middleware between Transformer Engine and Kitchen"""

    def __init__(
        self,
        dtype: utils.Fp4Formats,
        rowwise: bool = True,
        columnwise: bool = True,
        pow_2_scales: bool = False,
        eps: float = 0.0,
        quant_tile_shape: Tuple[int, int] = (1, 16),
        with_rht: bool = False,
        with_random_sign_mask: bool = True,
    ):
        super().__init__(rowwise=rowwise, columnwise=columnwise)
        self.internal = True
        self.dtype = dtype
        self.pow_2_scales = pow_2_scales
        self.eps = eps
        self.quant_tile_shape = quant_tile_shape
        self.with_rht = with_rht
        self.with_random_sign_mask = with_random_sign_mask

    @property
    def experimental(self) -> bool:
        """Flag to indicate this quantizer is using experimental Kitchen middleware"""
        return True
    @staticmethod
    def _build_hadamard_matrix(
        size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True
    ) -> torch.Tensor:
        """Construct a Hadamard matrix of given power-of-two size with entries +-1.

        Uses Sylvester construction to avoid SciPy dependency.
        """
        assert (size & (size - 1)) == 0, "Hadamard size must be a power of two"
        h = torch.ones((1, 1), device=device, dtype=torch.float32)
        while h.shape[0] < size:
            h = torch.cat(
                [
                    torch.cat([h, h], dim=1),
                    torch.cat([h, -h], dim=1),
                ],
                dim=0,
            )
        if with_random_sign_mask:
            sign_mat = get_wgrad_sign_vector().to(device) * torch.eye(
                size, device=device, dtype=torch.float32
            )
            h = sign_mat @ h
        return h.to(dtype)
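A hedged check of the Sylvester doubling: the result has pure +-1 entries and satisfies H H^T = size * I, so H / sqrt(size) is orthogonal.

# Hedged check of the Sylvester construction (sign mask disabled for clarity).
import torch

H = NVFP4QuantizerRef._build_hadamard_matrix(
    16, torch.device("cpu"), torch.float32, with_random_sign_mask=False
)
assert torch.equal(H.abs(), torch.ones(16, 16))
assert torch.equal(H @ H.T, 16.0 * torch.eye(16))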
    def _apply_rht(self, x: torch.Tensor) -> torch.Tensor:
        """Apply randomized Hadamard transform without random signs (reference path).

        This matches the reference used in tests: x_reshaped @ (H * (1/sqrt(g))).
        """
        # Only apply when enabled
        if not self.with_rht:
            return x
        # RHT dimension equals the quantization tile length (NVFP4 uses 16)
        rht_dim = self.quant_tile_shape[1]
        assert (
            x.shape[-1] % rht_dim == 0
        ), f"Inner dimension {x.shape[-1]} must be divisible by hadamard dimension {rht_dim}"
        # Build H and scale
        H = self._build_hadamard_matrix(rht_dim, x.device, x.dtype, self.with_random_sign_mask)
        scale = 1.0 / float(rht_dim) ** 0.5
        # Perform blockwise transform along the last dimension
        original_shape = x.shape
        x_mat = x.contiguous().view(-1, rht_dim)
        # Random sign matrix is identity in this reference (no sign flipping)
        transform = H * scale
        out = x_mat @ transform
        return out.view(original_shape)
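Since H / sqrt(g) is orthogonal, the transform preserves norms and inner products per tile, which is why it can be applied to both wgrad operands without changing their exact product. A hedged norm-preservation check:

# Hedged check: the RHT preserves norms up to floating-point error.
import torch

q = NVFP4QuantizerRef(
    dtype=utils.Fp4Formats.E2M1, with_rht=True, with_random_sign_mask=False
)
x = torch.randn(8, 64)
assert torch.allclose(q._apply_rht(x).norm(), x.norm(), atol=1e-5)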
    @staticmethod
    def _recover_swizzled_scales(
        swizzled_scale: bool, scale: torch.Tensor, m: int, n: int, block_length: int
    ) -> torch.Tensor:
        if not swizzled_scale:
            return scale
        rounded_m = utils.roundup_div(m, 128) * 128
        scale_n = utils.roundup_div(n, block_length)
        rounded_n = utils.roundup_div(scale_n, 4) * 4
        # Recover swizzled scaling factor layout -> linear layout
        tmp = torch.reshape(scale, (rounded_m // 128, rounded_n // 4, 32, 4, 4))
        # after permutation, the layout is [rounded_m // 128, 4, 32, rounded_n // 4, 4]
        tmp = torch.permute(tmp, (0, 3, 2, 1, 4))
        result = torch.reshape(tmp, (rounded_m, rounded_n))
        return result[:m, :scale_n]
    @classmethod
    def _quantize_blockwise_reference(
        cls,
        x: torch.Tensor,
        global_amax: torch.Tensor,
        tile_len_x: int,
        tile_len_y: int,
        *,
        pow_2_scales: bool,
        eps: float,  # pylint: disable=unused-argument
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert x.ndim == 2
        using_2d_quantization = tile_len_x == 16 and tile_len_y == 16
        m, n = x.shape
        # Compute vec_max based on the original x (before reshape)
        # For 1D quantization: amax over each row chunk of 16
        # For 2D quantization: amax over each 16x16 block, but output shape is still
        # (128, 8, 1), filled with block amax
        if using_2d_quantization:
            # x shape: (128, 128)
            x_blocks = (
                x.unfold(0, tile_len_y, tile_len_y)
                .unfold(1, tile_len_x, tile_len_x)
                .to(torch.float32)
            )  # (8, 8, 16, 16)
            block_amax = torch.amax(torch.abs(x_blocks), dim=(-1, -2))  # (8, 8)
            # Now, expand to (128, 8, 1) by repeating each block_amax for 16 rows
            vec_max = block_amax.repeat_interleave(tile_len_y, dim=0).unsqueeze(-1)  # (128, 8, 1)
        else:
            # x shape: (128, 128)
            x_reshaped = x.view(m, n // tile_len_x, tile_len_x)  # (128, 8, 16)
            vec_max = torch.amax(torch.abs(x_reshaped), dim=-1, keepdim=True).to(
                torch.float32
            )  # (128, 8, 1)
        x = x.view(m, n // tile_len_x, tile_len_x)
        FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32)
        FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32)
        decode_scale = torch.div(vec_max, FLOAT4_E2M1_MAX)
        if pow_2_scales:
            decode_scale = cast_to_e8(decode_scale)
            encode_scale = torch.div(
                torch.tensor(1.0, device=x.device, dtype=torch.float32),
                decode_scale.to(torch.float32),
            )
        else:
            global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax)
            global_encode_scale = torch.min(
                global_encode_scale,
                torch.tensor(
                    torch.finfo(torch.float32).max,
                    device=global_encode_scale.device,
                    dtype=torch.float32,
                ),
            )
            if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32):
                global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32)
            global_decode_scale = torch.div(1.0, global_encode_scale)
            decode_scale = decode_scale * global_encode_scale
            decode_scale = torch.min(
                decode_scale,
                torch.tensor(
                    torch.finfo(torch.float32).max,
                    device=decode_scale.device,
                    dtype=torch.float32,
                ),
            )
            decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX)
            decode_scale = decode_scale.to(torch.float8_e4m3fn)
            encode_scale = torch.min(
                torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale),
                torch.tensor(
                    torch.finfo(torch.float32).max,
                    device=decode_scale.device,
                    dtype=torch.float32,
                ),
            )
        scaled_x = x.to(torch.float32) * encode_scale
        clipped_x = torch.clamp(scaled_x, -FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX).reshape(m, n)
        return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1)
staticmethod
def
_pad_tensor
(
tensor
:
torch
.
Tensor
,
row_divisor
:
Optional
[
int
],
col_divisor
:
Optional
[
int
]
)
->
torch
.
Tensor
:
assert
tensor
.
dim
()
==
2
,
"only supports 2D tensors"
M
,
N
=
tensor
.
shape
padding_needed_rows
=
0
padding_needed_cols
=
0
if
row_divisor
is
not
None
and
M
%
row_divisor
!=
0
:
padding_needed_rows
=
row_divisor
-
(
M
%
row_divisor
)
# Check and calculate column padding if col_divisor is provided
if
col_divisor
is
not
None
and
N
%
col_divisor
!=
0
:
padding_needed_cols
=
col_divisor
-
(
N
%
col_divisor
)
# Return original tensor if no padding is needed
if
padding_needed_rows
==
0
and
padding_needed_cols
==
0
:
return
tensor
# pad the tensor
out
=
torch
.
nn
.
functional
.
pad
(
tensor
,
(
0
,
padding_needed_cols
,
0
,
padding_needed_rows
),
mode
=
"constant"
,
value
=
0.0
,
).
contiguous
()
return
out
@
staticmethod
def
_rm_pad_tensor
(
tensor
:
torch
.
Tensor
,
original_size
:
tuple
[
int
,
...])
->
torch
.
Tensor
:
assert
tensor
.
dim
()
==
2
,
"only supports 2D tensors"
M
,
N
=
original_size
out
=
tensor
[:
M
,
:
N
].
contiguous
()
return
out
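Padding rounds each dimension up to its tile divisor and is stripped again after quantization (note that the unpad in `_quantize` below uses the packed width, hence `(M, N // 2)`). A hedged check:

# Hedged check: pad up to the divisors, then strip back to the original shape.
import torch

t = torch.randn(30, 50)
padded = NVFP4QuantizerRef._pad_tensor(t, row_divisor=16, col_divisor=16)
assert padded.shape == (32, 64)
assert torch.equal(NVFP4QuantizerRef._rm_pad_tensor(padded, (30, 50)), t)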
    def _quantize(self, tensor: torch.Tensor) -> Tuple[
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        torch.Tensor,
        torch.Tensor,
    ]:
        """
        Python implementation of microblock FP4 quantization.

        Parameters
        ----------
        tensor : torch.Tensor
            Input tensor to quantize (should be 2D)

        Returns
        -------
        Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor],
              Optional[torch.Tensor], torch.Tensor]
            (qx, sx, qx_t, sx_t, global_amax) where:
            - qx: quantized data in row-major order (if rowwise_usage), None otherwise
            - sx: scale tensor for qx (if rowwise_usage), None otherwise
            - qx_t: quantized data in column-major order (if columnwise_usage), None otherwise
            - sx_t: scale tensor for qx_t (if columnwise_usage), None otherwise
            - global_amax_row, global_amax_col: global amax tensors
        """
        if self.pow_2_scales:
            assert self.quant_tile_shape == (1, 32), "MXFP4 only supports 1x32 tile shape."
            # TODO(etsykunov): Fix bug where global_amax_row and
            # global_amax_col are not defined
            # global_amax = torch.empty(0, device=tensor.device, dtype=torch.float32)
        else:
            assert self.quant_tile_shape in (
                (1, 16),
                (16, 16),
            ), "NVFP4 only supports 1x16 or 16x16 tile shape."
        # Prepare inputs once so we can reuse for both amax and quantization
        # Row-input will always be the original input.
        row_input = tensor
        col_input = (
            self._apply_rht(tensor.t().contiguous()) if self.with_rht else tensor.t().contiguous()
        )
        # Compute amax for rowwise and columnwise paths separately
        global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1)
        global_amax_col = (
            torch.max(torch.abs(col_input)).to(torch.float32).view(1)
            if self.columnwise_usage
            else global_amax_row
        )
        transpose_scales = False
        M, N = tensor.shape
        if self.rowwise_usage:
            x_input = row_input
            x_padded = self._pad_tensor(
                x_input, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1]
            )
            qx, sx = self._quantize_blockwise_reference(
                x_padded,
                global_amax_row,
                self.quant_tile_shape[1],
                self.quant_tile_shape[0],
                pow_2_scales=self.pow_2_scales,
                eps=self.eps,
            )
            if transpose_scales:
                sx = sx.T
            qx = self._rm_pad_tensor(qx, (M, N // 2))
        else:
            qx = None
            sx = None
        if self.columnwise_usage:
            x_t = col_input
            x_t_padded = self._pad_tensor(
                x_t, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1]
            )
            qx_t, sx_t = self._quantize_blockwise_reference(
                x_t_padded,
                global_amax_col,
                self.quant_tile_shape[1],
                self.quant_tile_shape[0],
                pow_2_scales=self.pow_2_scales,
                eps=self.eps,
            )
            qx_t = self._rm_pad_tensor(qx_t, (N, M // 2))
            if transpose_scales:
                sx_t = sx_t.T
        else:
            qx_t = None
            sx_t = None
        return qx, sx, qx_t, sx_t, global_amax_row, global_amax_col
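A hedged sketch of the dual outputs: the rowwise path quantizes the tensor as-is, while the columnwise path quantizes the (optionally RHT-transformed) transpose, each with its own global amax. It assumes the base `Quantizer` exposes the `rowwise_usage`/`columnwise_usage` flags set via `__init__`.

# Hedged shape sketch for a 32x64 input with 1x16 tiles.
import torch

q = NVFP4QuantizerRef(dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16))
qx, sx, qx_t, sx_t, amax_row, amax_col = q._quantize(torch.randn(32, 64))
assert qx.shape == (32, 32) and qx_t.shape == (64, 16)  # packed widths are halved
assert amax_row.shape == (1,) and amax_col.shape == (1,)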
    def quantize(
        self,
        tensor: torch.Tensor,
        **kwargs,  # pylint: disable=unused-argument
    ) -> NVFP4TensorRef:
        # sanity checks
        assert tensor.dtype in utils.HIGH_PRECISION_FLOAT_DTYPES, "Unsupported input dtype."
        # Make it work with 3D tensors
        original_shape = tensor.shape
        if tensor.ndim > 2:
            tensor = tensor.view(-1, tensor.shape[-1])
        qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(tensor)
        return NVFP4TensorRef(
            data=qx,
            scale=sx,
            data_t=qx_t,
            scale_t=sx_t,
            global_amax_row=global_amax_row,
            global_amax_col=global_amax_col,
            dtype=tensor.dtype,
            device=tensor.device,
            quant_dtype=self.dtype,
            _quantizer=self,
            original_shape=original_shape,
        )
    def update_quantized(
        self,
        src: torch.Tensor,
        dst: QuantizedTensorStorage,
        *,
        noop_flag: Optional[torch.Tensor] = None,
    ) -> QuantizedTensorStorage:
        """Update the quantized tensor with the given tensor in-place

        Parameters
        ----------
        src: torch.Tensor
            Source tensor to copy from
        dst: QuantizedTensorStorage
            Destination QuantizedTensorStorage to update
        noop_flag: torch.Tensor, optional
            float32 flag indicating whether to avoid performing update
        """
        # Handle noop flag
        if noop_flag is not None and noop_flag.item() != 0:
            return dst
        # Make sure input is in expected format
        if not src.is_contiguous():
            src = src.contiguous()
        # Store the original shape and reshape for processing
        original_shape = src.shape
        if src.ndim > 2:
            src = src.view(-1, src.shape[-1])
        qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(src)
        # Update the destination with new data
        dst.data = qx
        dst.scale = sx
        dst.data_t = qx_t
        dst.scale_t = sx_t
        dst.global_amax_row = global_amax_row
        dst.global_amax_col = global_amax_col
        dst.dtype = src.dtype
        dst.quant_dtype = self.dtype
        dst.original_shape = original_shape
        return dst
    @property
    def supports_allgather_fp8(self) -> bool:
        """Whether the tensor data can be all-gathered with an FP8 all-gather.

        TODO(etsykunov): Confirm docstring is correct. Also, this API
        seems too FP8-specific and should be reconsidered.
        """
        return False

    def transpose_qresult(self, qresult: QuantizedTensorStorage) -> QuantizedTensorStorage:
        """Convert row-wise data to column-wise data (?)

        TODO(etsykunov): Confirm docstring is correct.
        """
        raise NotImplementedError("Transpose qresult is not implemented for FP4.")

    @property
    def supports_dequantize(self) -> bool:
        """Whether the quantized tensor can be converted to a high-precision tensor"""
        return False

    @property
    def is_data_t_transposed_in_memory(self) -> bool:
        """Whether column-wise data is stored in transposed layout.

        TODO(etsykunov): Confirm docstring is correct.
        """
        raise NotImplementedError("Not implemented yet")
    def qgemm(
        self,
        qx: torch.Tensor,
        qw: torch.Tensor,
        m_params: quantization.MMParams,  # pylint: disable=unused-argument
        out_dtype: torch.dtype,
        sx: torch.Tensor,
        sw: torch.Tensor,
        bias: torch.Tensor | None = None,
        out: torch.Tensor | None = None,
        accumulate: bool = False,
        gemm_type: quantization.GEMMType = quantization.GEMMType.FPROP,
        qresult_x: QuantizedTensorStorage | None = None,
        qresult_w: QuantizedTensorStorage | None = None,
    ) -> torch.Tensor:
        """Python implementation of microblock FP4 GEMM."""
        assert bias is None, "Bias is not implemented for FP4 GEMM."
        high_precision_x = cast_from_fp4x2(qx, out_dtype)
        high_precision_w = cast_from_fp4x2(qw, out_dtype)
        if self.pow_2_scales:
            if sx.dtype == torch.uint8:
                # if scaling factor is stored in uint8 container
                sx = torch.tensor(2.0, device=sx.device, dtype=torch.float32) ** (
                    sx.to(torch.float32)
                    - torch.tensor(127, device=sx.device, dtype=torch.float32)
                )
                sw = torch.tensor(2.0, device=sw.device, dtype=torch.float32) ** (
                    sw.to(torch.float32)
                    - torch.tensor(127, device=sw.device, dtype=torch.float32)
                )
            else:
                # if scaling factor is torch.float8_e8m0fnu
                sx = sx.to(torch.float32)
                sw = sw.to(torch.float32)
            alpha = torch.tensor(1.0, device=high_precision_x.device, dtype=torch.float32)
        else:
            assert qresult_x is not None
            assert qresult_w is not None
            assert qresult_x.global_amax_row is not None
            assert qresult_w.global_amax_col is not None
            sx = sx.to(torch.float32)
            sw = sw.to(torch.float32)
            factor = 6.0 * 6.0 * 448.0 * 448.0
            if gemm_type == quantization.GEMMType.WGRAD:
                partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col
            else:
                partial_alpha = qresult_x.global_amax_row * qresult_w.global_amax_row
            alpha = torch.div(partial_alpha, factor).squeeze(-1)
        M, K = high_precision_x.shape
        N, K_w = high_precision_w.shape
        assert K == K_w, "K dimension mismatch between qx and qw"
        assert K % 32 == 0, "K dimension must be divisible by 32"
        assert N % 8 == 0, "N dimension must be divisible by 8"
        block_length = 32 if self.pow_2_scales else 16
        grid_k = K // block_length
        assert sx.shape == (
            M,
            K // block_length,
        ), f"sx shape mismatch: expected ({M}, {K // block_length}), got {sx.shape}"
        assert sw.shape == (
            N,
            K // block_length,
        ), f"sw shape mismatch: expected ({N}, {K // block_length}), got {sw.shape}"
        y = torch.zeros(M, N, dtype=torch.float32, device=qx.device)
        # below implementation is to match the FP4 tensor core implementation
        # Each output element (i, j) is fp32 accumulation of (K // block_length) inner products
        # Each inner product is sx * sw * (1, block_length) x (block_length, 1) with precision
        # in fp32. Then batch the computation in M, N dimension
        for k in range(grid_k):
            k_start = k * block_length
            k_end = k_start + block_length
            qx_block = high_precision_x[:, k_start:k_end].clone().contiguous()
            qw_block = high_precision_w[:, k_start:k_end].clone().contiguous()
            # Extract scaling factors for the current blocks
            sx_block = sx[:, k]
            sw_block = sw[:, k]
            y += torch.outer(sx_block, sw_block) * high_precision_gemm_ref(
                qx_block, qw_block, torch.float32, is_b_transposed=True
            )
        if not self.pow_2_scales and K > 0:
            # only apply global scale for NVFP4 and non-empty cases
            y = alpha * y
        # accumulation happens at epilogue in float32
        if accumulate:
            assert out is not None, "Output tensor must be provided for accumulation."
            y += out.to(torch.float32)
        else:
            assert out is None, "Output tensor should be None when accumulate is False."
        y = y.to(out_dtype)
        return y
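A hedged end-to-end sketch tying the file together: quantize both operands and compare the FP4 GEMM against a high-precision matmul. The match is approximate by construction; the discrepancy only reflects FP4 rounding noise.

# Hedged sketch: FP4 qgemm vs. torch matmul (approximate agreement expected).
import torch

q = NVFP4QuantizerRef(dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(1, 16))
x, w = torch.randn(32, 64), torch.randn(16, 64)
qx, qw = q.quantize(x), q.quantize(w)
y = q.qgemm(
    qx.data, qw.data, quantization.MMParams(), torch.float32,
    qx.scale, qw.scale, gemm_type=quantization.GEMMType.FPROP,
    qresult_x=qx, qresult_w=qw,
)
print((y - x @ w.T).abs().max())  # small relative to |x @ w.T|, not exact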
transformer_engine/pytorch/experimental/utils.py
0 → 100644
View file @ 063ef88d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility functions for experimental middleware between Transformer Engine and Kitchen."""
import enum

import torch

HIGH_PRECISION_FLOAT_DTYPES = (
    torch.float,
    torch.float16,
    torch.bfloat16,
    torch.float32,
)


class Fp4Formats(enum.Enum):
    """FP4 data format"""

    E2M1 = "e2m1"


def roundup_div(x: int, y: int) -> int:
    """Round up division"""
    assert x >= 0
    assert y > 0
    return (x + y - 1) // y
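A quick hedged check of the ceiling division used in the swizzled-scale layout math above:

# Hedged check: roundup_div is ceiling division over non-negative integers.
assert roundup_div(130, 128) == 2
assert roundup_div(128, 128) == 1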
transformer_engine/pytorch/float8_tensor.py
View file @ 063ef88d
...
@@ -4,6 +4,16 @@
"""Tensor class with FP8 data"""
+import warnings
+
from .tensor.float8_tensor import Float8Tensor
+
+warnings.warn(
+    "transformer_engine.pytorch.float8_tensor is deprecated and will be removed"
+    " in a future release. Float8Tensor should be imported directly through "
+    "`from transformer_engine.pytorch import Float8Tensor`",
+    DeprecationWarning,
+    stacklevel=2,
+)
+
__all__ = ["Float8Tensor"]
transformer_engine/pytorch/fp8.py
View file @ 063ef88d
...
@@ -2,18 +2,26 @@
#
# See LICENSE for license information.
-"""FP8 utilities for TransformerEngine"""
-from __future__ import annotations
-
-import abc
-import itertools
-import os
-from contextlib import contextmanager
-from collections import deque
-from typing import Callable, List, Optional, Dict, Any, Tuple, Union
-
-import torch
-import transformer_engine_torch as tex
+"""
+DEPRECATED in favor of `transformer_engine.pytorch.quantization.py`.
+"""
+# pylint: disable=wrong-import-position,unused-import
+import os
+import torch
+import warnings
+
+warnings.warn(
+    "Using deprecated internal API from Transformer Engine. "
+    "transformer_engine.pytorch.fp8 will be removed in a "
+    "future release.",
+    DeprecationWarning,
+    stacklevel=2,
+)
+
+# There are some users indirectly importing these classes
+# from fp8.py. This ensures backwards compatibility.
+# https://github.com/Lightning-AI/lightning-thunder/pull/2635.
from transformer_engine.common.recipe import (
    Recipe,
    DelayedScaling,
@@ -21,1082 +29,43 @@ from transformer_engine.common.recipe import (
    MXFP8BlockScaling,
    Float8CurrentScaling,
    Float8BlockScaling,
+    NVFP4BlockScaling,
+    CustomRecipe,
)
+# Importing each function instead of 'import *' allows us specify '__all__' in
+# quantize.py and also makes any newer additions to quantize.py invisible via
+# fp8.py so that we don't reinforce importing internal TE functions.
+from .quantization import (
+    check_fp8_support,
+    check_mxfp8_support,
+    check_nvfp4_support,
+    check_fp8_block_scaling_support,
+    check_recipe_support,
+    get_default_fp8_recipe,
+    get_fp8_torch_dtype,
+    get_fp8_te_dtype,
+    get_fp4_te_dtype,
+    get_fp8_max,
+    FP8GlobalStateManager,
+    fp8_model_init,
+    fp8_autocast,
+    _update_amax_history,
+    _default_get_amax_and_update_history,
+    _default_sf_compute,
+    _compute_amax_and_update_history,
+    _compute_scaling_factor,
+    _amax_and_scale_update,
+    split_and_copy,
+    RecipeState,
+    DelayedScalingRecipeState,
+    Float8CurrentScalingRecipeState,
+    MXFP8BlockScalingRecipeState,
+    Float8BlockScalingRecipeState,
+    NVFP4BlockScalingRecipeState,
+    CustomRecipeState,
+    int8_simulation_fp8,
+    int8_simulation_fp8_tensorwise,
+    blockwise_fp8_block_len
+)
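With the shim above, legacy imports keep resolving but warn at import time. A minimal hedged sketch of observing that (assuming only that the module path stays importable):

# Hedged sketch: the legacy module still imports, but emits a DeprecationWarning.
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    from transformer_engine.pytorch.fp8 import fp8_autocast  # noqa: F401
print(any(issubclass(w.category, DeprecationWarning) for w in caught))

The remainder of the diff below is the old implementation being deleted from fp8.py (it now lives in quantization.py):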
-from .constants import dist_group_type
-from .utils import get_device_compute_capability
-from .jit import jit_fuser
-from torch.utils.cpp_extension import IS_HIP_EXTENSION
-
-int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
-int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
-blockwise_fp8_block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))
-
-__all__ = ["fp8_autocast", "fp8_model_init"]
-
-if IS_HIP_EXTENSION:
-    from transformer_engine.pytorch.utils import is_K100_AI, is_BW
-
-
-def check_fp8_support() -> Tuple[bool, str]:
-    """Return if fp8 support is available"""
-    if IS_HIP_EXTENSION:
-        if (is_K100_AI() or is_BW()) and int8_simulation_fp8:
-            return True, "DCU turn on fp8 simulation with int8"
-        else:
-            return False, "DCU not support fp8 for now"
-    else:
-        if get_device_compute_capability() >= (9, 0):  # hopper and above
-            return True, ""
-        if get_device_compute_capability() < (8, 9):  # pre-ada
-            return False, "Device compute capability 8.9 or higher required for FP8 execution."
-        if tex.get_cublasLt_version() < 120103:
-            return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
-        if float(torch.version.cuda) < 12.1:
-            return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
-        return True, ""
-
-
-def check_mxfp8_support() -> Tuple[bool, str]:
-    """Return if fp8 support is available"""
-    if get_device_compute_capability() >= (12, 0):
-        return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
-    if get_device_compute_capability() >= (10, 0):  # blackwell and above
-        return True, ""
-    return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
-
-
-def check_fp8_block_scaling_support() -> Tuple[bool, str]:
-    """Return if fp8 block scaling support is available"""
-    if IS_HIP_EXTENSION:
-        if is_K100_AI() or is_BW():
-            return True, ""
-        else:
-            return False, "DCU not support block_scaling fp8 for now"
-    if (
-        get_device_compute_capability() >= (9, 0)
-        and get_device_compute_capability() < (10, 0)
-        and float(torch.version.cuda) >= 12.9
-    ):
-        return True, ""
-    return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."
-def check_recipe_support(recipe: Recipe) -> None:
-    """Check if the given recipe is supported."""
-    recipe_supported = True
-    unsupported_reason = ""
-    if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)):
-        recipe_supported, unsupported_reason = check_fp8_support()
-    elif isinstance(recipe, Float8BlockScaling):
-        recipe_supported, unsupported_reason = check_fp8_block_scaling_support()
-    elif isinstance(recipe, MXFP8BlockScaling):
-        recipe_supported, unsupported_reason = check_mxfp8_support()
-    assert recipe_supported, unsupported_reason
-def get_default_fp8_recipe() -> Recipe:
-    """FP8 recipe with default args."""
-    if check_mxfp8_support()[0]:
-        return MXFP8BlockScaling()
-    if get_device_compute_capability() >= (12, 0):
-        # This is a temporary restriction until MXFP8 is supported for all gemm layouts.
-        return Float8CurrentScaling()
-    return DelayedScaling()
-def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype:
-    """Get fp8 data type according to recipe and tensor"""
-    if fp8_recipe.fp8_format == Format.E4M3 or (
-        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
-    ):
-        return torch.float8_e4m3fn
-    return torch.float8_e5m2
-
-
-def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
-    """Get fp8 data type according to recipe and tensor"""
-    if fp8_recipe.fp8_format == Format.E4M3 or (
-        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
-    ):
-        return tex.DType.kFloat8E4M3
-    return tex.DType.kFloat8E5M2
-
-
-def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
-    """Get max representible FP8 value."""
-    if fp8_recipe.fp8_format == Format.E4M3 or (
-        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
-    ):
-        return Format.E4M3.value.max_fwd
-    return Format.E5M2.value.max_fwd
class
FP8GlobalStateManager
:
"""Class to keep track of and manipulate the global
FP8 state at different stages of execution.
"""
FP8_ENABLED
=
False
FP8_CALIBRATION
=
False
FP8_RECIPE
=
None
FP8_DISTRIBUTED_GROUP
=
None
FP8_PARAMETERS
=
False
HIGH_PRECISION_INIT_VAL
=
False
IS_FIRST_FP8_MODULE
=
False
FP8_GRAPH_CAPTURING
=
False
FP8_AUTOCAST_DEPTH
=
0
global_amax_buffer
=
{}
global_amax_history_buffer
=
{}
global_scale_buffer
=
{}
fp8_tensors_recompute_buffer
=
[]
fp8_available
=
None
reason_for_no_fp8
=
""
autocast_arguments
=
{}
autocast_to_fp8_params
=
{}
fp8_param_to_autocast
=
{}
skip_fp8_weight_update_tensor
=
None
mxfp8_available
=
None
reason_for_no_mxfp8
=
""
fp8_block_scaling_available
=
None
reason_for_no_fp8_block_scaling
=
None
@
classmethod
def
reset
(
cls
)
->
None
:
"""Reset the global state"""
cls
.
FP8_ENABLED
=
False
cls
.
FP8_CALIBRATION
=
False
cls
.
FP8_RECIPE
=
None
cls
.
FP8_DISTRIBUTED_GROUP
=
None
cls
.
FP8_PARAMETERS
=
False
cls
.
HIGH_PRECISION_INIT_VAL
=
False
cls
.
IS_FIRST_FP8_MODULE
=
False
cls
.
FP8_GRAPH_CAPTURING
=
False
cls
.
FP8_AUTOCAST_DEPTH
=
0
cls
.
global_amax_buffer
=
{}
cls
.
global_amax_history_buffer
=
{}
cls
.
global_scale_buffer
=
{}
cls
.
fp8_tensors_recompute_buffer
=
[]
cls
.
fp8_available
=
None
cls
.
reason_for_no_fp8
=
""
cls
.
autocast_arguments
=
{}
cls
.
autocast_to_fp8_params
=
{}
cls
.
fp8_param_to_autocast
=
{}
cls
.
skip_fp8_weight_update_tensor
=
None
cls
.
mxfp8_available
=
None
cls
.
reason_for_no_mxfp8
=
""
cls
.
fp8_block_scaling_available
=
None
cls
.
reason_for_no_fp8_block_scaling
=
""
@
classmethod
def
set_skip_fp8_weight_update_tensor
(
cls
,
skip
:
bool
)
->
None
:
"""`skip_fp8_weight_update_tensor` inplace setter."""
if
cls
.
skip_fp8_weight_update_tensor
is
None
:
cls
.
skip_fp8_weight_update_tensor
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
cls
.
skip_fp8_weight_update_tensor
.
fill_
(
skip
)
@
classmethod
def
get_skip_fp8_weight_update_tensor
(
cls
)
->
None
:
"""`skip_fp8_weight_update_tensor` getter."""
return
cls
.
skip_fp8_weight_update_tensor
@
classmethod
def
is_fp8_available
(
cls
)
->
Tuple
[
bool
,
str
]:
"""Return if fp8 support is available"""
if
cls
.
fp8_available
is
None
:
cls
.
fp8_available
,
cls
.
reason_for_no_fp8
=
check_fp8_support
()
return
cls
.
fp8_available
,
cls
.
reason_for_no_fp8
@
classmethod
def
is_mxfp8_available
(
cls
)
->
Tuple
[
bool
,
str
]:
"""Return if MXFP8/current scaling support is available."""
if
cls
.
mxfp8_available
is
None
:
cls
.
mxfp8_available
,
cls
.
reason_for_no_mxfp8
=
check_mxfp8_support
()
return
cls
.
mxfp8_available
,
cls
.
reason_for_no_mxfp8
@
classmethod
def
is_fp8_block_scaling_available
(
cls
)
->
Tuple
[
bool
,
str
]:
"""Return if Float8 block scaling support is available."""
if
cls
.
fp8_block_scaling_available
is
None
:
cls
.
fp8_block_scaling_available
,
cls
.
reason_for_no_fp8_block_scaling
=
(
check_fp8_block_scaling_support
()
)
return
cls
.
fp8_block_scaling_available
,
cls
.
reason_for_no_fp8_block_scaling
@
staticmethod
def
get_meta_tensor_key
(
forward
:
bool
=
True
)
->
str
:
"""Returns scaling key in `fp8_meta`."""
if
forward
:
return
"scaling_fwd"
return
"scaling_bwd"
@
staticmethod
def
get_fwd_bwd_key
(
forward
:
bool
=
True
)
->
str
:
"""Convert bool `forward` to string."""
return
"forward"
if
forward
else
"backward"
@
classmethod
def
get_buffer_info
(
cls
)
->
str
:
"""
Returns a key for `fp8_meta` that stores the module's index
in the global buffers along with autocast information.
"""
return
"buffer_index_and_autocast_key"
@
classmethod
def
get_key_in_buffer
(
cls
,
forward
:
bool
,
fp8_recipe
:
Recipe
,
fp8_group
:
dist_group_type
,
)
->
str
:
"""Returns a key into the global FP8 buffers."""
autocast_key
=
cls
.
get_unique_autocast_key
(
fp8_recipe
,
fp8_group
)
fwd_bwd_key
=
cls
.
get_fwd_bwd_key
(
forward
)
return
f
"
{
fwd_bwd_key
}
_
{
autocast_key
}
"
@
classmethod
def
split_key_in_buffer
(
cls
,
key
:
str
)
->
Tuple
[
bool
,
str
]:
"""Splits buffer key into relevant parts."""
forward
,
autocast_key
=
key
.
split
(
"_"
,
1
)
forward
=
forward
==
"forward"
return
forward
,
autocast_key
@
classmethod
def
add_fp8_tensors_to_global_buffer
(
cls
,
fp8_meta
:
Dict
[
str
,
Any
],
)
->
None
:
"""
Delayed scaling only.
The amax reduction process happens completely outside the FP8 modules.
To participate in the reduction, the only role played by a module is
to call this function in order to append it's FP8 tensor into a global
buffer. There are 5 global buffers maintained, one each for amax, amax
history, scale, scale-inverse, and non-weight-mask. Each buffer has
keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix
to indicate the type of FP8 tensor, since the forward and backward
reductions happen separately.
Note: For CG capture, this method is called from the graphed
wrapper. For non CG case, it's called from within the module.
"""
# delayed scaling only function, noop for any other recipe
if
not
fp8_meta
[
"recipe"
].
delayed
():
return
# Every module must call this function exactly once since
# the amax tensors are static. Ensures that compatibility
# with non-graphed modules is maintained.
index_in_buffer
=
cls
.
get_buffer_info
()
# Same index for fwd/bwd fp8 tensors.
if
index_in_buffer
in
fp8_meta
:
return
fp8_meta
[
index_in_buffer
]
=
[]
for
forward
in
(
True
,
False
):
fp8_meta_tensor_key
=
cls
.
get_meta_tensor_key
(
forward
=
forward
)
if
fp8_meta_tensor_key
not
in
fp8_meta
:
# Handles non-parameter FP8 modules, e.g. DPA.
continue
key
=
cls
.
get_key_in_buffer
(
forward
,
fp8_meta
[
"recipe"
],
fp8_meta
[
"fp8_group"
])
if
key
not
in
cls
.
global_amax_buffer
:
cls
.
global_amax_buffer
[
key
]
=
[
fp8_meta
[
fp8_meta_tensor_key
].
amax_history
[
0
]]
cls
.
global_amax_history_buffer
[
key
]
=
[
fp8_meta
[
fp8_meta_tensor_key
].
amax_history
]
cls
.
global_scale_buffer
[
key
]
=
[
fp8_meta
[
fp8_meta_tensor_key
].
scale
]
else
:
cls
.
global_amax_buffer
[
key
].
append
(
fp8_meta
[
fp8_meta_tensor_key
].
amax_history
[
0
])
cls
.
global_amax_history_buffer
[
key
].
append
(
fp8_meta
[
fp8_meta_tensor_key
].
amax_history
)
cls
.
global_scale_buffer
[
key
].
append
(
fp8_meta
[
fp8_meta_tensor_key
].
scale
)
fp8_meta
[
index_in_buffer
].
append
(
len
(
cls
.
global_amax_buffer
[
key
])
-
1
)
fp8_meta
[
index_in_buffer
].
append
(
key
)
@
classmethod
def
is_fp8_enabled
(
cls
)
->
bool
:
"""Is FP8 enabled"""
return
cls
.
FP8_ENABLED
@
classmethod
def
is_fp8_calibration
(
cls
)
->
bool
:
"""Is FP8 calibration"""
return
cls
.
FP8_CALIBRATION
@
classmethod
def
with_fp8_parameters
(
cls
)
->
bool
:
"""Should the parameters be stored as FP8"""
return
cls
.
FP8_PARAMETERS
@
classmethod
def
with_high_precision_init_val
(
cls
)
->
bool
:
"""Should the high precision initial values be stored with FP8 parameters"""
return
cls
.
HIGH_PRECISION_INIT_VAL
@
classmethod
def
fp8_graph_capturing
(
cls
)
->
bool
:
"""Is CUDA graph capture under way?"""
return
cls
.
FP8_GRAPH_CAPTURING
or
torch
.
cuda
.
is_current_stream_capturing
()
@
classmethod
def
is_first_fp8_module
(
cls
):
"""Returns `True` only the first time when called multiple
times from within the same `fp8_autocast` context.
"""
tmp
=
cls
.
IS_FIRST_FP8_MODULE
cls
.
IS_FIRST_FP8_MODULE
=
False
return
tmp
@
classmethod
def
get_fp8_recipe
(
cls
)
->
Recipe
:
"""Return the fp8 recipe"""
if
cls
.
FP8_RECIPE
is
not
None
:
return
cls
.
FP8_RECIPE
return
get_default_fp8_recipe
()
@
classmethod
def
get_fp8_group
(
cls
)
->
Union
[
dist_group_type
,
None
]:
"""Return the fp8 group for scale/amax comm"""
return
cls
.
FP8_DISTRIBUTED_GROUP
@
classmethod
def
get_fp8_autocast_state
(
cls
)
->
Tuple
[
bool
,
bool
,
Recipe
,
dist_group_type
,
bool
]:
"""FP8 autocast state getter"""
return
(
cls
.
FP8_ENABLED
,
cls
.
FP8_CALIBRATION
,
cls
.
FP8_RECIPE
,
cls
.
FP8_DISTRIBUTED_GROUP
,
cls
.
IS_FIRST_FP8_MODULE
,
cls
.
FP8_GRAPH_CAPTURING
,
)
@
classmethod
def
set_fp8_autocast_state
(
cls
,
fp8_state
:
Tuple
[
bool
,
bool
,
DelayedScaling
,
dist_group_type
,
bool
]
)
->
None
:
"""FP8 autocast state setter"""
(
cls
.
FP8_ENABLED
,
cls
.
FP8_CALIBRATION
,
cls
.
FP8_RECIPE
,
cls
.
FP8_DISTRIBUTED_GROUP
,
cls
.
IS_FIRST_FP8_MODULE
,
cls
.
FP8_GRAPH_CAPTURING
,
)
=
fp8_state
@
staticmethod
def
reduce_tensor_across_group_op_max
(
tensor
:
torch
.
Tensor
,
group
:
dist_group_type
)
->
None
:
"""Reduce tensor across given group."""
if
torch
.
distributed
.
is_initialized
():
torch
.
distributed
.
all_reduce
(
tensor
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
group
=
group
,
async_op
=
False
,
)
@
classmethod
def
reduce_and_update_fp8_tensors
(
cls
,
forward
:
bool
=
True
,
)
->
None
:
"""Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer."""
# global_amax_buffer should only be non-empty for fp8 delayed scaling
for
buffer_key
,
amax_buffer
in
cls
.
global_amax_buffer
.
items
():
# Check for forward or backward reduction.
fwd_update
,
autocast_key
=
cls
.
split_key_in_buffer
(
buffer_key
)
if
fwd_update
!=
forward
:
continue
if
len
(
amax_buffer
)
==
0
:
continue
# Retrieve autocast specific args and concat amaxes.
recipe
,
group
=
cls
.
autocast_arguments
[
autocast_key
]
contiguous_amax
=
torch
.
cat
(
amax_buffer
)
# Reduction.
if
(
recipe
.
reduce_amax
and
torch
.
distributed
.
is_initialized
()
and
torch
.
distributed
.
get_world_size
(
group
=
group
)
>
1
):
cls
.
reduce_tensor_across_group_op_max
(
contiguous_amax
,
group
)
# Amax and scale update.
unfused_update
=
(
bool
(
int
(
os
.
getenv
(
"NVTE_UNFUSED_FP8_UPDATE"
,
"0"
)))
or
callable
(
recipe
.
amax_compute_algo
)
or
callable
(
recipe
.
scaling_factor_compute_algo
)
)
if
not
unfused_update
:
tex
.
fused_amax_and_scale_update_after_reduction
(
contiguous_amax
,
cls
.
global_amax_history_buffer
[
buffer_key
],
cls
.
global_scale_buffer
[
buffer_key
],
recipe
.
amax_compute_algo
,
get_fp8_te_dtype
(
recipe
,
forward
),
recipe
.
margin
,
)
else
:
split_and_copy
(
contiguous_amax
,
amax_buffer
,
[
x
.
numel
()
for
x
in
amax_buffer
])
for
amax_history
,
scale
in
zip
(
cls
.
global_amax_history_buffer
[
buffer_key
],
cls
.
global_scale_buffer
[
buffer_key
],
):
_amax_and_scale_update
(
amax_history
,
scale
,
get_fp8_max
(
recipe
,
forward
),
recipe
)
@
classmethod
def
get_unique_autocast_key
(
cls
,
recipe
:
Optional
[
Recipe
]
=
None
,
group
:
Optional
[
dist_group_type
]
=
None
,
):
"""
For FP8, each autocast can be uniquely identified by the recipe and fp8 group.
Safely using `hash` as we never cross checkpoint boundaries.
"""
return
f
"
{
str
(
recipe
)
}
:
{
hash
(
group
)
}
"
@
classmethod
def
fp8_autocast_enter
(
cls
,
enabled
:
bool
=
False
,
calibrating
:
bool
=
False
,
fp8_recipe
:
Optional
[
Recipe
]
=
None
,
fp8_group
:
Optional
[
dist_group_type
]
=
None
,
_graph
:
bool
=
False
,
)
->
None
:
"""Set state and tracking variables for entry into FP8 region."""
fp8_recipe
=
get_default_fp8_recipe
()
if
fp8_recipe
is
None
else
fp8_recipe
autocast_key
=
cls
.
get_unique_autocast_key
(
fp8_recipe
,
fp8_group
)
cls
.
autocast_arguments
[
autocast_key
]
=
(
fp8_recipe
,
fp8_group
)
cls
.
FP8_ENABLED
=
enabled
cls
.
FP8_CALIBRATION
=
calibrating
cls
.
FP8_RECIPE
=
fp8_recipe
cls
.
FP8_DISTRIBUTED_GROUP
=
fp8_group
cls
.
FP8_GRAPH_CAPTURING
=
_graph
if
cls
.
FP8_AUTOCAST_DEPTH
==
0
:
cls
.
IS_FIRST_FP8_MODULE
=
True
cls
.
FP8_AUTOCAST_DEPTH
+=
1
if
enabled
:
fp8_available
,
reason_for_no_fp8
=
cls
.
is_fp8_available
()
assert
fp8_available
,
reason_for_no_fp8
if
isinstance
(
fp8_recipe
,
MXFP8BlockScaling
):
mxfp8_available
,
reason_for_no_mxfp8
=
cls
.
is_mxfp8_available
()
assert
mxfp8_available
,
reason_for_no_mxfp8
if
isinstance
(
fp8_recipe
,
Float8BlockScaling
):
fp8_block_available
,
reason_for_no_fp8_block
=
cls
.
is_fp8_block_scaling_available
()
assert
fp8_block_available
,
reason_for_no_fp8_block
@
classmethod
def
fp8_autocast_exit
(
cls
,
enabled
:
bool
,
_graph
:
bool
)
->
None
:
"""Set state and tracking variables for exit from FP8 region."""
cls
.
FP8_AUTOCAST_DEPTH
-=
1
# Reduce only the non-FP8 weight modules here.
# FP8 weight modules are reduced at the end of the optimizer
# step after the weight amax is populated.
if
enabled
and
cls
.
FP8_AUTOCAST_DEPTH
==
0
and
not
_graph
and
torch
.
is_grad_enabled
():
# delayed scaling only function, for other recipes (current scaling with any granularity),
# this is noop for other recipes because cls.global_amax_buffer is empty list
cls
.
reduce_and_update_fp8_tensors
(
forward
=
True
)
@
classmethod
def
copy_forward_fp8_meta_tensors_for_recompute
(
cls
,
fp8_meta
:
Dict
[
str
,
Any
])
->
None
:
"""Copy the scaling factors and amaxes for recompute forward phase
to ensure both forward steps are numerically same.
"""
# delayed scaling only function, noop for any other recipe
if
not
fp8_meta
[
"recipe"
].
delayed
():
return
buffer_position_key
=
"global_fp8_buffer_pos_fwd_recompute"
to_copy
=
[
fp8_meta
[
"scaling_fwd"
].
amax_history
.
clone
(),
fp8_meta
[
"scaling_fwd"
].
scale
.
clone
(),
]
if
buffer_position_key
in
fp8_meta
:
cls
.
fp8_tensors_recompute_buffer
[
fp8_meta
[
buffer_position_key
]].
append
(
to_copy
)
else
:
if
len
(
cls
.
fp8_tensors_recompute_buffer
)
==
0
:
cls
.
fp8_tensors_recompute_buffer
=
[
deque
()]
else
:
cls
.
fp8_tensors_recompute_buffer
.
append
(
deque
())
cls
.
fp8_tensors_recompute_buffer
[
-
1
].
append
(
to_copy
)
fp8_meta
[
buffer_position_key
]
=
len
(
cls
.
fp8_tensors_recompute_buffer
)
-
1
@
classmethod
def
get_old_fp8_meta_tensors_for_recompute
(
cls
,
fp8_meta
:
Dict
[
str
,
Any
])
->
None
:
"""Switch to the copied scaling factors and amaxes from phase
1 forward for indentical numerical outputs.
"""
# delayed scaling only function, noop for any other recipe
if
not
fp8_meta
[
"recipe"
].
delayed
():
return
# Store updated amaxes and scales from phase 1 post forward.
fp8_meta
[
"updated_amax_history_fwd"
]
=
fp8_meta
[
"scaling_fwd"
].
amax_history
.
clone
()
fp8_meta
[
"updated_scale_fwd"
]
=
fp8_meta
[
"scaling_fwd"
].
scale
.
clone
()
# Retrieve stashed amaxes and scales from phase 1 pre forward.
buffer_position_key
=
"global_fp8_buffer_pos_fwd_recompute"
stashed_fp8_meta
=
cls
.
fp8_tensors_recompute_buffer
[
fp8_meta
[
buffer_position_key
]].
popleft
()
# Replace amaxes and scales with stashed values for phase 2 forward
fp8_meta
[
"scaling_fwd"
].
amax_history
.
copy_
(
stashed_fp8_meta
[
0
])
fp8_meta
[
"scaling_fwd"
].
scale
.
copy_
(
stashed_fp8_meta
[
1
])
@
staticmethod
def
restore_fp8_meta_tensors
(
fp8_meta
:
Dict
[
str
,
Any
])
->
None
:
"""Restore latest scaling factors and amaxes after recompute forward run."""
# delayed scaling only function, noop for any other recipe
if
not
fp8_meta
[
"recipe"
].
delayed
():
return
fp8_meta
[
"scaling_fwd"
].
amax_history
.
copy_
(
fp8_meta
[
"updated_amax_history_fwd"
])
fp8_meta
[
"scaling_fwd"
].
scale
.
copy_
(
fp8_meta
[
"updated_scale_fwd"
])
@
contextmanager
def
fp8_model_init
(
enabled
:
bool
=
True
,
recipe
:
Optional
[
Recipe
]
=
None
,
preserve_high_precision_init_val
:
bool
=
False
,
)
->
None
:
"""
Context manager for FP8 initialization of parameters.
Example usage:
.. code-block:: python
with fp8_model_init(enabled=True):
model = transformer_engine.pytorch.Linear(768, 768)
# Preserving high precision initial value to initialize master weight
with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
model = transformer_engine.pytorch.Linear(768, 768)
master_weight = model.weight.get_high_precision_init_val()
model.weight.clear_high_precision_init_val()
Parameters
----------
enabled: bool, default = `True`
when enabled, Transformer Engine modules created inside this `fp8_model_init`
region will hold only FP8 copies of its parameters, as opposed to the default
behavior where both higher precision and FP8 copies are present. Setting this
option to `True` may result in lower memory consumption and is especially
useful for scenarios like:
* full model training using optimizer with master weights, where the high
precision copies of weights are already present in the optimizer.
* inference, where only the FP8 copies of the parameters are used.
* LoRA-like fine-tuning, where the main parameters of the model do not change.
recipe: transformer_engine.common.recipe.Recipe, default = `None`
Recipe used to create the parameters. If left to None, it uses the default FP8 recipe.
preserve_high_precision_init_val: bool, default = `False`
when enabled, store the high precision tensor used to initialize FP8 parameters
in CPU memory, and add two function attributes named `get_high_precision_init_val()`
and `clear_high_precision_init_val()` to FP8 parameters to get/clear this high
precision tensor. The purpose is that users can use this high-precision copy
to initialize master weights, avoiding the loss of precision that can occur when
using FP8 parameters directly. Note that after the master weights are initialized,
users should call `clear_high_precision_init_val()` to release this CPU memory.
This functionality is *EXPERIMENTAL*.
"""
_fp8_parameters
=
FP8GlobalStateManager
.
FP8_PARAMETERS
_fp8_recipe
=
FP8GlobalStateManager
.
FP8_RECIPE
_high_precision_init_val
=
FP8GlobalStateManager
.
HIGH_PRECISION_INIT_VAL
FP8GlobalStateManager
.
FP8_PARAMETERS
=
enabled
FP8GlobalStateManager
.
FP8_RECIPE
=
get_default_fp8_recipe
()
if
recipe
is
None
else
recipe
FP8GlobalStateManager
.
HIGH_PRECISION_INIT_VAL
=
preserve_high_precision_init_val
try
:
yield
finally
:
FP8GlobalStateManager
.
FP8_PARAMETERS
=
_fp8_parameters
FP8GlobalStateManager
.
FP8_RECIPE
=
_fp8_recipe
FP8GlobalStateManager
.
HIGH_PRECISION_INIT_VAL
=
_high_precision_init_val

@contextmanager
def fp8_autocast(
    enabled: bool = True,
    calibrating: bool = False,
    fp8_recipe: Optional[Recipe] = None,
    fp8_group: Optional[dist_group_type] = None,
    _graph: bool = False,
) -> None:
    """
    Context manager for FP8 usage.

    .. code-block:: python

        with fp8_autocast(enabled=True):
            out = model(inp)

    .. note::

        Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors
        with shapes where both dimensions are divisible by 16. In terms of the input to the full
        Transformer network, this typically requires padding the sequence length to a multiple of 16.

    .. note::

        When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once
        inside a single `fp8_autocast` region. This is unsupported behavior because the amax
        reduction is handled during the exit of the `fp8_autocast` context. Calling the same
        module more than once inside an `fp8_autocast` region overwrites the amax tensors
        before the reduction can occur.

    Parameters
    ----------
    enabled: bool, default = `True`
        whether or not to enable fp8
    calibrating: bool, default = `False`
        calibration mode allows collecting statistics such as amax and scale
        data of fp8 tensors even when executing without fp8 enabled. This is
        useful for saving an inference ready fp8 checkpoint while training
        using a higher precision.
    fp8_recipe: recipe.Recipe, default = `None`
        recipe used for FP8 training.
    fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
        distributed group over which amaxes for the fp8 tensors
        are reduced at the end of each training step.
    """
    if enabled:
        check_recipe_support(fp8_recipe)
    fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
    FP8GlobalStateManager.fp8_autocast_enter(
        enabled=enabled,
        calibrating=calibrating,
        fp8_recipe=fp8_recipe,
        fp8_group=fp8_group,
        _graph=_graph,
    )
    try:
        yield
    finally:
        FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)
        FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)
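
# A minimal calibration sketch (model and batches are illustrative): run in
# higher precision with `calibrating=True` so amax/scale statistics are still
# collected, enabling an FP8-ready checkpoint to be saved later.
def _example_calibration(model, batches):
    for batch in batches:
        with fp8_autocast(enabled=False, calibrating=True):
            model(batch)
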

def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
    """Update amax history and set next amax to zero."""
    if amax_history.shape[0] > 1:
        new_amax_history = torch.roll(amax_history, -1, 0)
        amax_history.copy_(new_amax_history)
    amax_history[0].fill_(0.0)
    return amax_history
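
# Behavior sketch: for a history of length 3 holding amaxes [3., 2., 1.]
# (newest first), the roll moves the newest entry to the back and slot 0 is
# zeroed for the next iteration, giving [0., 1., 3.].
def _example_amax_history_roll():
    history = torch.tensor([[3.0], [2.0], [1.0]])
    return _update_amax_history(history)  # tensor([[0.], [1.], [3.]])
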

@torch.jit.script
def _default_get_amax_and_update_history(
    amax_history: torch.Tensor,
    amax_compute_algo: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Default function to obtain amax from history."""
    if amax_compute_algo == "max":
        amax = torch.max(amax_history, dim=0).values
    else:  # amax_compute_algo == "most_recent"
        amax = amax_history[0].clone()
    amax_history = _update_amax_history(amax_history)
    return amax_history, amax

@jit_fuser
def _default_sf_compute(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    margin: int,
    _fp32_max: float = torch.finfo(torch.float32).max,  # finfo not available inside the jitted fn
) -> torch.Tensor:
    """Default function to convert amax to scaling factor.

    Computing the scaling factor requires consideration of the following scenarios:
    1. amax == 0:
       No action is possible, set scale to the previous scale (or 1).
    2. 0 < amax < tiny_amax:
       The amax is so tiny that the scale becomes infinite in FP32.
       Set scale = FP32_max.
    3. tiny_amax <= amax < FP32_max:
       Set scale = FP8_max (or scaled_max) / amax.
    4. amax == inf or amax == nan:
       No action is possible, set scale to the previous scale (or 1).
    """
    sf = (fp8_max / amax) / (2**margin)
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)
    sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf)
    scale.copy_(sf)
    return scale
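
# Worked example (margin=0, fp8_max=448.0 as for E4M3): amax=2.0 yields
# scale 448/2 = 224; amax=0.0 and non-finite amax fall back to the previous
# scale (here 1.0).
def _example_sf_compute():
    amax = torch.tensor([2.0, 0.0, float("nan")])
    scale = torch.ones(3)
    return _default_sf_compute(amax, scale, 448.0, 0)  # tensor([224., 1., 1.])
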

def _compute_amax_and_update_history(
    amax_history: torch.Tensor,
    amax_compute_algo: Union[Callable, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Obtain the amax from the history."""
    if callable(amax_compute_algo):
        amax = amax_compute_algo(amax_history)
        amax_history = _update_amax_history(amax_history)
        return amax_history, amax
    return _default_get_amax_and_update_history(
        amax_history,
        amax_compute_algo,
    )

def _compute_scaling_factor(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    recipe: DelayedScaling,
) -> torch.Tensor:
    """Convert amax to scaling factor."""
    if recipe.scaling_factor_compute_algo is None:
        return _default_sf_compute(
            amax,
            scale,
            fp8_max,
            recipe.margin,
        )
    return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe)

def _amax_and_scale_update(
    amax_history: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    recipe: DelayedScaling,
) -> None:
    """Updates FP8 meta tensors."""
    new_amax_history, amax = _compute_amax_and_update_history(
        amax_history,
        recipe.amax_compute_algo,
    )
    new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe)
    scale.copy_(new_scale)
    amax_history.copy_(new_amax_history)
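
# End-to-end sketch of one delayed-scaling update: with history [[2.], [4.]]
# and amax_compute_algo="max" the reduced amax is 4.0, so with fp8_max=448.0
# the new scale is 112 and the rolled history gets a zeroed front slot.
def _example_amax_and_scale_update():
    recipe = DelayedScaling(amax_history_len=2, amax_compute_algo="max")
    amax_history = torch.tensor([[2.0], [4.0]])
    scale = torch.ones(1)
    _amax_and_scale_update(amax_history, scale, 448.0, recipe)
    return scale, amax_history  # tensor([112.]), tensor([[0.], [2.]])
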

def split_and_copy(
    buffer: torch.Tensor,
    outputs: List[torch.Tensor],
    chunk_sizes: List[int],
) -> None:
    """Split `buffer` by `chunk_sizes` and copy into `outputs`."""
    splits = buffer.split(chunk_sizes)
    torch._foreach_copy_(outputs, splits)
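
# Behavior sketch: scatter a flat 6-element buffer into two preallocated
# outputs of sizes 2 and 4 with a single fused foreach copy.
def _example_split_and_copy():
    buffer = torch.arange(6.0)
    outputs = [torch.empty(2), torch.empty(4)]
    split_and_copy(buffer, outputs, [2, 4])
    return outputs  # [tensor([0., 1.]), tensor([2., 3., 4., 5.])]
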

class RecipeState(abc.ABC):
    """Configuration and state for a quantization recipe.

    This is a builder class for quantizers, which are in turn builder
    classes for quantized tensors.

    This class may pack together the state for multiple quantizers,
    which is helpful for applying fused kernels with less overhead.
    """

    @staticmethod
    def create(
        recipe: Recipe,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> RecipeState:
        """Factory method to create the state for a quantization recipe

        Parameters
        ----------
        recipe: Recipe
            Quantization recipe.
        mode: {"forward", "backward"}
            Training stage where quantization will be performed.
        num_quantizers: int, default = 1
            Number of quantizers to create state for.
        device: torch.device, default = default CUDA device
            Device for quantized tensors.

        Returns
        -------
        RecipeState:
            Quantization recipe state.
        """
        cls = None
        if recipe.delayed():
            cls = DelayedScalingRecipeState
        elif recipe.mxfp8():
            cls = MXFP8BlockScalingRecipeState
        elif recipe.float8_current_scaling():
            cls = Float8CurrentScalingRecipeState
        elif recipe.float8_block_scaling():
            cls = Float8BlockScalingRecipeState
        else:
            raise ValueError(f"{recipe.__class__.__name__} is not supported")
        return cls(
            recipe,
            mode=mode,
            num_quantizers=num_quantizers,
            device=device,
        )

    @abc.abstractmethod
    def make_quantizers(self) -> list:
        """Convert recipe state to quantizers.

        Quantizers are builder classes for quantized tensors. They are
        typically used to convert a high-precision tensor (e.g. in
        FP32 or BF16) into a quantized tensor (e.g. in FP8).
        """

class DelayedScalingRecipeState(RecipeState):
    """State for FP8 quantization with per-tensor delayed scaling.

    Delayed scaling recipe requires a scaling factor (applied when
    casting to FP8) and a history of max-abs values ("amax") from
    recent FP8 casts for updating the scaling factor. The scale update
    is handled externally by `FP8GlobalStateManager`.
    """

    recipe: DelayedScaling
    mode: str
    dtype: tex.DType
    scale: torch.Tensor
    amax_history: torch.Tensor

    def __init__(
        self,
        recipe: DelayedScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp8_te_dtype(recipe, mode == "forward")

        # Allocate buffers
        if device is None:
            device = torch.device("cuda")
        self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=device)
        self.amax_history = torch.zeros(
            recipe.amax_history_len,
            num_quantizers,
            dtype=torch.float32,
            device=device,
        )

    def make_quantizers(self) -> list:
        # TODO(ksivamani): Find better design for this, adding here to avoid circular import.
        from .tensor.float8_tensor import Float8Quantizer

        return [
            Float8Quantizer(self.scale[i], self.amax_history[0][i].reshape((1,)), self.dtype)
            for i in range(self.num_quantizers)
        ]
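
# Usage sketch (assumes a CUDA device): each quantizer produced here views one
# column of the shared scale/amax buffers, so a later fused update can refresh
# all scales at once.
def _example_delayed_scaling_quantize():
    state = DelayedScalingRecipeState(DelayedScaling(), mode="forward", num_quantizers=1)
    quantizer = state.make_quantizers()[0]
    x = torch.randn(128, 128, device="cuda")
    return quantizer(x)  # Float8Tensor holding the FP8 cast of x
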

class Float8CurrentScalingRecipeState(RecipeState):
    """Configuration for per-tensor current scaling quantization.

    Per-tensor current scaling quantization does not require state.
    """

    recipe: Float8CurrentScaling
    mode: str
    dtype: tex.DType
    device: torch.device

    def __init__(
        self,
        recipe: Float8CurrentScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp8_te_dtype(recipe, mode == "forward")

        # Allocate buffers
        if device is None:
            device = torch.device("cuda")
        self.device = device

    def make_quantizers(self) -> list:
        from .tensor.float8_tensor import Float8CurrentScalingQuantizer

        return [
            Float8CurrentScalingQuantizer(self.dtype, device=self.device)
            for _ in range(self.num_quantizers)
        ]

class MXFP8BlockScalingRecipeState(RecipeState):
    """Configuration for MXFP8 quantization.

    MXFP8 quantization does not require state.
    """

    recipe: MXFP8BlockScaling
    mode: str
    dtype: tex.DType

    def __init__(
        self,
        recipe: MXFP8BlockScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp8_te_dtype(recipe, mode == "forward")

        # Allocate buffers
        if device is None:
            device = torch.device("cuda")

    def make_quantizers(self) -> list:
        # TODO(ksivamani): Find better design for this, adding here to avoid circular import.
        from .tensor.mxfp8_tensor import MXFP8Quantizer

        return [MXFP8Quantizer(self.dtype) for _ in range(self.num_quantizers)]

class Float8BlockScalingRecipeState(RecipeState):
    """Configuration for Float8BlockScaling quantization.

    Float8BlockScaling quantization does not require state,
    but different quantizers use different modes.
    """

    recipe: Float8BlockScaling
    mode: str
    qx_dtype: tex.DType
    qw_dtype: tex.DType
    qgrad_dtype: tex.DType

    def __init__(
        self,
        recipe: Float8BlockScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.qx_dtype = get_fp8_te_dtype(recipe, True)
        self.qw_dtype = get_fp8_te_dtype(recipe, True)
        self.qgrad_dtype = get_fp8_te_dtype(recipe, False)

        # Allocate buffers
        if device is None:
            device = torch.device("cuda")
        self.device = device

    def make_quantizers(self) -> list:
        # TODO(ksivamani): Find better design for this, adding here to avoid circular import.
        from .tensor.float8_blockwise_tensor import Float8BlockQuantizer

        if self.mode == "forward":
            # The index convention (coming from base.py set_meta_tensor)
            # is somewhat awkward, and doesn't play nicely with QuantizeOp,
            # which is not associated with a GEMM.
            assert self.num_quantizers % 3 == 0  # x, w, output per gemm
            return list(
                itertools.chain.from_iterable(
                    [
                        [
                            Float8BlockQuantizer(
                                fp8_dtype=self.qx_dtype,
                                rowwise=True,
                                columnwise=True,
                                amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
                                force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
                                block_scaling_dim=self.recipe.x_block_scaling_dim,
                            ),
                            Float8BlockQuantizer(
                                fp8_dtype=self.qw_dtype,
                                rowwise=True,
                                columnwise=True,
                                amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon,
                                force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale,
                                block_scaling_dim=self.recipe.w_block_scaling_dim,
                            ),
                            Float8BlockQuantizer(
                                fp8_dtype=self.qx_dtype,
                                rowwise=True,
                                columnwise=True,
                                amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
                                force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
                                block_scaling_dim=self.recipe.x_block_scaling_dim,
                            ),
                        ]
                        for _ in range(self.num_quantizers // 3)
                    ]
                )
            )

        assert self.mode == "backward", f"Unexpected mode {self.mode}"
        assert self.num_quantizers % 2 == 0  # grad_output and grad_input per gemm
        return list(
            itertools.chain.from_iterable(
                [
                    [
                        Float8BlockQuantizer(
                            fp8_dtype=self.qgrad_dtype,
                            rowwise=True,
                            columnwise=True,
                            amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
                            force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
                            block_scaling_dim=self.recipe.grad_block_scaling_dim,
                        ),
                        Float8BlockQuantizer(
                            fp8_dtype=self.qgrad_dtype,
                            rowwise=True,
                            columnwise=True,
                            amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
                            force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
                            block_scaling_dim=self.recipe.grad_block_scaling_dim,
                        ),
                    ]
                    for _ in range(self.num_quantizers // 2)
                ]
            )
        )
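
# Index-convention sketch: the forward state above packs quantizers as
# [input, weight, output] per GEMM and the backward state as
# [grad_output, grad_input] per GEMM, so callers index them like this.
def _example_quantizer_indexing(fwd_quantizers, bwd_quantizers, gemm_idx):
    x_q = fwd_quantizers[3 * gemm_idx + 0]         # input
    w_q = fwd_quantizers[3 * gemm_idx + 1]         # weight
    out_q = fwd_quantizers[3 * gemm_idx + 2]       # output
    grad_out_q = bwd_quantizers[2 * gemm_idx + 0]  # grad_output
    grad_in_q = bwd_quantizers[2 * gemm_idx + 1]   # grad_input
    return x_q, w_q, out_q, grad_out_q, grad_in_q
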
transformer_engine/pytorch/graph.py
View file @ 063ef88d

...
@@ -6,6 +6,7 @@
 from collections.abc import Iterable
 import contextlib
 import gc
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
 import torch
...
@@ -15,8 +16,8 @@ from torch._C import _graph_pool_handle
 from transformer_engine.common.recipe import DelayedScaling, Recipe
 from transformer_engine.pytorch.constants import dist_group_type
-from .fp8 import (
-    fp8_autocast,
+from .quantization import (
+    autocast,
     FP8GlobalStateManager,
     get_default_fp8_recipe,
 )
...
@@ -84,7 +85,7 @@ def _make_graphed_callables(
     sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
     num_warmup_iters: int = 3,
     allow_unused_input: bool = False,
-    fp8_weight_caching: bool = False,
+    cache_quantized_params: bool = False,
     sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
     _order: Optional[List[int]] = None,
     _num_layers_per_chunk: Optional[List[int]] = None,
...
@@ -252,7 +253,7 @@ def _make_graphed_callables(
             consumed_sample_q[sample_keys].append(per_callable_fwd_idx)
         fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:]
-    if fp8_weight_caching:
+    if cache_quantized_params:
         # Initialize flag that controls FP8 weight updates
         FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
...
@@ -687,7 +688,7 @@ def _make_graphed_callables(
             # Decide whether to update FP8 weights
             skip_fp8_weight_update = None
-            if fp8_weight_caching:
+            if cache_quantized_params:
                 assert "is_first_microbatch" in user_kwargs and isinstance(
                     user_kwargs["is_first_microbatch"], bool
                 ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching."
...
@@ -796,14 +797,14 @@ def save_fp8_tensors(
 def save_fp8_tensors(
     modules: Iterable[torch.nn.Module],
-    fp8_recipe: Optional[Recipe],
+    recipe: Optional[Recipe],
 ) -> Optional[List[Any]]:
     """
     Returns the FP8 tensors for all modules
     with adjusted amax history sizes.
     """
-    if not isinstance(fp8_recipe, DelayedScaling):
+    if not isinstance(recipe, DelayedScaling):
         return None
     fp8_tensors = []
...
@@ -812,10 +813,10 @@ def save_fp8_tensors(
         module_tensors = None
         if isinstance(m, TransformerEngineBaseModule):
             if m.primary_weights_in_fp8:
-                m.adjust_amax_history_length(fp8_recipe.amax_history_len)
+                m.adjust_amax_history_length(recipe.amax_history_len)
             module_tensors = m.get_fp8_meta_tensors()
         elif isinstance(m, BasicOperation):
-            m.reset_recipe_state(recipe=fp8_recipe)
+            m.reset_recipe_state(recipe=recipe)
             module_tensors = m._save_fp8_metas()
         fp8_tensors.append(module_tensors)
     return fp8_tensors
...
@@ -850,11 +851,16 @@ def make_graphed_callables(
     num_warmup_iters: int = 3,
     allow_unused_input: bool = False,
     sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
-    fp8_enabled: SingleOrTuple[bool] = False,
-    fp8_calibrating: bool = False,
+    fp8_enabled: Optional[SingleOrTuple[bool]] = None,
+    fp8_calibrating: Optional[bool] = None,
     fp8_recipe: Optional[Recipe] = None,
     fp8_group: Optional[dist_group_type] = None,
-    fp8_weight_caching: bool = False,
+    fp8_weight_caching: Optional[bool] = None,
+    enabled: Optional[SingleOrTuple[bool]] = None,
+    calibrating: Optional[bool] = None,
+    recipe: Optional[Recipe] = None,
+    amax_reduction_group: Optional[dist_group_type] = None,
+    cache_quantized_params: Optional[bool] = None,
     _order: Optional[List[int]] = None,
     _num_layers_per_chunk: Optional[List[int]] = None,
     pool: Optional[Tuple[int, ...]] = None,
...
@@ -870,6 +876,11 @@ def make_graphed_callables(
     `original PyTorch implementation <https://pytorch.org/docs/stable/generated/torch.cuda.make_graphed_callables.html>`_
     for more documentation.

+    .. warning::
+
+        Arguments 'fp8_enabled', 'fp8_calibrating', 'fp8_recipe', 'fp8_group', and
+        'fp8_weight_caching' are deprecated. Use arguments 'enabled', 'calibrating',
+        'recipe', 'amax_reduction_group', and 'cache_quantized_params' instead.
+
     Graphing parameters
     -------------------
     modules: (tuple of) callable
...
@@ -894,30 +905,110 @@ def make_graphed_callables(
         when `_order` is provided. All callables in `modules` are assumed to have
         inputs and outputs with the same dtype and shape.

-    FP8-related parameters
-    ----------------------
-    fp8_enabled: (tuple of) bool, default = `False`
-        whether or not to enable fp8.
-        If tuple, the length must match the number of modules.
-    fp8_calibrating: bool, default = `False`
-        calibration mode allows collecting statistics such as amax and scale
-        data of fp8 tensors even when executing without fp8 enabled. This is
-        useful for saving an inference ready fp8 checkpoint while training
-        using a higher precision.
-    fp8_recipe: Recipe, default = `None`
-        recipe used for FP8 training.
-    fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
-        distributed group over which amaxes for the fp8 tensors
-        are reduced at the end of each training step.
-    fp8_weight_caching: bool, default = `False`
-        Whether or not to cache FP8 weights across microbatches. If set to `True`,
-        the `is_first_microbatch` boolean argument must be passed into the forward
-        method for TransformerEngine modules. When storing primary weights in FP8
-        using TE's `fp8_model_init` API and using an FP8 aware optimizer, this arg
-        must be set to `False` if calculating weight transposes outside TE, e.g.,
-        in the optimizer step.
+    Quantization related parameters
+    -------------------------------
+    enabled: (tuple of) bool, default = `False`
+        whether or not to enable low precision quantization (FP8/FP4).
+        If tuple, the length must match the number of modules.
+    calibrating: bool, default = `False`
+        calibration mode allows collecting statistics such as amax and scale
+        data of quantized tensors even when executing without quantization enabled.
+        This is useful for saving an inference ready checkpoint while training
+        using a higher precision.
+    recipe: recipe.Recipe, default = `None`
+        recipe used for low precision quantization.
+    amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
+        distributed group over which amaxes for the quantized tensors
+        are reduced at the end of each training step.
+    cache_quantized_params: bool, default = `False`
+        Whether or not to cache quantized weights across microbatches. If set to `True`,
+        the `is_first_microbatch` boolean argument must be passed into the forward
+        method for TransformerEngine modules. When storing primary weights in low precision
+        using TE's `quantized_model_init` API and using a quantization-aware optimizer,
+        this arg must be set to `False` if calculating weight transposes outside TE,
+        e.g., in the optimizer step.
     """
+    # Handle deprecated args. If old kwargs are set, they are prioritized with warning.
+    if fp8_enabled is not None:
+        if enabled is not None:
+            raise ValueError(
+                "make_graphed_callables has deprecated `fp8_enabled` kwarg "
+                "in favor of `enabled`, but both kwargs are set."
+            )
+        warnings.warn(
+            "make_graphed_callables has deprecated `fp8_enabled` kwarg in favor of `enabled`. "
+            "`fp8_enabled` will be removed in a future release.",
+            category=DeprecationWarning,
+            stacklevel=2,
+        )
+        enabled = fp8_enabled
+    if enabled is None:
+        enabled = False
+    if fp8_calibrating is not None:
+        if calibrating is not None:
+            raise ValueError(
+                "make_graphed_callables has deprecated `fp8_calibrating` kwarg "
+                "in favor of `calibrating`, but both kwargs are set."
+            )
+        warnings.warn(
+            "make_graphed_callables has deprecated `fp8_calibrating` kwarg in favor of "
+            "`calibrating`. `fp8_calibrating` will be removed in a future release.",
+            category=DeprecationWarning,
+            stacklevel=2,
+        )
+        calibrating = fp8_calibrating
+    if calibrating is None:
+        calibrating = False
+    if fp8_recipe is not None:
+        if recipe is None:
+            warnings.warn(
+                "make_graphed_callables has deprecated `fp8_recipe` kwarg in favor of "
+                "`recipe`. `fp8_recipe` will be removed in a future release.",
+                category=DeprecationWarning,
+                stacklevel=2,
+            )
+        else:
+            raise ValueError(
+                "make_graphed_callables has deprecated `fp8_recipe` kwarg "
+                "in favor of `recipe`, but both kwargs are set."
+            )
+        recipe = fp8_recipe
+    if fp8_group is not None:
+        if amax_reduction_group is None:
+            warnings.warn(
+                "make_graphed_callables has deprecated `fp8_group` kwarg in favor of "
+                "`amax_reduction_group`. `fp8_group` will be removed in a future release.",
+                category=DeprecationWarning,
+                stacklevel=2,
+            )
+        else:
+            raise ValueError(
+                "make_graphed_callables has deprecated `fp8_group` kwarg "
+                "in favor of `amax_reduction_group`, but both kwargs are set."
+            )
+        amax_reduction_group = fp8_group
+    if fp8_weight_caching is not None:
+        if cache_quantized_params is not None:
+            raise ValueError(
+                "make_graphed_callables has deprecated `fp8_weight_caching` kwarg "
+                "in favor of `cache_quantized_params`, but both kwargs are set."
+            )
+        warnings.warn(
+            "make_graphed_callables has deprecated `fp8_weight_caching` kwarg in favor of "
+            "`cache_quantized_params`. `fp8_weight_caching` will be removed in a future release.",
+            category=DeprecationWarning,
+            stacklevel=2,
+        )
+        cache_quantized_params = fp8_weight_caching
+    if cache_quantized_params is None:
+        cache_quantized_params = False

     set_capture_start()

     # Handle single module.
...
@@ -926,21 +1017,21 @@ def make_graphed_callables(
         just_one_callable = True
         modules = (modules,)

-    if not isinstance(fp8_enabled, tuple):
-        assert isinstance(fp8_enabled, bool), "fp8_enabled must be a bool or a tuple of bools"
-        fp8_enabled = (fp8_enabled,) * len(modules)
+    if not isinstance(enabled, tuple):
+        assert isinstance(enabled, bool), "enabled must be a bool or a tuple of bools"
+        enabled = (enabled,) * len(modules)
     else:
-        assert len(fp8_enabled) == len(
+        assert len(enabled) == len(
             modules
-        ), f"fp8_enabled length ({len(fp8_enabled)}) must match modules length ({len(modules)})"
-    if any(fp8_enabled) and fp8_recipe is None:
-        fp8_recipe = get_default_fp8_recipe()
-    elif not any(fp8_enabled):
-        fp8_recipe = None
-    module_uses_fp8 = dict(zip((id(m) for m in modules), fp8_enabled))
+        ), f"enabled length ({len(enabled)}) must match modules length ({len(modules)})"
+    if any(enabled) and recipe is None:
+        recipe = get_default_fp8_recipe()
+    elif not any(enabled):
+        recipe = None
+    module_uses_fp8 = dict(zip((id(m) for m in modules), enabled))

     # Store FP8 tensors to reset later.
-    saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)
+    saved_fp8_tensors = save_fp8_tensors(modules, recipe=recipe)

     # FP8 wrapper.
     old_call_funcs = {}
...
@@ -954,11 +1045,11 @@ def make_graphed_callables(
         # Wrap the original call function of the module class.
         def call_func(self, *args, **kwargs):
-            with fp8_autocast(
+            with autocast(
                 enabled=module_uses_fp8.get(id(self), False),
-                calibrating=fp8_calibrating,
-                fp8_recipe=fp8_recipe,
-                fp8_group=fp8_group,
+                calibrating=calibrating,
+                recipe=recipe,
+                amax_reduction_group=amax_reduction_group,
                 _graph=True,
             ):
                 outputs = old_call_funcs[block_cls](self, *args, **kwargs)
...
@@ -992,7 +1083,7 @@ def make_graphed_callables(
         sample_args,
         num_warmup_iters=num_warmup_iters,
         allow_unused_input=allow_unused_input,
-        fp8_weight_caching=fp8_weight_caching,
+        cache_quantized_params=cache_quantized_params,
         sample_kwargs=sample_kwargs,
         _order=_order,
         _num_layers_per_chunk=_num_layers_per_chunk,
...
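A minimal sketch of the renamed keywords above (module, shapes, and device are illustrative; the deprecated `fp8_*` spellings still work but emit a `DeprecationWarning`):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

model = te.Linear(768, 768)
sample_input = torch.randn(16, 768, device="cuda")
graphed_model = te.make_graphed_callables(
    model,
    (sample_input,),
    enabled=True,                    # was fp8_enabled
    calibrating=False,               # was fp8_calibrating
    recipe=DelayedScaling(),         # was fp8_recipe
    cache_quantized_params=False,    # was fp8_weight_caching
)
out = graphed_model(sample_input)
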
transformer_engine/pytorch/module/_common.py
View file @ 063ef88d

...
@@ -4,16 +4,17 @@
 """Internal function used by multiple modules."""

-from typing import Any, List, Optional, Tuple, Union, Callable
-from dataclasses import dataclass
+import dataclasses
 import queue
+from typing import Any, Callable, List, Optional, Tuple, Union

 import torch

 from .. import cpp_extensions as tex
 from ..constants import TE_DType
-from ..utils import get_default_init_method
 from ..export import is_in_onnx_export_mode
+from ..utils import get_default_init_method

 import warnings

 try:
     from lightop import rmsnorm_forward, rmsnorm_backward
...
@@ -179,7 +180,7 @@ def noop_cat(
     return _NoopCatFunc.apply(dim, *tensors)


-@dataclass
+@dataclasses.dataclass
 class _ParameterInitMeta:
     """
     Stores essential metadata needed to support deferred parameter initialization.
...
transformer_engine/pytorch/module/base.py
View file @ 063ef88d

...
@@ -22,11 +22,12 @@ import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from ._common import _ParameterInitMeta, noop_cat
-from ..fp8 import (
+from ..quantization import (
     MXFP8BlockScalingRecipeState,
     DelayedScalingRecipeState,
     Float8CurrentScalingRecipeState,
     Float8BlockScalingRecipeState,
+    NVFP4BlockScalingRecipeState,
     FP8GlobalStateManager,
     RecipeState,
 )
...
@@ -37,14 +38,14 @@ from ..distributed import (
     _fsdp_gather_tensors,
 )
 from ..constants import dist_group_type
-from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer
+from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
 from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from ..tensor._internal.float8_tensor_base import Float8TensorBase
-from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
+from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
+from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
 from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype
-from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
+from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
 from ...common.recipe import DelayedScaling, Recipe
 from ...debug.pytorch.debug_state import TEDebugState
 from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
...
@@ -82,7 +83,8 @@ def get_cublas_workspace_size_bytes() -> None:
     if IS_HIP_EXTENSION:
         return 134_217_728
     if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
-        return 33_554_432
+        # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales
+        return 32 * 1024 * 1024 + 1024
     return 4_194_304
...
@@ -547,7 +549,7 @@ def fill_userbuffers_buffer_for_all_gather(
     local_tensor: torch.Tensor,
     quantizer: Optional[Quantizer],
     process_group,
-) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]:
+) -> tuple[torch.Tensor | QuantizedTensorStorage, torch.Tensor | QuantizedTensorStorage]:
     """Fill local shard of Userbuffers buffer with data for all-gather

     Returns the full tensor and the local shard, both using the
...
@@ -571,7 +573,7 @@ def fill_userbuffers_buffer_for_all_gather(
     # Unquantized data
     if quantizer is None:
-        if isinstance(local_tensor, QuantizedTensorBase):
+        if isinstance(local_tensor, QuantizedTensorStorage):
             local_tensor = local_tensor.dequantize()
         if comm.is_fp8_ubuf():
             raise RuntimeError(
...
@@ -584,8 +586,8 @@ def fill_userbuffers_buffer_for_all_gather(
     # FP8 data
     if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
-        if not isinstance(local_tensor, Float8TensorBase):
-            if isinstance(local_tensor, QuantizedTensorBase):
+        if not isinstance(local_tensor, Float8TensorStorage):
+            if isinstance(local_tensor, QuantizedTensorStorage):
                 local_tensor.dequantize()
             quantizer.set_usage(rowwise=True, columnwise=False)
             local_tensor = quantizer(local_tensor)
...
@@ -596,7 +598,7 @@ def fill_userbuffers_buffer_for_all_gather(
             )
         comm.copy_into_buffer(local_tensor._data, local_chunk=True)
         global_tensor_data = comm.get_buffer(shape=global_shape)
-        global_tensor = Float8TensorBase(
+        global_tensor = Float8TensorStorage(
             data=global_tensor_data,
             fp8_scale_inv=local_tensor._scale_inv,
             fp8_dtype=local_tensor._fp8_dtype,
...
@@ -608,8 +610,8 @@ def fill_userbuffers_buffer_for_all_gather(
     if isinstance(quantizer, MXFP8Quantizer):
         # Cast to MXFP8 if needed
-        if not isinstance(local_tensor, MXFP8TensorBase):
-            if isinstance(local_tensor, QuantizedTensorBase):
+        if not isinstance(local_tensor, MXFP8TensorStorage):
+            if isinstance(local_tensor, QuantizedTensorStorage):
                 local_tensor.dequantize()
             local_tensor = quantizer(local_tensor)
         if not comm.is_fp8_ubuf():
...
@@ -664,7 +666,7 @@ def fill_userbuffers_buffer_for_all_gather(
             rowwise_data, rowwise_scale_inv = global_data, global_scale_inv
         else:
             columnwise_data, columnwise_scale_inv = global_data, global_scale_inv
-        global_tensor = MXFP8TensorBase(
+        global_tensor = MXFP8TensorStorage(
             rowwise_data=rowwise_data,
             rowwise_scale_inv=rowwise_scale_inv,
             columnwise_data=columnwise_data,
...
@@ -802,6 +804,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             recipe_state, Float8BlockScalingRecipeState
         ):
             return
+        if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState):
+            return

         # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
         # 2 (grad_output and grad_input) for bwd
...
@@ -826,10 +830,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 f"({len(weight_quantizers)}) must match"
             )
         for weight, quantizer in zip(weight_tensors, weight_quantizers):
-            if quantizer is not None and isinstance(weight, QuantizedTensorBase):
+            if quantizer is not None and isinstance(weight, QuantizedTensorStorage):
                 weight.update_quantizer(quantizer)

-    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
         """Get the weight tensors of the module."""
         raise NotImplementedError(
             f"{self.__class__.__name__} class does not implement _get_weight_tensors function"
...
@@ -1011,12 +1015,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             return

         dtype = inp.dtype
-        for name, param in self.named_parameters():
-            if param is not None:
-                assert dtype == param.dtype, (
-                    "Data types for parameters must match when outside of autocasted region. "
-                    f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
-                )
+        if not self.allow_different_data_and_param_types:
+            for name, param in self.named_parameters():
+                if param is not None:
+                    assert dtype == param.dtype, (
+                        "Data types for parameters must match when outside of autocasted region. "
+                        f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
+                    )
         self.activation_dtype = dtype

     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
...
@@ -1077,8 +1082,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

         # Set FP8_MAX per tensor according to recipe
-        self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
-        self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
+        if hasattr(self.fp8_meta["recipe"], "fp8_format"):
+            self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
+            self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd

         # Allocate scales and amaxes
         self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
...
@@ -1105,6 +1111,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         inp: torch.Tensor,
         num_gemms: int = 1,
         allow_non_contiguous: bool = False,
+        allow_different_data_and_param_types: bool = False,
     ) -> Generator[torch.Tensor, None, None]:
         """Checks and prep for FWD.

         The context manager is needed because there isn't a way for a module to know
...
@@ -1112,6 +1119,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         to setup the forward aggregated amax reduction for every module
         just in case. The autocast exit will pick up the most recent one.
         """
+        self.allow_different_data_and_param_types = allow_different_data_and_param_types
         self.forwarded_at_least_once = True

         # Activation recomputation is used and this is the second forward phase.
         if self.fp8 and in_fp8_activation_recompute_phase():
...
@@ -1207,9 +1215,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 grad_output,
                 (
                     QuantizedTensor,
-                    Float8TensorBase,
-                    MXFP8TensorBase,
-                    Float8BlockwiseQTensorBase,
+                    Float8TensorStorage,
+                    MXFP8TensorStorage,
+                    Float8BlockwiseQTensorStorage,
                 ),
             ):
                 grad_output = quantizer(grad_output)
...
@@ -1238,9 +1246,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 grad_output_.get_tensor(True),
                 (
                     QuantizedTensor,
-                    Float8TensorBase,
-                    MXFP8TensorBase,
-                    Float8BlockwiseQTensorBase,
+                    Float8TensorStorage,
+                    MXFP8TensorStorage,
+                    Float8BlockwiseQTensorStorage,
                 ),
             )
             and ctx.use_bias
...
@@ -1256,7 +1264,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if ctx.use_bias:
             if isinstance(
                 grad_output,
-                (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
+                (
+                    QuantizedTensor,
+                    Float8TensorStorage,
+                    MXFP8TensorStorage,
+                    Float8BlockwiseQTensorStorage,
+                ),
             ):
                 grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
             else:
...
@@ -1265,10 +1278,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
         else:
             grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
-        if not isinstance(
-            grad_output,
-            (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
-        ):
+        if not isinstance(grad_output, QuantizedTensorStorage):
             grad_output = quantizer(grad_output)
         return grad_output, grad_bias
...
@@ -1422,14 +1432,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         # Reset cache if workspace is invalid
         if out is not None and quantizer is not None:
             reset_cache = False
-            if isinstance(out, Float8TensorBase):
+            if isinstance(out, Float8TensorStorage):
                 if (
                     not is_non_tn_fp8_gemm_supported()
                     and quantizer.columnwise_usage
                     and out._transpose is None
                 ):
                     reset_cache = True
-            elif isinstance(out, MXFP8TensorBase):
+            elif isinstance(out, MXFP8TensorStorage):
                 if quantizer.rowwise_usage and out._rowwise_data is None:
                     reset_cache = True
                 elif quantizer.columnwise_usage and out._columnwise_data is None:
...
@@ -1609,8 +1619,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         - MXFP8BlockScaling → MXFP8Tensor
         - Float8BlockScaling → Float8BlockTensor

-        Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()),
-        but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()).
+        Example case to check: recipe is DelayedScaling (DelayedScaling is set in autocast()),
+        but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in quantized_model_init()).
         """
         if not self.fp8 and not self.fp8_calibration:
             return
...
@@ -1620,7 +1630,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         recipe = self.fp8_meta["recipe"]
         weight_tensors = [getattr(self, name) for name in self.weight_names]
         for i, tensor in enumerate(weight_tensors):
-            if isinstance(tensor, QuantizedTensorBase):
+            if isinstance(tensor, QuantizedTensorStorage):
                 quantizer = tensor._get_quantizer()
                 if quantizer is None:
                     continue
...
@@ -1631,6 +1641,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 raise RuntimeError(
                     f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe"
                     f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}."
-                    " Please check the recipes assigned during fp8_model_init() and"
-                    " fp8_autocast() calls."
+                    " Please check the recipes assigned during quantized_model_init() and"
+                    " autocast() calls."
                 )
...
transformer_engine/pytorch/module/fp8_padding.py
View file @ 063ef88d

...
@@ -10,7 +10,7 @@ import torch
 import transformer_engine_torch as tex

-from ..fp8 import FP8GlobalStateManager
+from ..quantization import FP8GlobalStateManager
 from ..jit import no_torch_dynamo
...
transformer_engine/pytorch/module/fp8_unpadding.py
View file @ 063ef88d

...
@@ -10,7 +10,7 @@ import torch
 import transformer_engine_torch as tex

-from ..fp8 import FP8GlobalStateManager
+from ..quantization import FP8GlobalStateManager
 from ..jit import no_torch_dynamo
...