OpenDAS / TransformerEngine / Commits / ab3e5a92

Commit ab3e5a92, authored May 09, 2025 by yuguo
Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
Parents: a8d19fd9, 04c730c0
Changes: 174
Showing 20 changed files with 2255 additions and 553 deletions (+2255 / -553).
transformer_engine/pytorch/csrc/extensions/quantizer.cpp         +141  -4
transformer_engine/pytorch/csrc/extensions/swizzle.cpp           +8    -4
transformer_engine/pytorch/csrc/extensions/transpose.cpp         +24   -20
transformer_engine/pytorch/csrc/extensions/type_converters.cpp   +32   -0
transformer_engine/pytorch/csrc/pybind.h                         +18   -1
transformer_engine/pytorch/csrc/util.h                           +13   -0
transformer_engine/pytorch/distributed.py                        +285  -19
transformer_engine/pytorch/dot_product_attention/inference.py    +7    -3
transformer_engine/pytorch/dot_product_attention/rope.py         +177  -71
transformer_engine/pytorch/fp8.py                                +137  -1
transformer_engine/pytorch/module/_common.py                     +77   -0
transformer_engine/pytorch/module/base.py                        +152  -9
transformer_engine/pytorch/module/fp8_padding.py                 +18   -4
transformer_engine/pytorch/module/fp8_unpadding.py               +17   -3
transformer_engine/pytorch/module/grouped_linear.py              +183  -58
transformer_engine/pytorch/module/layernorm_linear.py            +234  -90
transformer_engine/pytorch/module/layernorm_mlp.py               +440  -153
transformer_engine/pytorch/module/linear.py                      +234  -94
transformer_engine/pytorch/ops/basic/activation.py               +29   -1
transformer_engine/pytorch/ops/basic/basic_linear.py             +29   -18
transformer_engine/pytorch/csrc/extensions/quantizer.cpp
...
@@ -109,6 +109,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
  }
  const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none();
  opts = opts.dtype(torch::kFloat32);
  // TODO: Replace with an empty tensor.
  at::Tensor scale_inv = at::reciprocal(scale);
  py::object ret;
  if (internal) {
...
@@ -250,6 +251,140 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
    tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
  }
  this->set_quantization_params(&tensor);
  return {std::move(tensor), std::move(ret)};
}

Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {
  this->dtype = quantizer.attr("dtype").cast<DType>();
  this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast<int>();
  this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast<bool>();
  this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
  NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2,
             "Unsupported block scaling dim.");
}

void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {
  // Change the rowwise and columnwise_data to the configured dtype.
  // May be a switch between E5M2 and E4M3.
  auto rowwise_data = tensor->get_rowwise_data();
  rowwise_data.dtype = static_cast<NVTEDType>(dtype);
  auto columnwise_data = tensor->get_columnwise_data();
  columnwise_data.dtype = static_cast<NVTEDType>(dtype);
  tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
                           rowwise_data.shape);
  tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
                              columnwise_data.shape);
}

std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
    const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
  using namespace pybind11::literals;
  std::vector<int64_t> torch_shape;
  size_t numel = 1;
  for (auto s : shape) {
    torch_shape.emplace_back(static_cast<int64_t>(s));
    numel *= s;
  }
  TensorWrapper tensor(this->get_scaling_mode());
  at::TensorOptions opts;
  at::TensorOptions scale_opts;
  at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise;
  opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
  scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);

  size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back();
  size_t m_dim = numel / k_dim;
  constexpr size_t kBlockLen = 128;

  if (rowwise_usage) {
    if (rowwise_data.has_value()) {
      data_rowwise = std::move(*rowwise_data);
    } else {
      data_rowwise = at::empty(torch_shape, opts);
    }
    size_t sinv0 = 0;
    size_t sinv1 = 0;
    if (block_scaling_dim == 2) {
      sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
      sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
    } else if (block_scaling_dim == 1) {
      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
      sinv1 = roundup(m_dim, 4);
    } else {
      NVTE_CHECK(false,
                 "Unsupported block_scaling_dim in create_tensor rowwise."
                 "Expected 1 or 2. Got ",
                 block_scaling_dim);
    }
    scale_inv_rowwise =
        at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
    tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape);
    tensor.set_rowwise_scale_inv(scale_inv_rowwise.data_ptr(), DType::kFloat32,
                                 std::vector<size_t>{sinv0, sinv1});
  }
  if (columnwise_usage) {
    std::vector<int64_t> torch_columnwise_shape;
    std::vector<size_t> columnwise_shape;
    NVTE_CHECK(torch_shape.size() == shape.size(),
               "Shape expected to match torch shape. Shape ", columnwise_shape,
               " torch shape: ", torch_columnwise_shape);
    if (torch_shape.size() > 0) {
      torch_columnwise_shape.reserve(torch_shape.size());
      columnwise_shape.reserve(shape.size());
      torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
      columnwise_shape.push_back(shape[shape.size() - 1]);
      for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
        torch_columnwise_shape.push_back(torch_shape[i]);
        columnwise_shape.push_back(shape[i]);
      }
    }
    size_t sinv0 = 0;
    size_t sinv1 = 0;
    if (block_scaling_dim == 2) {
      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
      sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
    } else if (block_scaling_dim == 1) {
      sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
      sinv1 = roundup(k_dim, 4);
    } else {
      NVTE_CHECK(false,
                 "Unsupported block_scaling_dim in create_tensor columnwise."
                 "Expected 1 or 2. Got ",
                 block_scaling_dim);
    }
    data_colwise = at::empty(torch_columnwise_shape, opts);
    scale_inv_colwise =
        at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
    tensor.set_columnwise_data(data_colwise.data_ptr(), this->dtype, columnwise_shape);
    tensor.set_columnwise_scale_inv(scale_inv_colwise.data_ptr(), DType::kFloat32,
                                    std::vector<size_t>{sinv0, sinv1});
  }
  this->set_quantization_params(&tensor);

  py::object ret;
  if (internal) {
    py::handle Float8BlockwiseQTensorClass(
        reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));
    ret = Float8BlockwiseQTensorClass(
        "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
        "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
        "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
        "is_2D_scaled"_a = (block_scaling_dim == 2));
  } else {
    py::handle Float8BlockwiseQTensorClass(
        reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));
    ret = Float8BlockwiseQTensorClass(
        "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype),
        "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
        "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
        "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
        "is_2D_scaled"_a = (block_scaling_dim == 2));
  }
  return {std::move(tensor), std::move(ret)};
}
...
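To make the scale-inverse shape arithmetic above easier to follow, here is a small Python sketch of the same computation for the rowwise branch (written for this note, not part of the commit; `roundup` and the block length of 128 mirror the C++ above, everything else is illustrative):

    def blockwise_scale_inv_shape(m_dim, k_dim, block_scaling_dim, block_len=128):
        """Rowwise scale_inv shape, mirroring Float8BlockQuantizer::create_tensor."""
        roundup = lambda x, mult: ((x + mult - 1) // mult) * mult
        ceil_div = lambda x, d: (x + d - 1) // d
        if block_scaling_dim == 2:   # 128x128 blocks
            return ceil_div(m_dim, block_len), roundup(ceil_div(k_dim, block_len), 4)
        if block_scaling_dim == 1:   # 1x128 blocks along the last dimension
            return ceil_div(k_dim, block_len), roundup(m_dim, 4)
        raise ValueError("block_scaling_dim must be 1 or 2")

    # e.g. a (256, 512) tensor with 1D scaling -> a (4, 256) float32 scale_inv buffer
    print(blockwise_scale_inv_shape(256, 512, block_scaling_dim=1))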
@@ -302,7 +437,8 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
    auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4);
    rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
    tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
    tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
                                 std::vector<size_t>{static_cast<size_t>(sinv0),
                                                     static_cast<size_t>(sinv1)});
  }
...
@@ -313,7 +449,8 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
    columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
    tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
    tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
                                    std::vector<size_t>{static_cast<size_t>(sinv0),
                                                        static_cast<size_t>(sinv1)});
  }
  this->set_quantization_params(&tensor);
...
transformer_engine/pytorch/csrc/extensions/swizzle.cpp
...
@@ -6,14 +6,16 @@
#include "extensions.h"
#include "transformer_engine/transformer_engine.h"
#include "util.h"

void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) {
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input,
                                                  bool rowwise) {
  using namespace transformer_engine::pytorch;
  if (input.scaling_mode() == NVTE_INVALID_SCALING) {
    NVTE_ERROR("Invalid scaling mode for swizzle.");
  } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) {
    return;
    return std::nullopt;
  }
  NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");
...
@@ -48,9 +50,9 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww
    output_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
    output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
  } else {
    input_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
    input_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape);
    input_cu.set_columnwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
    output_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
    output_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape);
    output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
  }
...
@@ -63,6 +65,8 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww
  } else {
    input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
  }
  return swizzled_scale_inv;
}

at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
...
transformer_engine/pytorch/csrc/extensions/transpose.cpp
...
@@ -6,27 +6,38 @@
#include <optional>

#include "ATen/core/TensorBody.h"
#include "extensions.h"
#include "pybind.h"

std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
                                             std::optional<std::vector<py::handle>> output_list,
namespace transformer_engine::pytorch {

std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
                                             std::optional<std::vector<py::object>> output_list,
                                             std::vector<py::handle> quantizer_list,
                                             transformer_engine::DType otype) {
  using namespace transformer_engine::pytorch;
  init_extension();
  std::vector<NVTETensor> nvte_tensor_input_list;
  std::vector<NVTETensor> nvte_tensor_output_list;
  std::vector<py::object> py_output_objects_list;
  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
  auto none = py::none();
  if (output_list.has_value()) {
    py_output_objects_list = output_list.value();
  }
  // Choose implementation
  // Note: Currently only have fused kernel for FP8 cast-transpose
  bool with_fused_kernel = true;

  // create TE tensors from input
  for (int i = 0; i < input_list.size(); i++) {
    auto input_tensor = makeTransformerEngineTensor(input_list[i], none);
  for (size_t i = 0; i < input_list.size(); i++) {
    auto input_tensor = makeTransformerEngineTensor(input_list[i]);
    const NVTEShape input_shape = input_tensor.shape();
    transformer_engine::TensorWrapper output_tensor;
    if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) {
      with_fused_kernel = false;
    }
    if (output_list == std::nullopt) {
      std::unique_ptr<Quantizer> quantizer = convert_quantizer(quantizer_list[i]);
      std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
...
@@ -48,16 +59,8 @@ std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
  NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(),
             "Number of input and output tensors must match");

  // Choose implementation
  // Note: Currently only have fused kernel for FP8 cast-transpose
  bool with_fused_kernel = true;
  for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) {
    const auto& tensor = nvte_tensor_output_list[i];
    if (nvte_tensor_scaling_mode(tensor) != NVTE_DELAYED_TENSOR_SCALING) {
      with_fused_kernel = false;
      break;
    }
    if (nvte_tensor_columnwise_data(tensor) == nullptr) {
    if (nvte_tensor_columnwise_data(nvte_tensor_output_list[i]) == nullptr) {
      with_fused_kernel = false;
      break;
    }
...
@@ -68,9 +71,8 @@ std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
    nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
                              nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
  } else {
    for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) {
      nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i],
                    at::cuda::getCurrentCUDAStream());
    for (size_t i = 0; i < py_output_objects_list.size(); i++) {
      quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt);
    }
  }
  return py_output_objects_list;
...
@@ -78,7 +80,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
                         std::optional<at::Tensor> output) {
  using namespace transformer_engine::pytorch;
  init_extension();
  const auto dim = input.dim();
  NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose.");
...
@@ -105,3 +107,5 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
  return out;
}

}  // namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/type_converters.cpp
...
@@ -84,6 +84,38 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer)
  return ret;
}

TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantizer) {
  const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
  bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast<bool>();
  bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
  bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());

  auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);

  if (rowwise_usage) {
    const at::Tensor& data_rowwise = tensor.attr("_rowwise_data").cast<at::Tensor>();
    const at::Tensor& scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
    void* scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr();
    const auto& rowwise_shape = getTensorShape(data_rowwise);
    ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape);
    const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise);
    ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape);
  }
  if (columnwise_usage) {
    const at::Tensor& data_colwise = tensor.attr("_columnwise_data").cast<at::Tensor>();
    const at::Tensor& scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
    void* scale_inv_colwise_dptr = scale_inv_colwise.data_ptr();
    const auto& shape = getTensorShape(data_colwise);
    ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape);
    const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
    ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape);
  }
  quantizer->set_quantization_params(&ret);
  return ret;
}

}  // namespace detail
}  // namespace transformer_engine::pytorch
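For orientation, the converter above only reads a handful of attributes from the Python tensor object. A minimal duck-typed sketch of that interface (illustrative only; the real class is the Float8BlockwiseQTensor defined on the Python side of this commit):

    class FakeBlockwiseQTensor:
        """Bare-bones stand-in exposing what NVTETensorFromFloat8BlockwiseQTensor reads."""
        def __init__(self, rowwise_data, rowwise_scale_inv, fp8_dtype, is_2d_scaled):
            self._rowwise_data = rowwise_data            # uint8 torch.Tensor or None
            self._rowwise_scale_inv = rowwise_scale_inv  # float32 torch.Tensor
            self._columnwise_data = None                 # branch skipped when None
            self._columnwise_scale_inv = None
            self._fp8_dtype = fp8_dtype                  # transformer_engine DType enum value
            self._is_2D_scaled = is_2d_scaled            # selects BLOCK_SCALING_2D vs 1D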
transformer_engine/pytorch/csrc/pybind.h
...
@@ -25,6 +25,9 @@ extern PyTypeObject *Float8CurrentScalingQuantizerClass;
extern PyTypeObject *MXFP8TensorPythonClass;
extern PyTypeObject *MXFP8TensorBasePythonClass;
extern PyTypeObject *MXFP8QuantizerClass;
extern PyTypeObject *Float8BlockwiseQTensorPythonClass;
extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass;
extern PyTypeObject *Float8BlockwiseQuantizerClass;

void init_extension();
...
@@ -50,6 +53,15 @@ inline bool IsMXFP8Tensor(PyObject *obj) {
  return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass;
}

inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) {
  return Py_TYPE(obj) == Float8BlockwiseQuantizerClass;
}

inline bool IsFloat8BlockwiseQTensor(PyObject *obj) {
  return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass ||
         Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass;
}

TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer);

template <typename T>
...
@@ -61,6 +73,9 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizati
std::unique_ptr<Quantizer> CreateMXFP8Params(const py::handle params);

TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor,
                                                   Quantizer *quantization_params);

inline bool IsFloatingPointType(at::ScalarType type) {
  return type == at::kFloat || type == at::kHalf || type == at::kBFloat16;
}
...
@@ -71,7 +86,9 @@ constexpr std::array custom_types_converters = {
    std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor,
                    CreateQuantizer<Float8CurrentScalingQuantizer>),
    std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor,
                    CreateQuantizer<MXFP8Quantizer>)};
                    CreateQuantizer<MXFP8Quantizer>),
    std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers,
                    NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer<Float8BlockQuantizer>)};

}  // namespace detail
...
transformer_engine/pytorch/csrc/util.h
...
@@ -7,6 +7,19 @@
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_

#include <torch/extension.h>

#include <optional>

#include "transformer_engine/transformer_engine.h"

bool non_tn_fp8_gemm_supported();

/* Swizzle the scaling factor of the input tensor.
 *
 * The returned swizzled scaling factor tensor should be kept alive during the GEMM.
 */
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input,
                                                  bool trans);

#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
transformer_engine/pytorch/distributed.py
...
@@ -19,15 +19,24 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules

from .utils import safely_set_viewless_tensor_data
from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm
from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager, fp8_autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor

try:
    import torch.distributed._symmetric_memory as symm_mem

    HAS_TORCH_SYMMETRIC = True
except ImportError:
    HAS_TORCH_SYMMETRIC = False

__all__ = ["checkpoint", "CudaRNGStatesTracker"]
...
@@ -660,6 +669,9 @@ def checkpoint(
            **kwargs,
        )

    from .module.base import TransformerEngineBaseModule

    if isinstance(function, TransformerEngineBaseModule):
        # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
        # to scatter/gather activations that we will recompute anyway.
        setattr(function, "fsdp_wrapped", False)
...
@@ -860,23 +872,29 @@ def _all_gather_fp8(
    process_group: dist_group_type,
    *,
    async_op: bool = False,
    quantizer: Optional[Float8Quantizer] = None,
    quantizer: Optional[Quantizer] = None,
    out_shape: Optional[list[int]] = None,
) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]:
    """All-gather FP8 tensor along first dimension."""
    world_size = get_distributed_world_size(process_group)

    # Check that quantizer is valid
    if quantizer is not None and not isinstance(
        quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
    ):
        raise ValueError(f"Got non-FP8 quantizer ({quantizer.__class__.__name__})")

    # Output tensor dims
    if out_shape is None:
        out_shape = list(inp.size())
        out_shape[0] *= world_size

    # Quantize input tensor if needed
    # Cast input tensor to FP8 if needed
    # Note: We cannot directly all-gather the transposed FP8 tensor,
    # so temporarily modify quantizer to avoid creating FP8 transpose.
    if not isinstance(inp, Float8TensorBase):
        assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
        # we cannot directly gather the transposed fp8 tensor
        # so we need to disable columnwise usage for the quantizer
        # and then set it back to the original value after quantizing
        if quantizer is None:
            raise ValueError("Input tensor is not FP8 and no quantizer was provided")
        init_rowwise_usage = quantizer.rowwise_usage
        init_columnwise_usage = quantizer.columnwise_usage
        quantizer.set_usage(rowwise=True, columnwise=False)
...
@@ -888,7 +906,7 @@ def _all_gather_fp8(
    # Construct output tensor
    out: Float8TensorBase
    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
    if quantizer is not None:
        dtype = torch.float32
        device = "cuda"
        if isinstance(inp, Float8Tensor):
...
@@ -906,9 +924,8 @@ def _all_gather_fp8(
        out._transpose_invalid = True
    else:
        raise RuntimeError("FP8TensorBase is not supported yet without Quantizer")

    # For delayed scaling, scale_inv is from history, so we can pass it from inp to out
    # For current scaling, scale_inv is from doing amax reduction in C++ code, so each rank should have same scale_inv,
    # so we can just pass it from inp to out
    # Assume scaling factors are identical across ranks
    out._scale_inv = inp._scale_inv

    # Perform communication
...
@@ -920,17 +937,86 @@ def _all_gather_fp8(
    )

    # Make sure FP8 transpose is populated if needed
    if out._transpose is not None:
    needs_transpose = (
        quantizer is not None and quantizer.columnwise_usage and not non_tn_fp8_gemm_supported()
    )
    if needs_transpose:
        if handle is not None:
            handle.wait()
            handle = None
        if not isinstance(out, Float8Tensor):
            raise RuntimeError("FP8TensorBase does not support FP8 transpose yet")
        out._create_transpose()

    return out, handle


def _all_gather_fp8_blockwise(
    inp: torch.Tensor,
    process_group: dist_group_type,
    *,
    async_op: bool = False,  # pylint: disable=unused-argument
    quantizer: Optional[Quantizer] = None,
    out_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]:
    """
    All-gather FP8 tensor along first dimension for blockwise quantization.

    Returns: quantizer(gather(inp))

    NOTE: The implementation is not sophisticated enough to honor async_op=True.
    In some cases it falls back to synchronous gather and invokes the quantizer.
    """

    # Input tensor attributes
    device: torch.device
    dtype: torch.dtype
    if isinstance(inp, torch.Tensor):
        device = inp.device
        dtype = inp.dtype
    elif isinstance(inp, Float8BlockwiseQTensorBase):
        if inp._rowwise_data is not None:
            device = inp._rowwise_data.device
        elif inp._columnwise_data is not None:
            device = inp._columnwise_data.device
        else:
            raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data")
        dtype = torch.bfloat16  # Only has fp8 dtype. Guess BF16 for dequant.
    else:
        raise ValueError(
            "Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, "
            f"found {inp.__class__.__name__})"
        )
    world_size = get_distributed_world_size(process_group)

    # Check that quantizer is valid
    if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer):
        raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
    if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128):
        raise NotImplementedError("Only 1D blockwise quantization is supported for allgather")

    # Output tensor dims
    if out_shape is None:
        out_shape = list(inp.size())
        out_shape[0] *= world_size

    # Doing BF16 gather for now as baseline because it's simpler
    if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None:
        out = torch.empty(
            out_shape,
            dtype=dtype,
            device=device,
            memory_format=torch.contiguous_format,
        )
        torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
        out = quantizer(out)
        return out, None

    # Implementation of fp8 gather needs to account for:
    # * Getting columnwise data as a transpose of how it is stored for GEMMS.
    # * Gathering non GEMM swizzled scales.
    # * Refer to scaffold code when implementing at:
    #   https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477
    raise NotImplementedError("fp8 blockwise allgather not yet implemented")


def _all_gather_mxfp8(
    inp: torch.Tensor,
    process_group: dist_group_type,
...
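A rough usage sketch of the baseline path above (gather in BF16, then re-quantize). The Float8BlockQuantizer keyword arguments are taken from the Python-side constructor used elsewhere in this commit; `process_group` is an assumed, already-initialized torch.distributed group:

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
    from transformer_engine.pytorch.distributed import gather_along_first_dim

    # 1D (1x128) blockwise quantizer; only this configuration takes the fast path above.
    quantizer = Float8BlockQuantizer(
        fp8_dtype=tex.DType.kFloat8E4M3,
        rowwise=True,
        columnwise=False,
        block_scaling_dim=1,
    )
    inp = torch.randn(128, 1024, device="cuda", dtype=torch.bfloat16)
    # process_group: a torch.distributed group spanning the gathering ranks (assumed set up).
    gathered, _ = gather_along_first_dim(inp, process_group, quantizer=quantizer)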
@@ -1069,7 +1155,9 @@ def gather_along_first_dim(
    async_op: bool = False,
    quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]:
    """All-gather tensors and concatenate along first dimension."""
    """
    All-gather tensors and concatenate along first dimension.
    """

    # Return immediately if no communication is required
    world_size = get_distributed_world_size(process_group)
...
@@ -1094,6 +1182,16 @@ def gather_along_first_dim(
            out_shape=out_shape,
        )

    # FP8 block scaling case, block length = 128
    if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer):
        return _all_gather_fp8_blockwise(
            inp,
            process_group,
            async_op=async_op,
            quantizer=quantizer,
            out_shape=out_shape,
        )

    # MXFP8 case
    if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer):
        assert isinstance(quantizer, MXFP8Quantizer)
...
@@ -1105,6 +1203,28 @@ def gather_along_first_dim(
            out_shape=out_shape,
        )

    # Debug case - call gather_along_first_dim on each tensor
    if isinstance(inp, DebugQuantizedTensor):
        out_obj = inp
        rowwise = inp.get_tensor(False)
        columnwise = inp.get_tensor(True)
        final_quantizer = (
            None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
        )
        rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0]
        out_obj.rowwise_gemm_tensor = rowwise_total
        if rowwise is not columnwise:
            final_quantizer_columnwise = (
                None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
            )
            columnwise_total, _ = gather_along_first_dim(
                columnwise, process_group, False, final_quantizer_columnwise
            )
            out_obj.columnwise_gemm_tensor = columnwise_total
        else:
            out_obj.rowwise_gemm_tensor = out_obj.rowwise_gemm_tensor
        return out_obj, None

    # High-precision communication for quantized tensors
    if quantizer is not None:
        warnings.warn(
...
@@ -1147,6 +1267,152 @@ def gather_along_first_dim(
    return out, handle


# Global cache to store symmetric memory tensors
symmetric_mem_cache = {}


def get_symmetric_memory_tensor(tensor_numel, tensor_dtype, tensor_device, tp_group, tag=None):
    """
    Gets or creates a symmetric memory tensor with specified properties.

    Reuses cached tensors when available to avoid redundant creation and rendezvous operations.
    Note: This function always returns a 1D tensor.

    Parameters
    ----------
    tensor_numel : int
        Number of elements in the tensor.
    tensor_dtype : torch.dtype
        Data type of the tensor.
    tensor_device : torch.device
        Device on which to allocate the tensor.
    tp_group : dist_group_type
        Process group for rendezvous operation.
    tag : Any, optional
        Optional identifier to further distinguish tensors.

    Returns
    -------
    torch.Tensor
        A symmetric memory tensor with the specified properties.
    """
    # Create a cache key based on tensor properties and group
    cache_key = (tensor_numel, tensor_dtype, tensor_device, tp_group.group_name, tag)

    # Check if we already have a symmetric memory tensor for this configuration
    if cache_key not in symmetric_mem_cache:
        # Create a new symmetric memory tensor if not in cache
        msg = symm_mem.empty(
            tensor_numel,
            dtype=tensor_dtype,
            device=tensor_device,
        )
        # Perform the rendezvous once for this tensor
        symm_mem.rendezvous(msg, group=tp_group)
        # Store in cache
        symmetric_mem_cache[cache_key] = msg
    else:
        # Reuse the existing symmetric memory tensor
        msg = symmetric_mem_cache[cache_key]

    return msg


def symmetric_all_reduce(
    inp: torch.Tensor,
    tp_group: Optional[dist_group_type] = None,
    async_op: bool = False,
    all_reduce_type: str = "multimem_all_reduce",
):
    """
    Performs an all-reduce operation across multiple processes using symmetric memory.

    If the input tensor is already in the symmetric memory cache we can avoid copy
    overheads by just directly using the input tensor for all reduce. Externally
    created symmetric memory tensors not in the cache currently will not be able to
    avoid the extra copies.

    Parameters
    ----------
    inp : torch.Tensor
        The input tensor to be reduced. The operation is performed in-place.
    tp_group : Optional[dist_group_type], default=None
        The process group over which to perform the all-reduce operation.
        If None, the default process group is used.
    async_op : bool, default=False
        Whether to perform the operation asynchronously.
        Note: Currently only synchronous operations are supported for symmetric memory variants.
    all_reduce_type : str, default="multimem_all_reduce"
        The type of all-reduce implementation to use. Options include:
        - "nccl": Standard PyTorch distributed all-reduce
        - "multimem_all_reduce": multimem symmetric all-reduce
        - "two_shot": Two-shot symmetric all-reduce
        - "one_shot": One-shot symmetric all-reduce

    Returns
    -------
    Tuple[torch.Tensor, Optional[torch.distributed.Work]]
        - The first element is the input tensor with the all-reduce result.
        - The second element is the async work handle if async_op=True,
          otherwise None.
    """
    assert async_op is False, "Async symmetric ops no supported yet"
    assert HAS_TORCH_SYMMETRIC, "Could not import symetric memory from torch"

    if get_distributed_world_size(tp_group) == 1:
        return inp, None

    if all_reduce_type == "nccl":
        # Standard all-reduce implementation
        handle = torch.distributed.all_reduce(inp, group=tp_group, async_op=async_op)
        return inp, handle

    all_reduce_impl = None
    if all_reduce_type == "multimem_all_reduce":
        all_reduce_impl = torch.ops.symm_mem.multimem_all_reduce_
    elif all_reduce_type == "two_shot":
        all_reduce_impl = torch.ops.symm_mem.two_shot_all_reduce_
    elif all_reduce_type == "one_shot":
        all_reduce_impl = torch.ops.symm_mem.one_shot_all_reduce
    else:
        raise TypeError(f"All reduce type {all_reduce_type} is not supported.")

    group_name = tp_group.group_name
    tensor_shape = inp.shape
    tensor_numel = inp.numel()
    tensor_dtype = inp.dtype
    tensor_device = inp.device
    input_id = id(inp)
    is_cached = any(id(cached_tensor) == input_id for cached_tensor in symmetric_mem_cache.values())

    # Check if the input tensor is already in the symmetric memory cache.
    # If it is we can avoid copy overheads.
    if is_cached:
        all_reduce_impl(
            inp,
            "sum",
            group_name,
        )
    else:
        # Get symmetric memory tensor. Build or retrieve from cache.
        msg = get_symmetric_memory_tensor(tensor_numel, tensor_dtype, tensor_device, tp_group)
        msg.copy_(inp.reshape(-1))
        all_reduce_impl(
            msg,
            "sum",
            group_name,
        )
        # Copy the result back to the input tensor
        inp.copy_(msg.reshape(tensor_shape))

    return inp, None


def allreduce(
    inp: torch.Tensor,
    tp_group: Optional[dist_group_type] = None,
...
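A hedged usage sketch of the new symmetric all-reduce path (assumes torch symmetric memory is available and that `tp_group` is an already-initialized tensor-parallel process group, one GPU per rank):

    import torch
    from transformer_engine.pytorch.distributed import symmetric_all_reduce

    hidden = torch.randn(8192, device="cuda", dtype=torch.bfloat16)

    # In-place sum across tp_group; switch all_reduce_type="nccl" to fall back to plain NCCL.
    out, _ = symmetric_all_reduce(hidden, tp_group=tp_group, all_reduce_type="multimem_all_reduce")

The first call for a given (numel, dtype, device, group) pays the symmetric-memory allocation and rendezvous; later calls reuse the cached buffer and only pay the copy in/out unless the input tensor itself is the cached symmetric tensor.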
transformer_engine/pytorch/dot_product_attention/inference.py
...
@@ -128,9 +128,9 @@ class InferenceParams:
        self,
        max_batch_size: int,
        max_sequence_length: int,
        num_heads_kv: int = 16,
        head_dim_k: int = 64,
        dtype: torch.dtype = torch.bfloat16,
        num_heads_kv: int = None,
        head_dim_k: int = None,
        dtype: torch.dtype = None,
        head_dim_v: int = None,
        is_paged: bool = False,
        total_num_pages: int = None,
...
@@ -141,6 +141,10 @@ class InferenceParams:
    ):
        self.max_batch_size = max_batch_size
        self.max_sequence_length = max_sequence_length
        assert all(x is not None for x in [num_heads_kv, head_dim_k, dtype]), (
            "num_heads_kv, head_dim_k, and dtype are required for InferenceParams since Transformer"
            " Engine 2.2."
        )
        self.num_heads_kv = num_heads_kv
        self.head_dim_k = head_dim_k
        self.dtype = dtype
...
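Given the new assertion, callers now have to pass `num_heads_kv`, `head_dim_k`, and `dtype` explicitly rather than relying on the old defaults. A minimal construction sketch (argument values are illustrative):

    import torch
    from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams

    params = InferenceParams(
        max_batch_size=4,
        max_sequence_length=2048,
        num_heads_kv=16,       # previously defaulted to 16; now required
        head_dim_k=64,         # previously defaulted to 64; now required
        dtype=torch.bfloat16,  # previously defaulted to torch.bfloat16; now required
    )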
transformer_engine/pytorch/dot_product_attention/rope.py
...
@@ -7,7 +7,12 @@ Rotary Position Embedding implementation of different types along with helper fu
"""
from typing import Optional, Tuple, Union
import torch

import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat


__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb"]


class RotaryPositionEmbedding(torch.nn.Module):
...
@@ -22,19 +27,24 @@ class RotaryPositionEmbedding(torch.nn.Module):
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
        rotary_base: float = 10000.0,
        interleaved: bool = False,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
            Rotary embedding dimension.
        rotary_percent: float
        rotary_percent: float, default = 1.0
            Percent of rotary dimension to use for rotary position embeddings.
        seq_len_interpolation_factor: int
            if not None, discrete positions will be interpolated by this factor via the trick in
        seq_len_interpolation_factor: int, default = None
            If not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595
        pretrained_max_position_embeddings: int
            pre-trained max_position_embeddings before position interpolation
        pretrained_max_position_embeddings: int, default = None
            Pre-trained max_position_embeddings before position interpolation.
        rotary_base: float, default = 10000.0
            Base of the rotary position embedding.
        interleaved: bool, default = False
            Whether to use interleaved rotary position embedding.
        """
        super().__init__()
        if rotary_percent < 1.0:
...
@@ -50,17 +60,18 @@ class RotaryPositionEmbedding(torch.nn.Module):
        )
        self.register_buffer("inv_freq", inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
        self.interleaved = interleaved

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies
        Create rotary position embedding frequencies.

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
            Sequence length of a sample.
        offset: int, default = 0
            fixed offset for freqencies
            Fixed offset for frequencies.
        """
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
...
@@ -84,7 +95,12 @@ class RotaryPositionEmbedding(torch.nn.Module):
        freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
        # first part even vector components, second part odd vector components,
        #  2 * dim in dimension size
        if not self.interleaved:
            emb = torch.cat((freqs, freqs), dim=-1)
        else:
            emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
                freqs.shape[0], -1
            )
        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))
...
@@ -104,61 +120,146 @@ class FusedRoPEFunc(torch.autograd.Function):
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        interleaved: bool = False,
        cu_seqlens: Union[torch.Tensor, None] = None,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        """Fused RoPE forward."""
        if freqs.dtype != torch.float32:
            freqs = freqs.float()
        if tensor_format == "sbhd":
            output = tex.fused_rope_forward(t, freqs, False)
        elif tensor_format == "bshd":
            output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
        elif tensor_format == "thd":
            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
        assert tensor_format in (
            "sbhd",
            "bshd",
            "thd",
        ), f"Unsupported tensor_format: {tensor_format}."
        output = tex.fused_rope_forward(
            t, freqs, QKVFormat[tensor_format], interleaved, cu_seqlens, cp_size, cp_rank
        )
        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank
        ctx.interleaved = interleaved

        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        # pylint: disable=missing-function-docstring
        """Fused RoPE backward."""
        freqs, cu_seqlens = ctx.saved_tensors
        if ctx.tensor_format == "sbhd":
            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
        elif ctx.tensor_format == "bshd":
            grad_input = tex.fused_rope_backward(
                grad_output.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif ctx.tensor_format == "thd":
            grad_input = tex.fused_rope_thd_backward(
                grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank
        grad_input = tex.fused_rope_backward(
            grad_output,
            freqs,
            QKVFormat[ctx.tensor_format],
            ctx.interleaved,
            cu_seqlens,
            ctx.cp_size,
            ctx.cp_rank,
        )
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        return grad_input, None, None, None, None, None
        return grad_input, None, None, None, None, None, None


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
    """Change sign so the last dimension becomes [-odd, +even]

    Args:
        x: torch.Tensor. Input tensor.
        interleaved: bool. Whether to use interleaved rotary position embedding.

    Returns:
        Tensor: Tensor rotated half.
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    if not interleaved:
        x1, x2 = torch.chunk(x, 2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    # interleaved
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x_new = torch.stack((-x2, x1), dim=-1)
    return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1)


def _apply_rotary_pos_emb_base(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    interleaved: bool = False,
) -> torch.Tensor:
    """
    Base implementation of applying rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional
        embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
        Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
        `[seq, bs, ...]`.
    interleaved: bool, default = False
        Whether to use interleaved rotary position embedding.
    """
    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

    # Only apply the rotary embeddings up to the sequence length of the running
    # input.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    freqs = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t, interleaved) * sin_)
    return torch.cat((t, t_pass), dim=-1)


def _get_freqs_on_this_cp_rank(
    freqs: torch.Tensor, seqlen: int, cp_size: int, cp_rank: int
) -> torch.Tensor:
    """Get the position embedding on the current context parallel rank.

    Args:
        freqs: torch.Tensor. Positional embedding tensor in shape `[s2, 1, 1, d2]`.
        seqlen: int. Length of the current sequence.
        cp_size: int. Context parallel world size.
        cp_rank: int. Context parallel rank.
    """
    if cp_size > 1:
        cp_seg = seqlen // 2
        full_seqlen = cp_size * seqlen
        return torch.cat(
            [
                freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
                freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
            ]
        )
    # cp_size == 1
    return freqs[:seqlen]


def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    interleaved: bool = False,
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
    cp_size: int = 1,
...
@@ -175,11 +276,13 @@ def apply_rotary_pos_emb(
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    interleaved: bool, default = False
        Whether to use interleaved rotary position embedding.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.
...
@@ -189,37 +292,40 @@ def apply_rotary_pos_emb(
    cp_rank: int, default = 0.
        Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    if fused:
        return FusedRoPEFunc.apply(t, freqs, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank)

    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

    # Only apply the rotary embeddings up to the sequence length of the running
    # input.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    freqs = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # Unfused THD format
    if tensor_format == "thd":
        cu_seqlens = cu_seqlens // cp_size
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return torch.cat(
            [
                _apply_rotary_pos_emb_base(
                    x.unsqueeze(1),
                    _get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank),
                    interleaved=interleaved,
                )
                for x in torch.split(t, seqlens)
            ]
        ).squeeze(1)

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)

    # Unfused SBHD/BSHD format
    if tensor_format == "sbhd":
        seqlen = t.size(0)
    elif tensor_format == "bshd":
        seqlen = t.size(1)
    else:
        raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
    return _apply_rotary_pos_emb_base(
        t,
        _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank),
        tensor_format,
        interleaved=interleaved,
    )
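A short usage sketch of the refactored API with the new `interleaved` flag (shapes are illustrative; the fused path additionally requires the compiled transformer_engine_torch extension, so the unfused path is shown):

    import torch
    from transformer_engine.pytorch.dot_product_attention.rope import (
        RotaryPositionEmbedding,
        apply_rotary_pos_emb,
    )

    rope = RotaryPositionEmbedding(64, interleaved=True)
    freqs = rope(max_seq_len=128)                  # [128, 1, 1, 64]

    q = torch.randn(128, 2, 8, 64)                 # [seq, batch, heads, dim] == "sbhd"
    q_rot = apply_rotary_pos_emb(q, freqs, tensor_format="sbhd", interleaved=True, fused=False)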
transformer_engine/pytorch/fp8.py
...
@@ -6,6 +6,7 @@
from __future__ import annotations
import abc
import itertools
import os
from contextlib import contextmanager
from collections import deque
...
@@ -19,6 +20,7 @@ from transformer_engine.common.recipe import (
    Format,
    MXFP8BlockScaling,
    Float8CurrentScaling,
    Float8BlockScaling,
)

from .constants import dist_group_type
...
@@ -56,6 +58,17 @@ def check_mxfp8_support() -> Tuple[bool, str]:
    return False, "Device compute capability 10.0 or higher required for MXFP8 execution."


def check_fp8_block_scaling_support() -> Tuple[bool, str]:
    """Return if fp8 block scaling support is available"""
    if (
        get_device_compute_capability() >= (9, 0)
        and get_device_compute_capability() < (10, 0)
        and float(torch.version.cuda) >= 12.9
    ):
        return True, ""
    return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."


def get_default_fp8_recipe() -> Recipe:
    """FP8 recipe with default args."""
    if get_device_compute_capability() >= (10, 0):  # blackwell and above
...
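For reference, a sketch of how the new capability check might be consumed by calling code (the gating pattern is illustrative; only the function itself comes from the hunk above):

    ok, reason = check_fp8_block_scaling_support()
    if not ok:
        # e.g. fall back to another recipe on non-Hopper devices or older CUDA toolkits
        print(f"FP8 block scaling unavailable: {reason}")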
@@ -116,6 +129,8 @@ class FP8GlobalStateManager:
    skip_fp8_weight_update_tensor = None
    mxfp8_available = None
    reason_for_no_mxfp8 = ""
    fp8_block_scaling_available = None
    reason_for_no_fp8_block_scaling = None

    @classmethod
    def reset(cls) -> None:
...
@@ -141,6 +156,8 @@ class FP8GlobalStateManager:
        cls.skip_fp8_weight_update_tensor = None
        cls.mxfp8_available = None
        cls.reason_for_no_mxfp8 = ""
        cls.fp8_block_scaling_available = None
        cls.reason_for_no_fp8_block_scaling = ""

    @classmethod
    def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None:
...
@@ -168,6 +185,15 @@ class FP8GlobalStateManager:
            cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support()
        return cls.mxfp8_available, cls.reason_for_no_mxfp8

    @classmethod
    def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]:
        """Return if Float8 block scaling support is available."""
        if cls.fp8_block_scaling_available is None:
            cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = (
                check_fp8_block_scaling_support()
            )
        return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling

    @staticmethod
    def get_meta_tensor_key(forward: bool = True) -> str:
        """Returns scaling key in `fp8_meta`."""
...
@@ -441,6 +467,9 @@ class FP8GlobalStateManager:
            if isinstance(fp8_recipe, MXFP8BlockScaling):
                mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available()
                assert mxfp8_available, reason_for_no_mxfp8
            if isinstance(fp8_recipe, Float8BlockScaling):
                fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available()
                assert fp8_block_available, reason_for_no_fp8_block

    @classmethod
    def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
...
@@ -793,8 +822,10 @@ class RecipeState(abc.ABC):
            cls = MXFP8BlockScalingRecipeState
        elif recipe.float8_current_scaling():
            cls = Float8CurrentScalingRecipeState
        elif recipe.float8_block_scaling():
            cls = Float8BlockScalingRecipeState
        else:
            raise ValueError("{recipe.__class__.__name__} is not supported")
            raise ValueError(f"{recipe.__class__.__name__} is not supported")
        return cls(
            recipe,
            mode=mode,
...
@@ -935,3 +966,108 @@ class MXFP8BlockScalingRecipeState(RecipeState):
        from .tensor.mxfp8_tensor import MXFP8Quantizer

        return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)]


class Float8BlockScalingRecipeState(RecipeState):
    """Configuration for Float8BlockScaling quantization.

    Float8BlockScaling quantization does not require state,
    but different quantizers use different modes.
    """

    recipe: Float8BlockScaling
    mode: str
    qx_dtype: tex.DType
    qw_dtype: tex.DType
    qgrad_dtype: tex.DType

    def __init__(
        self,
        recipe: Float8BlockScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.qx_dtype = get_fp8_te_dtype(recipe, True)
        self.qw_dtype = get_fp8_te_dtype(recipe, True)
        self.qgrad_dtype = get_fp8_te_dtype(recipe, False)

        # Allocate buffers
        if device is None:
            device = torch.device("cuda")
        self.device = device

    def make_quantizers(self) -> list:
        # TODO(ksivamani); Find better design for this, adding here to avoid circular import.
        from .tensor.float8_blockwise_tensor import Float8BlockQuantizer

        if self.mode == "forward":
            # The index convention (coming from base.py set_meta_tensor)
            # is somewhat awkward, and doesn't play nicely with QuantizeOp,
            # which is not associated with a GEMM.
            assert self.num_quantizers % 3 == 0  # x, w, output per gemm
            return list(
                itertools.chain.from_iterable(
                    [
                        [
                            Float8BlockQuantizer(
                                fp8_dtype=self.qx_dtype,
                                rowwise=True,
                                columnwise=True,
                                amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
                                force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
                                block_scaling_dim=self.recipe.x_block_scaling_dim,
                            ),
                            Float8BlockQuantizer(
                                fp8_dtype=self.qw_dtype,
                                rowwise=True,
                                columnwise=True,
                                amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon,
                                force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale,
                                block_scaling_dim=self.recipe.w_block_scaling_dim,
                            ),
                            Float8BlockQuantizer(
                                fp8_dtype=self.qx_dtype,
                                rowwise=True,
                                columnwise=True,
                                amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
                                force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
                                block_scaling_dim=self.recipe.x_block_scaling_dim,
                            ),
                        ]
                        for _ in range(self.num_quantizers // 3)
                    ]
                )
            )
        assert self.mode == "backward", f"Unexpected mode {self.mode}"
        assert self.num_quantizers % 2 == 0  # grad_output and grad_input per gemm
        return list(
            itertools.chain.from_iterable(
                [
                    [
                        Float8BlockQuantizer(
                            fp8_dtype=self.qgrad_dtype,
                            rowwise=True,
                            columnwise=True,
                            amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
                            force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
                            block_scaling_dim=self.recipe.grad_block_scaling_dim,
                        ),
                        Float8BlockQuantizer(
                            fp8_dtype=self.qgrad_dtype,
                            rowwise=True,
                            columnwise=True,
                            amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
                            force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
                            block_scaling_dim=self.recipe.grad_block_scaling_dim,
                        ),
                    ]
                    for _ in range(self.num_quantizers // 2)
                ]
            )
        )
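A hedged end-to-end sketch of enabling the new recipe through fp8_autocast (the Float8BlockScaling constructor arguments are left at their defaults, which this diff does not show; fp8_autocast asserts is_fp8_block_scaling_available() for this recipe, i.e. Hopper plus CUDA >= 12.9):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Float8BlockScaling

    recipe = Float8BlockScaling()            # defaults assumed
    layer = te.Linear(1024, 1024).cuda()
    x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)

    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = layer(x)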
transformer_engine/pytorch/module/_common.py
...
@@ -9,6 +9,7 @@ from dataclasses import dataclass
from functools import reduce
from operator import mul as multiply_op
import queue

import torch

from .. import cpp_extensions as tex
...
@@ -226,3 +227,79 @@ class _ParameterInitMeta:
"""Safeguard reference to the parameter's parent module and initialization function."""
if
self
.
init_fn
is
None
:
self
.
init_fn
=
get_default_init_method
()
class
WeightGradStore
:
"""
A class to manage weight gradient storage and computation in Transformer modules.
This class enables split backward propagation for better memory efficiency.
"""
def
__init__
(
self
,
delay_wgrad_compute
=
False
,
ub_bulk_wgrad
=
False
):
"""
Initialize the WeightGradStore.
Args:
delay_wgrad_compute (bool): Whether to delay weight gradient computation
ub_bulk_wgrad (bool): Whether to enable bulk weight gradient computation
"""
if
delay_wgrad_compute
:
self
.
context
=
queue
.
Queue
()
assert
(
ub_bulk_wgrad
is
False
),
"ub_bulk_wgrad is not supported when enabling delay_wgrad_compute"
self
.
enabled
=
delay_wgrad_compute
else
:
self
.
context
=
None
self
.
enabled
=
False
def
delay_wgrad_compute
(
self
):
"""
Get the current split backward propagation status.
Returns:
bool: True if split backward is enabled, False otherwise
"""
return
self
.
enabled
def
enable_delay_wgrad_compute
(
self
):
"""Enable split backward propagation."""
self
.
enabled
=
True
def
disable_delay_wgrad_compute
(
self
):
"""Disable split backward propagation."""
self
.
enabled
=
False
def
put
(
self
,
tensor_list
,
func
):
"""
Store tensors and computation function for later execution.
Args:
tensor_list (list): List of tensors needed for computation
func (callable): Function to be executed with the tensors
"""
assert
self
.
enabled
is
True
,
"delay_wgrad_compute is not enabled"
self
.
context
.
put
([
tensor_list
,
func
])
def
pop
(
self
):
"""
Execute the stored computation with the stored tensors.
Raises an exception if the queue is empty.
"""
assert
self
.
enabled
is
True
,
"delay_wgrad_compute is not enabled"
if
self
.
context
.
qsize
()
>
0
:
tensor_list
,
func
=
self
.
context
.
get
()
return
func
(
*
tensor_list
),
tensor_list
if
torch
.
distributed
.
is_initialized
():
rank
=
torch
.
distributed
.
get_rank
()
raise
RuntimeError
(
f
"Pop empty queue. rank
{
rank
}
"
)
raise
RuntimeError
(
"Pop empty queue. No distributed environment detected."
)
def
assert_empty
(
self
):
"""
Assert that the queue is empty.
Used for debugging and ensuring proper cleanup.
"""
assert
self
.
enabled
is
True
,
"delay_wgrad_compute is not enabled"
rank
=
torch
.
distributed
.
get_rank
()
assert
self
.
context
.
empty
(),
f
"Queue is not empty. rank
{
rank
}
"
transformer_engine/pytorch/module/base.py  View file @ ab3e5a92
...
@@ -10,6 +10,7 @@ import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager
import logging
from types import MethodType

import torch
...
@@ -18,11 +19,12 @@ import torch.nn.functional as F
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
-from ._common import _ParameterInitMeta
+from ._common import _ParameterInitMeta, noop_cat
from ..fp8 import (
    MXFP8BlockScalingRecipeState,
    DelayedScalingRecipeState,
    Float8CurrentScalingRecipeState,
    Float8BlockScalingRecipeState,
    FP8GlobalStateManager,
    RecipeState,
)
...
@@ -34,8 +36,13 @@ from ..distributed import (
)
from ..constants import dist_group_type
from ..tensor import QuantizedTensor, Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import Recipe
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION

__all__ = ["initialize_ub", "destroy_ub"]
...
@@ -44,7 +51,8 @@ _2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_multi_stream_cublas_workspace = []
_multi_stream_cublas_batchgemm_workspace = []
_dummy_wgrads = {}
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3
...
@@ -82,6 +90,7 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
        )
    return _multi_stream_cublas_workspace


def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
    """Returns workspace for multi-stream cublas."""
    global _multi_stream_cublas_batchgemm_workspace
...
@@ -92,11 +101,29 @@ def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
        )
    return _multi_stream_cublas_batchgemm_workspace
if bool(int(os.getenv("NVTE_DISABLE_FC2_DGRAD_OVERLAP", "0"))):
    remove_ag_gemm_dgrad = ["fc2_dgrad"]
else:
    remove_ag_gemm_dgrad = []


def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor:
    """Returns a dummy tensor of given shape."""
    assert len(shape) == 2
    global _dummy_wgrads
    if (shape[0], shape[1], dtype) not in _dummy_wgrads:
        _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty(
            shape,
            dtype=dtype,
            device="cuda",
            requires_grad=False,
        )
    if zero:
        _dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0)
    return _dummy_wgrads[(shape[0], shape[1], dtype)].detach()
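get_dummy_wgrad caches one placeholder gradient buffer per (rows, cols, dtype) key, so repeated calls with the same shape reuse memory instead of allocating a new tensor every microbatch. A minimal sketch of that caching behavior (shapes are illustrative and a CUDA device is assumed, since the buffer is allocated on "cuda"):

a = get_dummy_wgrad([1024, 1024], torch.bfloat16)               # allocates and caches
b = get_dummy_wgrad([1024, 1024], torch.bfloat16, zero=True)    # reuses the cached buffer, zero-filled
assert a.data_ptr() == b.data_ptr()  # same underlying storage, returned as detached views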
def initialize_ub(
    shape: list,
    tp_size: int,
...
@@ -429,6 +456,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
    def __init__(self) -> None:
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
        self.name = None
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False
...
@@ -448,6 +476,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
        self.activation_dtype: Optional[torch.dtype] = None

        if not TEDebugState.debug_enabled:
            TEDebugState.initialize()

        # Names of attributes that can be set quickly (see __setattr__
        # method)
        _fast_setattr_names: Set[str] = {
...
@@ -535,6 +566,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            recipe_state, Float8CurrentScalingRecipeState
        ):
            return
        if recipe.float8_block_scaling() and isinstance(recipe_state, Float8BlockScalingRecipeState):
            return

        # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
        # 2 (grad_output and grad_input) for bwd
@@ -860,7 +895,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel

        # Non-FP8 case: bgrad is fused with wgrad for this case.
-       if not ctx.fp8:
+       if not ctx.fp8 and not ctx.debug:
            if gather_grad_output:
                if not ctx.ub_overlap_ag or ctx.ub_obj_gradout is None:
                    grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
...
@@ -870,6 +905,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            return grad_output, None

        # FP8 with all-gather: unfused bgrad, fused cast + transpose
        # Also supports debug quantization, which is handled inside gather_along_first_dim.
        if gather_grad_output:
            grad_bias = None
            if ctx.use_bias:
...
@@ -877,7 +913,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            if ctx.ub_overlap_ag and ctx.ub_obj_gradout is not None:
                # Quantize the gradient if needed
                if not isinstance(
                    grad_output,
                    (
                        QuantizedTensor,
                        Float8TensorBase,
                        MXFP8TensorBase,
                        Float8BlockwiseQTensorBase,
                    ),
                ):
                    grad_output = quantizer(grad_output)
...
@@ -892,14 +934,41 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            )
            return grad_output, grad_bias

        # Debug without all-gather: unfused cast and bgrad
        # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None
        if ctx.debug:
            grad_output_ = quantizer(grad_output)
            if (
                isinstance(
                    grad_output_.get_tensor(True),
                    (QuantizedTensor, Float8TensorBase, MXFP8TensorBase),
                )
                and ctx.use_bias
            ):
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
            else:
                grad_bias = None
            grad_output = grad_output_
            return grad_output, grad_bias

        # FP8 without all-gather: fused bgrad + cast + transpose
        grad_bias = None
        if ctx.use_bias:
            if isinstance(
                grad_output,
                (
                    QuantizedTensor,
                    Float8TensorBase,
                    MXFP8TensorBase,
                    Float8BlockwiseQTensorBase,
                ),
            ):
                grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
            else:
                if isinstance(quantizer, Float8BlockQuantizer):
                    # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer.
                    grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
                else:
                    grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
        if not isinstance(
            grad_output,
            (
                QuantizedTensor,
                Float8TensorBase,
                MXFP8TensorBase,
                Float8BlockwiseQTensorBase,
            ),
        ):
            grad_output = quantizer(grad_output)
        return grad_output, grad_bias
@@ -998,6 +1067,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        update_workspace: bool = True,
        skip_update_flag: Optional[torch.Tensor] = None,
        fsdp_group: Optional[dist_group_type] = None,
        workspace_dtype: Optional[torch.dtype] = None,
    ) -> QuantizedTensor:
        """Get FP8 workspace buffer and maybe update its values
...
@@ -1020,6 +1090,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            over `update_workspace` if provided.
        fsdp_group: bool, default = None
            FSDP process group that the weights are distributed over.
        workspace_dtype: torch.dtype, default = None
            If weight workspace contains high-precision tensor - for example
            for debug quantization, this is dtype of the tensor.
        """

        # FP8 primary weights
...
@@ -1033,6 +1106,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        # Try getting workspace from cache
        out = None
        if cache_name is not None:
            out = self._fp8_workspaces.get(cache_name, None)
            if quantizer is not None and isinstance(out, MXFP8TensorBase):
...
@@ -1043,6 +1117,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                out = None
                del self._fp8_workspaces[cache_name]

            is_debug = isinstance(quantizer, DebugQuantizer)
            is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor)
            if is_debug != is_out_debug_tensor:
                out = None

        # Gather cached Fp8 workspace if it's distributed
        # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
        # for models initialized with Fp8 primary weights.
...
@@ -1060,7 +1139,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                raise ValueError(
                    "tensor and quantizer kwargs must be provided to construct FP8 workspace"
                )
-           out = quantizer(tensor)
+           out = quantizer.quantize(tensor, dtype=workspace_dtype)

            # Update cache
            if cache_name is not None:
...
@@ -1077,7 +1156,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                out.quantize_(tensor, noop_flag=skip_update_flag)
            else:
                tex.quantize(tensor, quantizer, out, skip_update_flag)
        return out

    def _load_from_state_dict(
@@ -1100,3 +1178,68 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def backward_dw(self):
        """
        Execute the delayed weight gradient computation.
        This method is called after the main backward pass to compute weight gradients.
        """
        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
            return
        with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
            (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop()
            if not self.fuse_wgrad_accumulation:
                unfused_weights = [getattr(self, name) for name in self.weight_names]
                weight_tensor = noop_cat(unfused_weights)
                if weight_tensor.grad is None:
                    weight_tensor.grad = wgrad.to(weight_tensor.dtype)
            if self.use_bias:
                bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
                if bias_tensor.grad is None:
                    bias_tensor.grad = grad_bias_.to(bias_tensor.dtype)
            del grad_bias_
            del wgrad
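The unfused bias-gradient path in grad_output_preprocess above reduces the output gradient over all token rows, i.e. grad_bias = grad_output.view(-1, last_dim).sum(dim=0). A small self-contained check of that identity (shapes are illustrative, not from the source):

import torch

grad_output = torch.randn(4, 3, 8)   # [batch, seq, hidden], hypothetical sizes
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
assert grad_bias.shape == (8,)
# Equivalent to summing over every non-hidden dimension.
assert torch.allclose(grad_bias, grad_output.sum(dim=(0, 1)))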
    def _validate_name(self):
        """
        Validate name passed to the module.
        This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM.
        If no name is assigned, it creates a default name with layer count as the variable.
        """
        assert TEDebugState.debug_enabled
        import nvdlfw_inspect.api as debug_api

        if self.name is None:
            debug_api.log_message(
                "Names are not provided to debug modules. ",
                "Creating and using generic names. Pass names to debug modules for better"
                " insight. ",
                level=logging.WARNING,
            )
            self.name = f"Layer_{TEDebugState.get_layer_count()}"

    def _turn_off_unsupported_features_in_debug(self):
        if (
            getattr(self, "ub_bulk_wgrad", False)
            or getattr(self, "ub_bulk_dgrad", False)
            or getattr(self, "ub_overlap_ag", False)
            or getattr(self, "ub_overlap_rs_dgrad", False)
            or getattr(self, "ub_overlap_rs", False)
        ):
            import nvdlfw_inspect.api as debug_api

            debug_api.log_message(
                "UserBuffers are not supported in debug module. "
                "Using UB optimization will not affect the debug module. ",
                level=logging.WARNING,
            )
            if hasattr(self, "ub_bulk_wgrad"):
                self.ub_bulk_wgrad = None
            if hasattr(self, "ub_bulk_dgrad"):
                self.ub_bulk_dgrad = None
            if hasattr(self, "ub_overlap_ag"):
                self.ub_overlap_ag = None
            if hasattr(self, "ub_overlap_rs_dgrad"):
                self.ub_overlap_rs_dgrad = None
            if hasattr(self, "ub_overlap_rs"):
                self.ub_overlap_rs = None
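The backward_dw method added to the base module above only does work when the module was built with a WeightGradStore that has delay_wgrad_compute enabled. A minimal sketch of the intended call pattern (module, input, and optimizer names are illustrative; it assumes a TE module constructed with delay_wgrad_compute=True as shown in the module diffs below):

# Hypothetical training step: weight gradients are produced only when
# backward_dw() is called, after the activation-gradient backward pass.
out = module(inp)            # module built with delay_wgrad_compute=True
loss = out.float().sum()
loss.backward()              # dgrad runs; the wgrad GEMM is queued in the WeightGradStore
module.backward_dw()         # pops the queued wgrad GEMM and fills the parameter grads
optimizer.step()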
transformer_engine/pytorch/module/fp8_padding.py  View file @ ab3e5a92
...
@@ -4,12 +4,13 @@
"""FP8 Padding API"""

-from typing import Union, List
+from typing import List, Optional, Tuple

import torch

import transformer_engine_torch as tex

from ..fp8 import FP8GlobalStateManager
from ..jit import no_torch_dynamo
...
@@ -74,22 +75,30 @@ class Fp8Padding(torch.nn.Module):
    ----------
    num_gemms: int
               number of GEMMs to be performed simutaneously.
    align_size: int, optional
               the alignment size for the input tensor. If not provided, the alignment size will
               be determined by the FP8 recipe, 32 for MXFP8 and 16 for others.
    """

    def __init__(
        self,
-       num_gemms,
+       num_gemms: int,
+       align_size: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.num_gemms = num_gemms
        if align_size is None:
            self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
        else:
            self.align_size = align_size

    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        m_splits: List[int],
-   ) -> Union[torch.Tensor, List[int]]:
+   ) -> Tuple[torch.Tensor, List[int]]:
        """
        Apply the padding to the input.
...
@@ -104,7 +113,12 @@ class Fp8Padding(torch.nn.Module):
        assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

        # FP8 padding calculate
-       padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits]
+       padded_m_splits = [
+           (m + self.align_size - 1) // self.align_size * self.align_size for m in m_splits
+       ]

        # no padding needed
        if m_splits == padded_m_splits:
            return inp, m_splits

        if torch.is_grad_enabled():
            fn = _Fp8Padding.apply
...
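The new padded split sizes round each per-GEMM row count up to the next multiple of align_size instead of a hard-coded 16. A quick worked example of the arithmetic (values are illustrative): with align_size = 16, m_splits = [30, 32, 5] pads to [32, 32, 16]; with the MXFP8 default of 32, the same splits pad to [32, 32, 32].

def round_up(m: int, align_size: int) -> int:
    # Same arithmetic as padded_m_splits above.
    return (m + align_size - 1) // align_size * align_size

assert [round_up(m, 16) for m in (30, 32, 5)] == [32, 32, 16]
assert [round_up(m, 32) for m in (30, 32, 5)] == [32, 32, 32]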
transformer_engine/pytorch/module/fp8_unpadding.py  View file @ ab3e5a92
...
@@ -4,12 +4,13 @@
"""FP8 Padding API"""

-from typing import List
+from typing import List, Optional

import torch

import transformer_engine_torch as tex

from ..fp8 import FP8GlobalStateManager
from ..jit import no_torch_dynamo
...
@@ -70,15 +71,23 @@ class Fp8Unpadding(torch.nn.Module):
    ----------
    num_gemms: int
               number of GEMMs to be performed simutaneously.
    align_size: int, optional
               the alignment size for the input tensor. If not provided, the alignment size will
               be determined by the FP8 recipe, 32 for MXFP8 and 16 for others.
    """

    def __init__(
        self,
-       num_gemms,
+       num_gemms: int,
+       align_size: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.num_gemms = num_gemms
        if align_size is None:
            self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
        else:
            self.align_size = align_size

    @no_torch_dynamo()
    def forward(
...
@@ -100,7 +109,12 @@ class Fp8Unpadding(torch.nn.Module):
        assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

        # FP8 padding calculate
-       padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits]
+       padded_m_splits = [
+           (m + self.align_size - 1) // self.align_size * self.align_size for m in m_splits
+       ]

        # no padding needed
        if m_splits == padded_m_splits:
            return inp

        if torch.is_grad_enabled():
            fn = _Fp8Unpadding.apply
...
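Fp8Padding and Fp8Unpadding are typically used as a pair around a grouped GEMM so that every per-GEMM row count meets the FP8 alignment requirement. A minimal round-trip sketch under that assumption (tensor sizes and the explicit align_size are illustrative; a CUDA device is assumed):

import torch
from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding

num_gemms = 2
pad = Fp8Padding(num_gemms, align_size=16)
unpad = Fp8Unpadding(num_gemms, align_size=16)

inp = torch.randn(30 + 5, 64, device="cuda")       # two groups of 30 and 5 rows
padded, padded_m_splits = pad(inp, [30, 5])         # rows padded up to [32, 16]
# ... grouped GEMMs would consume `padded` with `padded_m_splits` here ...
restored = unpad(padded, [30, 5])                   # back to the original 35 rows
assert restored.shape == inp.shape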
transformer_engine/pytorch/module/grouped_linear.py  View file @ ab3e5a92
...
@@ -5,10 +5,12 @@
"""GroupedLinear API"""
from typing import Union, Optional, Callable, Tuple, List
import functools

import torch

import transformer_engine_torch as tex

from transformer_engine.common.recipe import Recipe
from .base import (
    get_multi_stream_cublas_workspace,
    TransformerEngineBaseModule,
...
@@ -16,6 +18,7 @@ from .base import (
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
from ._common import WeightGradStore
from ..fp8 import FP8GlobalStateManager
from ..utils import (
    divide,
...
@@ -37,7 +40,6 @@ from ..cpp_extensions import (
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
-from ..tensor.float8_tensor import Float8Tensor
from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.quantized_tensor import (
...
@@ -47,7 +49,6 @@ from ..tensor.quantized_tensor import (
    restore_from_saved,
)

__all__ = ["GroupedLinear"]
...
@@ -65,6 +66,7 @@ class _GroupedLinear(torch.autograd.Function):
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        wgrad_store: WeightGradStore,
        input_quantizers: List[Quantizer],
        weight_quantizers: List[Quantizer],
        output_quantizers: List[Quantizer],
...
@@ -85,13 +87,6 @@ class _GroupedLinear(torch.autograd.Function):
        biases = weights_and_biases[num_gemms:]
        device = inp.device

-       # TODO Support MXFP8 # pylint: disable=fixme
-       if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8():
-           raise NotImplementedError("GroupedLinear does not yet support MXFP8")
-       # TODO Support Float8 Current Scaling # pylint: disable=fixme
-       if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling():
-           raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling")

        # Make sure input dimensions are compatible
        in_features = weights[0].shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
...
@@ -124,7 +119,11 @@ class _GroupedLinear(torch.autograd.Function):
                for output_quantizer in output_quantizers:
                    output_quantizer.set_usage(rowwise=True, columnwise=False)

            fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
            if fp8:
                recipe = FP8GlobalStateManager.get_fp8_recipe()
                if hasattr(recipe, "fp8_gemm_fprop"):
                    fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

            inputmats = tex.fused_multi_quantize(
                inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
            )
...
@@ -165,7 +164,7 @@ class _GroupedLinear(torch.autograd.Function):
            m_splits=m_splits,
            bias=biases,
            use_bias=use_bias,
-           use_split_accumulator=_2X_ACC_FPROP,
+           use_split_accumulator=fprop_gemm_use_split_accumulator,
        )

        if fp8_calibration:
...
@@ -177,9 +176,19 @@ class _GroupedLinear(torch.autograd.Function):
                weight_quantizers[i].calibrate(weights[i])

        if is_grad_enabled:
            ctx.weight_quantizers = weight_quantizers
            ctx.weights_shape_1 = weights[0].shape[1]

            # TODO: update after #1638 is merged. # pylint: disable=fixme
            if weight_requires_grad:
                for inputmat in inputmats:
                    if isinstance(inputmat, QuantizedTensor):
                        inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
            if inp.requires_grad:
                for weight in weights_fp8:
                    if isinstance(weight, QuantizedTensor):
                        weight.update_usage(columnwise_usage=True)

            tensors_to_save, tensor_objects = prepare_for_saving(
                *inputmats,
                *weights_fp8,
...
@@ -200,6 +209,7 @@ class _GroupedLinear(torch.autograd.Function):
            ctx.num_gemms = num_gemms
            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
...
@@ -213,6 +223,7 @@ class _GroupedLinear(torch.autograd.Function):
                ctx.reduce_and_update_bwd_fp8_tensors or FP8GlobalStateManager.is_first_fp8_module()
            )
            ctx.wgrad_store = wgrad_store

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        return out.view(-1, *inp.shape[1:-1], out.shape[-1])
...
@@ -245,6 +256,13 @@ class _GroupedLinear(torch.autograd.Function):
            grad_biases = [None] * ctx.num_gemms
            if ctx.fp8:
                if ctx.use_bias:
                    # unfuse bgrad for now until cast_transpose + dgrad calculation is ready
                    # for Float8BlockQuantizer.
                    if ctx.fp8_recipe.float8_block_scaling():
                        for i in range(ctx.num_gemms):
                            grad_biases[i] = grad_output_mats[i].sum(dim=0)
                            grad_output[i] = ctx.grad_output_quantizers[i](grad_output_mats[i])
                    else:
                        for i in range(ctx.num_gemms):
                            grad_biases[i], grad_output[i] = tex.bgrad_quantize(
                                grad_output_mats[i], ctx.grad_output_quantizers[i]
...
@@ -267,12 +285,25 @@ class _GroupedLinear(torch.autograd.Function):
            accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            if ctx.requires_dgrad:
                dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
                if ctx.fp8:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_dgrad"):
                        dgrad_gemm_use_split_accumulator = (
                            recipe.fp8_gemm_dgrad.use_split_accumulator
                        )
                dgrad = torch.empty(
                    (sum(ctx.m_splits), ctx.weights_shape_1),
                    dtype=ctx.activation_dtype,
                    device=ctx.device,
                )
                for weight, quantizer in zip(weights, ctx.weight_quantizers):
                    if quantizer is not None and isinstance(weight, QuantizedTensor):
                        weight.update_usage(
                            rowwise_usage=quantizer.rowwise_usage,
                            columnwise_usage=quantizer.columnwise_usage,
                        )

                general_grouped_gemm(
                    weights,
                    grad_output,
...
@@ -283,10 +314,17 @@ class _GroupedLinear(torch.autograd.Function):
                    layout="NN",
                    m_splits=ctx.m_splits,
                    grad=True,
-                   use_split_accumulator=_2X_ACC_DGRAD,
+                   use_split_accumulator=dgrad_gemm_use_split_accumulator,
                )

            if ctx.weights_requires_grad:
                wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
                if ctx.fp8:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_wgrad"):
                        wgrad_gemm_use_split_accumulator = (
                            recipe.fp8_gemm_wgrad.use_split_accumulator
                        )
                if ctx.fuse_wgrad_accumulation:
                    wgrad_list = main_grads
                else:
...
@@ -294,21 +332,24 @@ class _GroupedLinear(torch.autograd.Function):
                        torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
                        for w in weights
                    ]

                grouped_gemm_wgrad = functools.partial(
                    general_grouped_gemm,
                    out_dtype=ctx.activation_dtype,
                    workspaces=get_multi_stream_cublas_workspace(),
                    layout="NT",
                    grad=True,
                    m_splits=ctx.m_splits,
                    use_bias=ctx.use_bias if grad_biases[0] is None else None,
                    bias=biases,
                    use_split_accumulator=wgrad_gemm_use_split_accumulator,
                    accumulate=accumulate_wgrad_into_param_main_grad,
                )

                # WGRAD
                if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
                    ctx.wgrad_store.put([inputmats, grad_output, wgrad_list], grouped_gemm_wgrad)
                else:
                    _, grad_biases_, _ = grouped_gemm_wgrad(inputmats, grad_output, wgrad_list)
                    for i in range(ctx.num_gemms):
                        if grad_biases[i] is None:
                            grad_biases[i] = grad_biases_[i]
...
@@ -351,7 +392,14 @@ class _GroupedLinear(torch.autograd.Function):
            else:
                wgrad_list = [None] * ctx.num_gemms

            if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
                wgrad_list = [None] * ctx.num_gemms
            if not ctx.use_bias or (
                ctx.wgrad_store is not None
                and ctx.wgrad_store.delay_wgrad_compute()
                and not ctx.fp8
            ):
                grad_biases = [None] * ctx.num_gemms

        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
...
@@ -372,8 +420,9 @@ class _GroupedLinear(torch.autograd.Function):
            None,
            None,
            None,
            None,  # is_grad_enabled
            None,
            None,
            None,
            None,
            *wgrad_list,
            *grad_biases,
        )
...
@@ -422,7 +471,12 @@ class GroupedLinear(TransformerEngineBaseModule):
                          it controls the type used to allocate the initial parameters. Useful when
                          the model is trained with lower precision and the original FP32 parameters
                          would not fit in GPU memory.
    delay_wgrad_compute : bool, default = `False`
                         Whether to delay weight gradient computation

    Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and
    `parallel_mode` are used to determine the shapes of weights and biases.
    The TP communication should be handled in the dispatch and combine stages of MoE models.
    """

    def __init__(
...
@@ -445,6 +499,7 @@ class GroupedLinear(TransformerEngineBaseModule):
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
    ) -> None:
        super().__init__()
...
@@ -465,7 +520,13 @@ class GroupedLinear(TransformerEngineBaseModule):
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name

-       self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}
+       self.wgrad_store = WeightGradStore(delay_wgrad_compute)
+       self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
+       self._num_fp8_tensors_per_gemm = {
+           "fwd": 3,
+           "bwd": 2,
+       }

        if tp_group is None:
            self.tp_size = tp_size
...
@@ -476,6 +537,12 @@ class GroupedLinear(TransformerEngineBaseModule):
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        if self.tp_size > 1 and bias:
            raise ValueError(
                "GroupedLinear doesn't support bias when TP > 1. "
                "Because the TP communication is handled outside of this module."
            )

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
...
@@ -502,7 +569,7 @@ class GroupedLinear(TransformerEngineBaseModule):
                ),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
-               fp8_meta_index=self._offsets["weight"] + i,
+               fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"],
            )

            # Construct bias parameters if needed
...
@@ -527,12 +594,18 @@ class GroupedLinear(TransformerEngineBaseModule):
            self.reset_parameters(defer_init=device == "meta")

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
        """Init scales and amaxes for fwd | bwd."""
        super().set_meta_tensor(fwd, recipe)

        # customize quantizers based on each recipe & layer configs
        recipe = FP8GlobalStateManager.get_fp8_recipe()
        if recipe.float8_current_scaling():
            assert not self.tp_size > 1, (
                "GroupedLinear doesn't support TP > 1 with Float8 current scaling. "
                "Because the TP communication is handled outside of this module."
            )
            self._customize_quantizers_float8_current_scaling(fwd, recipe)

    def reset_parameters(self, defer_init=False):
        super().reset_parameters(defer_init=defer_init)
...
@@ -590,7 +663,7 @@ class GroupedLinear(TransformerEngineBaseModule):
                          produced)
        """
        assert not isinstance(
-           inp, Float8Tensor
+           inp, QuantizedTensor
        ), "GroupedLinear doesn't support input tensor in FP8."
        assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
...
@@ -615,20 +688,27 @@ class GroupedLinear(TransformerEngineBaseModule):
        grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms
        if self.fp8:
            input_quantizers = [
                self.quantizers["scaling_fwd"][
                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
                ]
                for i in range(self.num_gemms)
            ]
            # TODO: use internal after #1638 is merged. # pylint: disable=fixme
            for i in range(self.num_gemms):
-               input_quantizers[i].internal = True
+               input_quantizers[i].internal = False
            weight_quantizers = [
                self.quantizers["scaling_fwd"][
                    self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
                ]
                for i in range(self.num_gemms)
            ]
            for i in range(self.num_gemms):
                weight_quantizers[i].internal = True
            if torch.is_grad_enabled():
                grad_output_quantizers = [
                    self.quantizers["scaling_bwd"][
                        self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
                    ]
                    for i in range(self.num_gemms)
                ]
                for i in range(self.num_gemms):
...
@@ -643,10 +723,11 @@ class GroupedLinear(TransformerEngineBaseModule):
            args += (
                inp,
                m_splits,
-               self.apply_bias and not self.gemm_bias_unfused_add,
+               self.apply_bias,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.wgrad_store,
                input_quantizers,
                weight_quantizers,
                output_quantizers,
...
@@ -663,17 +744,61 @@ class GroupedLinear(TransformerEngineBaseModule):
            )
        out = linear_fn(*args)

        if self.gemm_bias_unfused_add:
            out_shape = out.shape
            out = torch.cat(
                [
                    o + cast_if_needed(b, self.activation_dtype)
                    for o, b in zip(
                        torch.split(out.view(-1, self.out_features), m_splits), bias_tensors
                    )
                ]
            ).view(out_shape)

        if self.return_bias:
            return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
        return out

    def backward_dw(self):
        """
        Execute the delayed weight gradient computation.
        This method is called after the main backward pass to compute weight gradients.
        """
        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
            return
        with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
            (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
            wgrad_list = tensor_list[2]
            if not self.fuse_wgrad_accumulation:
                for i in range(self.num_gemms):
                    weight_param = getattr(self, f"weight{i}")
                    if weight_param.grad is None:
                        weight_param.grad = wgrad_list[i].to(weight_param.dtype)
            if self.use_bias:
                for i in range(self.num_gemms):
                    bias_param = getattr(self, f"bias{i}")
                    if bias_param.grad is None:
                        bias_param.grad = grad_biases_[i].to(bias_param.dtype)
            del grad_biases_
            del wgrad_list
            del tensor_list

    def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
        """Customize quantizers based on current scaling recipe + linear."""
        assert (
            recipe.float8_current_scaling()
        ), "current scaling recipe quantizer customization here"
        if fwd:
            for i in range(self.num_gemms):
                # set configs about amax epsilon and power_2_scale
                self.quantizers["scaling_fwd"][
                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
                ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
                self.quantizers["scaling_fwd"][
                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
                ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon
                # also set weight quantizer with same amax_epsilon & power_2_scale
                self.quantizers["scaling_fwd"][
                    self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
                ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
                self.quantizers["scaling_fwd"][
                    self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
                ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
        else:
            for i in range(self.num_gemms):
                # set grad_output_quantizer with amax epsilon and power_2_scale
                self.quantizers["scaling_bwd"][
                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
                ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
                self.quantizers["scaling_bwd"][
                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
                ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
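Putting the new GroupedLinear options together, here is a rough usage sketch of the delayed weight-gradient path (shapes, num_gemms, and the loss are illustrative; it assumes the delay_wgrad_compute flag added in this commit and a CUDA device):

import torch
import transformer_engine.pytorch as te

layer = te.GroupedLinear(
    num_gemms=2, in_features=128, out_features=256,
    bias=False, delay_wgrad_compute=True,
)

inp = torch.randn(48, 128, device="cuda", requires_grad=True)
out = layer(inp, m_splits=[16, 32])   # per-GEMM row counts must sum to 48
out.sum().backward()                  # wgrad GEMM is queued, not executed
layer.backward_dw()                   # executes the queued wgrad GEMM and fills weight grads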
transformer_engine/pytorch/module/layernorm_linear.py  View file @ ab3e5a92
...
@@ -9,16 +9,19 @@ from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import functools

import torch
from torch.nn import init

import transformer_engine_torch as tex

from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
    get_workspace,
    get_ub,
    TransformerEngineBaseModule,
    get_dummy_wgrad,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
...
@@ -34,11 +37,13 @@ from ..utils import (
    nvtx_range_pop,
    nvtx_range_push,
    requires_grad,
    needs_quantized_gemm,
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    symmetric_all_reduce,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
    in_fp8_activation_recompute_phase,
...
@@ -48,16 +53,21 @@ from ..distributed import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
-from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose
+from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
from ..tensor.quantized_tensor import (
    QuantizedTensor,
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
-from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
+from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpp_extensions import (
    general_gemm,
)
...
@@ -89,12 +99,14 @@ class _LayerNormLinear(torch.autograd.Function):
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        wgrad_store: WeightGradStore,
        fuse_wgrad_accumulation: bool,
        input_quantizer: Optional[Quantizer],
        weight_quantizer: Optional[Quantizer],
        output_quantizer: Optional[Quantizer],
-       grad_output_quantizer: Optional[Quantizer],
+       grad_input_quantizer: Optional[Quantizer],
+       grad_weight_quantizer: Optional[Quantizer],
+       grad_output_quantizer: Optional[Quantizer],
        cpu_offloading: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
...
@@ -119,6 +131,8 @@ class _LayerNormLinear(torch.autograd.Function):
        fsdp_group: Union[dist_group_type, None],
        module: torch.nn.Module,
        skip_fp8_weight_update: bool,
        symmetric_ar_type: str,
        debug: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        # pylint: disable=missing-function-docstring
...
@@ -143,11 +157,6 @@ class _LayerNormLinear(torch.autograd.Function):
            ln_bias = cast_if_needed(ln_bias, activation_dtype)
        nvtx_range_pop(f"{nvtx_label}.norm_input_cast")

-       # Avoid quantized norm kernel if norm output will be returned
-       with_quantized_norm = (
-           fp8 and not return_layernorm_output and not return_layernorm_output_gathered
-       )

        tp_world_size = get_distributed_world_size(tp_group)
        ub_overlap_ag_fprop = (
            ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
...
@@ -180,6 +189,18 @@ class _LayerNormLinear(torch.autograd.Function):
                columnwise_usage = False
            input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)

        # Avoid quantized norm kernel if norm output will be returned
        # or if a gather of ln_out must be in high precision.
        force_hp_blockwise_ln_out_gather = (
            fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)
        )  # Perform TP communication in high precision.
        with_quantized_norm = (
            fp8
            and not return_layernorm_output
            and not return_layernorm_output_gathered
            and not force_hp_blockwise_ln_out_gather
        )

        # Apply normalization
        nvtx_range_push(f"{nvtx_label}.norm")
        ln_out, mu, rsigma = apply_normalization(
...
@@ -210,13 +231,13 @@ class _LayerNormLinear(torch.autograd.Function):
                # norm output will be returned
                ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
                ln_out_return = ln_out_total
-               if fp8:
+               if fp8 or debug:
                    ln_out = input_quantizer(ln_out)
                    input_quantizer.set_usage(rowwise=True, columnwise=False)
                    ln_out_total = input_quantizer(ln_out_total)
            else:
-               if fp8:
-                   if not with_quantized_norm:
+               if fp8 or debug:
+                   if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
                        ln_out = input_quantizer(ln_out)
                        input_quantizer.set_usage(rowwise=True, columnwise=False)
                if ub_overlap_ag_fprop:
...
@@ -229,18 +250,19 @@ class _LayerNormLinear(torch.autograd.Function):
                    ln_out_total, _ = gather_along_first_dim(
                        ln_out,
                        tp_group,
-                       quantizer=(input_quantizer if fp8 else None),
+                       quantizer=(input_quantizer if fp8 or debug else None),
                    )
        else:
-           if fp8 and not with_quantized_norm:
+           if (fp8 or debug) and not with_quantized_norm:
                ln_out = input_quantizer(ln_out)
            ln_out_total = ln_out
        nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")

        # Cast weight to expected dtype
-       if not fp8:
-           weightmat = cast_if_needed(weight, activation_dtype)
+       weightmat = weight
+       quantized_weight = False
+       if not fp8 and not debug:
+           weightmat = cast_if_needed(weightmat, activation_dtype)
        else:
            quantized_weight = not isinstance(weight, QuantizedTensor)
...
@@ -250,6 +272,7 @@ class _LayerNormLinear(torch.autograd.Function):
                # FP8 cast to workspace buffer
                update_workspace = is_first_microbatch is None or is_first_microbatch
                weightmat = module.get_weight_workspace(
                    tensor=weight,
                    quantizer=weight_quantizer,
...
@@ -257,11 +280,12 @@ class _LayerNormLinear(torch.autograd.Function):
                    update_workspace=update_workspace,
                    skip_update_flag=skip_fp8_weight_update,
                    fsdp_group=fsdp_group,
                    workspace_dtype=activation_dtype,
                )

        # Cast bias to expected dtype
        bias_dtype = activation_dtype
-       if fp8 and activation_dtype == torch.float32:
+       if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
            bias_dtype = torch.bfloat16
        bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
...
@@ -319,9 +343,11 @@ class _LayerNormLinear(torch.autograd.Function):
            clear_tensor_data(ln_out, ln_out_total)

        if is_grad_enabled:
            ctx.weight_quantizer = weight_quantizer
            ctx.ln_out_needs_gather = (
                weight.requires_grad and parallel_mode == "column" and sequence_parallel
            )
            ctx.force_hp_blockwise_ln_out_gather = force_hp_blockwise_ln_out_gather

            # Input with column-wise usage is needed for wgrad GEMM.
            if backward_needs_input:
...
@@ -332,21 +358,16 @@ class _LayerNormLinear(torch.autograd.Function):
                    if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
                        ln_out.update_usage(rowwise_usage=False)
                    # For force_hp_blockwise_ln_out_gather, we should
                    # be saving the unquantized ln_out to ctx.
                    assert not force_hp_blockwise_ln_out_gather

            # Weight with column-wise usage is needed for dgrad GEMM.
            if inp.requires_grad:
                if isinstance(weightmat, QuantizedTensor):
                    weightmat.update_usage(columnwise_usage=True)

            if cpu_offloading:
-               if fp8 and weightmat is not None:
-                   set_offloading_param(weightmat, "weight_offloading", True)
-               set_offloading_param(ln_weight, "weight_offloading", True)
-               set_offloading_param(weight, "weight_offloading", True)
-               set_offloading_param(inputmat, "activation_offloading", True)
-               set_offloading_param(mu, "activation_offloading", True)
-               set_offloading_param(rsigma, "activation_offloading", True)
-               set_offloading_param(ln_out, "activation_offloading", True)
+               mark_activation_offload(inputmat, mu, rsigma, ln_out)

            # Scatter intermediate/activation tensors saved for the backward pass
            # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
...
@@ -391,6 +412,7 @@ class _LayerNormLinear(torch.autograd.Function):
            if fuse_wgrad_accumulation and weight.requires_grad:
                ctx.main_grad = weight.main_grad

            ctx.grad_input_quantizer = grad_input_quantizer
            ctx.grad_weight_quantizer = grad_weight_quantizer
            ctx.grad_output_quantizer = grad_output_quantizer
            ctx.input_quantizer = input_quantizer
            ctx.owns_input = inputmat is not inp
...
@@ -425,6 +447,8 @@ class _LayerNormLinear(torch.autograd.Function):
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                if in_fp8_activation_recompute_phase():
                    FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
            ctx.wgrad_store = wgrad_store
            ctx.debug = debug

        # Row Parallel Linear
        if ub_overlap_rs_fprop:
...
@@ -434,6 +458,9 @@ class _LayerNormLinear(torch.autograd.Function):
            if sequence_parallel:
                out, _ = reduce_scatter_along_first_dim(out, tp_group)
            elif tensor_parallel:
                if symmetric_ar_type is not None:
                    out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
                else:
                    out, _ = allreduce(out, tp_group)
            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
...
@@ -564,12 +591,27 @@ class _LayerNormLinear(torch.autograd.Function):
                ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
                dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)

        # Configure quantizer for grad output tensor
        # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
        # requires column-wise usage
        if ctx.grad_output_quantizer is not None:
-           # Reduce duplicated transpose, which is performed in grad_output.update_usage
-           if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
-               ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
-           else:
-               ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
+           rowwise_usage = True
+           columnwise_usage = True
+           if ctx.ub_overlap_ag and isinstance(
+               ctx.grad_output_quantizer,
+               (Float8Quantizer, Float8CurrentScalingQuantizer),
+           ):
+               # If data is in FP8 and communication is handled
+               # with Userbuffers, we compute FP8 transposes
+               # manually
+               columnwise_usage = False
+           ctx.grad_output_quantizer.set_usage(
+               rowwise=rowwise_usage,
+               columnwise=columnwise_usage,
+           )

        # Prepare grad output tensor
        # Note: Cast to expected dtype and perform tensor-parallel communication
        nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
        (
            grad_output,
...
@@ -582,21 +624,28 @@ class _LayerNormLinear(torch.autograd.Function):
        )
        nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")

        # Launch tensor-parallel communication for LayerNorm out tensor
        ln_out_total = None
        ln_out_total_work = None
        if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
            quantizer = None
-           if ctx.fp8:
+           if ctx.input_quantizer is not None:
                quantizer = ctx.input_quantizer
                if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                    # If data is in FP8, we compute FP8 transposes manually
                    quantizer.set_usage(rowwise=True, columnwise=False)
                else:
                    # wgrad GEMM requires input with column-wise usage
                    quantizer.set_usage(rowwise=False, columnwise=True)
            nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
            # async_op is not compatible with high precision gather since
            # gather_along_first_dim does not offer callback chaining.
            gather_quantizer = None if ctx.force_hp_blockwise_ln_out_gather else quantizer
            ln_out_total, ln_out_total_work = gather_along_first_dim(
                ln_out,
                ctx.tp_group,
                async_op=True,
-               quantizer=quantizer,
+               quantizer=gather_quantizer,
            )
            nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
        else:
...
@@ -621,6 +670,11 @@ class _LayerNormLinear(torch.autograd.Function):
            if hasattr(recipe, "fp8_gemm_dgrad"):
                dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator

        if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensor):
            weight.update_usage(
                rowwise_usage=ctx.weight_quantizer.rowwise_usage,
                columnwise_usage=ctx.weight_quantizer.columnwise_usage,
            )

        dgrad, *_ = general_gemm(
            weight,
            grad_output,
...
@@ -659,6 +713,8 @@ class _LayerNormLinear(torch.autograd.Function):
        # Compute grad weight tensor
        wgrad = None
        if ctx.requires_wgrad:
            # Synchronize tensor-parallel communication for input tensor
            if ctx.ub_bulk_dgrad:
                ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
                if ctx.fp8:
...
@@ -672,18 +728,32 @@ class _LayerNormLinear(torch.autograd.Function):
                        # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                        # have a valid transpose.
                        ln_out_total._create_transpose()
            else:
                if ln_out_total_work is not None:
                    # Synchronize tensor-parallel communication
                    ln_out_total_work.wait()
                    ln_out_total_work = None
                if ctx.input_quantizer is not None and not isinstance(
                    ln_out_total, QuantizedTensor
                ):
                    # Async gather may have been done in BF16
                    # call quantizer after gather.
                    ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
                    ln_out_total = ctx.input_quantizer(ln_out_total)

            # Make sure GEMM inputs have required data
            if isinstance(ln_out_total, QuantizedTensor):
                ln_out_total.update_usage(columnwise_usage=True)
            if isinstance(grad_output, QuantizedTensor):
                # This is a no-op if platform supports non-TN FP8 GEMM or the transpose
                # already exists.
                grad_output.update_usage(columnwise_usage=True)

-           # Figure out whether to use split accumulator
-           use_split_accumulator = _2X_ACC_WGRAD
-           if ctx.fp8:
-               recipe = ctx.fp8_recipe
-               if hasattr(recipe, "fp8_gemm_wgrad"):
-                   use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

            # Output buffer for overlapping grad input
            # reduce-scatter with wgrad GEMM
            if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
                rs_out = torch.empty(
                    dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device
...
@@ -692,39 +762,29 @@ class _LayerNormLinear(torch.autograd.Function):
            # wgrad GEMM
            # Note: Fuse with bgrad computation if needed
            nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
            wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
            if ctx.fp8:
                recipe = ctx.fp8_recipe
                if hasattr(recipe, "fp8_gemm_wgrad"):
                    wgrad_gemm_use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
            general_gemm_wgrad = functools.partial(
                general_gemm,
                out_dtype=(
                    main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                ),
                workspace=get_workspace(),
                layout="NT",
                grad=True,
                bias=(bias if (grad_bias is None and not ctx.fp8) else None),
                out=main_grad if ctx.fuse_wgrad_accumulation else None,
                use_split_accumulator=wgrad_gemm_use_split_accumulator,
                accumulate=accumulate_wgrad_into_param_main_grad,
                quantization_params=ctx.grad_weight_quantizer,
                ub=ub_obj_wgrad,
                ub_type=ub_type_wgrad,
                extra_output=rs_out,
                bulk_overlap=ctx.ub_bulk_wgrad,
            )
            if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
                ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad)
            else:
                wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output)
                if grad_bias is None:
                    grad_bias = grad_bias_
...
@@ -734,6 +794,17 @@ class _LayerNormLinear(torch.autograd.Function):
            if not ctx.return_layernorm_output:
                # TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme
                clear_tensor_data(ln_out_total)
            nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")

            if ctx.ub_bulk_wgrad:
                if ub_obj_wgrad.is_fp8_ubuf():
                    dgrad = rs_out
                else:
                    dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True)

        # Don't return grad bias if not needed
        if not ctx.use_bias:
            grad_bias = None

        # Synchronize tensor parallel communication
        if ln_out_total_work is not None:
...
@@ -787,18 +858,15 @@ class _LayerNormLinear(torch.autograd.Function):
            if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"):
                origin_weight.grad_added_to_main_grad = True
                if getattr(origin_weight, "zero_out_wgrad", False):
-                   wgrad = torch.zeros(
-                       origin_weight.main_grad.shape,
-                       dtype=origin_weight.dtype,
-                       device=torch.cuda.current_device(),
-                       requires_grad=False,
-                   )
+                   wgrad = get_dummy_wgrad(
+                       list(origin_weight.main_grad.shape),
+                       origin_weight.dtype,
+                       zero=True,
+                   )
                else:
-                   wgrad = torch.empty(
-                       origin_weight.main_grad.shape,
-                       dtype=origin_weight.dtype,
-                       device=torch.cuda.current_device(),
-                       requires_grad=False,
-                   )
+                   wgrad = get_dummy_wgrad(
+                       list(origin_weight.main_grad.shape),
+                       origin_weight.dtype,
+                   )
            elif ctx.fuse_wgrad_accumulation:
                wgrad = None
...
@@ -824,12 +892,14 @@ class _LayerNormLinear(torch.autograd.Function):
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # wgrad_store
            None,  # fuse_wgrad_accumulation
            None,  # input_quantizer
            None,  # weight_quantizer
            None,  # output_quantizer
            None,  # grad_input_quantizer
            None,  # grad_weight_quantizer
            None,  # grad_output_quantizer
            None,  # cpu_offloading
            None,  # tp_group
            None,  # tp_size
...
@@ -852,8 +922,10 @@ class _LayerNormLinear(torch.autograd.Function):
            None,  # ub_bulk_wgrad
            None,  # ub_name
            None,  # fsdp_group
            None,  # debug
            None,  # module
            None,  # skip_fp8_weight_update
            None,  # symmetric_ar_type
        )
...
@@ -906,6 +978,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
                The device on which the parameters of the model will be allocated. It is the user's
                responsibility to ensure all parameters are moved to the GPU before running the
                forward pass.
    name: str, default = `None`
         name of the module, currently used for debugging purposes.

    Parallelism parameters
    ----------------------
...
@@ -941,6 +1015,15 @@ class LayerNormLinear(TransformerEngineBaseModule):
                          it controls the type used to allocate the initial parameters. Useful when
                          the model is trained with lower precision and the original FP32 parameters
                          would not fit in GPU memory.
    delay_wgrad_compute : bool, default = `False`
                         Whether or not to delay weight gradient computation. If set to `True`,
                         it's the user's responsibility to call `module.backward_dw` to compute
                         weight gradients.
    symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
                       Type of symmetric memory all-reduce to use during the forward pass.
                       This can help in latency bound communication situations.
                       Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
                       is used.
    """

    def __init__(
...
@@ -970,6 +1053,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
        symmetric_ar_type: Optional[str] = None,
        name: str = None,
    ) -> None:
        super().__init__()
...
@@ -985,6 +1071,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
        self.return_layernorm_output = return_layernorm_output
        self.return_layernorm_output_gathered = return_layernorm_output_gathered
        self.zero_centered_gamma = zero_centered_gamma
        self.symmetric_ar_type = symmetric_ar_type
        self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
        self.name = name
        if TEDebugState.debug_enabled:
            self._turn_off_unsupported_features_in_debug()  # turn off userbuffers

        if tp_group is None:
            self.tp_size = tp_size
...
@@ -1050,6 +1142,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
            assert ub_name is not None, "Userbuffer name [string] is not set."
        self.ub_name = ub_name

        if self.symmetric_ar_type is not None:
            assert torch_version() >= (
                2,
                7,
                0,
            ), "Torch version must be at least 2.7 to use symmetric memory"

        self.eps = eps
        layer_norm_weight = torch.nn.Parameter(
            torch.empty(self.in_features, device=device, dtype=params_dtype)
...
@@ -1252,6 +1351,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
        fp8_output: Optional[bool] = False,
        fp8_grad: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply layer normalization to the input followed by a linear transformation.
...
@@ -1274,6 +1374,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
                               first microbatch (since it is the first gradient being
                               produced)
        """
        debug = TEDebugState.debug_enabled
        if debug:
            self._validate_name()

        if FP8GlobalStateManager.fp8_graph_capturing():
            skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
...
@@ -1282,6 +1385,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False

        if self.ub_overlap_rs_fprop:
            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
                fp8_output = True
        if self.ub_overlap_rs_dgrad:
            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
                fp8_grad = True

        with self.prepare_forward(
            inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
        ) as inp:
...
@@ -1303,13 +1413,28 @@ class LayerNormLinear(TransformerEngineBaseModule):
            else:
                bias_tensor = getattr(self, self.bias_names[0])  # Unused

            quantizers = (
                self._get_quantizers(fp8_output, fp8_grad)
                if not debug
                else self._get_debug_quantizers(fp8_output, fp8_grad)
            )
            if debug:
                if not any_feature_enabled(quantizers):
                    # If no feature is used, then run faster implementation with debug = False.
                    quantizers = self._get_quantizers(fp8_output, fp8_grad)
                    debug = False
                if isinstance(weight_tensor, QuantizedTensor):
                    raise RuntimeError("FP8 weights are not supported in debug mode.")

-           (
-               input_quantizer,
-               weight_quantizer,
-               output_quantizer,
-               grad_output_quantizer,
-               grad_input_quantizer,
-           ) = self._get_quantizers(fp8_output)
+           (
+               input_quantizer,
+               weight_quantizer,
+               output_quantizer,
+               grad_input_quantizer,
+               grad_weight_quantizer,
+               grad_output_quantizer,
+           ) = quantizers

            if torch.is_grad_enabled():
                fwd_fn = _LayerNormLinear.apply
...
@@ -1327,12 +1452,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.wgrad_store,
                self.fuse_wgrad_accumulation,
                input_quantizer,
                weight_quantizer,
                output_quantizer,
                grad_input_quantizer,
                grad_weight_quantizer,
                grad_output_quantizer,
                is_cpu_offload_enabled(),
                self.tp_group,
                self.tp_size,
...
@@ -1357,6 +1484,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
                self.fsdp_group,
                self,
                skip_fp8_weight_update,
                self.symmetric_ar_type,
                debug,
            )
            out = fwd_fn(*args)
...
@@ -1374,10 +1503,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
            return out, ln_out
        return out

-   def _get_quantizers(self, fp8_output):
+   def _get_quantizers(self, fp8_output, fp8_grad):
        if not self.fp8:
-           return [None] * 5
+           return [None] * 6
        grad_input_quantizer = None
        grad_weight_quantizer = None
        grad_output_quantizer = None
        output_quantizer = None
        input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
...
@@ -1389,13 +1519,27 @@ class LayerNormLinear(TransformerEngineBaseModule):
        if torch.is_grad_enabled():
            grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
            grad_output_quantizer.internal = True
            if fp8_grad:
                grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]

        return (
            input_quantizer,
            weight_quantizer,
            output_quantizer,
            grad_input_quantizer,
            grad_weight_quantizer,
            grad_output_quantizer,
        )

    def _get_debug_quantizers(self, fp8_output, fp8_grad):
        original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
        assert TEDebugState.debug_enabled
        from ...debug.pytorch.debug_quantization import DebugQuantizer

        names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
        return tuple(
            DebugQuantizer(self.name, name, q, self.tp_group)
            for name, q in zip(names, original_quantizers)
        )

    def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
...
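A rough constructor sketch for the options added to LayerNormLinear in this commit (feature sizes and the module name are illustrative; symmetric_ar_type values other than None assume PyTorch >= 2.7 and an initialized tensor-parallel group, as the assertion above requires, and a CUDA device is assumed):

import transformer_engine.pytorch as te

layer = te.LayerNormLinear(
    in_features=1024,
    out_features=3072,
    delay_wgrad_compute=True,      # wgrad GEMM deferred until layer.backward_dw()
    symmetric_ar_type=None,        # or "one_shot" / "two_shot" / "multimem_all_reduce"
    name="mlp_fc1",                # picked up by the debug tooling
)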
transformer_engine/pytorch/module/layernorm_mlp.py
View file @
ab3e5a92
...
...
@@ -8,6 +8,7 @@ import warnings
 from typing import Callable, Optional, Tuple, Union
 from functools import reduce
 from operator import mul as multiply_op
+import functools

 import torch
 from torch.nn.parameter import Parameter
...
...
@@ -17,6 +18,7 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
+from transformer_engine.pytorch import torch_version
 from .base import (
     get_workspace,
     _ub_communicators,
...
...
@@ -42,25 +44,31 @@ from ..utils import (
     clear_tensor_data,
     requires_grad,
     non_tn_fp8_gemm_supported,
+    needs_quantized_gemm,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,
     get_distributed_world_size,
     allreduce,
+    symmetric_all_reduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
     use_reentrant_activation_recompute,
     in_fp8_activation_recompute_phase,
     _fsdp_scatter_tensors,
 )
 from ..constants import dist_group_type
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
-from ..tensor.float8_tensor import Float8Tensor
+from ..tensor.float8_tensor import (
+    Float8CurrentScalingQuantizer,
+    Float8Quantizer,
+    Float8Tensor,
+)
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
-from ._common import apply_normalization, _fix_gathered_fp8_transpose
-from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
+from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
+from ._common import apply_normalization, _fix_gathered_fp8_transpose, WeightGradStore
+from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ..tensor.quantized_tensor import (
     QuantizedTensor,
     Quantizer,
...
...
@@ -70,6 +78,8 @@ from ..tensor.quantized_tensor import (
 from ..cpp_extensions import (
     general_gemm,
 )
+from ...debug.pytorch.utils import any_feature_enabled
+from ...debug.pytorch.debug_state import TEDebugState

 __all__ = ["LayerNormMLP"]
...
...
@@ -101,7 +111,8 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
            "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu),
        }
    # no activation fusion written yet
-   # Per-tensor current scaling: []
+   # Per-tensor current scaling or fp8 blockwise scaling: []
    if recipe.float8_current_scaling() or recipe.float8_block_scaling():
        return {
            "gelu": (tex.gelu, tex.dgelu, None),
            "relu": (tex.relu, tex.drelu, None),
...
...
@@ -112,6 +123,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
            "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
            "srelu": (tex.srelu, tex.dsrelu, None),
        }
    raise NotImplementedError(f"Unhandled recipe type {recipe}")


def _act_func(activation: str, recipe: Optional[Recipe] = None):
...
...
@@ -119,7 +131,7 @@ def _act_func(activation: str, recipe: Optional[Recipe] = None):
    # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
    # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
    # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
-   # Per-tensor current scaling: []
+   # Per-tensor current scaling or fp8 blockwise scaling: []
    funcs = _get_act_func_supported_list(recipe)
    if activation not in funcs:
        raise NotImplementedError("Activation type " + activation + " is not supported!")
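_act_func returns an (activation, backward, fused dbias+backward) triple and raises for unsupported names; recipes without a fused kernel carry None in the third slot. Below is a minimal sketch of the same dispatch-table pattern, with plain-Python stand-ins for the tex.* kernels; the names here are illustrative only.

import math

def gelu(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def dgelu(grad, x):
    # d/dx GELU(x), chained with the incoming gradient.
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return grad * (cdf + x * pdf)

SUPPORTED = {
    # third slot is the fused dbias+dact kernel; None means "no fusion for this recipe"
    "gelu": (gelu, dgelu, None),
}

def act_func(name):
    if name not in SUPPORTED:
        raise NotImplementedError("Activation type " + name + " is not supported!")
    return SUPPORTED[name]

fwd, bwd, fused_dbias_bwd = act_func("gelu")
use_fused_bias_path = fused_dbias_bwd is not None  # mirrors how callers pick a fusion path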
...
...
@@ -145,15 +157,20 @@ class _LayerNormMLP(torch.autograd.Function):
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
+       wgrad_store: WeightGradStore,
        fuse_wgrad_accumulation: bool,
        fc1_input_quantizer: Optional[Quantizer],
        fc1_weight_quantizer: Optional[Quantizer],
+       fc1_output_quantizer: Optional[Quantizer],
+       fc1_grad_input_quantizer: Optional[Quantizer],
+       fc1_grad_weight_quantizer: Optional[Quantizer],
+       fc1_grad_output_quantizer: Optional[Quantizer],
        fc2_input_quantizer: Optional[Quantizer],
        fc2_weight_quantizer: Optional[Quantizer],
-       output_quantizer: Optional[Quantizer],
-       grad_fc2_output_quantizer: Optional[Quantizer],
-       grad_fc1_output_quantizer: Optional[Quantizer],
-       grad_input_quantizer: Optional[Quantizer],
+       fc2_output_quantizer: Optional[Quantizer],
+       fc2_grad_input_quantizer: Optional[Quantizer],
+       fc2_grad_weight_quantizer: Optional[Quantizer],
+       fc2_grad_output_quantizer: Optional[Quantizer],
        cpu_offloading: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
...
...
@@ -179,6 +196,8 @@ class _LayerNormMLP(torch.autograd.Function):
        fsdp_group: Union[dist_group_type, None],
        module: torch.nn.Module,
        skip_fp8_weight_update: bool,
        symmetric_ar_type: str,
        debug: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        # pylint: disable=missing-function-docstring
...
...
@@ -207,16 +226,31 @@ class _LayerNormMLP(torch.autograd.Function):
        if ln_bias is not None:
            ln_bias = cast_if_needed(ln_bias, activation_dtype)

        # Avoid quantized norm kernel if norm output will be returned
        # for fp8 DelayedScaling: layernorm output = FP8
        #     only output of the linear is returned
        # for return_layernorm_output: layernorm output = High precision, then cast to FP8
        #     high precision layernorm output and output of the linear are returned
        # for debug: layernorm output = High precision to enable processing of this norm
        with_quantized_norm = (
-           fp8 and not return_layernorm_output and not return_layernorm_output_gathered
+           fp8
+           and not return_layernorm_output
+           and not return_layernorm_output_gathered
+           and not debug
        )
+       if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
+           # Kernels not available for norm fusion.
+           with_quantized_norm = False

        tp_world_size = get_distributed_world_size(tp_group)
        ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
        ub_overlap_rs = ub_overlap_rs and is_grad_enabled

        backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
+       # TODO(kwyss): Support FP8 allgather of Float8BlockQuantizer recipe
+       force_hp_fc1_input_gather = (
+           fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
+       )  # Perform TP communication in high precision.

        # Configure quantizer for norm output
        if fp8:
            if fc1_input_quantizer is None:
...
...
@@ -257,13 +291,14 @@ class _LayerNormMLP(torch.autograd.Function):
                # norm output will be returned
                ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
                ln_out_return = ln_out_total
-               if fp8:
+               if fp8 or debug:
+                   if not force_hp_fc1_input_gather:
                        ln_out = fc1_input_quantizer(ln_out)
                    fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
                    ln_out_total = fc1_input_quantizer(ln_out_total)
            else:
-               if fp8:
-                   if not with_quantized_norm:
+               if fp8 or debug:
+                   if not with_quantized_norm and not force_hp_fc1_input_gather:
                        ln_out = fc1_input_quantizer(ln_out)
                    fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
            if ub_overlap_ag:
...
...
@@ -276,18 +311,21 @@ class _LayerNormMLP(torch.autograd.Function):
                ln_out_total, _ = gather_along_first_dim(
                    ln_out,
                    tp_group,
-                   quantizer=(fc1_input_quantizer if fp8 else None),
+                   quantizer=(fc1_input_quantizer if fp8 or debug else None),
                )
            else:
-               if fp8 and not with_quantized_norm:
+               # NOTE: force_hp_fc1_input_gather is redundant with else, but
+               # here for clarity. We should not quantize ln_out if bwd needs
+               # to gather in hp.
+               if (fp8 or debug) and not with_quantized_norm and not force_hp_fc1_input_gather:
                    ln_out = fc1_input_quantizer(ln_out)
                ln_out_total = ln_out

        # Cast weights to expected dtype
-       if not fp8:
-           fc1_weight_final = cast_if_needed(fc1_weight, activation_dtype)
-           fc2_weight_final = cast_if_needed(fc2_weight, activation_dtype)
-       else:
+       fc1_weight_final = fc1_weight
+       fc2_weight_final = fc2_weight
+       if fp8 or debug:
            # If weights are not quantized, we call get_weight_workspace,
            # which handles weight caching etc.
            # FP8 cast to workspace buffer
...
...
@@ -299,6 +337,7 @@ class _LayerNormMLP(torch.autograd.Function):
                update_workspace=update_workspace,
                skip_update_flag=skip_fp8_weight_update,
                fsdp_group=fsdp_group,
                workspace_dtype=activation_dtype,
            )
            fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
            fc2_weight_final = module.get_weight_workspace(
...
...
@@ -308,11 +347,15 @@ class _LayerNormMLP(torch.autograd.Function):
                update_workspace=update_workspace,
                skip_update_flag=skip_fp8_weight_update,
                fsdp_group=fsdp_group,
                workspace_dtype=activation_dtype,
            )
+       else:
+           fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype)
+           fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype)

        # Cast biases to expected dtype
        bias_dtype = activation_dtype
-       if fp8 and activation_dtype == torch.float32:
+       if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
            bias_dtype = torch.bfloat16
        if fc1_bias is not None:
            fc1_bias = cast_if_needed(fc1_bias, bias_dtype)
...
...
@@ -333,6 +376,7 @@ class _LayerNormMLP(torch.autograd.Function):
        #   - bias_gelu_fusion - only for full precision.
        # If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performed
        if activation != "gelu":
+           # blockwise scaled gemms don't support gemm_gelu_fusion in fwd.
            gemm_gelu_fusion = bias_gelu_fusion = False
        else:
            if fp8:
...
...
@@ -341,13 +385,16 @@ class _LayerNormMLP(torch.autograd.Function):
                gemm_gelu_fusion = True
        if gemm_gelu_fusion and bias_gelu_fusion:
            gemm_gelu_fusion = False
+       if debug:
+           gemm_gelu_fusion = False

        fc1_outputs = general_gemm(
            fc1_weight_final,
            ln_out_total,
            get_workspace(),
            quantization_params=(
-               fc2_input_quantizer if gemm_gelu_fusion else None  # fused gelu output is in fp8
+               fc2_input_quantizer if gemm_gelu_fusion else fc1_output_quantizer  # fused gelu output is in fp8
            ),
            out_dtype=activation_dtype,
            bias=(
...
...
@@ -358,6 +405,7 @@ class _LayerNormMLP(torch.autograd.Function):
            ub=ub_obj_lnout,
            ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None,
        )
        if not is_grad_enabled and (ln_out_total is not ln_out_return):
            clear_tensor_data(ln_out_total)
...
...
@@ -371,8 +419,17 @@ class _LayerNormMLP(torch.autograd.Function):
            act_out = bias_gelu_fused(fc1_out_without_bias, fc1_bias)
        elif gemm_gelu_fusion:
            act_out, _, fc1_out, _ = fc1_outputs
+       elif debug:
+           fc1_out, *_ = fc1_outputs
+           act_out = activation_func(fc1_out, None)
+           act_out = fc2_input_quantizer(act_out)
        else:
            fc1_out, *_ = fc1_outputs
+           if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling():
+               # tex.quantize does not support GELU fusion for blockwise.
+               act_out = activation_func(fc1_out, None)
+               act_out = tex.quantize(act_out, fc2_input_quantizer)
+           else:
                act_out = activation_func(fc1_out, fc2_input_quantizer)

        if not is_grad_enabled:
...
...
@@ -403,7 +460,7 @@ class _LayerNormMLP(torch.autograd.Function):
            get_workspace(),
            out_dtype=activation_dtype,
            bias=fc2_bias,
-           quantization_params=output_quantizer,
+           quantization_params=fc2_output_quantizer,
            out=fc2_out,
            use_split_accumulator=_2X_ACC_FPROP,
            ub=ub_obj_fc2out,
...
...
@@ -412,7 +469,7 @@ class _LayerNormMLP(torch.autograd.Function):
        )

        # Weight with column-wise usage is needed for dgrad GEMM.
-       if is_grad_enabled and inp.requires_grad:
+       if is_grad_enabled:
            if isinstance(fc1_weight_final, QuantizedTensor):
                fc1_weight_final.update_usage(columnwise_usage=True)
            if isinstance(fc2_weight_final, QuantizedTensor):
...
...
@@ -422,23 +479,9 @@ class _LayerNormMLP(torch.autograd.Function):
                    clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
            else:
                if cpu_offloading:
-                   if fp8 and fc1_weight_final is not None:
-                       set_offloading_param(fc1_weight_final, "weight_offloading", True)
-                   if fp8 and fc2_weight_final is not None:
-                       set_offloading_param(fc2_weight_final, "weight_offloading", True)
-                   set_offloading_param(ln_weight, "weight_offloading", True)
-                   set_offloading_param(fc1_weight, "weight_offloading", True)
-                   set_offloading_param(fc2_weight, "weight_offloading", True)
-                   set_offloading_param(fc1_bias, "weight_offloading", True)
-                   set_offloading_param(inputmat, "activation_offloading", True)
-                   set_offloading_param(mu, "activation_offloading", True)
-                   set_offloading_param(rsigma, "activation_offloading", True)
-                   set_offloading_param(mu, "activation_offloading", True)
-                   set_offloading_param(ln_out, "activation_offloading", True)
-                   set_offloading_param(fc1_out, "activation_offloading", True)
-                   set_offloading_param(fc1_out_without_bias, "activation_offloading", True)
-                   set_offloading_param(act_out, "activation_offloading", True)
+                   mark_activation_offload(
+                       inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
+                   )

            # Scatter intermediate/activation tensors saved for the backward pass
            # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
...
...
@@ -455,10 +498,14 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight_final
if
fp8
and
not
isinstance
(
fc2_weight
,
Float8Tensor
)
else
None
,
)
ctx
.
fc1_weight_quantizer
=
fc1_weight_quantizer
ctx
.
fc2_weight_quantizer
=
fc2_weight_quantizer
if
not
fc1_weight
.
requires_grad
:
if
not
return_layernorm_output
:
clear_tensor_data
(
ln_out
)
ln_out
=
None
elif
force_hp_fc1_input_gather
:
assert
not
isinstance
(
ln_out
,
QuantizedTensor
)
if
not
fc2_weight
.
requires_grad
:
clear_tensor_data
(
act_out
)
act_out
=
None
...
...
@@ -487,11 +534,15 @@ class _LayerNormMLP(torch.autograd.Function):
            ctx.tensor_objects = tensor_objects

            ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
-           ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer
-           ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer
-           ctx.grad_input_quantizer = grad_input_quantizer
-           ctx.fc2_input_quantizer = fc2_input_quantizer
+           ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather
+           ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer
+           ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer
+           ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer
+           ctx.fc2_grad_input_quantizer = fc2_grad_input_quantizer
+           ctx.fc2_grad_weight_quantizer = fc2_grad_weight_quantizer
+           ctx.fc2_grad_output_quantizer = fc2_grad_output_quantizer
            ctx.fc1_input_quantizer = fc1_input_quantizer
            ctx.fc2_input_quantizer = fc2_input_quantizer

            ctx.fc1_weight_requires_grad = fc1_weight.requires_grad
            ctx.fc2_weight_requires_grad = fc2_weight.requires_grad
...
...
@@ -502,6 +553,7 @@ class _LayerNormMLP(torch.autograd.Function):
            ctx.activation_dtype = activation_dtype
            ctx.activation = activation
            ctx.fp8 = fp8
            ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
...
...
@@ -523,6 +575,7 @@ class _LayerNormMLP(torch.autograd.Function):
            ctx.ub_bulk_dgrad = ub_bulk_dgrad
            ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
            ctx.ub_overlap_ag = ub_overlap_ag
            ctx.debug = debug
            ctx.requires_dgrad = (
                inp.requires_grad or ln_weight.requires_grad or ln_bias.requires_grad
...
...
@@ -537,12 +590,19 @@ class _LayerNormMLP(torch.autograd.Function):
            if in_fp8_activation_recompute_phase():
                FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
            ctx.wgrad_store = wgrad_store

        # Row Parallel Linear
        if ub_overlap_rs:
            fc2_out = rs_out
        elif set_parallel_mode and sequence_parallel:
            fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
        elif set_parallel_mode and tensor_parallel:
+           if symmetric_ar_type is not None:
+               fc2_out, _ = symmetric_all_reduce(fc2_out, tp_group, all_reduce_type=symmetric_ar_type)
+           else:
                fc2_out, _ = allreduce(fc2_out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
...
...
@@ -643,15 +703,27 @@ class _LayerNormMLP(torch.autograd.Function):
        ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad
        ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad

-       # Prepare grad output tensor
-       # Note: Cast to expected dtype and perform tensor-parallel communication
-       if ctx.grad_fc2_output_quantizer is not None:
-           # Reduce duplicated transpose, which is performed in grad_output.update_usage
-           if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
-               ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=False)
-           else:
-               ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=True)
+       # Configure quantizer for FC2 grad output tensor
+       # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
+       # requires column-wise usage
+       if ctx.fc2_grad_output_quantizer is not None:
+           rowwise_usage = True
+           columnwise_usage = True
+           if ctx.ub_overlap_ag and isinstance(
+               ctx.fc2_grad_output_quantizer,
+               (Float8Quantizer, Float8CurrentScalingQuantizer),
+           ):
+               # If data is in FP8 and communication is handled
+               # with Userbuffers, we compute FP8 transposes
+               # manually
+               columnwise_usage = False
+           ctx.fc2_grad_output_quantizer.set_usage(
+               rowwise=rowwise_usage,
+               columnwise=columnwise_usage,
+           )

+       # Prepare FC2 grad output tensor
+       # Note: Cast to expected dtype and perform tensor-parallel communication
        ub_obj_fc2_dgrad = None
        if ctx.ub_overlap_ag:
            ub_obj_fc2_dgrad = get_ub("fc2_dgrad")
...
...
@@ -660,11 +732,10 @@ class _LayerNormMLP(torch.autograd.Function):
        (
            grad_output,
            fc2_bias_grad,
        ) = TransformerEngineBaseModule.grad_output_preprocess(
-           ctx, grad_outputs[0], True, ctx.grad_fc2_output_quantizer
+           ctx, grad_outputs[0], True, ctx.fc2_grad_output_quantizer
        )

-       # Prepare FC1 GEMM input
-       # Note: Perform tensor-parallel communication if needed
+       # Launch tensor-parallel communication for FC1 GEMM input
        ln_out_total = None
        ln_out_total_work = None
        if (
...
...
@@ -674,14 +745,20 @@ class _LayerNormMLP(torch.autograd.Function):
            and not ctx.ub_bulk_dgrad
        ):
            quantizer = None
-           if ctx.fp8:
+           if ctx.fp8 or ctx.debug:
                quantizer = ctx.fc1_input_quantizer
+               if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+                   # If data is in FP8, we compute FP8 transposes manually
+                   quantizer.set_usage(rowwise=True, columnwise=False)
+               else:
+                   # wgrad GEMM requires input with column-wise usage
+                   quantizer.set_usage(rowwise=False, columnwise=True)
+           gather_quantizer = None if ctx.force_hp_fc1_input_gather else quantizer
            ln_out_total, ln_out_total_work = gather_along_first_dim(
                ln_out,
                ctx.tp_group,
                async_op=True,
-               quantizer=quantizer,
+               quantizer=gather_quantizer,
            )
        else:
            ln_out_total = ln_out
...
...
@@ -693,17 +770,26 @@ class _LayerNormMLP(torch.autograd.Function):
)
else
:
accumulate_wgrad_into_param_main_grad
=
ctx
.
fuse_wgrad_accumulation
# There are
5
possible fusion paths
# There are
6
possible fusion paths
# 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu,
# 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize
# 3 fp8 activation+bias+quantize fusion: gemm, activation + FC1_bias + quantize
# 4 fp8 bias+quantize fusion: gemm, activation, FC1_bias + quantize
# 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm
# 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm
fc2_dgrad_gemm_gelu_fusion
=
(
not
ctx
.
fp8
and
(
ctx
.
activation
==
"gelu"
)
and
(
not
ctx
.
bias_gelu_fusion
)
not
ctx
.
fp8
and
(
ctx
.
activation
==
"gelu"
)
and
(
not
ctx
.
bias_gelu_fusion
)
and
(
not
ctx
.
debug
)
)
# FC2 DGRAD; Unconditional
if
ctx
.
fc2_weight_quantizer
is
not
None
and
isinstance
(
ctx
.
fc2_weight
,
QuantizedTensor
):
ctx
.
fc2_weight
.
update_usage
(
rowwise_usage
=
ctx
.
fc2_weight_quantizer
.
rowwise_usage
,
columnwise_usage
=
ctx
.
fc2_weight_quantizer
.
columnwise_usage
,
)
gemm_output
,
*
_
=
general_gemm
(
fc2_weight
,
grad_output
,
...
...
@@ -711,7 +797,9 @@ class _LayerNormMLP(torch.autograd.Function):
layout
=
"NN"
,
grad
=
True
,
quantization_params
=
(
ctx
.
grad_fc1_output_quantizer
if
fc2_dgrad_gemm_gelu_fusion
else
None
ctx
.
fc1_grad_input_quantizer
if
fc2_dgrad_gemm_gelu_fusion
or
ctx
.
debug
else
None
),
# high precision to activation
out_dtype
=
ctx
.
activation_dtype
,
gelu
=
fc2_dgrad_gemm_gelu_fusion
,
...
...
@@ -734,39 +822,65 @@ class _LayerNormMLP(torch.autograd.Function):
if
isinstance
(
grad_output
,
QuantizedTensor
):
grad_output
.
update_usage
(
rowwise_usage
=
True
,
columnwise_usage
=
True
)
fc2_wgrad
,
fc2_bias_grad_
,
*
_
=
general_gemm
(
act_out
,
grad_output
,
get_workspace
(),
grad_arg
=
True
if
ctx
.
fp8
and
ctx
.
fp8_recipe
.
float8_block_scaling
():
grad_arg
=
False
general_gemm_fc2_wgrad
=
functools
.
partial
(
general_gemm
,
out_dtype
=
(
origin_fc2_weight
.
main_grad
.
dtype
if
ctx
.
fuse_wgrad_accumulation
else
ctx
.
activation_dtype
),
quantization_params
=
None
,
# wgrad in high precision
workspace
=
get_workspace
(),
quantization_params
=
ctx
.
fc2_grad_weight_quantizer
,
# wgrad in high precision
layout
=
"NT"
,
grad
=
True
,
bias
=
fc2_bias
if
fc2_bias_grad
is
None
else
None
,
grad
=
grad_arg
,
bias
=
fc2_bias
if
fc2_bias
is
not
None
and
fc2_bias_grad
is
None
else
None
,
accumulate
=
accumulate_wgrad_into_param_main_grad
,
use_split_accumulator
=
_2X_ACC_WGRAD
,
out
=
origin_fc2_weight
.
main_grad
if
ctx
.
fuse_wgrad_accumulation
else
None
,
)
if
ctx
.
wgrad_store
is
not
None
and
ctx
.
wgrad_store
.
delay_wgrad_compute
():
ctx
.
wgrad_store
.
put
([
act_out
,
grad_output
],
general_gemm_fc2_wgrad
)
fc2_wgrad
=
None
else
:
fc2_wgrad
,
fc2_bias_grad_
,
*
_
=
general_gemm_fc2_wgrad
(
act_out
,
grad_output
,
)
if
fc2_bias_grad
is
None
:
if
(
ctx
.
fp8
and
ctx
.
fp8_recipe
.
float8_block_scaling
()
and
fc2_bias
is
not
None
):
# BGRAD not fused with GEMM for float8 blockwise gemm.
fc2_bias_grad_
=
act_out
.
view
(
-
1
,
act_out
.
shape
[
-
1
]).
sum
(
dim
=
0
)
fc2_bias_grad
=
fc2_bias_grad_
del
fc2_bias_grad_
if
ctx
.
wgrad_store
is
not
None
and
not
ctx
.
wgrad_store
.
delay_wgrad_compute
():
clear_tensor_data
(
act_out
)
# bias computation
fc1_bias_grad
=
None
fuse_gemm_and_bias_fc1_wgrad
=
False
if
ctx
.
grad
_fc1
_output_quantizer
is
not
None
:
ctx
.
grad
_fc1
_output_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
True
)
if
ctx
.
fc1_
grad_output_quantizer
is
not
None
:
ctx
.
fc1_
grad_output_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
True
)
if
ctx
.
bias_gelu_fusion
:
# Fusion: gemm, bias + gelu
assert
ctx
.
activation
==
"gelu"
assert
not
ctx
.
fp8
fc1_bias_grad
,
dact
=
bgrad_dgelu_fused
(
fc2_dgrad
,
fc1_out_without_bias
,
fc1_bias
)
if
ctx
.
grad_fc1_output_quantizer
is
not
None
:
dact
=
ctx
.
grad_fc1_output_quantizer
(
dact
)
if
ctx
.
fc1_grad_output_quantizer
is
not
None
:
dact
=
ctx
.
fc1_grad_output_quantizer
(
dact
)
elif
ctx
.
debug
:
dact_func
=
_act_func
(
ctx
.
activation
)[
1
]
dact
=
dact_func
(
fc2_dgrad
,
fc1_out
.
to
(
ctx
.
activation_dtype
),
None
)
fc1_bias_grad
=
dact
.
sum
(
dim
=
0
)
dact
=
ctx
.
fc1_grad_output_quantizer
(
dact
)
elif
(
_act_func
(
ctx
.
activation
,
ctx
.
fp8_recipe
if
ctx
.
fp8
else
None
)[
2
]
is
not
None
and
ctx
.
fp8
...
...
@@ -776,7 +890,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx
.
activation
,
ctx
.
fp8_recipe
if
ctx
.
fp8
else
None
)[
2
]
fc1_bias_grad
,
dact
=
dbias_dact_quantize_func
(
fc2_dgrad
,
fc1_out
.
to
(
ctx
.
activation_dtype
),
ctx
.
grad
_fc1
_output_quantizer
fc2_dgrad
,
fc1_out
.
to
(
ctx
.
activation_dtype
),
ctx
.
fc1_
grad_output_quantizer
)
# quantize bgrad gelu fused
else
:
# Fusion: gemm + gelu,
...
...
@@ -789,7 +903,14 @@ class _LayerNormMLP(torch.autograd.Function):
)
# activation in high precision
if
ctx
.
fp8
:
fc1_bias_grad
,
dact
=
tex
.
bgrad_quantize
(
dact
,
ctx
.
grad_fc1_output_quantizer
)
# TODO float8 blockwise current scaling has no bgrad fusion for now
if
isinstance
(
ctx
.
fc1_grad_output_quantizer
,
Float8BlockQuantizer
):
fc1_bias_grad
=
dact
.
view
(
-
1
,
dact
.
shape
[
-
1
]).
sum
(
dim
=
0
)
dact
=
ctx
.
fc1_grad_output_quantizer
(
dact
)
else
:
fc1_bias_grad
,
dact
=
tex
.
bgrad_quantize
(
dact
,
ctx
.
fc1_grad_output_quantizer
)
else
:
fuse_gemm_and_bias_fc1_wgrad
=
(
True
# fc1_bias_grad is computed later, fused with wgrad gemm for the FC1
...
...
@@ -836,12 +957,20 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_dgrad_bulk
=
ub_obj_fc1_wgrad
.
get_buffer
(
None
)
# FC1 DGRAD: Unconditional
if
ctx
.
fc1_weight_quantizer
is
not
None
and
isinstance
(
ctx
.
fc1_weight_quantizer
,
QuantizedTensor
):
ctx
.
fc1_weight
.
update_usage
(
rowwise_usage
=
ctx
.
fc1_weight_quantizer
.
rowwise_usage
,
columnwise_usage
=
ctx
.
fc1_weight_quantizer
.
columnwise_usage
,
)
fc1_dgrad
,
*
_
,
fc1_dgrad_rs_out
=
general_gemm
(
fc1_weight
,
dact
,
get_workspace
(),
out
=
fc1_dgrad_bulk
,
out_dtype
=
ctx
.
activation_dtype
,
quantization_params
=
ctx
.
fc1_grad_input_quantizer
,
layout
=
"NN"
,
grad
=
True
,
ub
=
ub_obj_fc1_dgrad
,
...
...
@@ -869,6 +998,8 @@ class _LayerNormMLP(torch.autograd.Function):
# FC1 WGRAD
fc1_wgrad
=
None
if
ctx
.
fc1_weight_requires_grad
:
# Synchronize tensor-parallel communication for FC1 GEMM input tensor
if
ctx
.
ub_bulk_dgrad
:
ln_out_total
=
ub_obj_fc1_dgrad
.
get_buffer
(
ctx
.
fc1_input_quantizer
)
if
ctx
.
fp8
:
...
...
@@ -880,34 +1011,41 @@ class _LayerNormMLP(torch.autograd.Function):
# FP8 GEMM on Hopper only supports TN layout so the gathered input must
# have a valid transpose.
ln_out_total
.
_create_transpose
()
else
:
if
ln_out_total_work
is
not
None
:
# Synchronize tensor-parallel communication
ln_out_total_work
.
wait
()
ln_out_total_work
=
None
if
ctx
.
fc1_input_quantizer
is
not
None
and
not
isinstance
(
ln_out_total
,
QuantizedTensor
):
# Async gather in BF16 does not asynchronously
# call quantizer after gather.
ctx
.
fc1_input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
ln_out_total
=
ctx
.
fc1_input_quantizer
(
ln_out_total
)
# Make sure GEMM inputs have
expect
ed data
# Make sure GEMM inputs have
requir
ed data
if
isinstance
(
ln_out_total
,
QuantizedTensor
):
ln_out_total
.
update_usage
(
rowwise_usage
=
True
,
columnwise_usage
=
True
)
ln_out_total
.
update_usage
(
columnwise_usage
=
True
)
if
isinstance
(
dact
,
QuantizedTensor
):
dact
.
update_usage
(
rowwise_usage
=
True
,
columnwise_usage
=
True
)
dact
.
update_usage
(
columnwise_usage
=
True
)
# Output buffer for overlapping grad input
# reduce-scatter with wgrad GEMM
if
ctx
.
ub_bulk_wgrad
and
ub_obj_fc1_wgrad
.
is_fp8_ubuf
():
fc1_dgrad_rs_out
=
torch
.
empty
(
fc1_dgrad_shape
,
dtype
=
ctx
.
activation_dtype
,
device
=
"cuda"
)
fc1_wgrad_outputs
=
general_gemm
(
ln_out_total
,
dact
,
get_workspace
(),
# wgrad GEMM
general_gemm_fc1_wgrad
=
functools
.
partial
(
general_gemm
,
out_dtype
=
(
origin_fc1_weight
.
main_grad
.
dtype
if
ctx
.
fuse_wgrad_accumulation
else
ctx
.
activation_dtype
),
workspace
=
get_workspace
(),
layout
=
"NT"
,
quantization_params
=
ctx
.
fc1_grad_weight_quantizer
,
grad
=
fuse_gemm_and_bias_fc1_wgrad
,
bias
=
fc1_bias
if
fuse_gemm_and_bias_fc1_wgrad
else
None
,
accumulate
=
accumulate_wgrad_into_param_main_grad
,
...
...
@@ -917,6 +1055,16 @@ class _LayerNormMLP(torch.autograd.Function):
extra_output
=
fc1_dgrad_rs_out
,
bulk_overlap
=
ctx
.
ub_bulk_wgrad
,
)
if
ctx
.
wgrad_store
is
not
None
and
ctx
.
wgrad_store
.
delay_wgrad_compute
():
ctx
.
wgrad_store
.
put
([
ln_out_total
,
dact
],
general_gemm_fc1_wgrad
)
fc1_wgrad
=
None
if
fuse_gemm_and_bias_fc1_wgrad
:
fc1_bias_grad
=
None
else
:
fc1_wgrad_outputs
=
general_gemm_fc1_wgrad
(
ln_out_total
,
dact
,
)
clear_tensor_data
(
ln_out_total
,
dact
)
...
...
@@ -931,7 +1079,7 @@ class _LayerNormMLP(torch.autograd.Function):
            else:
                fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(None, local_chunk=True)

-       # Synchronize tensor parallel communication
+       # Make sure all tensor-parallel communication is finished
        if ln_out_total_work is not None:
            ln_out_total_work.wait()
            ln_out_total_work = None
...
...
@@ -1040,15 +1188,20 @@ class _LayerNormMLP(torch.autograd.Function):
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # wgrad_store
            None,  # fuse_wgrad_accumulation
-           None,  # fc1_input_quantizer
-           None,  # fc1_weight_quantizer
-           None,  # fc2_input_quantizer
-           None,  # fc2_weight_quantizer
-           None,  # output_quantizer
-           None,  # grad_fc2_output_quantizer
-           None,  # grad_fc1_output_quantizer
-           None,  # grad_input_quantizer
+           None,  # fc1_input_quantizer,
+           None,  # fc1_weight_quantizer,
+           None,  # fc1_output_quantizer,
+           None,  # fc1_grad_input_quantizer,
+           None,  # fc1_grad_weight_quantizer,
+           None,  # fc1_grad_output_quantizer,
+           None,  # fc2_input_quantizer,
+           None,  # fc2_weight_quantizer,
+           None,  # fc2_output_quantizer,
+           None,  # fc2_grad_input_quantizer,
+           None,  # fc2_grad_weight_quantizer,
+           None,  # fc2_grad_output_quantizer,
            None,  # cpu_offloading
            None,  # tp_group
            None,  # tp_size
...
...
@@ -1074,6 +1227,8 @@ class _LayerNormMLP(torch.autograd.Function):
None
,
# fsdp_group
None
,
# module
None
,
# skip_fp8_weight_update
None
,
# symmetric_ar_type
None
,
# debug
)
...
...
@@ -1126,6 +1281,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters
----------------------
...
...
@@ -1168,6 +1325,15 @@ class LayerNormMLP(TransformerEngineBaseModule):
batch size per training step. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are
used for forward propagation and activation recompute phase.
delay_wgrad_compute : bool, default = `False`
Whether or not to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call `module.backward_dw` to compute
weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
is used.
"""
def __init__(
...
...
@@ -1195,10 +1361,13 @@ class LayerNormMLP(TransformerEngineBaseModule):
zero_centered_gamma
:
bool
=
False
,
device
:
Union
[
torch
.
device
,
str
]
=
"cuda"
,
ub_overlap_ag
:
bool
=
False
,
name
:
str
=
None
,
ub_overlap_rs
:
bool
=
False
,
ub_overlap_rs_dgrad
:
bool
=
False
,
ub_bulk_dgrad
:
bool
=
False
,
ub_bulk_wgrad
:
bool
=
False
,
delay_wgrad_compute
:
bool
=
False
,
symmetric_ar_type
:
Optional
[
str
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -1217,6 +1386,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
)
self
.
set_parallel_mode
=
set_parallel_mode
self
.
zero_centered_gamma
=
zero_centered_gamma
self
.
symmetric_ar_type
=
symmetric_ar_type
# GEMM-GELU fusion is currently only supported with split GEMM-AG overlap
self
.
gemm_gelu_fusion
=
(
...
...
@@ -1224,6 +1394,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
and
self
.
activation
==
"gelu"
and
((
_ub_communicators
is
None
)
or
(
not
get_ub
(
"fc1_fprop"
).
is_atomic_gemm
()))
)
self
.
name
=
name
if
TEDebugState
.
debug_enabled
:
self
.
_turn_off_unsupported_features_in_debug
()
# turn off userbuffers
self
.
wgrad_store
=
WeightGradStore
(
delay_wgrad_compute
,
ub_bulk_wgrad
)
if
tp_group
is
None
:
self
.
tp_size
=
tp_size
...
...
@@ -1252,6 +1428,13 @@ class LayerNormMLP(TransformerEngineBaseModule):
ub_bulk_dgrad
and
self
.
sequence_parallel
and
not
self
.
ub_overlap_rs_dgrad
)
if
self
.
symmetric_ar_type
is
not
None
:
assert
torch_version
()
>=
(
2
,
7
,
0
,
),
"Torch version must be at least 2.7 to use symmetric memory"
# Initialize params in FP8
with_fp8_params
=
FP8GlobalStateManager
.
with_fp8_parameters
()
...
...
@@ -1384,7 +1567,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
    @no_torch_dynamo()
    def forward(
-       self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
+       self,
+       inp: torch.Tensor,
+       is_first_microbatch: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).
...
...
@@ -1407,6 +1592,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
debug
=
TEDebugState
.
debug_enabled
if
debug
:
self
.
_validate_name
()
if
FP8GlobalStateManager
.
fp8_graph_capturing
():
skip_fp8_weight_update
=
FP8GlobalStateManager
.
get_skip_fp8_weight_update_tensor
()
...
...
@@ -1415,18 +1603,41 @@ class LayerNormMLP(TransformerEngineBaseModule):
if
skip_fp8_weight_update
is
not
None
:
is_first_microbatch
=
False
fp8_output
=
False
if
self
.
ub_overlap_rs
:
if
get_ub
(
"fc2_fprop"
).
is_fp8_ubuf
():
fp8_output
=
True
with
self
.
prepare_forward
(
inp
,
num_gemms
=
2
)
as
inp
:
quantizers
=
(
self
.
_get_quantizers
(
fp8_output
)
if
not
debug
else
self
.
_get_debug_quantizers
(
fp8_output
)
)
if
debug
:
if
not
any_feature_enabled
(
quantizers
):
quantizers
=
self
.
_get_quantizers
(
fp8_output
)
debug
=
False
if
isinstance
(
self
.
fc1_weight
,
QuantizedTensor
):
raise
RuntimeError
(
"FP8 weights are not supported in debug mode."
)
# Get quantizers
(
fc1_input_quantizer
,
fc1_weight_quantizer
,
fc1_output_quantizer
,
fc1_grad_input_quantizer
,
fc1_grad_weight_quantizer
,
fc1_grad_output_quantizer
,
fc2_input_quantizer
,
fc2_weight_quantizer
,
output_quantizer
,
grad_
fc1_out
put_quantizer
,
grad_
fc2_outpu
t_quantizer
,
grad_
in
put_quantizer
,
)
=
self
.
_get_
quantizers
()
fc2_
output_quantizer
,
fc2_
grad_
in
put_quantizer
,
fc2_
grad_
weigh
t_quantizer
,
fc2_
grad_
out
put_quantizer
,
)
=
quantizers
# Get weight tensors
fc1_weight
=
self
.
fc1_weight
...
...
@@ -1462,15 +1673,20 @@ class LayerNormMLP(TransformerEngineBaseModule):
is_first_microbatch
,
self
.
fp8
,
self
.
fp8_calibration
,
self
.
wgrad_store
,
self
.
fuse_wgrad_accumulation
,
fc1_input_quantizer
,
fc1_weight_quantizer
,
fc1_output_quantizer
,
fc1_grad_input_quantizer
,
fc1_grad_weight_quantizer
,
fc1_grad_output_quantizer
,
fc2_input_quantizer
,
fc2_weight_quantizer
,
output_quantizer
,
grad_input_quantizer
,
grad_
fc1_outpu
t_quantizer
,
grad
_fc2
_output_quantizer
,
fc2_
output_quantizer
,
fc2_
grad_input_quantizer
,
fc2_
grad_
weigh
t_quantizer
,
fc2_
grad_output_quantizer
,
is_cpu_offload_enabled
(),
self
.
tp_group
,
self
.
tp_size
,
...
...
@@ -1479,7 +1695,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self
.
activation_dtype
,
self
.
return_layernorm_output
,
self
.
return_layernorm_output_gathered
,
self
.
bias_gelu_nvfusion
and
not
self
.
fp8
,
self
.
bias_gelu_nvfusion
and
not
self
.
fp8
and
not
debug
,
self
.
set_parallel_mode
,
torch
.
is_grad_enabled
(),
self
.
fwd_ln_sm_margin
if
torch
.
is_grad_enabled
()
else
self
.
inf_ln_sm_margin
,
...
...
@@ -1492,10 +1708,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
self
.
ub_overlap_rs_dgrad
,
self
.
ub_bulk_dgrad
,
self
.
ub_bulk_wgrad
,
self
.
gemm_gelu_fusion
,
self
.
gemm_gelu_fusion
and
not
debug
,
self
.
fsdp_group
,
self
,
skip_fp8_weight_update
,
self
.
symmetric_ar_type
,
debug
,
)
out
=
fwd_fn
(
*
args
)
...
...
@@ -1513,17 +1731,21 @@ class LayerNormMLP(TransformerEngineBaseModule):
return
out
,
ln_out
return
out
def
_get_quantizers
(
self
):
def
_get_quantizers
(
self
,
fp8_output
):
(
fc1_input_quantizer
,
fc1_weight_quantizer
,
fc1_output_quantizer
,
fc1_grad_input_quantizer
,
fc1_grad_weight_quantizer
,
fc1_grad_output_quantizer
,
fc2_input_quantizer
,
fc2_weight_quantizer
,
output_quantizer
,
grad_
fc1_out
put_quantizer
,
grad_
fc2_outpu
t_quantizer
,
grad_
in
put_quantizer
,
)
=
[
None
]
*
8
fc2_
output_quantizer
,
fc2_
grad_
in
put_quantizer
,
fc2_
grad_
weigh
t_quantizer
,
fc2_
grad_
out
put_quantizer
,
)
=
[
None
]
*
12
if
self
.
fp8
:
fc1_input_quantizer
=
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
]
fc1_input_quantizer
.
internal
=
False
# temporary
...
...
@@ -1531,32 +1753,59 @@ class LayerNormMLP(TransformerEngineBaseModule):
            fc1_weight_quantizer.internal = True
            fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
            fc2_input_quantizer.set_usage(
-               rowwise=True, columnwise=isinstance(fc2_input_quantizer, MXFP8Quantizer)
+               rowwise=True,
+               columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
            )
            fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
            fc2_weight_quantizer.internal = True
+           if fp8_output:
+               fc2_output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_OUTPUT]
            if torch.is_grad_enabled():
-               grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][
+               fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][
                    tex.FP8BwdTensors.GRAD_OUTPUT1
                ]
-               grad_fc2_output_quantizer.internal = True
-               grad_fc1_output_quantizer = self.quantizers["scaling_bwd"][
+               fc2_grad_output_quantizer.internal = True
+               fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][
                    tex.FP8BwdTensors.GRAD_INPUT1
                ]
-               grad_fc1_output_quantizer.internal = True
-               grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT2]
-               grad_input_quantizer.internal = True
+               fc1_grad_output_quantizer.internal = True
        return (
            fc1_input_quantizer,
            fc1_weight_quantizer,
+           fc1_output_quantizer,
+           fc1_grad_input_quantizer,
+           fc1_grad_weight_quantizer,
+           fc1_grad_output_quantizer,
            fc2_input_quantizer,
            fc2_weight_quantizer,
-           output_quantizer,
-           grad_fc1_output_quantizer,
-           grad_fc2_output_quantizer,
-           grad_input_quantizer,
+           fc2_output_quantizer,
+           fc2_grad_input_quantizer,
+           fc2_grad_weight_quantizer,
+           fc2_grad_output_quantizer,
        )

    def _get_debug_quantizers(self, fp8_output):
        from ...debug.pytorch.debug_quantization import DebugQuantizer

        base_quantizers = list(self._get_quantizers(fp8_output))
        assert TEDebugState.debug_enabled

        def make_debug(prefix, offset):
            labels = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
            return [
                DebugQuantizer(
                    f"{self.name}.{prefix}",
                    label,
                    None if label in ("dgrad", "wgrad") else base_quantizers[i + offset],
                    self.tp_group,
                )
                for i, label in enumerate(labels)
            ]

        return tuple(make_debug("fc1", 0) + make_debug("fc2", 6))

    def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
        """Customize quantizers based on current scaling recipe + layernorm_mlp."""
...
...
@@ -1602,14 +1851,14 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex
.
FP8FwdTensors
.
GEMM1_INPUT
].
amax_reduction_group
=
self
.
tp_group
else
:
# grad
_fc2
_output_quantizer: set configs about amax epsilon and power_2_scale for grad
_fc2
_output_quantizer
#
fc2_
grad_output_quantizer: set configs about amax epsilon and power_2_scale for
fc2_
grad_output_quantizer
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
].
force_pow_2_scales
=
recipe
.
fp8_quant_bwd_grad
.
power_2_scale
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
].
amax_epsilon
=
recipe
.
fp8_quant_bwd_grad
.
amax_epsilon
# grad
_fc1
_output_quantizer: also set numerical configs for grad
_fc1
_output_quantizer
#
fc1_
grad_output_quantizer: also set numerical configs for
fc1_
grad_output_quantizer
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_INPUT1
].
force_pow_2_scales
=
recipe
.
fp8_quant_bwd_grad
.
power_2_scale
...
...
@@ -1617,10 +1866,48 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex
.
FP8BwdTensors
.
GRAD_INPUT1
].
amax_epsilon
=
recipe
.
fp8_quant_bwd_grad
.
amax_epsilon
if
self
.
sequence_parallel
and
self
.
set_parallel_mode
:
# grad
_fc2
_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
#
fc2_
grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
].
with_amax_reduction
=
True
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
].
amax_reduction_group
=
self
.
tp_group
def
backward_dw
(
self
):
"""
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if
self
.
wgrad_store
is
None
or
not
self
.
wgrad_store
.
delay_wgrad_compute
():
return
with
torch
.
cuda
.
nvtx
.
range
(
"_LayerNormMLP_wgrad"
):
(
fc2_wgrad
,
fc2_bias_grad_
,
*
_
),
tensor_list_fc2
=
self
.
wgrad_store
.
pop
()
if
self
.
use_bias
and
self
.
fc1_bias
.
grad
is
None
:
(
fc1_wgrad
,
fc1_bias_grad
,
*
_
),
_
=
self
.
wgrad_store
.
pop
()
else
:
(
fc1_wgrad
,
*
_
),
_
=
self
.
wgrad_store
.
pop
()
fc1_bias_grad
=
None
if
self
.
use_bias
:
if
self
.
fc2_bias
.
grad
is
None
:
if
(
self
.
fp8
and
FP8GlobalStateManager
.
get_fp8_recipe
().
float8_block_scaling
()
and
self
.
apply_bias
and
not
self
.
gemm_bias_unfused_add
):
act_out
=
tensor_list_fc2
[
0
]
# BGRAD not fused with GEMM for float8 blockwise gemm.
fc2_bias_grad_
=
act_out
.
view
(
-
1
,
act_out
.
shape
[
-
1
]).
sum
(
dim
=
0
)
self
.
fc2_bias
.
grad
=
fc2_bias_grad_
.
to
(
self
.
fc2_bias
.
dtype
)
if
self
.
fc1_bias
.
grad
is
None
:
self
.
fc1_bias
.
grad
=
fc1_bias_grad
.
to
(
self
.
fc1_bias
.
dtype
)
if
not
self
.
fuse_wgrad_accumulation
:
if
self
.
fc2_weight
.
grad
is
None
:
self
.
fc2_weight
.
grad
=
fc2_wgrad
.
to
(
self
.
fc2_weight
.
dtype
)
if
self
.
fc1_weight
.
grad
is
None
:
self
.
fc1_weight
.
grad
=
fc1_wgrad
.
to
(
self
.
fc1_weight
.
dtype
)
del
fc2_bias_grad_
del
fc2_wgrad
del
fc1_wgrad
del
fc1_bias_grad
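When BGRAD is not fused into the blockwise GEMM, the code above falls back to a plain flatten-and-sum reduction over all leading dimensions. A tiny illustration of that tensor operation (shapes here are arbitrary):

import torch

t = torch.randn(4, 3, 16)                     # any [..., features] tensor
bgrad = t.view(-1, t.shape[-1]).sum(dim=0)    # collapse leading dims, sum over rows
assert bgrad.shape == (16,)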
transformer_engine/pytorch/module/linear.py View file @ ab3e5a92
...
...
@@ -7,36 +7,41 @@ from typing import Callable, Dict, Optional, Tuple, Union
 from functools import reduce
 from operator import mul as multiply_op
+import functools

 import torch
 import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
+from transformer_engine.pytorch import torch_version
 from .base import (
     get_workspace,
     get_ub,
     TransformerEngineBaseModule,
     get_dummy_wgrad,
     _2X_ACC_FPROP,
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
 )
-from ._common import noop_cat, _fix_gathered_fp8_transpose
+from ._common import noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
 from ..fp8 import FP8GlobalStateManager
 from ..utils import (
     cast_if_needed,
     clear_tensor_data,
     divide,
     init_method_constant,
-    requires_grad,
+    needs_quantized_gemm,
     non_tn_fp8_gemm_supported,
     assert_dim_for_fp8_exec,
     nvtx_range_pop,
     nvtx_range_push,
+    requires_grad,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,
     get_distributed_world_size,
     allreduce,
+    symmetric_all_reduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
     is_fp8_activation_recompute_enabled,
...
@@ -56,10 +61,13 @@ from ..tensor.quantized_tensor import (
prepare_for_saving
,
restore_from_saved
,
)
from
..tensor.float8_tensor
import
Float8CurrentScalingQuantizer
,
Float8Quantizer
from
..tensor.mxfp8_tensor
import
MXFP8Quantizer
from
..tensor._internal.mxfp8_tensor_base
import
MXFP8TensorBase
from
..cpu_offload
import
is_cpu_offload_enabled
,
set_offloading_param
from
..tensor.float8_blockwise_tensor
import
Float8BlockQuantizer
from
..cpu_offload
import
is_cpu_offload_enabled
,
mark_activation_offload
from
...debug.pytorch.debug_state
import
TEDebugState
from
...debug.pytorch.utils
import
any_feature_enabled
__all__
=
[
"Linear"
]
...
...
@@ -78,11 +86,13 @@ class _Linear(torch.autograd.Function):
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
+       wgrad_store: WeightGradStore,
        input_quantizer: Optional[Quantizer],
        weight_quantizer: Optional[Quantizer],
        output_quantizer: Optional[Quantizer],
-       grad_output_quantizer: Optional[Quantizer],
        grad_input_quantizer: Optional[Quantizer],
+       grad_weight_quantizer: Optional[Quantizer],
+       grad_output_quantizer: Optional[Quantizer],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
        tp_group: Union[dist_group_type, None],
...
@@ -103,6 +113,8 @@ class _Linear(torch.autograd.Function):
fsdp_group
:
Union
[
dist_group_type
,
None
],
module
:
torch
.
nn
.
Module
,
skip_fp8_weight_update
:
bool
,
symmetric_ar_type
:
str
,
debug
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
# pylint: disable=missing-function-docstring
...
...
@@ -128,6 +140,10 @@ class _Linear(torch.autograd.Function):
            parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
        )
        own_quantized_input = False
+       # TODO(kwyss): Support FP8 allgather for FP8 block quantization.
+       force_hp_input_gather = (
+           fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
+       )  # Perform TP communication in high precision.
        if fp8:
            assert_dim_for_fp8_exec(inputmat, weight)
            if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
...
...
@@ -137,14 +153,22 @@ class _Linear(torch.autograd.Function):
                    "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
                    " current scaling"
                )

        if fp8 or debug:
            if input_quantizer is None:
                raise ValueError("Missing quantizer for input tensor")
            if with_input_all_gather_nccl:
+               if force_hp_input_gather:
+                   input_quantizer.set_usage(rowwise=True, columnwise=False)
                inputmat_total, _ = gather_along_first_dim(inputmat, tp_group, quantizer=input_quantizer)
            else:
                if not isinstance(inputmat, QuantizedTensor):
                    columnwise_usage = backward_needs_input and isinstance(input_quantizer, MXFP8Quantizer)
+                   # force_hp_input_gather should enforce this
+                   assert not isinstance(input_quantizer, Float8BlockQuantizer)
                    input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
                    inputmat = input_quantizer(inputmat)
                    own_quantized_input = True
...
...
@@ -181,9 +205,9 @@ class _Linear(torch.autograd.Function):
        nvtx_range_pop(f"{nvtx_label}.input_cast_comm")

        # Cast weight to expected dtype
-       if not fp8:
-           weightmat = cast_if_needed(weight, activation_dtype)
-       else:
+       weightmat = weight
+       if fp8 or debug:
            # Configure quantizer
            if weight_quantizer is not None:
                columnwise_usage = is_grad_enabled and inp.requires_grad
...
...
@@ -193,7 +217,6 @@ class _Linear(torch.autograd.Function):
                    and not in_fp8_activation_recompute_phase()
                )
                weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)

            # FP8 cast to workspace buffer
            update_workspace = is_first_microbatch is None or is_first_microbatch
            weightmat = module.get_weight_workspace(
...
...
@@ -203,11 +226,14 @@ class _Linear(torch.autograd.Function):
                update_workspace=update_workspace,
                skip_update_flag=skip_fp8_weight_update,
                fsdp_group=fsdp_group,
                workspace_dtype=activation_dtype,
            )
+       else:
+           weightmat = cast_if_needed(weightmat, activation_dtype)

        # Cast bias to expected dtype
        bias_dtype = activation_dtype
-       if fp8 and activation_dtype == torch.float32:
+       if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32:
            bias_dtype = torch.bfloat16
        bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
...
...
@@ -262,6 +288,7 @@ class _Linear(torch.autograd.Function):
        nvtx_range_pop(f"{nvtx_label}.gemm")

        if is_grad_enabled:
            ctx.weight_quantizer = weight_quantizer
            saved_inputmat = None
            ctx.backward_input_needs_gather = (
@@ -275,6 +302,8 @@ class _Linear(torch.autograd.Function):
# can be allgathered.
if
isinstance
(
inputmat
,
MXFP8TensorBase
)
or
not
ctx
.
backward_input_needs_gather
:
inputmat
.
update_usage
(
rowwise_usage
=
False
,
columnwise_usage
=
True
)
if
force_hp_input_gather
:
assert
not
isinstance
(
inputmat
,
QuantizedTensor
)
saved_inputmat
=
inputmat
# Weight with column-wise usage is needed for dgrad GEMM.
...
...
@@ -282,11 +311,8 @@ class _Linear(torch.autograd.Function):
            if isinstance(weightmat, QuantizedTensor):
                weightmat.update_usage(columnwise_usage=True)

-           if cpu_offloading:
-               set_offloading_param(weight, "weight_offloading", True)
-               set_offloading_param(weightmat, "weight_offloading", True)
-               if saved_inputmat is not None:
-                   set_offloading_param(saved_inputmat, "activation_offloading", True)
+           if cpu_offloading and saved_inputmat is not None:
+               mark_activation_offload(saved_inputmat)

        # Scatter intermediate/activation tensors saved for the backward pass
        # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
...
...
@@ -321,15 +347,18 @@ class _Linear(torch.autograd.Function):
            ctx.tensor_objects = tensor_objects

            ctx.activation_dtype = activation_dtype
-           ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
            ctx.fp8 = fp8
+           ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
+           ctx.force_hp_input_gather = force_hp_input_gather
            ctx.input_quantizer = input_quantizer
-           ctx.grad_output_quantizer = grad_output_quantizer
            ctx.grad_input_quantizer = grad_input_quantizer
+           ctx.grad_weight_quantizer = grad_weight_quantizer
+           ctx.grad_output_quantizer = grad_output_quantizer
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            if fuse_wgrad_accumulation and weight.requires_grad:
                ctx.main_grad = weight.main_grad
+           ctx.debug = debug
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = bias is not None
...
...
@@ -353,6 +382,7 @@ class _Linear(torch.autograd.Function):
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                if in_fp8_activation_recompute_phase():
                    FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
            ctx.wgrad_store = wgrad_store

        # Row Parallel Linear
        if ub_overlap_rs_fprop:
...
...
@@ -362,6 +392,9 @@ class _Linear(torch.autograd.Function):
            if sequence_parallel:
                out, _ = reduce_scatter_along_first_dim(out, tp_group)
            elif tensor_parallel:
+               if symmetric_ar_type is not None:
+                   out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
+               else:
                    out, _ = allreduce(out, tp_group)
            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
...
...
@@ -471,14 +504,27 @@ class _Linear(torch.autograd.Function):
                ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
                dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)

+       # Configure quantizer for grad output tensor
+       # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
+       # requires column-wise usage
+       if ctx.grad_output_quantizer is not None:
+           rowwise_usage = True
+           columnwise_usage = True
+           if ctx.ub_overlap_ag and isinstance(
+               ctx.grad_output_quantizer,
+               (Float8Quantizer, Float8CurrentScalingQuantizer),
+           ):
+               # If data is in FP8 and communication is handled
+               # with Userbuffers, we compute FP8 transposes
+               # manually
+               columnwise_usage = False
+           ctx.grad_output_quantizer.set_usage(
+               rowwise=rowwise_usage,
+               columnwise=columnwise_usage,
+           )

        # Prepare grad output tensor
        # Note: Cast to expected dtype and perform tensor-parallel communication
-       if ctx.grad_output_quantizer is not None:
-           # Reduce duplicated transpose, which is performed in grad_output.update_usage
-           if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
-               ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
-           else:
-               ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
        nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
        (
            grad_output,
...
...
@@ -491,21 +537,26 @@ class _Linear(torch.autograd.Function):
        )
        nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")

-       # Prepare input tensor
-       # Note: Perform tensor-parallel communication if needed
+       # Launch tensor-parallel communication for input tensor
        inputmat_total = None
        inputmat_total_work = None
        if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad:
            quantizer = None
-           if ctx.fp8:
+           if ctx.fp8 or ctx.debug:
                quantizer = ctx.input_quantizer
                if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                    # If data is in FP8, we compute FP8 transposes manually
                    quantizer.set_usage(rowwise=True, columnwise=False)
                else:
                    # wgrad GEMM requires input with column-wise usage
                    quantizer.set_usage(rowwise=False, columnwise=True)
            nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
+           gather_quantizer = None if ctx.force_hp_input_gather else quantizer
            inputmat_total, inputmat_total_work = gather_along_first_dim(
                inputmat,
                ctx.tp_group,
                async_op=True,
-               quantizer=quantizer,
+               quantizer=gather_quantizer,
            )
            nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
        else:
...
...
@@ -527,7 +578,6 @@ class _Linear(torch.autograd.Function):
        # Update quantizer
        if ctx.grad_input_quantizer is not None:
            ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

        # dgrad GEMM
        nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
        dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
...
...
@@ -538,6 +588,12 @@ class _Linear(torch.autograd.Function):
                recipe.fp8_gemm_dgrad.use_split_accumulator
            )
+       if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensor):
+           weight_fp8.update_usage(
+               rowwise_usage=ctx.weight_quantizer.rowwise_usage,
+               columnwise_usage=ctx.weight_quantizer.columnwise_usage,
+           )
        dgrad, *_, rs_out = general_gemm(
            weight_fp8,
            grad_output,
...
@@ -573,6 +629,8 @@ class _Linear(torch.autograd.Function):
# Compute grad weight tensor
wgrad
=
None
if
ctx
.
requires_wgrad
:
# Synchronize tensor-parallel communication for input tensor
if
ctx
.
ub_bulk_dgrad
:
inputmat_total
=
ub_obj_dgrad
.
get_buffer
(
ctx
.
input_quantizer
)
if
ctx
.
fp8
:
...
...
@@ -586,18 +644,32 @@ class _Linear(torch.autograd.Function):
                    # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                    # have a valid transpose.
                    inputmat_total._create_transpose()
            else:
                if inputmat_total_work is not None:
                    # Synchronize tensor-parallel communication
                    inputmat_total_work.wait()
                    inputmat_total_work = None
+               if ctx.input_quantizer is not None and not isinstance(inputmat_total, QuantizedTensor):
+                   # Async gather in BF16 does not asynchronously
+                   # call quantizer after gather.
+                   ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
+                   inputmat_total = ctx.input_quantizer(inputmat_total)

            # Make sure GEMM inputs have required data
            if isinstance(inputmat_total, QuantizedTensor):
                inputmat_total.update_usage(columnwise_usage=True)
            if isinstance(grad_output, QuantizedTensor):
                # This is a no-op if platform supports non-TN FP8 GEMM or the transpose
                # already exists.
-               grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
+               grad_output.update_usage(columnwise_usage=True)

+           # Figure out whether to use split accumulator
+           use_split_accumulator = _2X_ACC_WGRAD
+           if ctx.fp8:
+               recipe = ctx.fp8_recipe
+               if hasattr(recipe, "fp8_gemm_wgrad"):
+                   use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

            # Output buffer for overlapping grad input
            # reduce-scatter with wgrad GEMM
            if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
                rs_out = torch.empty(
                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
...
@@ -606,39 +678,29 @@ class _Linear(torch.autograd.Function):
@@ -606,39 +678,29 @@ class _Linear(torch.autograd.Function):
            # wgrad GEMM
            # Note: Fuse with bgrad computation if needed
            nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
-           wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
-           if ctx.fp8:
-               recipe = ctx.fp8_recipe
-               if hasattr(recipe, "fp8_gemm_wgrad"):
-                   wgrad_gemm_use_split_accumulator = (
-                       recipe.fp8_gemm_wgrad.use_split_accumulator
-                   )
-           wgrad, grad_bias_, _, rs_out = general_gemm(
-               inputmat_total,
-               grad_output,
-               get_workspace(),
-               layout="NT",
-               grad=True,
+           general_gemm_wgrad = functools.partial(
+               general_gemm,
                out_dtype=(
                    main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                ),
+               workspace=get_workspace(),
+               layout="NT",
+               grad=True,
                bias=(bias if (grad_bias is None and not ctx.fp8) else None),
                out=main_grad if ctx.fuse_wgrad_accumulation else None,
-               use_split_accumulator=wgrad_gemm_use_split_accumulator,
+               use_split_accumulator=use_split_accumulator,
                accumulate=accumulate_wgrad_into_param_main_grad,
                quantization_params=ctx.grad_weight_quantizer,
                ub=ub_obj_wgrad,
                ub_type=ub_type_wgrad,
                extra_output=rs_out,
                bulk_overlap=ctx.ub_bulk_wgrad,
            )
-           nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
-           if ctx.ub_bulk_wgrad:
-               if ub_obj_wgrad.is_fp8_ubuf():
-                   dgrad = rs_out
+           if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
+               ctx.wgrad_store.put([inputmat_total, grad_output], general_gemm_wgrad)
            else:
-               dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True)
+               wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(inputmat_total, grad_output)

                if grad_bias is None:
                    grad_bias = grad_bias_
...
...
@@ -647,12 +709,19 @@ class _Linear(torch.autograd.Function):
            # Deallocate input tensor
            if ctx.owns_input:
                clear_tensor_data(inputmat_total)
+           nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")

+           if ctx.ub_bulk_wgrad:
+               if ub_obj_wgrad.is_fp8_ubuf():
+                   dgrad = rs_out
+               else:
+                   dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True)

        # Don't return grad bias if not needed
        if not ctx.use_bias:
            grad_bias = None

-       # Synchronize tensor parallel communication
+       # Make sure all tensor-parallel communication is finished
        if inputmat_total_work is not None:
            inputmat_total_work.wait()
            inputmat_total_work = None
...
...
@@ -669,18 +738,15 @@ class _Linear(torch.autograd.Function):
):
    weight.grad_added_to_main_grad = True
    if getattr(weight, "zero_out_wgrad", False):
        wgrad = torch.zeros(
            weight.main_grad.shape,
            dtype=weight.dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        wgrad = get_dummy_wgrad(
            list(weight.main_grad.shape),
            weight.dtype,
            zero=True,
        )
    else:
        wgrad = torch.empty(
            weight.main_grad.shape,
            dtype=weight.dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        wgrad = get_dummy_wgrad(
            list(weight.main_grad.shape),
            weight.dtype,
        )
elif ctx.fuse_wgrad_accumulation:
    wgrad = None
...
...
@@ -702,11 +768,13 @@ class _Linear(torch.autograd.Function):
None,  # is_first_microbatch
None,  # fp8
None,  # fp8_calibration
None,  # wgrad_store
None,  # input_quantizer
None,  # weight_quantizer
None,  # output_quantizer
None,  # grad_output_quantizer
None,  # grad_input_quantizer
None,  # grad_weight_quantizer
None,  # grad_output_quantizer
None,  # fuse_wgrad_accumulation
None,  # cpu_offloading
None,  # tp_group
...
...
@@ -727,6 +795,8 @@ class _Linear(torch.autograd.Function):
None,  # fsdp_group
None,  # module
None,  # skip_fp8_weight_update
None,  # symmetric_ar_type
None,  # debug
)
...
...
@@ -762,6 +832,8 @@ class Linear(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters
----------------------
...
...
@@ -797,7 +869,15 @@ class Linear(TransformerEngineBaseModule):
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False`
Whether or not to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call `module.backward_dw` to compute
weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency-bound communication situations.
Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
is used.
"""
def __init__(
...
...
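A minimal usage sketch for the two constructor arguments documented above (sizes and the module name are illustrative; symmetric_ar_type additionally needs torch >= 2.7 and an initialized tensor-parallel group, so it is left at its default here):

import torch
import transformer_engine.pytorch as te

layer = te.Linear(
    1024,
    1024,
    bias=True,
    delay_wgrad_compute=True,  # wgrad GEMM is queued instead of run in backward
    name="proj",               # only used for debug-mode reporting
)
inp = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
out = layer(inp)
out.sum().backward()           # weight gradient not yet materialized
layer.backward_dw()            # user-triggered weight gradient computation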
@@ -823,6 +903,9 @@ class Linear(TransformerEngineBaseModule):
ub_bulk_dgrad: bool = False,
ub_bulk_wgrad: bool = False,
ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
name: Optional[str] = None,
) -> None:
    super().__init__()
...
...
@@ -835,6 +918,13 @@ class Linear(TransformerEngineBaseModule):
self.apply_bias = bias and not return_bias
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
self.symmetric_ar_type = symmetric_ar_type
self.name = name
if TEDebugState.debug_enabled:
    self._turn_off_unsupported_features_in_debug()  # turn off userbuffers
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
if device == "meta":
    assert parameters_split is None, "Cannot split module parameters on 'meta' device."
...
...
@@ -900,6 +990,13 @@ class Linear(TransformerEngineBaseModule):
assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized."
self.ub_name = ub_name
if self.symmetric_ar_type is not None:
    assert torch_version() >= (
        2,
        7,
        0,
    ), "Torch version must be at least 2.7 to use symmetric memory"

# Initialize params in FP8
with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()
...
...
@@ -1078,6 +1175,10 @@ class Linear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
debug = TEDebugState.debug_enabled
if debug:
    self._validate_name()

if FP8GlobalStateManager.fp8_graph_capturing():
    skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
...
...
@@ -1085,6 +1186,13 @@ class Linear(TransformerEngineBaseModule):
if skip_fp8_weight_update is not None:
    is_first_microbatch = False

if self.ub_overlap_rs_fprop:
    if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
        fp8_output = True
if self.ub_overlap_rs_dgrad:
    if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
        fp8_grad = True

with self.prepare_forward(
    inp,
    allow_non_contiguous=isinstance(inp, QuantizedTensor),
...
...
@@ -1106,13 +1214,28 @@ class Linear(TransformerEngineBaseModule):
else:
    bias_tensor = None

quantizers = (
    self._get_quantizers(fp8_output, fp8_grad)
    if not debug
    else self._get_debug_quantizers(fp8_output, fp8_grad)
)
if debug:
    if not any_feature_enabled(quantizers):
        # If no feature is used, then run faster implementation with debug = False.
        quantizers = self._get_quantizers(fp8_output, fp8_grad)
        debug = False
    if isinstance(weight_tensor, QuantizedTensor):
        raise RuntimeError("FP8 weights are not supported in debug mode.")

(
    input_quantizer,
    weight_quantizer,
    output_quantizer,
    grad_output_quantizer,
    grad_input_quantizer,
) = self._get_quantizers(fp8_output, fp8_grad)
    grad_weight_quantizer,
    grad_output_quantizer,
) = quantizers

# Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization
...
...
@@ -1133,11 +1256,13 @@ class Linear(TransformerEngineBaseModule):
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
...
...
@@ -1158,6 +1283,8 @@ class Linear(TransformerEngineBaseModule):
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = linear_fn(*args)

if self.gemm_bias_unfused_add:
...
...
@@ -1169,8 +1296,9 @@ class Linear(TransformerEngineBaseModule):
def _get_quantizers(self, fp8_output, fp8_grad):
    if not self.fp8:
        return [None] * 5
        return [None] * 6
    grad_input_quantizer = None
    grad_weight_quantizer = None
    grad_output_quantizer = None
    output_quantizer = None
    input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
...
...
@@ -1188,8 +1316,20 @@ class Linear(TransformerEngineBaseModule):
    input_quantizer,
    weight_quantizer,
    output_quantizer,
    grad_output_quantizer,
    grad_input_quantizer,
    grad_weight_quantizer,
    grad_output_quantizer,
)

def _get_debug_quantizers(self, fp8_output, fp8_grad):
    original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
    assert TEDebugState.debug_enabled
    from ...debug.pytorch.debug_quantization import DebugQuantizer

    names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
    return tuple(
        DebugQuantizer(self.name, name, q, self.tp_group)
        for name, q in zip(names, original_quantizers)
    )

def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
...
...
transformer_engine/pytorch/ops/basic/activation.py
View file @ ab3e5a92
...
...
@@ -13,6 +13,7 @@ import torch
import transformer_engine_torch as tex
from ...fp8 import FP8GlobalStateManager
from ...tensor import QuantizedTensor
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer
from ...utils import clear_tensor_data, devices_match
from ..op import BasicOperation, OperationContext
from .._common import reshape
...
...
@@ -37,8 +38,20 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
the first half of the input tensor, while PyTorch applies it to
the second half.
Parameters
----------
cache_quantized_input: bool, default = False
Quantize input tensor when caching for use in the backward
pass. This will typically reduce memory usage but require
extra compute and increase numerical error. This feature is
highly experimental.
"""
def __init__(self, *, cache_quantized_input: bool = False):
    super().__init__()
    self.cache_quantized_input: bool = cache_quantized_input

@abc.abstractmethod
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
    """Forward implementation
...
...
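The memory trade-off described for cache_quantized_input comes down to storing the activation input as FP8 rather than BF16 until the backward pass. A rough sketch of that pattern, mirroring the quantizer calls used in the forward and backward hunks below (standalone use like this is illustrative only and assumes a CUDA build of Transformer Engine):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)  # ~32 MiB if cached as-is

quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
quantizer.set_usage(rowwise=True, columnwise=False)
x_fp8 = quantizer(x)  # roughly half the cached memory, at the cost of an extra quantize

# The backward pass later dequantizes back to the compute dtype:
x_restored = x_fp8.dequantize(dtype=torch.bfloat16)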
@@ -100,9 +113,16 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
if y.dim() != x.dim():
    y = y.reshape(list(x.shape[:-1]) + [-1])

# Quantize input to FP8 before caching if needed
if self.cache_quantized_input:
    quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
    quantizer.set_usage(rowwise=True, columnwise=False)
    x = quantizer(x)

# Save state for backward pass
ctx.save_for_backward(x.detach())
ctx.fp8_enabled = fp8_enabled
ctx.dtype = dtype
ctx.prev_op = prev_op
return y
...
...
@@ -116,10 +136,18 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Saved tensors from forward pass
(x,) = ctx.saved_tensors

# Check input tensor
if isinstance(x, QuantizedTensor):
    x = x.dequantize(dtype=ctx.dtype)
elif x.dtype != ctx.dtype:
    x = x.to(dtype=ctx.dtype)
if not x.is_contiguous():
    x = x.contiguous()

# Check grad output tensor
dy = grad_output
if isinstance(dy, QuantizedTensor):
    dy = dy.dequantize()
    dy = dy.dequantize(dtype=ctx.dtype)
if not devices_match(dy.device, x.device) or dy.dtype != x.dtype:
    dy = dy.to(device=x.device, dtype=x.dtype)
if not dy.is_contiguous():
...
...
transformer_engine/pytorch/ops/basic/basic_linear.py
View file @ ab3e5a92
...
...
@@ -23,6 +23,7 @@ from ...fp8 import FP8GlobalStateManager
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer, QuantizedTensor
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext
...
...
@@ -412,7 +413,6 @@ class BasicLinear(BasicOperation):
x = None
x_async = None
with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel
own_quantized_x_local = False
if with_quantized_compute:
    if input_quantizer is None:
        raise ValueError("Missing quantizer for input tensor")
...
...
@@ -428,7 +428,6 @@ class BasicLinear(BasicOperation):
else:
    if not isinstance(x_local, QuantizedTensor):
        x_local = input_quantizer(x_local)
        own_quantized_x_local = True
    x = x_local
else:
    if isinstance(x_local, QuantizedTensor):
...
...
@@ -483,6 +482,12 @@ class BasicLinear(BasicOperation):
"Attempting to generate MXFP8 output tensor, "
"but GEMM with MXFP8 output is not supported"
)
if
isinstance
(
output_quantizer
,
Float8BlockQuantizer
):
raise
RuntimeError
(
"Attempting to generate Float8BlockQuantized output tensor, "
"but GEMM with Float8BlockQuantized output is not supported"
)
if
output_quantizer
is
not
None
:
output_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
...
...
@@ -521,16 +526,16 @@ class BasicLinear(BasicOperation):
else:
    torch.distributed.all_reduce(y, group=tensor_parallel_group)

# Configure input tensor for backward pass
if own_quantized_x_local:
    x_local.update_usage(rowwise_usage=False)

# Detach input tensor if needed
# Note: PyTorch autograd produces esoteric errors if we save
# input tensor as context for backward pass.
if x_local is input:
    x_local = x_local.detach()

# Configure input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensor):
    x_local.update_usage(rowwise_usage=False, columnwise_usage=True)

return y, x_local, w

@staticmethod
...
...
@@ -679,7 +684,9 @@ class BasicLinear(BasicOperation):
    quantizer=input_quantizer,
)
else:
    if not isinstance(x_local, QuantizedTensor):
    if isinstance(x_local, QuantizedTensor):
        x_local.update_usage(columnwise_usage=True)
    else:
        x_local = input_quantizer(x_local)
    x = x_local
else:
...
...
@@ -706,14 +713,18 @@ class BasicLinear(BasicOperation):
raise ValueError("Weight tensor is required to compute input grad")
w = weight
w_is_quantized = isinstance(w, QuantizedTensor)
if with_quantized_compute and not w_is_quantized:
if with_quantized_compute:
    if w_is_quantized:
        w.update_usage(columnwise_usage=True)
    else:
        if weight_quantizer is None:
            raise ValueError("Missing quantizer for weight tensor")
        weight_quantizer.set_usage(columnwise=True)
        w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized:
    w = w.dequantize()
if not with_quantized_compute and w.dtype != dtype:
else:
    if w_is_quantized:
        w = w.dequantize(dtype=dtype)
    elif w.dtype != dtype:
        w = w.to(dtype=dtype)
# Synchronize tensor-parallel communication
...
...
@@ -867,8 +878,8 @@ class BasicLinear(BasicOperation):
# Configure quantizers
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
input_quantizer.set_usage(columnwise=weight_requires_grad)
weight_quantizer.set_usage(columnwise=False)
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=False)

# Get autocast dtype if needed
dtype = None
...
...
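The quantizer configuration above encodes a design choice: the rowwise FP8 data feeds the forward GEMM, while the columnwise (transposed) copy exists only to feed the wgrad GEMM in the backward pass, so it is requested for the input only when the weight actually needs a gradient, and never for the quantized weight, which is discarded. A small illustrative snippet of the same usage flags in isolation (assumes a CUDA build of Transformer Engine; the current-scaling quantizer is chosen here only because it is easy to construct):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

weight_requires_grad = True
x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)

input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
# Keep the transposed copy only if the wgrad GEMM will need it later.
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
x_q = input_quantizer(x)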