OpenDAS / TransformerEngine — Commit 53fa872c

Merge branch 'nv_release_v2.8' into release_v2.8
Authored Oct 11, 2025 by wenjh
Parents: 27ddce40, 40c69e75
Changes: 159 — showing 20 changed files with 3319 additions and 2161 deletions
transformer_engine/jax/csrc/extensions/misc.h                                         +26    -0
transformer_engine/jax/csrc/extensions/normalization.cpp                              +63    -0
transformer_engine/jax/csrc/extensions/pybind.cpp                                     +19    -3
transformer_engine/jax/dense.py                                                       +59    -12
transformer_engine/jax/flax/transformer.py                                            +1     -0
transformer_engine/jax/layernorm_mlp.py                                               +47    -4
transformer_engine/jax/quantize/scaling_modes.py                                      +61    -45
transformer_engine/jax/sharding.py                                                    +19    -0
transformer_engine/pytorch/attention/dot_product_attention/backends.py                +402   -178
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py        +1621  -1786
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py   +361   -2
transformer_engine/pytorch/attention/dot_product_attention/utils.py                   +255   -21
transformer_engine/pytorch/attention/multi_head_attention.py                          +48    -12
transformer_engine/pytorch/constants.py                                               +2     -0
transformer_engine/pytorch/cpp_extensions/fused_attn.py                               +23    -3
transformer_engine/pytorch/cpp_extensions/gemm.py                                     +20    -0
transformer_engine/pytorch/csrc/common.cpp                                            +30    -0
transformer_engine/pytorch/csrc/common.h                                              +76    -7
transformer_engine/pytorch/csrc/extensions.h                                          +19    -11
transformer_engine/pytorch/csrc/extensions/activation.cpp                             +167   -77
transformer_engine/jax/csrc/extensions/misc.h
@@ -87,5 +87,31 @@ constexpr struct Alignment {

 std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise);

+template <typename T, typename... Rest>
+void hash_combine(int64_t &seed, const T &v, Rest... rest) {
+  seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+  (hash_combine(seed, rest), ...);
+}
+
+enum class JAXX_Collective_Op : int64_t {
+  NONE = 0,
+  ALL_GATHER = 1,
+  REDUCE_SCATTER = 2,
+};
+
+static CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) {
+  switch (op) {
+    case JAXX_Collective_Op::ALL_GATHER:
+      return CommOverlapType::AG;
+      break;
+    case JAXX_Collective_Op::REDUCE_SCATTER:
+      return CommOverlapType::RS;
+      break;
+    default:
+      NVTE_ERROR("Invalid Collective Op ", static_cast<int>(op));
+      break;
+  }
+}
+
 }  // namespace jax
 }  // namespace transformer_engine
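
The hash_combine template above is the classic Boost recipe: each value's hash is XOR-mixed into the seed together with the golden-ratio constant 0x9e3779b9 and two shifted copies of the seed, and the C++17 fold expression repeats the step across the parameter pack. A minimal Python sketch of the same mixing step, for illustration only (the 64-bit mask is an assumption to mimic C++ integer wraparound; Python's hash() stands in for std::hash):

def hash_combine(seed, *values):
    # Boost-style mixing: fold each value's hash into the running seed.
    mask = (1 << 64) - 1  # emulate fixed-width integer arithmetic
    for v in values:
        seed = (seed ^ (hash(v) + 0x9E3779B9 + ((seed << 6) & mask) + (seed >> 2))) & mask
    return seed

# e.g. one cache key built from several kernel attributes:
key = hash_combine(0, 128, 4096, "bfloat16", 2)

Backing JAXX_Collective_Op with int64_t keeps the enum cheap to pass through FFI attributes and, presumably, to mix into cache keys like the one above.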
transformer_engine/jax/csrc/extensions/normalization.cpp
@@ -180,6 +180,42 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
                                   .Attr<bool>("is_2x"),
                               FFI_CudaGraph_Traits);

+Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
+                                    Buffer_Type gamma_buf, Buffer_Type beta_buf,
+                                    Result_Type output_buf, Result_Type colwise_output_buf,
+                                    Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
+                                    Result_Type amax_buf, Result_Type mu_buf,
+                                    Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type,
+                                    bool zero_centered_gamma, double epsilon, int64_t sm_margin,
+                                    JAXX_Scaling_Mode scaling_mode, bool is_2x) {
+  return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, gamma_buf,
+                             beta_buf, output_buf, colwise_output_buf, scale_inv_buf,
+                             colwise_scale_inv_buf, amax_buf, mu_buf, rsigma_buf, wkspace_buf,
+                             norm_type, zero_centered_gamma, epsilon, sm_margin, scaling_mode,
+                             is_2x);
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI,
+                              FFI::Bind<FFI_Initialize>()
+                                  .Ctx<FFI_Stream_Type>()  // stream
+                                  .Arg<Buffer_Type>()      // x
+                                  .Arg<Buffer_Type>()      // scale
+                                  .Arg<Buffer_Type>()      // gamma
+                                  .Arg<Buffer_Type>()      // beta
+                                  .Ret<Buffer_Type>()      // output
+                                  .Ret<Buffer_Type>()      // colwise_output
+                                  .Ret<Buffer_Type>()      // scale_inv
+                                  .Ret<Buffer_Type>()      // colwise_scale_inv
+                                  .Ret<Buffer_Type>()      // amax
+                                  .Ret<Buffer_Type>()      // mu
+                                  .Ret<Buffer_Type>()      // rsigma
+                                  .Ret<Buffer_Type>()      // wkspace
+                                  .Attr<int64_t>("norm_type")
+                                  .Attr<bool>("zero_centered_gamma")
+                                  .Attr<double>("epsilon")
+                                  .Attr<int64_t>("sm_margin")
+                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
+                                  .Attr<bool>("is_2x"));
+
 pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                               DType in_dtype, DType w_dtype,
                                               NVTE_Norm_Type norm_type,
                                               bool zero_centered_gamma, int sm_margin) {

@@ -305,5 +341,32 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardHandler, NormBackwardFFI,
                                   .Attr<int64_t>("sm_margin"),
                               FFI_CudaGraph_Traits);

+Error_Type NormBackwardInitializeFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
+                                     Buffer_Type mu_buf, Buffer_Type rsigma_buf,
+                                     Buffer_Type gamma_buf, Result_Type xgrad_buf,
+                                     Result_Type wgrad_buf, Result_Type dbeta_buf,
+                                     Result_Type wkspace_buf, int64_t norm_type,
+                                     bool zero_centered_gamma, int64_t sm_margin) {
+  return wrapInStreamCapture(std::function(NormBackwardFFI), stream, dz_buf, x_buf, mu_buf,
+                             rsigma_buf, gamma_buf, xgrad_buf, wgrad_buf, dbeta_buf, wkspace_buf,
+                             norm_type, zero_centered_gamma, sm_margin);
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardInitializeHandler, NormBackwardInitializeFFI,
+                              FFI::Bind<FFI_Initialize>()
+                                  .Ctx<FFI_Stream_Type>()  // stream
+                                  .Arg<Buffer_Type>()      // dz
+                                  .Arg<Buffer_Type>()      // x
+                                  .Arg<Buffer_Type>()      // mu
+                                  .Arg<Buffer_Type>()      // rsigma
+                                  .Arg<Buffer_Type>()      // gamma
+                                  .Ret<Buffer_Type>()      // xgrad
+                                  .Ret<Buffer_Type>()      // wgrad
+                                  .Ret<Buffer_Type>()      // dbeta
+                                  .Ret<Buffer_Type>()      // wkspace
+                                  .Attr<int64_t>("norm_type")
+                                  .Attr<bool>("zero_centered_gamma")
+                                  .Attr<int64_t>("sm_margin"));
+
 }  // namespace jax
 }  // namespace transformer_engine
transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -5,6 +5,8 @@
  ************************************************************************/

 #include "../extensions.h"
+#include "cgemm_helper.h"
+#include "common/util/cuda_runtime.h"

 namespace transformer_engine {
 namespace jax {

@@ -20,8 +22,12 @@ pybind11::dict Registrations() {
   pybind11::dict dict;

   // Activation
-  dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
-  dict["te_dact_dbias_quantize_ffi"] = EncapsulateFFI(DActLuDBiasQuantizeHandler);
+  dict["te_act_lu_ffi"] =
+      pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(ActLuInitializeHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(ActLuHandler));
+  dict["te_dact_dbias_quantize_ffi"] = pybind11::dict(
+      pybind11::arg("initialize") = EncapsulateFFI(DActLuDBiasQuantizeInitializeHandler),
+      pybind11::arg("execute") = EncapsulateFFI(DActLuDBiasQuantizeHandler));

   // Quantization
   dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler);

@@ -42,9 +48,11 @@ pybind11::dict Registrations() {
   // Normalization
   dict["te_norm_forward_ffi"] =
       pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("initialize") = EncapsulateFFI(NormForwardInitializeHandler),
                      pybind11::arg("execute") = EncapsulateFFI(NormForwardHandler));
   dict["te_norm_backward_ffi"] =
       pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("initialize") = EncapsulateFFI(NormBackwardInitializeHandler),
                      pybind11::arg("execute") = EncapsulateFFI(NormBackwardHandler));

   // Attention

@@ -57,7 +65,7 @@ pybind11::dict Registrations() {
   // GEMM
   dict["te_gemm_ffi"] =
-      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CollectiveGemmInitHandler),
                      pybind11::arg("execute") = EncapsulateFFI(GemmHandler));

   // Grouped GEMM

@@ -84,6 +92,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
   m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
   m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
   m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported);
+  m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator);
+  m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams);

   pybind11::enum_<DType>(m, "DType", pybind11::module_local())
       .value("kByte", DType::kByte)

@@ -159,6 +169,12 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
       .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE)
       .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
       .export_values();
+
+  pybind11::enum_<JAXX_Collective_Op>(m, "JAXX_Collective_Op", pybind11::module_local())
+      .value("NONE", JAXX_Collective_Op::NONE)
+      .value("ALL_GATHER", JAXX_Collective_Op::ALL_GATHER)
+      .value("REDUCE_SCATTER", JAXX_Collective_Op::REDUCE_SCATTER)
+      .export_values();
 }
 }  // namespace jax
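
Each entry in Registrations() is now either a bare FFI capsule or a small pybind11 dict of per-stage capsules ("prepare", "initialize", "execute"), mirroring XLA's multi-stage custom-call handlers; the added "initialize" stage lets handles and workspaces be set up outside CUDA graph capture. A hedged sketch of how the Python side could consume this table (the accessor name registrations() is an assumption, not shown in this diff; jax.ffi.register_ffi_target does accept a per-stage dict):

import jax
import transformer_engine_jax as tex_jax  # the module built by pybind.cpp

for name, handler in tex_jax.registrations().items():  # hypothetical accessor
    # handler: a single PyCapsule, or a dict like {"prepare": ..., "execute": ...}
    jax.ffi.register_ffi_target(name, handler, platform="CUDA")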
transformer_engine/jax/dense.py
@@ -11,10 +11,12 @@ customizable contracting dimensions for flexible tensor operations.
 from typing import Tuple, Sequence
 from functools import partial
+import warnings

 import jax
 import jax.numpy as jnp

 from . import cpp_extensions as tex
+from .cpp_extensions.quantization import AmaxScope
 from .quantize import (
     ScaledTensorFactory,
     ScalingMode,

@@ -61,8 +63,12 @@ def dense(
     kernel: jnp.ndarray,
     bias: jnp.ndarray = None,
     contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
+    batch_sequence_transpose: bool = False,
     input_axes: Tuple[str, ...] = None,
     kernel_axes: Tuple[str, ...] = None,
+    output_axes: Tuple[str, ...] = None,
+    using_global_amax_of_x: bool = False,
+    collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set,
     quantizer_set: QuantizerSet = noop_quantizer_set,
 ):
     """Perform dense layer transformation with optional quantization.

@@ -76,11 +82,20 @@ def dense(
         kernel: Weight matrix for the dense layer transformation
         bias: Optional bias tensor to add after the transformation
         contracting_dims: Tuple of sequences specifying which dimensions to contract
+        batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor.
+        input_axes: Logical axes for sharding the activation input
+        kernel_axes: Logical axes for sharding the weight matrix
+        output_axes: Logical axes for sharding the output
+        using_global_amax_of_x: Indicate whether to use global amax for x. Only works when using current-scaling. Default is False.
+        collective_op_set: A set of CollectiveOp objects for forward and backward passes.
         quantizer_set: QuantizerSet which contains quantizers for different tensor types

     Returns:
         Transformed output tensor
     """
+    if batch_sequence_transpose:
+        warnings.warn("batch_sequence_transpose is not well tested, use with caution!")
+
     if not get_quantize_config().is_fp8_enabled():
         input_dtype = x.dtype
         kernel = kernel.astype(input_dtype)

@@ -90,29 +105,30 @@ def dense(
         kernel,
         bias,
         contracting_dims,
+        batch_sequence_transpose,
         input_axes,
         kernel_axes,
+        output_axes,
+        using_global_amax_of_x,
+        collective_op_set,
         quantizer_set,
     )
     return output


-@partial(
-    jax.custom_vjp,
-    nondiff_argnums=(
-        3,
-        4,
-        5,
-    ),
-)
+@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9))
 def _dense(
     x,
     kernel,
     bias,
     contracting_dims,
+    batch_sequence_transpose,
     input_axes,
     kernel_axes,
-    quantizer_set,
+    output_axes,
+    using_global_amax_of_x,
+    collective_op_set,
+    quantizer_set,  # need to be a diff_arg for DelayedScaling state management
 ):
     """Internal implementation of dense layer transformation with custom VJP.

@@ -124,8 +140,12 @@ def _dense(
         kernel: Weight matrix
         bias: Optional bias tensor
         contracting_dims: Contracting dimensions specification
+        batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor.
         input_axes: Logical axes for sharding the activation input
+        output_axes: Logical axes for sharding the output
         kernel_axes: Logical axes for sharding the weight matrix
+        using_global_amax_of_x: Indicate whether to use global amax for x. Only works when using current-scaling. Default is False.
+        collective_op_set: A set of CollectiveOp objects for forward and backward passes.
         quantizer_set: QuantizerSet which contains quantizers for different tensor types

     Returns:

@@ -136,8 +156,12 @@ def _dense(
         kernel,
         bias,
         contracting_dims,
+        batch_sequence_transpose,
         input_axes,
         kernel_axes,
+        output_axes,
+        using_global_amax_of_x,
+        collective_op_set,
         quantizer_set,
     )
     return output

@@ -148,8 +172,12 @@ def _dense_fwd_rule(
     kernel,
     bias,
     contracting_dims,
+    batch_sequence_transpose,
     input_axes,
     kernel_axes,
+    output_axes,
+    using_global_amax_of_x,
+    collective_op_set,
     quantizer_set,
 ):
     """Forward pass rule for dense layer transformation.

@@ -175,6 +203,7 @@ def _dense_fwd_rule(
         x,
         flatten_axis=flatten_axis_x,
         quantizer=quantizer_set.x,
+        amax_scope=AmaxScope.TPSP if using_global_amax_of_x else AmaxScope.LOCAL,
     )
     casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)

@@ -182,6 +211,7 @@ def _dense_fwd_rule(
         kernel,
         flatten_axis=flatten_axis_k,
         quantizer=quantizer_set.kernel,
+        amax_scope=AmaxScope.FSDP,
     )
     casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)

@@ -191,9 +221,12 @@ def _dense_fwd_rule(
         casted_x.get_tensor(usage=TensorUsage.LHS),
         casted_kernel.get_tensor(usage=TensorUsage.RHS),
         contracting_dims=(x_contracting_dims, k_contracting_dims),
+        transpose_batch_sequence=batch_sequence_transpose,
         bias=bias if not tex.gemm_uses_jax_dot() else None,
         fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
+        collective_op=collective_op_set.forward,
     )
+    output = with_sharding_constraint_by_logical_axes(output, output_axes)

     if use_bias and tex.gemm_uses_jax_dot():
         bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape

@@ -212,8 +245,16 @@ def _dense_fwd_rule(
 def _dense_bwd_rule(
-    contracting_dims, input_axes, kernel_axes, ctx, grad
-):  # pylint: disable=unused-argument
+    contracting_dims,
+    batch_sequence_transpose,
+    input_axes,
+    kernel_axes,
+    output_axes,
+    using_global_amax_of_x,
+    collective_op_set,
+    ctx,
+    grad,
+):
     """Backward pass rule for dense layer transformation.

     Returns:

@@ -228,6 +269,7 @@ def _dense_bwd_rule(
         quantizer_set,
         flatten_axis_k,
     ) = ctx
+    grad = with_sharding_constraint_by_logical_axes(grad, output_axes)

     fwd_x_contracting_dims, fwd_k_contracting_dims = map(
         tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims

@@ -238,6 +280,7 @@ def _dense_bwd_rule(
         is_dbias=use_bias,
         flatten_axis=flatten_axis_k,
         quantizer=quantizer_set.dgrad,
+        amax_scope=AmaxScope.LOCAL if using_global_amax_of_x else AmaxScope.TPSP,
     )

     # GEMM NT

@@ -254,8 +297,9 @@ def _dense_bwd_rule(
         casted_grad.get_tensor(usage=TensorUsage.LHS),
         casted_kernel_rhs,
         contracting_dims=(g_contracting_dim, k_contracting_dim),
+        transpose_batch_sequence=batch_sequence_transpose,
+        collective_op=collective_op_set.backward,
     )
-    dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)

     # GEMM TN
     # x_non_contracting_dims

@@ -267,7 +311,10 @@ def _dense_bwd_rule(
         casted_x_lhs,
         casted_grad.get_tensor(usage=TensorUsage.RHS),
         contracting_dims=(x_contracting_dim, g_contracting_dim),
+        transpose_batch_sequence=batch_sequence_transpose,
     )

+    dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
     wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)

     return dgrad, wgrad, dbias, quantizer_set
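
Every configuration argument added to _dense must also be listed in nondiff_argnums and then arrive, in the same order, at the head of the bwd rule's parameter list — which is why _dense_bwd_rule grows from (contracting_dims, input_axes, kernel_axes, ctx, grad) to seven leading non-diff parameters. A toy sketch of the pattern (names are illustrative, not from this file):

from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(1,))  # `scale_by` is static config
def scaled_square(x, scale_by):
    return scale_by * x * x

def _fwd(x, scale_by):
    return scaled_square(x, scale_by), x  # residual saved for backward

def _bwd(scale_by, x, g):  # non-diff args first, then residuals, then cotangent
    return (g * scale_by * 2.0 * x,)

scaled_square.defvjp(_fwd, _bwd)
print(jax.grad(scaled_square)(3.0, 2.0))  # 12.0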
transformer_engine/jax/flax/transformer.py
@@ -53,6 +53,7 @@ def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[
     return drop_path_shape


+# TODO(Phuong): move this function to sharding.py
 def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
     """
     Extend the given Flax logical axis rules with the predefined TransformerLayer's
transformer_engine/jax/layernorm_mlp.py
@@ -21,6 +21,7 @@ import jax.numpy as jnp
 from jax.ad_checkpoint import checkpoint_name

 from . import cpp_extensions as tex
+from .cpp_extensions.quantization import AmaxScope
 from .layernorm import canonicalize_norm_type
 from .quantize import (
     with_sharding_constraint_by_logical_axes,

@@ -40,6 +41,7 @@ def layernorm_mlp(
     norm_type: str,
     zero_centered_gamma: bool = False,
     epsilon: float = 1e-6,
+    batch_sequence_transpose: bool = False,
     norm_input_axes: Tuple[str, ...] = None,
     dot_1_input_axes: Tuple[str, ...] = None,
     dot_2_input_axes: Tuple[str, ...] = None,

@@ -48,6 +50,10 @@ def layernorm_mlp(
     ffn1_ckpt_name: str = "ffn1",
     ffn2_ckpt_name: str = "ffn2",
     activation_type: Sequence[Union[str, Callable]] = ("gelu",),
+    collective_op_sets: Tuple[tex.CollectiveOpSet] = (
+        tex.noop_collective_op_set,
+        tex.noop_collective_op_set,
+    ),
     quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
 ) -> jnp.ndarray:
     """Apply layer normalization followed by MLP block.

@@ -71,6 +77,7 @@ def layernorm_mlp(
         norm_type: Type of normalization ("layernorm" or "rmsnorm")
         zero_centered_gamma: Whether to use zero-centered gamma for normalization
         epsilon: Small constant for numerical stability in normalization
+        batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
         norm_input_axes: Logical axes for sharding the layernorm input
         dot_1_input_axes: Logical axes for sharding the first matrix multiplication
         dot_2_input_axes: Logical axes for sharding the second matrix multiplication

@@ -79,6 +86,7 @@ def layernorm_mlp(
         ffn1_ckpt_name: Name for checkpointing the first feed-forward network
         ffn2_ckpt_name: Name for checkpointing the second feed-forward network
         activation_type: Activation function(s) to apply after the first dense layer transformation
+        collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations
         quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations

     Returns:

@@ -121,6 +129,7 @@ def layernorm_mlp(
         norm_type,
         zero_centered_gamma,
         epsilon,
+        batch_sequence_transpose,
         norm_input_axes,
         dot_1_input_axes,
         dot_2_input_axes,

@@ -129,12 +138,13 @@ def layernorm_mlp(
         ffn1_ckpt_name,
         ffn2_ckpt_name,
         activation_type,
+        collective_op_sets,
         quantizer_sets,
     )
     return output


-@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
+@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
 def _layernorm_mlp(
     x: jnp.ndarray,
     gamma: jnp.ndarray,

@@ -146,6 +156,7 @@ def _layernorm_mlp(
     norm_type: str,
     zero_centered_gamma: bool,
     epsilon: float,
+    batch_sequence_transpose: bool,
     norm_input_axes: Tuple[str, ...],
     dot_1_input_axes: Tuple[str, ...],
     dot_2_input_axes: Tuple[str, ...],

@@ -154,6 +165,7 @@ def _layernorm_mlp(
     ffn1_ckpt_name: str,
     ffn2_ckpt_name: str,
     activation_type: Sequence[Union[str, Callable]],
+    collective_op_sets: Tuple[tex.CollectiveOpSet],
     quantizer_sets,
 ):
     """Internal implementation of layernorm_mlp with custom VJP.

@@ -173,12 +185,16 @@ def _layernorm_mlp(
         norm_type: Type of normalization
         zero_centered_gamma: Whether to use zero-centered gamma
         epsilon: Small constant for numerical stability
+        batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
         norm_input_axes: Logical axes for layernorm sharding
         dot_1_input_axes: Logical axes for first matrix multiplication sharding
         dot_2_input_axes: Logical axes for second matrix multiplication sharding
+        kernel_1_axes: Logical axes for first weight matrix sharding
+        kernel_2_axes: Logical axes for second weight matrix sharding
         ffn1_ckpt_name: Name for first feed-forward network checkpointing
         ffn2_ckpt_name: Name for second feed-forward network checkpointing
         activation_type: Activation function(s)
+        collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations
         quantizer_sets: Tuple of quantizer sets

     Returns:

@@ -195,6 +211,7 @@ def _layernorm_mlp(
         norm_type,
         zero_centered_gamma,
         epsilon,
+        batch_sequence_transpose,
         norm_input_axes,
         dot_1_input_axes,
         dot_2_input_axes,

@@ -203,6 +220,7 @@ def _layernorm_mlp(
         ffn1_ckpt_name,
         ffn2_ckpt_name,
         activation_type,
+        collective_op_sets,
         quantizer_sets,
     )
     return output

@@ -219,6 +237,7 @@ def _layernorm_mlp_fwd_rule(
     norm_type,
     zero_centered_gamma,
     epsilon,
+    batch_sequence_transpose,
     norm_input_axes,
     dot_1_input_axes,
     dot_2_input_axes,

@@ -227,6 +246,7 @@ def _layernorm_mlp_fwd_rule(
     ffn1_ckpt_name,
     ffn2_ckpt_name,
     activation_type,
+    collective_op_sets,
     quantizer_sets,
 ):
     """Forward pass rule for layernorm_mlp.

@@ -246,6 +266,10 @@ def _layernorm_mlp_fwd_rule(
     del kernel_1_axes, kernel_2_axes

     ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
+    collective_op_set_1, collective_op_set_2 = collective_op_sets
+
+    assert not collective_op_set_1.forward.is_reduce_scatter
+    assert not collective_op_set_2.forward.is_all_gather

     # x should be in shape of (batch..., hidden)
     # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)

@@ -272,13 +296,12 @@ def _layernorm_mlp_fwd_rule(
         epsilon,
         norm_type,
         quantizer=ffn1_quantizer_set.x,
+        amax_scope=AmaxScope.TPSP,
     )
     casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)

     casted_kernel_1 = tex.quantize(
         kernel_1,
-        flatten_axis=-2,
-        quantizer=ffn1_quantizer_set.kernel,
+        flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP
     )

     # NN GEMM

@@ -287,8 +310,10 @@ def _layernorm_mlp_fwd_rule(
         casted_ln_out.get_tensor(TensorUsage.LHS),
         casted_kernel_1.get_tensor(TensorUsage.RHS),
         contracting_dims=(x_contracting_dims, k_contracting_dims),
+        transpose_batch_sequence=batch_sequence_transpose,
         bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
         fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
+        collective_op=collective_op_set_1.forward,
     )

     if use_bias_1 and tex.gemm_uses_jax_dot():

@@ -317,6 +342,7 @@ def _layernorm_mlp_fwd_rule(
     casted_kernel_2 = tex.quantize(
         kernel_2,
         quantizer=ffn2_quantizer_set.kernel,
+        amax_scope=AmaxScope.FSDP,
     )

     # NN GEMM

@@ -325,8 +351,10 @@ def _layernorm_mlp_fwd_rule(
         casted_act_out.get_tensor(TensorUsage.LHS),
         casted_kernel_2.get_tensor(TensorUsage.RHS),
         contracting_dims=(x_contracting_dims, k_contracting_dims),
+        transpose_batch_sequence=batch_sequence_transpose,
         bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
         fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
+        collective_op=collective_op_set_2.forward,
     )

     if use_bias_2 and tex.gemm_uses_jax_dot():

@@ -334,6 +362,8 @@ def _layernorm_mlp_fwd_rule(
         bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
         dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)

+    # sharding of outputs should be the same as dot_1's input
+    dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_1_input_axes)
     dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

     ctx = (

@@ -363,6 +393,7 @@ def _layernorm_mlp_bwd_rule(
     norm_type,
     zero_centered_gamma,
     epsilon,
+    batch_sequence_transpose,
     norm_input_axes,
     dot_1_input_axes,
     dot_2_input_axes,

@@ -371,6 +402,7 @@ def _layernorm_mlp_bwd_rule(
     ffn1_ckpt_name,
     ffn2_ckpt_name,
     activation_type,
+    collective_op_sets,
     ctx,
     grad,
 ):

@@ -409,6 +441,10 @@ def _layernorm_mlp_bwd_rule(
     ) = ctx

     ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
+    collective_op_set_1, collective_op_set_2 = collective_op_sets
+
+    assert not collective_op_set_1.backward.is_all_gather
+    assert not collective_op_set_2.backward.is_reduce_scatter

     # Since the sharding of outputs should be the same as dot_1's input
     grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

@@ -417,6 +453,7 @@ def _layernorm_mlp_bwd_rule(
         grad,
         is_dbias=use_bias_2,
         quantizer=ffn1_quantizer_set.dgrad,
+        amax_scope=AmaxScope.TPSP,
     )

     # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim

@@ -434,6 +471,8 @@ def _layernorm_mlp_bwd_rule(
         casted_grad.get_tensor(TensorUsage.LHS),
         casted_kernel_2,
         contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
+        transpose_batch_sequence=batch_sequence_transpose,
+        collective_op=collective_op_set_2.backward,
     )
     dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

@@ -448,6 +487,7 @@ def _layernorm_mlp_bwd_rule(
         casted_act_out,
         casted_grad.get_tensor(TensorUsage.RHS),
         contracting_dims=(x_contracting_dims, g_contracting_dims),
+        transpose_batch_sequence=batch_sequence_transpose,
     )
     wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)

@@ -474,6 +514,8 @@ def _layernorm_mlp_bwd_rule(
         casted_dact_out.get_tensor(TensorUsage.LHS),
         casted_kernel_1,
         contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
+        transpose_batch_sequence=batch_sequence_transpose,
+        collective_op=collective_op_set_1.backward,
     )
     dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)

@@ -484,6 +526,7 @@ def _layernorm_mlp_bwd_rule(
         casted_ln_out,
         casted_dact_out.get_tensor(TensorUsage.RHS),
         contracting_dims=(x_contracting_dims, g_contracting_dims),
+        transpose_batch_sequence=batch_sequence_transpose,
     )
     wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
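
The new asserts encode the standard tensor-parallel MLP schedule: the first GEMM may only all-gather its input in the forward pass, the second may only reduce-scatter its output, and the roles mirror in the backward pass. A hypothetical configuration sketch — the CollectiveOpSet/CollectiveOp factory names below are assumptions, not APIs shown in this diff:

from transformer_engine.jax import cpp_extensions as tex

# Hypothetical factories: forward all-gather for FFN1, forward reduce-scatter for FFN2.
ffn1_ops = tex.CollectiveOpSet.create(forward_collective_op=tex.CollectiveOp.ALL_GATHER)
ffn2_ops = tex.CollectiveOpSet.create(forward_collective_op=tex.CollectiveOp.REDUCE_SCATTER)

output = layernorm_mlp(
    x, gamma, beta, kernels, biases,  # schematic positional args
    norm_type="rmsnorm",
    collective_op_sets=(ffn1_ops, ffn2_ops),  # satisfies the asserts above
)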
transformer_engine/jax/quantize/scaling_modes.py
@@ -17,7 +17,7 @@ from functools import reduce, lru_cache
 import operator
 import numpy as np
-from jax.experimental.custom_partitioning import BATCHING
+from jax.experimental.custom_partitioning import BATCHING, CompoundFactor
 from jax.tree_util import register_pytree_node_class
 import jax.numpy as jnp

@@ -152,12 +152,15 @@ class ScalingModeMetadataImpl(ABC):
     @abstractmethod
     def get_shardy_sharding_rules(
-        self, input_rank, unique_var, flatten_axis
+        self,
+        input_shape,
+        unique_var,
+        flatten_axis,
     ) -> QuantizeShardyRules:
         """Sharding rules for the input and (row, col)wise scale tensors.

         Args:
-            input_rank: The rank of the input tensor (for which we produce the scale tensor)
+            input_shape: The shape of the input tensor (for which we produce the scale tensor)
             unique_var: An otherwise unused Shardy variable name prefix
             flatten_axis: Axis along which data can be flattened to 2D for quantization.

@@ -232,12 +235,15 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
         return (n_groups,)

     def get_shardy_sharding_rules(
-        self, input_rank, unique_var, flatten_axis
+        self,
+        input_shape,
+        unique_var,
+        flatten_axis,
     ) -> QuantizeShardyRules:
         """Sharding rules for the input and (row, col)wise scale tensors.

         Args:
-            input_rank: The rank of the input tensor (for which we produce the scale tensor)
+            input_shape: The shape of the input tensor (for which we produce the scale tensor)
             unique_var: An otherwise unused Shardy variable name prefix
             flatten_axis: Axis along which data can be flattened to 2D for quantization.

@@ -245,7 +251,7 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
             The Shardy rules for the scaling mode
         """
         del flatten_axis
-        input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank))
+        input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
         scale_var = BATCHING + unique_var + "_scale_inv"
         return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})

@@ -323,20 +329,23 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
         return (n_groups,)

     def get_shardy_sharding_rules(
-        self, input_rank, unique_var, flatten_axis
+        self,
+        input_shape,
+        unique_var,
+        flatten_axis,
     ) -> QuantizeShardyRules:
         """Sharding rules for the input and (row, col)wise scale tensors.

         Args:
-            input_rank: The rank of the input tensor (for which we produce the scale tensor)
+            input_shape: The shape of the input tensor (for which we produce the scale tensor)
             unique_var: An otherwise unused Shardy variable name prefix
-            flatten_axis: Axis along which data can be flattened to 2D for quantization.
+            flatten_axis: Axis along which data can be flattened to 2D for quantization

         Returns:
             The Shardy rules for the scaling mode
         """
         del flatten_axis
-        input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank))
+        input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
         scale_var = BATCHING + unique_var + "_scale_inv"
         return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})

@@ -562,52 +571,55 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
         return (n_block_x * n_block_y,)

     def get_shardy_sharding_rules(
-        self, input_rank, unique_var, flatten_axis
+        self,
+        input_shape,
+        unique_var,
+        flatten_axis,
     ) -> QuantizeShardyRules:
         """Sharding rules for the input and (row, col)wise scale tensors.

         Args:
-            input_rank: The rank of the input tensor (for which we produce the scale tensor)
+            input_shape: The shape of the input tensor (for which we produce the scale tensor)
             unique_var: An otherwise unused Shardy variable name prefix
             flatten_axis: Axis along which data can be flattened to 2D for quantization

         Returns:
             The Shardy rules for the scaling mode
         """
-        del flatten_axis
-        input_spec = [f"{unique_var}{i}" for i in range(input_rank)]
-        rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)]
-        colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)]
-
-        # NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors.
-        # Unfortunately, because Shardy rules are applied to the inner primitive, the
-        # only way to preserve the relationship is to lower unpadded scales to the
-        # underlying custom call and pad them in C++. Until that's implemented, the
-        # Shardy rules for block scales have to be completely disconnected from the
-        # Shardy rules for the tensor they belong to.
-        # # We have to use two different factors in the two CompoundFactors because of Shardy
-        # # verifier requirements, even though they are the same.
-        # rowwise_var = unique_var
-        # colwise_var = f"{unique_var}_"
-        # input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise")
-        # input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise")
-        # # The rowwise and colwise scale tensors should be sharded the same way as the input.
-        # # However, we need to adjust the dimensions where the block scaling factor applies.
-        # rowwise = input_spec.copy()
-        # rowwise[-1] = rowwise_var
-        # colwise = input_spec.copy()
-        # colwise[flatten_axis - 1] = colwise_var
-        # # This implementation needs to be updated for different block dims.
-        # assert self._block_dims == (1, 32)
+        input_rank = len(input_shape)
+        input_spec = [f"{unique_var}_{i}" for i in range(input_rank)]
+        flatten_axis = (flatten_axis + input_rank) % input_rank
+
+        # This implementation needs to be updated for different block dims.
+        assert self._block_dims == (1, 32)
+
+        # We have to use two different factors in the two CompoundFactors because of Shardy
+        # verifier requirements, even though they are the same.
+        blocksizes = {}
+        rowwise_var = f"{unique_var}_None"
+        colwise_var = f"{unique_var}_None"
+        if not input_shape[-1] == 32:
+            rowwise_var = input_spec[-1] + "_compound"
+            input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x")
+            blocksizes["blocksize_x"] = 32
+        if not input_shape[flatten_axis - 1] == 32:
+            colwise_var = input_spec[flatten_axis - 1] + "_compound"
+            input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y")
+            blocksizes["blocksize_y"] = 32
+
+        # The rowwise and colwise scale tensors should be sharded the same way as the input.
+        # However, we need to adjust the dimensions where the block scaling factor applies.
+        rowwise = input_spec.copy()
+        rowwise[-1] = rowwise_var
+        colwise = input_spec.copy()
+        colwise[flatten_axis - 1] = colwise_var

         return QuantizeShardyRules(
             tuple(input_spec),
             tuple(rowwise),
             tuple(colwise),
-            {},  # {"block_size_rowwise": 32, "block_size_colwise": 32}
+            blocksizes,
         )

@@ -697,18 +709,22 @@ class ScalingMode(Enum):
         return self._get_impl().get_quantize_layout(usage)

     def get_shardy_sharding_rules(
-        self, input_rank, unique_var, flatten_axis=-1
+        self,
+        input_shape,
+        unique_var,
+        flatten_axis=-1,
     ) -> Tuple[Tuple[str]]:
         """Sharding rules for the input and (row, col)wise scale tensors.

         Args:
-            input_rank: The rank of the input tensor (for which we produce the scale tensor)
+            input_shape: The shape of the input tensor (for which we produce the scale tensor)
             unique_var: An otherwise unused Shardy variable name prefix
             flatten_axis: Axis along which data can be flattened to 2D for quantization.

         Returns:
             The Shardy rules for the scaling mode
         """
-        return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis)
+        return self._get_impl().get_shardy_sharding_rules(input_shape, unique_var, flatten_axis)

     def get_grouped_scale_shape_2x(
         self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1
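
The (1, 32) block-dims assert reflects MXFP8-style scaling: one scale element per 32 consecutive values along the quantization axis, which is the size relationship the CompoundFactor with blocksize 32 expresses to Shardy. The shape arithmetic, as a small standalone illustration (plain Python, not the library's API):

def rowwise_block_scale_shape(data_shape, block=32):
    """Unpadded scale shape for (1, 32) rowwise block scaling."""
    *lead, last = data_shape
    assert last % block == 0, "last dim must divide evenly into blocks"
    return (*lead, last // block)

print(rowwise_block_scale_shape((16, 128, 4096)))  # (16, 128, 128)

Passing input_shape instead of input_rank is what enables the new conditionals: when a dimension is exactly one block wide, the rule keeps the disconnected "_None" variable, presumably because splitting a 32-wide dimension into a compound factor would leave a size-1 factor.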
transformer_engine/jax/sharding.py
@@ -13,6 +13,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Callable, Optional
+import warnings

 import jax
 import jax.numpy as jnp
 from jax.interpreters import pxla

@@ -364,3 +365,21 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh
         if axis != global_mesh_resource().pp_resource:
             x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
     return x
+
+
+def tpsp_axis_size():
+    """
+    Get the size of the tensor parallelism axis.
+    Return 1 if no TP axis is set.
+    """
+    return get_mesh_axis_size(global_mesh_resource().tpsp_resource)
+
+
+def dp_or_fsdp_axis_size():
+    """
+    Get the size of the data parallelism or FSDP axis.
+    Return 1 if no DP/FSDP axis is set.
+    """
+    dp_size = get_mesh_axis_size(global_mesh_resource().dp_resource)
+    fsdp_size = get_mesh_axis_size(global_mesh_resource().fsdp_resource)
+    return dp_size if dp_size > 1 else fsdp_size
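
Both helpers read axis sizes off the globally registered MeshResource, so library code can scale amax reductions or pick collective strategies without threading the mesh through every call. A hedged usage sketch (the axis names and exact MeshResource fields are assumptions based on this file):

import jax
from transformer_engine.jax.sharding import (
    MeshResource,
    global_shard_guard,
    tpsp_axis_size,
    dp_or_fsdp_axis_size,
)

mesh = jax.make_mesh((2, 4), ("dp", "tp"))  # assumes 8 available devices
with mesh, global_shard_guard(MeshResource(dp_resource="dp", tpsp_resource="tp")):
    assert tpsp_axis_size() == 4        # size of the "tp" axis
    assert dp_or_fsdp_axis_size() == 2  # "dp" wins when its size > 1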
transformer_engine/pytorch/attention/dot_product_attention/backends.py
View file @
53fa872c
...
@@ -13,17 +13,20 @@ import logging
...
@@ -13,17 +13,20 @@ import logging
from
packaging.version
import
Version
as
PkgVersion
from
packaging.version
import
Version
as
PkgVersion
import
torch
import
torch
import
torch.nn.functional
as
F
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
from
transformer_engine.pytorch.utils
import
(
from
transformer_engine.pytorch.utils
import
(
SplitAlongDim
,
get_device_compute_capability
,
get_device_compute_capability
,
combine_tensors
,
split_tensor_along_dim
,
split_tensor_along_dim
,
)
)
from
transformer_engine.pytorch.utils
import
attention_mask_func
from
transformer_engine.pytorch.utils
import
attention_mask_func
,
nvtx_range_push
,
nvtx_range_pop
from
transformer_engine.pytorch.tensor.float8_tensor
import
(
Float8Quantizer
,
Float8CurrentScalingQuantizer
,
)
from
transformer_engine.pytorch.tensor.quantized_tensor
import
(
from
transformer_engine.pytorch.tensor.quantized_tensor
import
(
QuantizedTensor
,
QuantizedTensor
Base
,
prepare_for_saving
,
prepare_for_saving
,
restore_from_saved
,
restore_from_saved
,
)
)
...
@@ -40,7 +43,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
...
@@ -40,7 +43,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
META_O
,
META_O
,
META_QKV
,
META_QKV
,
)
)
from
transformer_engine.pytorch.fp8
import
get_fp8_torch_dtype
from
transformer_engine.pytorch.fp8
import
get_fp8_torch_dtype
,
FP8GlobalStateManager
from
transformer_engine.pytorch.distributed
import
get_distributed_world_size
from
transformer_engine.pytorch.distributed
import
get_distributed_world_size
from
transformer_engine.pytorch.jit
import
no_torch_dynamo
from
transformer_engine.pytorch.jit
import
no_torch_dynamo
from
transformer_engine.pytorch.attention.dot_product_attention.context_parallel
import
(
from
transformer_engine.pytorch.attention.dot_product_attention.context_parallel
import
(
...
@@ -53,6 +56,9 @@ from transformer_engine.pytorch.attention.inference import InferenceParams
...
@@ -53,6 +56,9 @@ from transformer_engine.pytorch.attention.inference import InferenceParams
import
transformer_engine.pytorch.attention.dot_product_attention.utils
as
dpa_utils
import
transformer_engine.pytorch.attention.dot_product_attention.utils
as
dpa_utils
from
transformer_engine.pytorch.attention.dot_product_attention.utils
import
(
from
transformer_engine.pytorch.attention.dot_product_attention.utils
import
(
FlashAttentionUtils
as
fa_utils
,
FlashAttentionUtils
as
fa_utils
,
combine_and_quantize
,
combine_and_dequantize
,
print_quantizers
,
)
)
from
transformer_engine.pytorch.attention.dot_product_attention.utils
import
(
from
transformer_engine.pytorch.attention.dot_product_attention.utils
import
(
AttentionLogging
as
attn_log
,
AttentionLogging
as
attn_log
,
...
@@ -131,6 +137,58 @@ if not IS_HIP_EXTENSION:
        fa_utils.set_flash_attention_3_params()

# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1"


class FP8EmulationFunc(torch.autograd.Function):
    """
    Emulate the effects of FP8 quantization on tensors. Used in UnfusedDotProductAttention as follows:
    - forward : QKV (quantize+dequantize), P (pass-through), S (quantize+dequantize), O (pass-through)
    - backward: dO (quantize+dequantize), dS (pass-through), dP (quantize+dequantize), dQKV (pass-through)
    """

    @staticmethod
    def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout):
        # pylint: disable=missing-function-docstring
        if quantizer_name == "QKV_quantizer":
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in [tensor1, tensor2, tensor3]
            ]
            q_fp8, k_fp8, v_fp8 = combine_and_quantize(
                qkv_layout, query_layer, key_layer, value_layer, quantizer
            )
            tensors = combine_and_dequantize(
                qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype
            )
        elif quantizer_name in ["S_quantizer", "O_quantizer"]:
            t_fp8 = quantizer(tensor1)
            tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3)
        else:
            tensors = (tensor1, tensor2, tensor3)
        ctx.quantizer = quantizer
        ctx.quantizer_name = quantizer_name
        ctx.qkv_layout = qkv_layout
        return tensors[0], tensors[1], tensors[2]

    @staticmethod
    def backward(ctx, grad1, grad2, grad3):
        # pylint: disable=missing-function-docstring
        if ctx.quantizer_name in ["dO_quantizer", "dP_quantizer"]:
            dt_fp8 = ctx.quantizer(grad1)
            tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3
        elif ctx.quantizer_name == "dQKV_quantizer":
            query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]]
            dq_fp8, dk_fp8, dv_fp8 = combine_and_quantize(
                ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer
            )
            tensors = combine_and_dequantize(
                ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype
            )
        else:
            tensors = grad1, grad2, grad3
        return tensors[0], tensors[1], tensors[2], None, None, None


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
...
@@ -144,6 +202,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        layer_number: Optional[int] = None,
        softmax_type: str = "vanilla",
    ) -> None:
        super().__init__()
...
@@ -151,6 +210,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
        self.attention_type = attention_type
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number
        self.softmax_type = softmax_type

        def mask_func(x, y):
            return (
...
@@ -187,6 +247,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        inference_params: Optional[InferenceParams] = None,
        softmax_offset: torch.Tensor = None,
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
        quantizers=None,
        fp8_output: bool = False,
    ) -> torch.Tensor:
        """Unfused attention fprop"""
        assert (
...
@@ -284,6 +349,35 @@ class UnfusedDotProductAttention(torch.nn.Module):
        if apply_qk_layer_scaling:
            scale /= self.layer_number

        if fp8:
            # get quantizers from DPA; all Nones if not fp8
            QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
                dpa_utils.get_attention_quantizers(fp8, quantizers)
            )
            # S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation
            fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
            if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
                fp8_recipe = fp8_meta["local_recipes"][0]
            if fp8_recipe.float8_current_scaling():
                S_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=S_quantizer.dtype, device="cuda")
                dP_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=dP_quantizer.dtype, device="cuda")
            if "2" in qkv_layout or "3" in qkv_layout:
                qkv_format, *_ = dpa_utils.get_qkv_format(qkv_layout)
                qkv_layout = "_".join([qkv_format] * 3)
            # quantize and dequantize QKV to emulate FP8
            query_layer, key_layer, value_layer = FP8EmulationFunc.apply(
                query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", qkv_layout
            )
            # quantize and dequantize dQKV to emulate FP8
            query_layer, key_layer, value_layer = FP8EmulationFunc.apply(
                query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", qkv_layout
            )

        # Raw attention scores. [b * np, sq, sk]
        if core_attention_bias_type == "no_bias":
            matmul_result = torch.baddbmm(
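For reference, the layout normalization in the hunk above collapses a packed layout string into its separate-tensor equivalent before the emulation round-trips. A tiny sketch with assumed values (the strings are illustrative; dpa_utils.get_qkv_format is the real helper):

qkv_layout = "sbh3d"                    # packed QKV layout: contains "3"
qkv_format = "sbhd"                     # format portion, as get_qkv_format would report
qkv_layout = "_".join([qkv_format] * 3)
print(qkv_layout)                       # sbhd_sbhd_sbhd -- separate-tensor layout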
...
@@ -328,7 +422,27 @@ class UnfusedDotProductAttention(torch.nn.Module):
                dtype=query_layer.dtype
            )

        if fp8:
            # quantize and dequantize dP to emulate FP8
            matmul_result, *_ = FP8EmulationFunc.apply(
                matmul_result, None, None, dP_quantizer, "dP_quantizer", None
            )

        # add attention sink to the last column: [b, np, sq, sk+1]
        if self.softmax_type != "vanilla":
            matmul_result = torch.cat(
                [
                    matmul_result,
                    softmax_offset.to(dtype=matmul_result.dtype).expand(
                        matmul_result.size(0), -1, matmul_result.size(2), -1
                    ),
                ],
                dim=-1,
            )
            attention_mask = F.pad(attention_mask, (0, 1), mode="constant", value=False)
            attn_mask_type = "arbitrary"

        # attention scores and attention mask
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(
            matmul_result, attention_mask, attn_mask_type, softmax_scale
...
@@ -339,6 +453,10 @@ class UnfusedDotProductAttention(torch.nn.Module):
        if "padding" in attn_mask_type:
            attention_probs = attention_probs.masked_fill(attention_mask, 0)

        # remove attention sink: [b, np, sq, sk]
        if self.softmax_type != "vanilla":
            attention_probs = attention_probs[..., :-1]

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
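The sink column appended before the softmax and stripped afterwards implements the off-by-one denominator 1 + sum(exp(S)). A standalone sketch of that identity (the function name and shapes are illustrative, not TE API):

import torch
import torch.nn.functional as F

def off_by_one_softmax(scores: torch.Tensor) -> torch.Tensor:
    # Append one zero logit per row; exp(0) = 1 contributes the "+1" to the denominator.
    sink = torch.zeros(*scores.shape[:-1], 1, dtype=scores.dtype)
    probs = F.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
    return probs[..., :-1]  # drop the sink column; rows now sum to < 1

s = torch.randn(2, 4, 3, 5)
ref = s.exp() / (1 + s.exp().sum(dim=-1, keepdim=True))
assert torch.allclose(off_by_one_softmax(s), ref, atol=1e-6)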
...
@@ -359,6 +477,12 @@ class UnfusedDotProductAttention(torch.nn.Module):
        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        if fp8:
            # quantize and dequantize S to emulate FP8
            attention_probs, *_ = FP8EmulationFunc.apply(
                attention_probs, None, None, S_quantizer, "S_quantizer", None
            )

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
...
@@ -393,6 +517,20 @@ class UnfusedDotProductAttention(torch.nn.Module):
        # [tq, np, hn] --> [tq, hp]
        context_layer = context_layer.view(total_tokens, -1)

        if fp8:
            # quantize and dequantize O to emulate FP8
            context_layer, *_ = FP8EmulationFunc.apply(
                context_layer, None, None, O_quantizer, "O_quantizer", None
            )
            # quantize and dequantize dO to emulate FP8
            context_layer, *_ = FP8EmulationFunc.apply(
                context_layer, None, None, dO_quantizer, "dO_quantizer", None
            )

        # quantize O
        if fp8_output:
            context_layer = O_quantizer(context_layer)

        return context_layer
...
@@ -491,6 +629,7 @@ class FlashAttention(torch.nn.Module):
        quantizers=None,
        inference_params: Optional[InferenceParams] = None,
        flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
        fp8_output: bool = False,
    ) -> torch.Tensor:
        """flash-attn fprop"""
...
@@ -696,6 +835,7 @@ class FlashAttention(torch.nn.Module):
                    quantizers=quantizers,
                    pad_between_seqs=False,
                    use_flash_attn_3=use_flash_attn_3,
                    fp8_output=fp8_output,
                )
        else:
            from transformer_engine.pytorch.cpu_offload import (
...
@@ -795,8 +935,6 @@ class FlashAttention(torch.nn.Module):
            )
            return out

        assert isinstance(key_layer, query_layer.__class__) and isinstance(
            value_layer, query_layer.__class__
        ), "q, k, and v must have the same type."
...
@@ -843,7 +981,7 @@ class FlashAttention(torch.nn.Module):
        if fp8:
            output = output.to(dtype=torch_orig_dtype)
        if fp8 and fp8_output:
            O_quantizer = quantizers["scaling_fwd"][META_O]
            output = O_quantizer(output)
...
@@ -871,7 +1009,7 @@ class FlashAttention(torch.nn.Module):
        if q_format == "sbhd":
            # (bs)hd -> bs(hd) -> sb(hd)
            if fp8 and fp8_output:
                output_data = (
                    output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
                    .transpose(0, 1)
...
@@ -895,7 +1033,7 @@ class FlashAttention(torch.nn.Module):
class FusedAttnFunc(torch.autograd.Function):
    """FusedAttention forward and backward implementation"""

    @staticmethod
    def forward(
...
@@ -919,6 +1057,7 @@ class FusedAttnFunc(torch.autograd.Function):
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        softmax_type,
        window_size,
        rng_gen,
        fused_attention_backend,
...
@@ -927,55 +1066,72 @@ class FusedAttnFunc(torch.autograd.Function):
        fp8_meta,
        quantizers,
        deterministic,
        softmax_offset,
        fp8_output,
        layer_number,
    ):
        # pylint: disable=missing-function-docstring
        # add NVTX range
        nvtx_label = "transformer_engine.FusedAttnFunc.forward"
        nvtx_range_push(f"{nvtx_label}")
        # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE;
        # may be different from fp8_meta["recipe"]
        fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
        if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
            fp8_recipe = fp8_meta["local_recipes"][0]
        # input types are inferred from the real data while output types are controlled by fp8_output
        # fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha)
        assert isinstance(k, q.__class__) and isinstance(
            v, q.__class__
        ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor."
        is_input_fp8 = isinstance(q, Float8Tensor)
        is_output_fp8 = fp8_output
        # whether fwd kernel in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa)
        # whether bwd kernel in FP8:
        is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        # get quantizers from DPA; all Nones if not fp8
        QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
            dpa_utils.get_attention_quantizers(fp8, quantizers)
        )
        # get nominal data type for out
        # FP16/BF16 attention: torch.float16 or torch.bfloat16
        # FP8 attention: torch.float16 or torch.bfloat16
        out_nominal_dtype = q.dtype
        if fp8:
            fused_attention_backend = FusedAttnBackend["FP8"]
            # q, k, v: torch.Tensor; dtype = torch.float16 or torch.bfloat16
            # q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
            # fp8_dtype = tex.DType.kFloat8E4M3
            q_fp8, k_fp8, v_fp8 = None, None, None
            if is_input_fp8:
                q_fp8, k_fp8, v_fp8 = q, k, v
            else:
                # 1: qkv packed, 2: kv packed, 3: qkv separate
                q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer)
            # print quantizers
            print_quantizers(
                "FusedAttnFunc.forward >> before: ",
                layer_number,
                QKV_quantizer,
                O_quantizer,
                S_quantizer,
                dQKV_quantizer,
                dO_quantizer,
                dP_quantizer,
            )
            # out_:
            # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
            # fp8_dtype = tex.DType.kFloat8E4M3
            # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
            out_, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
...
@@ -984,7 +1140,7 @@ class FusedAttnFunc(torch.autograd.Function):
                q_fp8,
                k_fp8,
                v_fp8,
                out_nominal_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
...
@@ -999,45 +1155,59 @@ class FusedAttnFunc(torch.autograd.Function):
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                softmax_type,
                window_size,
                rng_gen,
                softmax_offset,
            )
            # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
            # fp8_dtype = tex.DType.kFloat8E4M3
            # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16
            out_fp8 = out_
            out = out_
            if isinstance(out_, Float8Tensor):
                if not is_output_fp8 or not is_bwd_fp8:
                    out = out_.dequantize().view(out_.shape)
            else:
                if is_output_fp8 or (
                    is_bwd_fp8
                    and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)
                ):
                    out_fp8 = O_quantizer(out_)
            # print quantizers
            print_quantizers(
                "FusedAttnFunc.forward >> after: ",
                layer_number,
                QKV_quantizer,
                O_quantizer,
                S_quantizer,
                dQKV_quantizer,
                dO_quantizer,
                dP_quantizer,
            )
            # return appropriate tensors
            out_ret = out_fp8 if is_output_fp8 else out
            # save appropriate tensors
            fp8_tensors = (None, None, None, None)
            qkvo_tensors = (None, None, None, None)
            if is_bwd_fp8:
                if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16:
                    fp8_tensors = (q_fp8, k_fp8, v_fp8, None)
                    qkvo_tensors = (None, None, None, out)
                else:
                    fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8)
            else:
                if is_input_fp8:
                    q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8)
                qkvo_tensors = (q, k, v, out)
        else:
            # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16
            out_, aux_ctx_tensors = fused_attn_fwd(
                is_training,
                max_seqlen_q,
                max_seqlen_kv,
...
@@ -1046,7 +1216,7 @@ class FusedAttnFunc(torch.autograd.Function):
                q,
                k,
                v,
                out_nominal_dtype,
                fused_attention_backend,
                attn_bias,
                cu_seqlens_q_padded,
...
@@ -1061,13 +1231,23 @@ class FusedAttnFunc(torch.autograd.Function):
                qkv_layout,
                attn_bias_type,
                attn_mask_type,
                softmax_type,
                window_size,
                rng_gen,
                softmax_offset,
            )
            out = out_
            out_ret = out_
            fp8_tensors = (None, None, None, None)
            qkvo_tensors = (q, k, v, out)

        nvtx_range_pop(f"{nvtx_label}")
        ctx.fp8_recipe = fp8_recipe
        ctx.fp8 = is_bwd_fp8
        # assume fwd and bwd always use the same high precision, i.e. torch.float16 or torch.bfloat16
        # used when some tensors are base tensors and lose the "dtype" attribute
        ctx.nominal_dtype = out_nominal_dtype

        from transformer_engine.pytorch.cpu_offload import (
            CPUOffloadEnabled,
...
@@ -1078,7 +1258,7 @@ class FusedAttnFunc(torch.autograd.Function):
            if ctx.fp8:
                tensor_list = fp8_tensors
            else:
                tensor_list = [q, k, v, out]
                qkv_layout = "sbhd_sbhd_sbhd"
            mark_activation_offload(*tensor_list)
...
@@ -1086,7 +1266,6 @@ class FusedAttnFunc(torch.autograd.Function):
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
        tensors_to_save, tensor_objects = prepare_for_saving(
            *fp8_tensors,
            *qkvo_tensors,
...
@@ -1100,11 +1279,14 @@ class FusedAttnFunc(torch.autograd.Function):
        ctx.tensor_objects = tensor_objects
        ctx.fp8_meta = fp8_meta
        ctx.layer_number = layer_number
        ctx.QKV_quantizer = QKV_quantizer
        ctx.O_quantizer = O_quantizer
        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.S_quantizer = S_quantizer
        if ctx.fp8 and isinstance(ctx.S_quantizer, Float8Quantizer):
            ctx.S_quantizer = S_quantizer.copy()
            ctx.S_quantizer.scale = S_quantizer.scale.clone()
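The copy()/clone() above snapshots the S quantizer's scale for the backward pass; without the clone, the saved scale would alias the live tensor and pick up any later in-place update. A minimal illustration of that hazard with plain tensors (not TE quantizers):

import torch

scale = torch.tensor([1.0])
alias = scale             # saved by reference
snapshot = scale.clone()  # saved by value
scale.mul_(2.0)           # later in-place update, e.g. an amax-driven rescale
print(alias)              # tensor([2.]) -- the aliased "saved" value drifted
print(snapshot)           # tensor([1.]) -- the cloned value is stable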
...
@@ -1116,6 +1298,7 @@ class FusedAttnFunc(torch.autograd.Function):
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.softmax_type = softmax_type
        ctx.window_size = window_size
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
...
@@ -1128,17 +1311,15 @@ class FusedAttnFunc(torch.autograd.Function):
    @staticmethod
    def backward(ctx, d_out):
        # pylint: disable=missing-function-docstring
        # d_out is expected to be in FP8 if is_output_fp8=True,
        # but in the case it's not, convert it to FP8 before any operation
        if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorBase):
            d_out = ctx.dO_quantizer(d_out)
            if not ctx.use_FAv2_bwd:
                d_out._data = d_out._data.contiguous()
        elif not ctx.use_FAv2_bwd:
            d_out = d_out.contiguous()
        (
            q_fp8,
            k_fp8,
...
@@ -1192,16 +1373,55 @@ class FusedAttnFunc(torch.autograd.Function):
            dk = dk[..., : d_out.shape[-1]]
            dv = dv[..., : d_out.shape[-1]]
        else:
            with torch.cuda.nvtx.range("FusedAttnFunc.backward"):
                # get nominal data type of dq, dk, dv
                # FP16/BF16 attention: torch.float16 or torch.bfloat16
                # FP8 attention: torch.float16 or torch.bfloat16
                dqkv_nominal_dtype = ctx.nominal_dtype
                if ctx.fp8:
                    # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16
                    # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
                    # fp8_dtype = tex.DType.kFloat8E5M2
                    if ctx.is_output_fp8:
                        d_out_fp8 = d_out
                    else:
                        d_out_fp8 = ctx.dO_quantizer(d_out)
                    # print quantizers
                    print_quantizers(
                        "FusedAttnFunc.backward >> before: ",
                        ctx.layer_number,
                        ctx.QKV_quantizer,
                        ctx.O_quantizer,
                        ctx.S_quantizer,
                        ctx.dQKV_quantizer,
                        ctx.dO_quantizer,
                        ctx.dP_quantizer,
                    )
                    # get tex.DType for dq, dk, dv data
                    dqkv_te_dtype = d_out_fp8._fp8_dtype
                    # q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16,
                    # fp8_dtype = tex.DType.kFloat8E4M3
                    # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
                    # fp8_dtype = tex.DType.kFloat8E5M2
                    # out_:
                    # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
                    # fp8_dtype = tex.DType.kFloat8E4M3
                    # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
                    #
                    # dq_, dk_, dv_:
                    # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
                    # fp8_dtype = tex.DType.kFloat8E5M2
                    # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
                    out_ = (
                        out
                        if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16
                        else out_fp8
                    )
                    dq_, dk_, dv_, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
                        cu_seqlens_q,
...
@@ -1209,10 +1429,10 @@ class FusedAttnFunc(torch.autograd.Function):
                        q_fp8,
                        k_fp8,
                        v_fp8,
                        out_,
                        d_out_fp8,
                        dqkv_nominal_dtype,
                        dqkv_te_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
...
@@ -1226,44 +1446,45 @@ class FusedAttnFunc(torch.autograd.Function):
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.softmax_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )
                    # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16
                    dq, dk, dv = dq_, dk_, dv_
                    is_float8tensor = isinstance(dq_, Float8Tensor)
                    if is_float8tensor and not ctx.is_input_fp8:
                        # return in F16
                        dq, dk, dv = combine_and_dequantize(
                            ctx.qkv_layout,
                            dq_,
                            dk_,
                            dv_,
                            src_nominal_dtype=dq_.dtype,
                        )
                    if not is_float8tensor and ctx.is_input_fp8:
                        # return in FP8
                        dq, dk, dv = combine_and_quantize(
                            ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer
                        )
                    # print quantizers
                    print_quantizers(
                        "FusedAttnFunc.backward >> after: ",
                        ctx.layer_number,
                        ctx.QKV_quantizer,
                        ctx.O_quantizer,
                        ctx.S_quantizer,
                        ctx.dQKV_quantizer,
                        ctx.dO_quantizer,
                        ctx.dP_quantizer,
                    )
                else:
                    if isinstance(d_out, QuantizedTensorBase):
                        d_out = d_out.dequantize(dtype=ctx.nominal_dtype)
                    dqkv_te_dtype = TE_DType[d_out.dtype]
                    # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16
                    dq, dk, dv, *rest = fused_attn_bwd(
                        ctx.max_seqlen_q,
                        ctx.max_seqlen_kv,
...
@@ -1274,8 +1495,8 @@ class FusedAttnFunc(torch.autograd.Function):
                        v,
                        out,
                        d_out,
                        dqkv_nominal_dtype,
                        dqkv_te_dtype,
                        aux_ctx_tensors,
                        ctx.fused_attention_backend,
                        cu_seqlens_q_padded,
...
@@ -1289,42 +1510,17 @@ class FusedAttnFunc(torch.autograd.Function):
                        ctx.qkv_layout,
                        ctx.attn_bias_type,
                        ctx.attn_mask_type,
                        ctx.softmax_type,
                        ctx.window_size,
                        ctx.deterministic,
                    )

        d_bias = None
        if ctx.attn_bias_type not in ["no_bias", "alibi"]:
            d_bias = rest[0]
        d_softmax_offset = None
        if ctx.softmax_type != "vanilla":
            d_softmax_offset = rest[1]
        return (
            None,
            None,
...
@@ -1338,7 +1534,10 @@ class FusedAttnFunc(torch.autograd.Function):
            dq,
            dk,
            dv,
            d_bias,
            None,
            None,
            None,
            None,
            None,
...
@@ -1351,6 +1550,7 @@ class FusedAttnFunc(torch.autograd.Function):
            None,
            None,
            None,
            d_softmax_offset,
            None,
            None,
        )
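The long tuple of Nones follows torch.autograd.Function's contract: backward must return exactly one gradient per forward argument, with None for non-differentiable inputs, which is why d_bias and d_softmax_offset occupy the positions of the bias and softmax_offset arguments. A compact sketch of the same contract:

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, label):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad):
        # One slot per forward argument: grad for x, None for factor and label.
        return grad * ctx.factor, None, None

x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2.0, "demo").sum().backward()
print(x.grad)  # tensor([2., 2., 2.])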
...
@@ -1392,6 +1592,7 @@ class FusedAttention(torch.nn.Module):
        attention_type: str = "self",
        layer_number: Optional[int] = None,
        deterministic: bool = False,
        softmax_type: str = "vanilla",
    ) -> None:
        super().__init__()
...
@@ -1404,6 +1605,7 @@ class FusedAttention(torch.nn.Module):
            ) == "1" and get_device_compute_capability() == (9, 0)
        self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic
        self.softmax_type = softmax_type

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
...
@@ -1455,6 +1657,8 @@ class FusedAttention(torch.nn.Module):
        quantizers=None,
        pad_between_seqs: bool = False,
        inference_params: Optional[InferenceParams] = None,
        softmax_offset: torch.Tensor = None,
        fp8_output: bool = False,
    ) -> torch.Tensor:
        """fused attention fprop"""
        assert (
...
@@ -1555,15 +1759,27 @@ class FusedAttention(torch.nn.Module):
        )

        if fp8:
            fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
            if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
                fp8_recipe = fp8_meta["local_recipes"][0]
            assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
                f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
                " is required for FP8 attention!"
            )
            assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!"
            if fp8_recipe.delayed():
                assert not context_parallel or fp8_recipe.reduce_amax, (
                    "Amax reduction across TP+CP group is necessary when using context parallelism"
                    " with FP8!"
                )
            if fp8_recipe.float8_current_scaling() and context_parallel:
                all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers)
                for q in all_quantizers:
                    if isinstance(q, Float8CurrentScalingQuantizer):
                        q.with_amax_reduction = True
                        q.amax_reduction_group = (
                            cp_group[0] if cp_comm_type == "a2a+p2p" else cp_group
                        )

        if context_parallel:
            assert (
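When current scaling runs under context parallelism, the quantizers in the hunk above are told to reduce their amax over the CP group so every rank derives the identical FP8 scale. A sketch of what that reduction amounts to, assuming torch.distributed is initialized (reduce_amax below is illustrative, not the TE entry point):

import torch
import torch.distributed as dist

def reduce_amax(amax: torch.Tensor, group=None) -> torch.Tensor:
    # Take the elementwise max of the local amax across all ranks in the
    # group, so each rank computes the same quantization scale from it.
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax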
...
@@ -1605,6 +1821,10 @@ class FusedAttention(torch.nn.Module):
                    fp8_meta=fp8_meta,
                    quantizers=quantizers,
                    pad_between_seqs=pad_between_seqs,
                    softmax_type=self.softmax_type,
                    softmax_offset=softmax_offset,
                    fp8_output=fp8_output,
                    layer_number=self.layer_number,
                )
        else:
            with self.attention_dropout_ctx():
...
@@ -1628,6 +1848,7 @@ class FusedAttention(torch.nn.Module):
                    qkv_layout,
                    core_attention_bias_type,
                    attn_mask_type,
                    self.softmax_type,
                    window_size,
                    None,  # rng_gen
                    fused_attention_backend,
...
@@ -1636,6 +1857,9 @@ class FusedAttention(torch.nn.Module):
                    fp8_meta,
                    quantizers,
                    self.deterministic,
                    softmax_offset,
                    fp8_output,
                    self.layer_number,
                )

        # ...hd -> ...(hd)
...
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
View file @ 53fa872c
This source diff could not be displayed because it is too large. You can view the blob instead.
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
View file @ 53fa872c
...
@@ -11,10 +11,25 @@ import warnings
import logging

import torch
from torch.nn.parameter import Parameter
import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
    Format,
    Recipe,
    DelayedScaling,
    Float8CurrentScaling,
)
from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.fp8 import (
    get_fp8_te_dtype,
    FP8GlobalStateManager,
    RecipeState,
    DelayedScalingRecipeState,
    MXFP8BlockScalingRecipeState,
    Float8CurrentScalingRecipeState,
    Float8BlockScalingRecipeState,
)
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.export import is_in_onnx_export_mode
...
@@ -72,6 +87,67 @@ _alibi_cache = {
    "_alibi_bias_require_update": False,
}

"""
This feature is **experimental** and subject to change.
Some models may use different FP8 recipes for their linear layers and attention layers. To support this,
users can either use multiple, nested fp8_autocast() contexts to assign a distinct recipe for each layer,
or use a single fp8_autocast() for the non-attention layers and configure the recipe for the attention
layers as follows.
+-------------------+-----------+-----------------------------------------------------------------------------------+
| Linear | Attention | Configuration |
+===================+===========+===================================================================================+
| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to fp8_autocast(); |
| | | export NVTE_DPA_FP8_RECIPE="F16" |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8DS | FP8DS | Pass FP8DS to fp8_autocast(); |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8CS | FP8DS | Pass FP8CS to fp8_autocast(); |
| | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; |
| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| NVFP4 | FP8DS | Pass NVFP4 to fp8_autocast(); |
| | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; |
| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS |
| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8DS | FP8CS | Pass FP8DS to fp8_autocast(); |
| | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,|
| | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; |
| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8CS | FP8CS | Pass FP8CS to fp8_autocast(); |
| | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe |
| | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| NVFP4 | FP8CS | Pass NVFP4 to fp8_autocast(); |
| | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe |
| | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: |
| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS |
| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
"""
_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
formats = {"HYBRID": Format.HYBRID, "E4M3": Format.E4M3, "E5M2": Format.E5M2}
_dpa_fp8_format = formats[os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID")]
_dpa_fp8ds_amax_algo = os.getenv("NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent")
_dpa_fp8ds_amax_histlen = int(os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1"))
_dpa_fp8ds_reduce_amax = os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") == "1"

__all__ = ["DotProductAttention"]
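A hedged usage sketch of the configuration the table above describes: steer attention onto DelayedScaling while linear layers stay on current scaling. The NVTE_DPA_* variables are read at import time of this module (the os.getenv calls above run at module scope), so they must be set in the launch environment; `model` and `inp` below are placeholders:

# Set in the launch environment, before transformer_engine is imported:
#   export NVTE_DPA_FP8_RECIPE="DelayedScaling"
#   export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent"
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

recipe = Float8CurrentScaling(fp8_dpa=True)  # linear layers: CS; attention: DS via env vars
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = model(inp)  # `model` / `inp` are placeholders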
...
@@ -168,6 +244,17 @@ class DotProductAttention(TransformerEngineBaseModule):
softmax_scale: Optional[float], default = `None`
              softmax scale for the attention scores. If `None`, defaults to
              `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
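To make the 'learnable' formula above concrete, here is a sketch that reproduces it by appending a per-head learnable logit alpha as the sink column (shapes follow the [b, h, s_q, s_kv] convention; the function name is illustrative, not TE API):

import torch
import torch.nn.functional as F

def learnable_sink_softmax(scores: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    # scores: [b, h, s_q, s_kv]; alpha: [h], one learnable sink logit per head.
    b, h, s_q, _ = scores.shape
    sink = alpha.view(1, h, 1, 1).expand(b, h, s_q, 1).to(scores.dtype)
    probs = F.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
    return probs[..., :-1]  # denominator becomes exp(alpha[j]) + sum(exp(S))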
Parallelism parameters
----------------------
...
@@ -223,6 +310,7 @@ class DotProductAttention(TransformerEngineBaseModule):
        cp_stream: torch.cuda.Stream = None,
        cp_comm_type: str = "p2p",
        softmax_scale: Optional[float] = None,
        softmax_type: str = "vanilla",
    ) -> None:
        super().__init__()
...
@@ -307,6 +395,20 @@ class DotProductAttention(TransformerEngineBaseModule):
        self.attention_type = attention_type
        self.attention_dropout = attention_dropout
        self.softmax_type = softmax_type
        if self.softmax_type == "vanilla":
            self.softmax_offset = None
        if self.softmax_type == "off-by-one":
            self.softmax_offset = torch.zeros(
                self.num_attention_heads // self.tp_size, device="cuda"
            )
        if self.softmax_type == "learnable":
            self.register_parameter(
                "softmax_offset",
                Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")),
                get_rng_state_tracker=get_rng_state_tracker,
            )
        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
...
@@ -328,6 +430,7 @@ class DotProductAttention(TransformerEngineBaseModule):
            layer_number=layer_number,
            deterministic=self.deterministic,
            **attn_kwargs,
            softmax_type=self.softmax_type,
        )
        self.unfused_attention = UnfusedDotProductAttention(
...
@@ -335,6 +438,7 @@ class DotProductAttention(TransformerEngineBaseModule):
            attention_type=attention_type,
            **attn_kwargs,
            layer_number=layer_number,
            softmax_type=self.softmax_type,
        )

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
...
@@ -433,6 +537,231 @@ class DotProductAttention(TransformerEngineBaseModule):
        self.cp_stream = cp_stream
        self.cp_comm_type = cp_comm_type

    def init_fp8_metadata(self, num_gemms: int = 1) -> None:
        """
        Override TransformerEngineBaseModule.init_fp8_metadata to allow for more flexible recipe support.
        Initialize fp8 related metadata and tensors during fprop.
        """
        _original_recipe = self.fp8_meta.get("recipe", None)

        # global recipe set in fp8_autocast()
        fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
        # switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to
        # a different recipe than fp8_recipe. DPA.quantizers may be a mix of different quantizers as well.
        #
        # fp8_recipe                | NVTE_DPA_FP8_RECIPE | self.fp8_meta["recipe"] | self.quantizers
        # --------------------------------------------------------------------------------------------
        # DelayedScaling (DS)       | unset               | DS                      | all DS
        # Float8CurrentScaling (CS) | unset               | DS                      | CS for QKV, O, dO, dQKV; DS for S, dP
        # x={DS, CS}                | y                   | refer to row x=y        | refer to row x=y
        fp8_recipe_dpa = fp8_recipe
        fp8_recipes = fp8_recipe
        if _dpa_fp8_recipe == "F16":
            # ignore the recipe from fp8_autocast, set fp8_dpa = False, fp8_mha = False
            fp8_recipe.fp8_dpa = False
            fp8_recipe.fp8_mha = False
        elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe == "DelayedScaling":
            # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a DS recipe
            fake_recipe = DelayedScaling(
                fp8_format=fp8_recipe.fp8_format,
                amax_history_len=_dpa_fp8ds_amax_histlen,
                amax_compute_algo=_dpa_fp8ds_amax_algo,
                fp8_dpa=fp8_recipe.fp8_dpa,
                fp8_mha=fp8_recipe.fp8_mha,
                reduce_amax=_dpa_fp8ds_reduce_amax,
            )
            fp8_recipe_dpa = fake_recipe
            fp8_recipes = fp8_recipe_dpa
        elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "DelayedScaling":
            # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a DS recipe
            fake_recipe = DelayedScaling(
                fp8_format=_dpa_fp8_format,
                amax_history_len=_dpa_fp8ds_amax_histlen,
                amax_compute_algo=_dpa_fp8ds_amax_algo,
                fp8_dpa=fp8_recipe.fp8_dpa,
                fp8_mha=fp8_recipe.fp8_mha,
                reduce_amax=_dpa_fp8ds_reduce_amax,
            )
            fp8_recipe_dpa = fake_recipe
            fp8_recipes = fp8_recipe_dpa
        elif fp8_recipe.delayed() and _dpa_fp8_recipe == "Float8CurrentScaling":
            # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a CS+DS recipe
            fake_recipes = [
                Float8CurrentScaling(
                    fp8_format=fp8_recipe.fp8_format,
                    fp8_dpa=fp8_recipe.fp8_dpa,
                    fp8_mha=fp8_recipe.fp8_mha,
                ),
                fp8_recipe,
            ]
            fp8_recipe_dpa = fake_recipes[1]
            fp8_recipes = fake_recipes
        elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe in (
            "",
            "Float8CurrentScaling",
        ):
            # use fp8_recipe for QKV, O, dO, dQKV, and construct a DS recipe for S, dP
            # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe
            fake_recipe = DelayedScaling(
                fp8_format=fp8_recipe.fp8_format,
                amax_history_len=_dpa_fp8ds_amax_histlen,
                amax_compute_algo=_dpa_fp8ds_amax_algo,
                fp8_dpa=fp8_recipe.fp8_dpa,
                fp8_mha=fp8_recipe.fp8_mha,
                reduce_amax=_dpa_fp8ds_reduce_amax,
            )
            fp8_recipe_dpa = fake_recipe
            fp8_recipes = [fp8_recipe, fp8_recipe_dpa]
        elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "Float8CurrentScaling":
            # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format
            # construct a CS recipe for QKV, O, dO, dQKV and a DS recipe for S, dP
            fake_recipes = [
                Float8CurrentScaling(
                    fp8_format=_dpa_fp8_format,
                    fp8_dpa=fp8_recipe.fp8_dpa,
                    fp8_mha=fp8_recipe.fp8_mha,
                ),
                DelayedScaling(
                    fp8_format=_dpa_fp8_format,
                    amax_history_len=_dpa_fp8ds_amax_histlen,
                    amax_compute_algo=_dpa_fp8ds_amax_algo,
                    fp8_dpa=fp8_recipe.fp8_dpa,
                    fp8_mha=fp8_recipe.fp8_mha,
                    reduce_amax=_dpa_fp8ds_reduce_amax,
                ),
            ]
            fp8_recipe_dpa = fake_recipes[1]
            fp8_recipes = fake_recipes

        # DPA only support DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False
        if not fp8_recipe_dpa.float8_per_tensor_scaling():
            assert not (
                fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha
            ), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe"

        # reduce over TP+CP groups; expect fp8_group to be set up so
        # assume attention uses the same fp8_group as GEMMs
        fp8_group = FP8GlobalStateManager.get_fp8_group()

        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
        fp8_enabled = self.fp8 or self.fp8_calibration
        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
        if self.fp8_parameters or fp8_enabled:
            self.fp8_meta["global_recipe"] = fp8_recipe
            self.fp8_meta["local_recipes"] = (
                fp8_recipes if isinstance(fp8_recipes, List) else [fp8_recipes]
            )

        if self.fp8_parameters or fp8_enabled:
            if self.fp8_initialized and fp8_recipe_dpa == self.fp8_meta["recipe"]:
                # FP8 init has already been run and recipe is the same, don't do anything.
                return
            self.fp8_meta["recipe"] = fp8_recipe_dpa
            if fp8_recipe != fp8_recipe_dpa:
                # fp8_recipe has changed, rehash the key.
                autocast_key = FP8GlobalStateManager.get_unique_autocast_key(
                    fp8_recipe_dpa, fp8_group
                )
                FP8GlobalStateManager.autocast_arguments[autocast_key] = (
                    fp8_recipe_dpa,
                    fp8_group,
                )
        else:
            # If fp8 isn't enabled, turn off and return.
            self.fp8_initialized = False
            return

        if self.fp8_parameters and not self.fp8_initialized:
            self.fp8_meta["num_gemms"] = num_gemms
            self.init_fp8_meta_tensors(fp8_recipes)

        if fp8_enabled:
            # Set FP8 and other FP8 metadata
            self.fp8_meta["num_gemms"] = num_gemms
            self.fp8_meta["fp8_group"] = fp8_group

            # Set FP8_MAX per tensor according to recipe
            self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
            self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd

            # Allocate scales and amaxes
            self.init_fp8_meta_tensors(fp8_recipes)
            self.fp8_initialized = True

        _current_recipe = self.fp8_meta["recipe"]
        if _original_recipe is not None and not (
            issubclass(_current_recipe.__class__, _original_recipe.__class__)
            or issubclass(_original_recipe.__class__, _current_recipe.__class__)
        ):
            warnings.warn(
                f"Recipe type changed from {_original_recipe.__class__.__name__} "
                f"to {_current_recipe.__class__.__name__}. "
                "This may affect model behavior."
            )
            # Clear cached workspaces as they were created with the old recipe/quantizer type
            self._fp8_workspaces.clear()
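The branch ladder above reduces to a small decision table. A pure-function paraphrase, under the assumption that only recipe class names matter (illustrative sketch, not TE API):

def select_dpa_recipes(autocast: str, override: str) -> list[str]:
    # Returns the recipe list used for DPA quantizers; the last entry becomes
    # fp8_meta["recipe"]. A CS entry covers QKV/O/dO/dQKV, a DS entry covers S/dP.
    if override == "F16":
        return []  # attention stays in high precision
    if autocast == "Float8CurrentScaling" and override in ("", "Float8CurrentScaling"):
        return ["Float8CurrentScaling", "DelayedScaling"]
    if autocast in ("Float8CurrentScaling", "NVFP4") and override == "DelayedScaling":
        return ["DelayedScaling"]
    if autocast in ("DelayedScaling", "NVFP4") and override == "Float8CurrentScaling":
        return ["Float8CurrentScaling", "DelayedScaling"]
    return [autocast]  # e.g. plain DelayedScaling with the override unset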
    def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> None:
        """Override to allow multiple recipes. Init scales and amaxes for fwd | bwd."""
        if isinstance(recipe, Recipe):
            recipe = [recipe]
        fp8_recipe_dpa = recipe[-1]
        fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"

        # Return early if recipe state matches recipe
        if self.fp8_meta_tensors_initialized:
            recipe_state = self.fp8_meta[fp8_meta_tensor_key]
            if fp8_recipe_dpa.delayed() and isinstance(recipe_state, DelayedScalingRecipeState):
                self.adjust_amax_history_length(fp8_recipe_dpa.amax_history_len, fwd=fwd)
                return
            if fp8_recipe_dpa.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
                return
            if fp8_recipe_dpa.float8_current_scaling() and isinstance(
                recipe_state, Float8CurrentScalingRecipeState
            ):
                return
            if fp8_recipe_dpa.float8_block_scaling() and isinstance(
                recipe_state, Float8BlockScalingRecipeState
            ):
                return

        # When fp8_recipe=Float8CurrentScaling, recipe=[CS, DS], and QKV/dQKV, O/dO use CS
        # quantizers, S/dP use DS quantizers.
        # See table above in init_fp8_metadata for more detail.
        num_gemms = [2, 1] if len(recipe) == 2 else [3]
        # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
        # 2 (grad_output and grad_input) for bwd
        num_fp8_tensors = [x * 3 if fwd else x * 2 for x in num_gemms]

        # Initialize recipe state and quantizers
        recipe_states = [
            RecipeState.create(
                recipe[i],
                mode=("forward" if fwd else "backward"),
                num_quantizers=num_fp8_tensors[i],
            )
            for i in range(len(recipe))
        ]
        self.fp8_meta[fp8_meta_tensor_key] = (
            recipe_states[-1] if len(recipe) == 2 else recipe_states[0]
        )
        self.quantizers[fp8_meta_tensor_key] = []
        for recipe_state in recipe_states:
            self.quantizers[fp8_meta_tensor_key].extend(recipe_state.make_quantizers())

    @no_torch_dynamo(recursive=False)
    def forward(
        self,
...
@@ -456,6 +785,7 @@ class DotProductAttention(TransformerEngineBaseModule):
        fast_zero_fill: bool = True,
        inference_params: Optional[InferenceParams] = None,
        pad_between_seqs: Optional[bool] = None,
        fp8_output: Optional[bool] = False,
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.
...
@@ -628,12 +958,15 @@ class DotProductAttention(TransformerEngineBaseModule):
...
@@ -628,12 +958,15 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs: Optional[bool], default = `None`
pad_between_seqs: Optional[bool], default = `None`
If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If true, there are padding tokens between individual sequences in a packed batch.
If true, there are padding tokens between individual sequences in a packed batch.
fp8_output: Optional[bool], default = `False`
Whether to enforce output to be in FP8 or not.
"""
"""
with
torch
.
cuda
.
device
(
query_layer
.
device
),
self
.
prepare_forward
(
with
torch
.
cuda
.
device
(
query_layer
.
device
),
self
.
prepare_forward
(
query_layer
,
query_layer
,
num_gemms
=
3
,
num_gemms
=
3
,
allow_non_contiguous
=
True
,
allow_non_contiguous
=
True
,
allow_different_data_and_param_types
=
self
.
softmax_type
!=
"vanilla"
,
)
as
query_layer
:
)
as
query_layer
:
# checks for RNG
# checks for RNG
if
self
.
rng_states_tracker
is
not
None
and
is_graph_capturing
():
if
self
.
rng_states_tracker
is
not
None
and
is_graph_capturing
():
...
@@ -663,6 +996,8 @@ class DotProductAttention(TransformerEngineBaseModule):
...
@@ -663,6 +996,8 @@ class DotProductAttention(TransformerEngineBaseModule):
tex
.
DType
.
kFloat8E4M3
,
tex
.
DType
.
kFloat8E4M3
,
tex
.
DType
.
kFloat8E5M2
,
tex
.
DType
.
kFloat8E5M2
,
],
"""DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
],
"""DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
else
:
fp8_output
=
False
# checks for q/k/v shapes
# checks for q/k/v shapes
assert
(
assert
(
...
@@ -922,6 +1257,7 @@ class DotProductAttention(TransformerEngineBaseModule):
...
@@ -922,6 +1257,7 @@ class DotProductAttention(TransformerEngineBaseModule):
False
False
),
"core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"
),
"core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"
# check if there is padding between sequences when qkv_format='thd'
if
pad_between_seqs
is
None
:
if
pad_between_seqs
is
None
:
if
qkv_format
==
"thd"
:
if
qkv_format
==
"thd"
:
pad_between_seqs
=
(
pad_between_seqs
=
(
...
@@ -957,11 +1293,13 @@ class DotProductAttention(TransformerEngineBaseModule):
...
@@ -957,11 +1293,13 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs
=
pad_between_seqs
,
pad_between_seqs
=
pad_between_seqs
,
attention_dropout
=
self
.
attention_dropout
,
attention_dropout
=
self
.
attention_dropout
,
context_parallel
=
context_parallel
,
context_parallel
=
context_parallel
,
cp_comm_type
=
self
.
cp_comm_type
,
deterministic
=
self
.
deterministic
,
deterministic
=
self
.
deterministic
,
is_training
=
self
.
training
,
is_training
=
self
.
training
,
fp8
=
self
.
fp8
,
fp8
=
self
.
fp8
,
fp8_meta
=
self
.
fp8_meta
,
fp8_meta
=
self
.
fp8_meta
,
inference_params
=
inference_params
,
inference_params
=
inference_params
,
softmax_type
=
self
.
softmax_type
,
)
)
global
_attention_backends
global
_attention_backends
if
is_in_onnx_export_mode
():
if
is_in_onnx_export_mode
():
...
@@ -1022,6 +1360,12 @@ class DotProductAttention(TransformerEngineBaseModule):
...
@@ -1022,6 +1360,12 @@ class DotProductAttention(TransformerEngineBaseModule):
)
)
# run attention
# run attention
softmax_offset
=
(
self
.
softmax_offset
.
reshape
(
1
,
-
1
,
1
,
1
).
to
(
torch
.
float32
)
if
self
.
softmax_offset
is
not
None
else
None
)
if
use_flash_attention
:
if
use_flash_attention
:
if
core_attention_bias_type
==
"alibi"
:
if
core_attention_bias_type
==
"alibi"
:
alibi_slopes
,
_
=
dpa_utils
.
get_alibi
(
alibi_slopes
,
_
=
dpa_utils
.
get_alibi
(
...
@@ -1053,6 +1397,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                    quantizers=self.quantizers,
                    inference_params=inference_params,
                    flash_attention_backend=flash_attention_backend,
                    fp8_output=fp8_output,
                )

            if use_fused_attention:
...
@@ -1071,7 +1416,6 @@ class DotProductAttention(TransformerEngineBaseModule):
                    bias_dtype=query_layer.dtype,
                    bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
                )
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.fused_attention,
...
@@ -1101,6 +1445,8 @@ class DotProductAttention(TransformerEngineBaseModule):
                        quantizers=self.quantizers,
                        pad_between_seqs=pad_between_seqs,
                        inference_params=inference_params,
                        softmax_offset=softmax_offset,
                        fp8_output=fp8_output,
                    )
                return self.fused_attention(
                    query_layer,
...
@@ -1129,6 +1475,8 @@ class DotProductAttention(TransformerEngineBaseModule):
                    quantizers=self.quantizers,
                    pad_between_seqs=pad_between_seqs,
                    inference_params=inference_params,
                    softmax_offset=softmax_offset,
                    fp8_output=fp8_output,
                )

            from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
...
@@ -1140,6 +1488,7 @@ class DotProductAttention(TransformerEngineBaseModule):
            )

            if use_unfused_attention:
                allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.unfused_attention,
...
@@ -1157,6 +1506,11 @@ class DotProductAttention(TransformerEngineBaseModule):
                        core_attention_bias=core_attention_bias,
                        alibi_slopes=alibi_slopes,
                        inference_params=inference_params,
                        softmax_offset=softmax_offset,
                        fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation,
                        fp8_meta=self.fp8_meta,
                        quantizers=self.quantizers,
                        fp8_output=fp8_output,
                    )
                return self.unfused_attention(
                    _alibi_cache,
...
@@ -1173,5 +1527,10 @@ class DotProductAttention(TransformerEngineBaseModule):
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
                    inference_params=inference_params,
                    softmax_offset=softmax_offset,
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation,
                    fp8_meta=self.fp8_meta,
                    quantizers=self.quantizers,
                    fp8_output=fp8_output,
                )

        return None
transformer_engine/pytorch/attention/dot_product_attention/utils.py  (View file @ 53fa872c)
...
@@ -17,6 +17,7 @@ import numpy as np
from packaging.version import Version as PkgVersion
import torch
import torch.distributed as dist
import torch.nn.functional as F
import transformer_engine_torch as tex
import transformer_engine as te
...
@@ -24,6 +25,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    QKVLayout,
    AttnBiasType,
    AttnMaskType,
    SoftmaxType,
    FusedAttnBackend,
    META_QKV,
    META_DQKV,
...
@@ -31,11 +33,13 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    META_DO,
    META_S,
    META_DP,
)
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.constants import TE_DType
from torch.utils.cpp_extension import IS_HIP_EXTENSION
...
@@ -43,6 +47,8 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.utils import (
    get_device_compute_capability,
    get_cudnn_version,
    SplitAlongDim,
    combine_tensors,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
...
@@ -53,6 +59,9 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
# print quantizer info for a particular layer on a particular rank
_print_layer = int(os.getenv("NVTE_PRINT_LAYER_NUMBER", "1"))
_print_rank = int(os.getenv("NVTE_PRINT_RANK", "0"))
_cu_seqlens_cache = {}
...
@@ -206,6 +215,8 @@ class AttentionParams:
        Attention dropout.
    context_parallel: bool, default = `False`
        Whether context parallelism is used or not.
    cp_comm_type: str, default = "p2p"
        The communication type of context parallelism.
    deterministic: bool, default = `False`
        Whether to run `DotProductAttention` with determinism or not.
    is_training: bool, default = `True`
...
@@ -216,6 +227,8 @@ class AttentionParams:
        The FP8 metadata tensor of `DotProductAttention`.
    inference_params: Optional[InferenceParams], default = `None`
        Inference-related parameters. See InferenceParams for details.
    softmax_type: str, default = "vanilla"
        The type of softmax operation. See DotProductAttention for details.
    """

    qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
...
@@ -237,11 +250,13 @@ class AttentionParams:
    pad_between_seqs: bool = False
    attention_dropout: float = 0.0
    context_parallel: bool = False
    cp_comm_type: str = "p2p"
    deterministic: bool = False
    is_training: bool = True
    fp8: bool = False
    fp8_meta: Union[Dict[str, Any], None] = None
    inference_params: Optional[InferenceParams] = None
    softmax_type: str = "vanilla"

    def __eq__(self, other):
        """
...
@@ -308,11 +323,13 @@ def get_attention_backend(
    pad_between_seqs = attention_params.pad_between_seqs
    attention_dropout = attention_params.attention_dropout
    context_parallel = attention_params.context_parallel
    cp_comm_type = attention_params.cp_comm_type
    deterministic = attention_params.deterministic
    is_training = attention_params.is_training
    fp8 = attention_params.fp8
    fp8_meta = attention_params.fp8_meta
    inference_params = attention_params.inference_params
    softmax_type = attention_params.softmax_type

    # Run config
    logger = logging.getLogger("DotProductAttention")
...
@@ -341,8 +358,31 @@ def get_attention_backend(
        field.name: getattr(attention_params, field.name) for field in fields(attention_params)
    }
    run_config.update(attention_params_dict)
    # Add FP8 environment variables to config
    if fp8:
        # all FP8 recipes: 1: (FP8 fwd, FP8 bwd), 0: (FP8 fwd, F16 bwd)
        run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
        # Float8CurrentScaling: 1: use F16 O in bwd, 0: use FP8 O in bwd
        run_config["NVTE_DPA_FP8CS_O_in_F16"] = int(os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1"))
        # switch recipe to "F16", "DelayedScaling", or "Float8CurrentScaling"
        _dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
        run_config["NVTE_DPA_FP8_RECIPE"] = _dpa_fp8_recipe
        if _dpa_fp8_recipe != "":
            # config new recipe if switched
            run_config["NVTE_DPA_FP8_FORMAT"] = os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID")
            run_config["NVTE_DPA_FP8DS_AMAX_ALGO"] = os.getenv(
                "NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent"
            )
            run_config["NVTE_DPA_FP8DS_AMAX_HISTLEN"] = int(
                os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1")
            )
            run_config["NVTE_DPA_FP8DS_REDUCE_AMAX"] = int(
                os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1")
            )
        # UnfusedDotProductAttention: 1: allow FP8 emulation, 0: do not allow
        run_config["NVTE_UnfusedDPA_Emulate_FP8"] = int(
            os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0")
        )
    logger.debug("Running with config=%s", run_config)
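For reference, a minimal sketch of driving these knobs from the environment (values mirror the defaults read above; some of these variables are also read at module import time elsewhere in this commit, so setting them before importing Transformer Engine is the safe pattern):

import os

os.environ["NVTE_DPA_FP8_RECIPE"] = "Float8CurrentScaling"  # or "DelayedScaling" / "F16"
os.environ["NVTE_DPA_FP8_FORMAT"] = "HYBRID"            # only read when a recipe override is set
os.environ["NVTE_DPA_FP8DS_AMAX_ALGO"] = "most_recent"
os.environ["NVTE_FP8_DPA_BWD"] = "1"                    # 1: FP8 fwd + FP8 bwd, 0: FP8 fwd + F16 bwd
os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"         # allow FP8 emulation in the unfused backend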
    # The following sections check if `FlashAttention` supports the provided attention params,
...
@@ -422,8 +462,20 @@ def get_attention_backend(
            logger.debug("Disabling FlashAttention 3 for FP8 training")
            use_flash_attention_3 = False
        if use_unfused_attention:
            allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
            if not allow_emulation:
                logger.debug("Disabling UnfusedDotProductAttention for FP8 attention")
                use_unfused_attention = False
        fp8_recipe = fp8_meta["recipe"]
        if fp8_meta.get("local_recipes", None) is not None:
            fp8_recipe = fp8_meta["local_recipes"][0]
        if (
            use_fused_attention
            and fp8_recipe.float8_current_scaling()
            and device_compute_capability < (10, 0)
        ):
            logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100")
            use_fused_attention = False
        # TODO: rocm fused attention backends does not support fp8 yet
        if IS_HIP_EXTENSION and use_fused_attention:
            logger.debug("Disabling ROCm FusedAttention as it does not support FP8")
...
@@ -581,6 +633,51 @@ def get_attention_backend(
        logger.debug("Disabling FlashAttention 3 for dropout")
        use_flash_attention_3 = False

    # Filter: Softmax type
    # context_parallel | softmax_type | supported backends
    # ----------------------------------------------------------------------------------------------------
    #       no         |   vanilla    | All
    #       no         |  off-by-one  | FusedAttention, UnfusedDotProductAttention
    #       no         |  learnable   | FusedAttention, UnfusedDotProductAttention
    #       yes        |   vanilla    | FusedAttention, FlashAttention
    #       yes        |  off-by-one  | FusedAttention
    #       yes        |  learnable   | FusedAttention
    if softmax_type != "vanilla":
        logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type)
        use_flash_attention = False
        if fp8 and fp8_meta["recipe"].fp8_dpa:
            logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type)
            use_fused_attention = False
            logger.debug(
                "Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type
            )
            use_unfused_attention = False
        if qkv_format == "thd":
            logger.debug(
                "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
            )
            use_fused_attention = False
            logger.debug(
                "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
                softmax_type,
            )
            use_unfused_attention = False
        if context_parallel:
            logger.debug(
                "Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
                " = %s",
                softmax_type,
            )
            use_unfused_attention = False
            if cp_comm_type != "a2a":
                logger.debug(
                    "Disabling FusedAttention for context parallelism with softmax_type = %s and"
                    " cp_comm_type = %s",
                    softmax_type,
                    cp_comm_type,
                )
                use_fused_attention = False
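The support table above condenses to a small predicate. A standalone sketch (backend names as illustrative strings, ignoring the FP8 and thd caveats handled in the code above):

def supported_backends(context_parallel, softmax_type, cp_comm_type="a2a"):
    if softmax_type == "vanilla":
        backends = {"FlashAttention", "FusedAttention", "UnfusedDotProductAttention"}
    else:
        backends = {"FusedAttention", "UnfusedDotProductAttention"}
    if context_parallel:
        backends.discard("UnfusedDotProductAttention")
        if softmax_type != "vanilla" and cp_comm_type != "a2a":
            backends.discard("FusedAttention")
    return backends

assert supported_backends(False, "off-by-one") == {"FusedAttention", "UnfusedDotProductAttention"}
assert supported_backends(True, "learnable") == {"FusedAttention"}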
    # Filter: Context parallelism
    # qkv_format | attn_mask_type | attn_bias_type | supported backends
    # ----------------------------------------------------------------------------------------------------
...
@@ -822,6 +919,7 @@ def get_attention_backend(
            QKVLayout[qkv_layout],
            AttnBiasType[fu_core_attention_bias_type],
            AttnMaskType[attn_mask_type],
            SoftmaxType[softmax_type],
            attention_dropout,
            num_heads,
            num_gqa_groups,
...
@@ -1836,11 +1934,10 @@ def check_set_window_size(
    return window_size


def get_attention_quantizers(fp8, quantizers):
    """Get the list of quantizers used in attention from the quantizers list."""
    if not fp8:
        return [None] * 6
    QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
    QKV_quantizer.internal = True
    QKV_quantizer.set_usage(rowwise=True, columnwise=False)
...
@@ -1849,6 +1946,7 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
    S_quantizer = quantizers["scaling_fwd"][META_S]
    S_quantizer.internal = True
    S_quantizer.set_usage(rowwise=True, columnwise=False)
    dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV]
    dQKV_quantizer.internal = True
    dQKV_quantizer.set_usage(rowwise=True, columnwise=False)
...
@@ -1858,22 +1956,158 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
    dP_quantizer = quantizers["scaling_bwd"][META_DP]
    dP_quantizer.set_usage(rowwise=True, columnwise=False)
    dP_quantizer.internal = True
    return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer


def print_quantizers(
    label,
    layer_number,
    QKV_quantizer,
    S_quantizer,
    O_quantizer,
    dO_quantizer,
    dP_quantizer,
    dQKV_quantizer,
):
    """Print the type and scale/amax of attention quantizers"""
    _to_print = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL == 2
    if (
        _to_print
        and _print_layer == layer_number
        and (
            not dist.is_initialized()
            or (dist.is_initialized() and dist.get_rank() == _print_rank)
        )
    ):
        names = [
            "QKV_quantizer",
            "S_quantizer",
            "O_quantizer",
            "dO_quantizer",
            "dP_quantizer",
            "dQKV_quantizer",
        ]
        quantizers = [
            QKV_quantizer,
            S_quantizer,
            O_quantizer,
            dO_quantizer,
            dP_quantizer,
            dQKV_quantizer,
        ]
        if "forward" in label:
            names = names[:3]
            quantizers = quantizers[:3]
        if "backward" in label:
            names = names[3:]
            quantizers = quantizers[3:]
        for i, q in enumerate(quantizers):
            type_str = ""
            if q is None:
                type_str = "None"
            elif isinstance(q, Float8Quantizer):
                type_str = "DS"
            elif isinstance(q, Float8CurrentScalingQuantizer):
                type_str = "CS"
            print(
                f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x"
                f" {q.amax.item():.4e} = {q.scale.item() * q.amax.item():.4e}"
            )
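To actually see print_quantizers output, the debug globals read at import time earlier in this file must be set first; a minimal sketch (environment values only, set before importing transformer_engine):

import os

# _NVTE_DEBUG * _NVTE_DEBUG_LEVEL == 2 requires NVTE_DEBUG=1 and NVTE_DEBUG_LEVEL=2
os.environ["NVTE_DEBUG"] = "1"
os.environ["NVTE_DEBUG_LEVEL"] = "2"
os.environ["NVTE_PRINT_LAYER_NUMBER"] = "1"  # which layer prints
os.environ["NVTE_PRINT_RANK"] = "0"          # which rank prints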
def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer):
    """Combine q,k,v based on qkv_layout and quantize them together"""
    # 1: qkv packed, 2: kv packed, 3: qkv separate
    qkv_layout = qkv_layout.replace("paged_kv_", "")
    qkv_group = len(qkv_layout.split("_"))
    src_nominal_dtype = q.dtype
    match qkv_group:
        case 1:
            dim = qkv_layout.find("3")
            qkv = combine_tensors([q, k, v], dim)
            qkv_fp8 = qkv_quantizer(qkv)
            q_data, k_data, v_data = SplitAlongDim.apply(qkv_fp8._data, dim, [1, 1, 1], True)
        case 2:
            dim = qkv_layout.split("_")[1].find("2")
            kv = combine_tensors([k, v], dim)
            tensors = [q, kv]
            num_tensors = len(tensors)
            shapes = [x.shape for x in tensors]
            numels = [x.numel() for x in tensors]
            numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
            qkv = torch.cat([x.view(-1) for x in tensors], dim=0)
            qkv_fp8 = qkv_quantizer(qkv)
            q_data, kv_data = [
                qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i])
                for i in range(num_tensors)
            ]
            k_data, v_data = SplitAlongDim.apply(kv_data, dim, [1, 1], True)
        case 3:
            tensors = [q, k, v]
            num_tensors = len(tensors)
            shapes = [x.shape for x in tensors]
            numels = [x.numel() for x in tensors]
            numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
            qkv = torch.cat([x.view(-1) for x in tensors], dim=0)
            qkv_fp8 = qkv_quantizer(qkv)
            q_data, k_data, v_data = [
                qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i])
                for i in range(num_tensors)
            ]
        case _:
            raise RuntimeError("Invalid qkv_layout " + qkv_layout)
    q_fp8, k_fp8, v_fp8 = [
        Float8Tensor.make_like(qkv_fp8, data=x, dtype=src_nominal_dtype)
        for x in [q_data, k_data, v_data]
    ]
    return q_fp8, k_fp8, v_fp8
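The pattern above (flatten, quantize once, split back) gives q, k and v a single shared FP8 scale instead of three independent ones. A standalone sketch of just the flatten/split bookkeeping for the separate-tensor case, with the quantization step stubbed out:

import torch

q, k, v = (torch.randn(2, 3) for _ in range(3))
tensors = [q, k, v]
shapes = [x.shape for x in tensors]
numels = [x.numel() for x in tensors]
offsets = [sum(numels[:i]) for i in range(len(tensors) + 1)]
flat = torch.cat([x.view(-1) for x in tensors], dim=0)
# a real quantizer maps `flat` to FP8 here with one scale; the split-back is identical
q2, k2, v2 = (flat[offsets[i] : offsets[i + 1]].view(shapes[i]) for i in range(3))
assert all(torch.equal(a, b) for a, b in [(q2, q), (k2, k), (v2, v)])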
def combine_and_dequantize(
    qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=None, des_nominal_dtype=None
):
    """Combine q,k,v based on qkv_layout and dequantize them together"""
    # 1: qkv packed, 2: kv packed, 3: qkv separate
    qkv_layout = qkv_layout.replace("paged_kv_", "")
    qkv_group = len(qkv_layout.split("_"))
    if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]):
        src_nominal_dtype = q_fp8.dtype
    else:
        assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!"
    if des_nominal_dtype is None:
        des_nominal_dtype = src_nominal_dtype
    q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]]
    match qkv_group:
        case 1:
            dim = qkv_layout.find("3")
            qkv_data = combine_tensors([q_data, k_data, v_data], dim)
            qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data)
            qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
            q, k, v = SplitAlongDim.apply(qkv, dim, [1, 1, 1], True)
        case 2:
            dim = qkv_layout.split("_")[1].find("2")
            kv_data = combine_tensors([k_data, v_data], dim)
            tensors = [q_data, kv_data]
            num_tensors = len(tensors)
            shapes = [x.shape for x in tensors]
            numels = [x.numel() for x in tensors]
            numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
            qkv_data = torch.cat([x.reshape(-1) for x in tensors], dim=0)
            qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype)
            qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
            q, kv = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)]
            k, v = SplitAlongDim.apply(kv, dim, [1, 1], True)
        case 3:
            tensors = [q_data, k_data, v_data]
            num_tensors = len(tensors)
            shapes = [x.shape for x in tensors]
            numels = [x.numel() for x in tensors]
            numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
            qkv_data = torch.cat([x.contiguous().reshape(-1) for x in tensors], dim=0)
            qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype)
            qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
            q, k, v = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)]
        case _:
            raise RuntimeError("Invalid qkv_layout " + qkv_layout)
    return q, k, v
transformer_engine/pytorch/attention/multi_head_attention.py  (View file @ 53fa872c)
...
@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""Multi-head Attention."""
import os
import collections
from typing import Callable, List, Optional, Tuple, Union

import torch
...
@@ -31,7 +32,13 @@ from transformer_engine.pytorch.distributed import (
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb

# Force DotProductAttention to use a different recipe than the fp8_recipe set in fp8_autocast().
# Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling"
# and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa.
_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
_dpa_fp8_recipe_dpa = os.getenv("NVTE_DPA_FP8_RECIPE_DPA", "0") == "1"
_dpa_fp8_recipe_mha = os.getenv("NVTE_DPA_FP8_RECIPE_MHA", "0") == "1"


class MultiheadAttention(torch.nn.Module):
...
@@ -135,6 +142,17 @@ class MultiheadAttention(torch.nn.Module):
        For that, please use `get_qkv_layout` to gain the layout information.
    name: str, default = `None`
        name of the module, currently used for debugging purposes.
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
        softmax type as described in this paper:
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.
        For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
        'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
        'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
        'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
        where alpha is a learnable parameter in shape [h].
        'off-by-one' and 'learnable' softmax types are also called sink attention
        ('zero sink' and 'learnable sink').

    Parallelism parameters
    ----------------------
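A direct transcription of the three formulas in the docstring above, as a standalone sketch (numerically naive, without the max-subtraction a production kernel would use):

import torch

def sink_softmax(scores, softmax_type="vanilla", alpha=None):
    """scores: [b, h, s_q, s_kv]; alpha: [h], used only for 'learnable'."""
    e = scores.exp()
    denom = e.sum(dim=-1, keepdim=True)
    if softmax_type == "off-by-one":
        denom = denom + 1.0  # the implicit "zero sink" contributes exp(0) = 1
    elif softmax_type == "learnable":
        denom = denom + alpha.exp().view(1, -1, 1, 1)  # per-head learnable sink
    return e / denom

s = torch.randn(2, 4, 8, 8)
p = sink_softmax(s, "off-by-one")
assert (p.sum(dim=-1) < 1.0).all()  # rows no longer sum to one; mass leaks to the sink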
...
@@ -245,6 +263,7 @@ class MultiheadAttention(torch.nn.Module):
        qk_norm_before_rope: bool = False,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        softmax_type: str = "vanilla",
    ) -> None:
        super().__init__()
...
@@ -262,6 +281,7 @@ class MultiheadAttention(torch.nn.Module):
        self.return_bias = return_bias
        self.cp_size = 1
        self.cp_rank = 0
        self.softmax_type = softmax_type

        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)
...
@@ -416,6 +436,7 @@ class MultiheadAttention(torch.nn.Module):
            tp_group=tp_group,
            layer_number=self.layer_number,
            attention_type=self.attention_type,
            softmax_type=self.softmax_type,
        )

        # Linear
...
@@ -556,10 +577,12 @@ class MultiheadAttention(torch.nn.Module):
            self.cp_size = get_distributed_world_size(cp_group)
            self.cp_rank = get_distributed_rank(cp_group)
        elif isinstance(cp_group, list):
            assert (
                cp_comm_type == "a2a+p2p"
            ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!"
            assert (
                len(cp_group) == 2
            ), "cp_comm_type = a2a+p2p requires cp_group = [a2a_cp_group, p2p_cp_group]!"
            cp_size_a2a = get_distributed_world_size(cp_group[0])
            cp_rank_a2a = get_distributed_rank(cp_group[0])
            cp_size_p2p = get_distributed_world_size(cp_group[1])
...
@@ -716,10 +739,22 @@ class MultiheadAttention(torch.nn.Module):
            # Query, Key, and Value
            # ======================

            fp8 = FP8GlobalStateManager.is_fp8_enabled()
            if _dpa_fp8_recipe == "":
                fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
                fp8_dpa = fp8_recipe.fp8_dpa
                fp8_mha = fp8_recipe.fp8_mha
                float8_current_scaling = fp8_recipe.float8_current_scaling()
            else:
                fp8_dpa = _dpa_fp8_recipe_dpa
                fp8_mha = _dpa_fp8_recipe_mha
                float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling"
            # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe
            qkv_fp8_output = (
                fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling
            )
            # DPA: always produce FP8 output when fp8=True to take advantage of the O amax
            dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha)
            # Proj Gemm: match DPA output except for Float8CurrentScaling
            proj_fp8_grad = dpa_fp8_output and not float8_current_scaling

            layernorm_output = None
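The three flags above can be summarized as a pure function; a standalone sketch with an assertion for the Float8CurrentScaling case:

def fp8_output_flags(fp8, fp8_dpa, fp8_mha, float8_current_scaling, has_rope):
    qkv_fp8_output = fp8 and fp8_mha and not has_rope and not float8_current_scaling
    dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha)
    proj_fp8_grad = dpa_fp8_output and not float8_current_scaling
    return qkv_fp8_output, dpa_fp8_output, proj_fp8_grad

# Float8CurrentScaling keeps the QKV output and proj grad out of FP8, but not the DPA output
assert fp8_output_flags(True, True, True, True, False) == (False, True, False)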
            if self.attention_type == "self":
...
@@ -728,7 +763,7 @@ class MultiheadAttention(torch.nn.Module):
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=qkv_fp8_output,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
...
@@ -738,7 +773,7 @@ class MultiheadAttention(torch.nn.Module):
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=qkv_fp8_output,
                )

            num_queries_per_key_value = (
...
@@ -792,7 +827,7 @@ class MultiheadAttention(torch.nn.Module):
            mixed_kv_layer = self.key_value(
                encoder_output,
                is_first_microbatch=is_first_microbatch,
                fp8_output=qkv_fp8_output,
            )

            if self.qkv_weight_interleaved:
...
@@ -847,7 +882,7 @@ class MultiheadAttention(torch.nn.Module):
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=qkv_fp8_output,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
...
@@ -857,7 +892,7 @@ class MultiheadAttention(torch.nn.Module):
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                    fp8_output=qkv_fp8_output,
                )

            # [sq, b, hp] --> [sq, b, np, hn]
...
@@ -958,6 +993,7 @@ class MultiheadAttention(torch.nn.Module):
                fast_zero_fill=fast_zero_fill,
                inference_params=inference_params,
                pad_between_seqs=pad_between_seqs,
                fp8_output=dpa_fp8_output,
            )

            # ===================
...
@@ -966,7 +1002,7 @@ class MultiheadAttention(torch.nn.Module):
            projection_output = self.proj(
                context_layer,
                is_first_microbatch=is_first_microbatch,
                fp8_grad=proj_fp8_grad,
            )
            if self.return_bias:
...
transformer_engine/pytorch/constants.py  (View file @ 53fa872c)
...
@@ -91,3 +91,5 @@ GemmParallelModes = ("row", "column", None)
dist_group_type = torch.distributed.ProcessGroup

MXFP8_BLOCK_SCALING_SIZE = 32

NVFP4_BLOCK_SCALING_SIZE = 16
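These constants fix the per-block scaling granularity (one scale factor per block of consecutive elements). A small sketch of the scale counts they imply; the helper is illustrative, not part of the library:

MXFP8_BLOCK_SCALING_SIZE = 32
NVFP4_BLOCK_SCALING_SIZE = 16

def scales_per_row(row_len, block_size):
    # one scale per block, rounding up for partial blocks
    return (row_len + block_size - 1) // block_size

assert scales_per_row(4096, MXFP8_BLOCK_SCALING_SIZE) == 128
assert scales_per_row(4096, NVFP4_BLOCK_SCALING_SIZE) == 256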
transformer_engine/pytorch/cpp_extensions/fused_attn.py  (View file @ 53fa872c)
...
@@ -12,6 +12,7 @@ from transformer_engine_torch import (
    NVTE_QKV_Format,
    NVTE_Bias_Type,
    NVTE_Mask_Type,
    NVTE_Softmax_Type,
    NVTE_Fused_Attn_Backend,
)
from ..tensor.quantized_tensor import Quantizer
...
@@ -86,6 +87,12 @@ AttnMaskType = {
    "padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
}

SoftmaxType = {
    "vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX,
    "off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX,
    "learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX,
}
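The dict mirrors the other name-to-enum maps in this file; usage is a plain lookup:

from transformer_engine.pytorch.cpp_extensions.fused_attn import SoftmaxType

softmax_type = SoftmaxType["off-by-one"]  # -> NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX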
FusedAttnBackend = {
    "F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen,
    "F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
...
@@ -102,9 +109,6 @@ META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3


def fused_attn_fwd(
...
@@ -131,8 +135,10 @@ def fused_attn_fwd(
    qkv_layout: str = "sbh3d",
    attn_bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
    softmax_type: str = "vanilla",
    window_size: Tuple[int, int] = (-1, -1),
    rng_gen: torch.Generator = None,
    softmax_offset: torch.Tensor = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
    """Fused Attention FWD for separate QKV input.
...
@@ -197,6 +203,8 @@ def fused_attn_fwd(
        type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
    attn_mask_type: str, default = "padding"
        type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
    softmax_type: str, default = "vanilla"
        type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
    window_size: Tuple[int, int], default = (-1, -1)
        sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
...
@@ -205,6 +213,9 @@ def fused_attn_fwd(
    rng_gen: torch.Generator, default = None
        random number generator;
        if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
    softmax_offset: torch.Tensor, default = None
        softmax offset tensor in shape [1, h_q, 1, 1].
        See softmax_type in DotProductAttention for details.

    Returns
    ----------
...
@@ -286,6 +297,7 @@ def fused_attn_fwd(
        QKVLayout[qkv_layout],
        AttnBiasType[attn_bias_type],
        AttnMaskType[attn_mask_type],
        SoftmaxType[softmax_type],
        window_size,
        cu_seqlens_q,
        cu_seqlens_kv,
...
@@ -300,6 +312,7 @@ def fused_attn_fwd(
        s_quantizer,
        o_quantizer,
        attn_bias,
        softmax_offset,
        rng_gen,
        rng_elts_per_thread,
    )
...
@@ -333,6 +346,7 @@ def fused_attn_bwd(
    qkv_layout: str = "sbh3d",
    attn_bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
    softmax_type: str = "vanilla",
    window_size: Tuple[int, int] = (-1, -1),
    deterministic: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
...
@@ -398,6 +412,8 @@ def fused_attn_bwd(
        type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
    attn_mask_type: str, default = "padding"
        type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
    softmax_type: str, default = "vanilla"
        type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
    window_size: Tuple[int, int], default = (-1, -1)
        sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
...
@@ -417,6 +433,9 @@ def fused_attn_bwd(
    d_bias: torch.Tensor, optional
        gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
        or "post_scale_bias"; same data type and shape as Bias
    d_softmax_offset: torch.Tensor, optional
        gradient tensor of softmax offset in shape [1, h_q, 1, 1].
        See softmax_type in DotProductAttention for details.
    """
    if attn_scale is None:
        d = q.size(-1)
...
@@ -454,6 +473,7 @@ def fused_attn_bwd(
        QKVLayout[qkv_layout],
        AttnBiasType[attn_bias_type],
        AttnMaskType[attn_mask_type],
        SoftmaxType[softmax_type],
        window_size,
        deterministic,
        cu_seqlens_q,
...
transformer_engine/pytorch/cpp_extensions/gemm.py  (View file @ 53fa872c)
...
@@ -20,6 +20,8 @@ from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..tensor.utils import is_experimental
from ..experimental.gemm import experimental_gemm
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
...
@@ -169,6 +171,24 @@ def general_gemm(
        if not out.is_contiguous():
            raise ValueError("Output tensor is not contiguous.")

    # If A or B are experimental tensors -> dispatch to quantizers's qgemm implementation
    if is_experimental(A) or is_experimental(B):
        return experimental_gemm(
            A,
            B,
            workspace,
            out_dtype,
            quantization_params,
            gelu,
            gelu_in,
            accumulate,
            layout,
            out,
            bias,
            use_split_accumulator,
            grad,
        )

    debug_quantizer = None
    if isinstance(quantization_params, DebugQuantizer):
        debug_quantizer = quantization_params
...
transformer_engine/pytorch/csrc/common.cpp  (View file @ 53fa872c)
...
@@ -12,6 +12,20 @@
namespace transformer_engine::pytorch {

/*! convert fp4 data shape back to original shape */
std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose) {
  std::vector<size_t> ret;
  size_t start_idx = (transpose) ? 1 : 0;
  for (size_t i = start_idx; i < shape.size() - 1; ++i) {
    ret.push_back(shape[i]);
  }
  ret.push_back(shape.back() * 2);
  if (transpose) {
    ret.push_back(shape.front());
  }
  return ret;
}

std::vector<size_t> getTensorShape(const at::Tensor& t) {
  std::vector<size_t> shape;
  for (auto s : t.sizes()) {
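The shape math in convert_shape_back_from_fp4 follows from FP4 packing two values per byte; a Python transcription for quick verification (the function name mirrors the C++ above):

def shape_back_from_fp4(shape, transpose):
    start = 1 if transpose else 0
    ret = list(shape[start:-1]) + [shape[-1] * 2]  # packed byte dim -> two FP4 values
    if transpose:
        ret.append(shape[0])                       # leading dim moves to the back
    return ret

assert shape_back_from_fp4([128, 32], transpose=False) == [128, 64]
assert shape_back_from_fp4([128, 32], transpose=True) == [64, 128]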
...
@@ -291,4 +305,20 @@ size_t roundup(const size_t value, const size_t multiple) {
  return ((value + multiple - 1) / multiple) * multiple;
}

void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) {
  NVTE_SCOPED_GIL_RELEASE({
    nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val,
                                 arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_,
                                 at::cuda::getCurrentCUDAStream());
  });
}

// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread) {
  at::PhiloxCudaState philox_args;
  std::lock_guard<std::mutex> lock(gen->mutex_);
  philox_args = gen->philox_cuda_state(elts_per_thread);
  return philox_args;
}

}  // namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/common.h  (View file @ 53fa872c)
...
@@ -35,6 +35,7 @@
...
@@ -35,6 +35,7 @@
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/fused_router.h>
#include <transformer_engine/fused_router.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/hadamard_transform.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/normalization.h>
...
@@ -212,20 +213,25 @@ class Float8CurrentScalingQuantizer : public Quantizer {
...
@@ -212,20 +213,25 @@ class Float8CurrentScalingQuantizer : public Quantizer {
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
)
const
override
;
DType
dtype
)
const
override
;
/*! @brief Construct a
high precision tensor giving it this
quantizer's amax
/*! @brief Construct a
n unquantized tensor that shares the
quantizer's amax
pointer.
*
Note: this member function also
zero
s
out
the amax, as it is meant to be used in conjunction with
* The amax is
zero
ed
out
. Most TE kernels that output amax expect
a kernel computing the amax, which might expect the
amax to be initialized to zero
*
amax to be initialized to zero
.
*/
*/
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_
hp
_tensor_with_amax
(
const
std
::
vector
<
size_t
>&
shape
,
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_
unquantized
_tensor_with_amax
(
DType
dtype
);
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
,
std
::
optional
<
at
::
Tensor
>
data
=
std
::
nullopt
);
std
::
pair
<
TensorWrapper
,
py
::
object
>
convert_and_update_tensor
(
py
::
object
shape
)
const
override
;
std
::
pair
<
TensorWrapper
,
py
::
object
>
convert_and_update_tensor
(
py
::
object
shape
)
const
override
;
void
quantize
(
const
TensorWrapper
&
input
,
TensorWrapper
&
out
,
void
quantize
(
const
TensorWrapper
&
input
,
TensorWrapper
&
out
,
const
std
::
optional
<
TensorWrapper
>&
noop_flag
=
std
::
nullopt
)
override
;
const
std
::
optional
<
TensorWrapper
>&
noop_flag
=
std
::
nullopt
)
override
;
/*! @brief Convert to a quantized data format avoiding amax computation */
/*! @brief Quantize to FP8, skipping local amax computation
*
* The quantizer's amax pointer is assumed to already hold the local
* amax. The amax may still be reduced across the amax reduction
* group.
*/
void
quantize_with_amax
(
TensorWrapper
&
input
,
TensorWrapper
&
out
,
void
quantize_with_amax
(
TensorWrapper
&
input
,
TensorWrapper
&
out
,
const
std
::
optional
<
TensorWrapper
>&
noop_flag
=
std
::
nullopt
);
const
std
::
optional
<
TensorWrapper
>&
noop_flag
=
std
::
nullopt
);
...
@@ -295,6 +301,60 @@ class MXFP8Quantizer : public Quantizer {
...
@@ -295,6 +301,60 @@ class MXFP8Quantizer : public Quantizer {
std
::
vector
<
size_t
>
get_scale_shape
(
const
std
::
vector
<
size_t
>&
shape
,
bool
columnwise
)
const
;
std
::
vector
<
size_t
>
get_scale_shape
(
const
std
::
vector
<
size_t
>&
shape
,
bool
columnwise
)
const
;
};
};
class
NVFP4Quantizer
:
public
Quantizer
{
public:
// fp4 dtype
DType
dtype
;
// amax reduction for low precision FP4 AG
bool
with_amax_reduction
;
c10
::
intrusive_ptr
<
dist_group_type
>
amax_reduction_group
;
// random hadamard transform
bool
with_rht
;
bool
with_post_rht_amax
;
// 2D block scaling
bool
with_2d_quantization
;
bool
stochastic_rounding
;
int
rht_matrix_random_sign_mask_t
;
at
::
Tensor
rht_matrix
;
explicit
NVFP4Quantizer
(
const
py
::
handle
&
quantizer
);
NVTEScalingMode
get_scaling_mode
()
const
override
{
return
NVTE_NVFP4_1D_SCALING
;
}
void
set_quantization_params
(
TensorWrapper
*
tensor
)
const
override
;
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
)
const
override
;
/*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer
*
* The amax is zeroed out. Most TE kernels that output amax expect
* amax to be initialized to zero.
*/
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_unquantized_tensor_with_amax
(
TensorWrapper
&
quantized_tensor
,
DType
dtype
);
std
::
pair
<
TensorWrapper
,
py
::
object
>
convert_and_update_tensor
(
py
::
object
shape
)
const
override
;
void
quantize
(
const
TensorWrapper
&
input
,
TensorWrapper
&
out
,
const
std
::
optional
<
TensorWrapper
>&
noop_flag
=
std
::
nullopt
)
override
;
/*! @brief Quantize to NVFP4, skipping local amax computation
*
* The input tensor's amax pointer is assumed to already hold the
* local amax. The amax may still be reduced across the amax
* reduction group.
*/
void
quantize_with_amax
(
TensorWrapper
&
input
,
TensorWrapper
&
out
);
std
::
vector
<
size_t
>
get_scale_shape
(
const
std
::
vector
<
size_t
>&
shape
,
bool
columnwise
)
const
;
private:
void
quantize_impl
(
const
TensorWrapper
&
input
,
TensorWrapper
&
out
,
const
std
::
optional
<
TensorWrapper
>&
noop_flag
,
bool
compute_amax
);
};
 std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);

 std::vector<size_t> getTensorShape(const at::Tensor& t);
...
@@ -445,6 +505,15 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
 size_t roundup(const size_t value, const size_t multiple);

 NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
+std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose);
+
+// unpack the PhiloxCudaState into CUDA tensor
+void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr);
+
+// extract PhiloxCudaState from CUDA random number generator
+at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread);
 }  // namespace transformer_engine::pytorch

 namespace std {
...
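The two Philox helpers are meant to be paired: capture the generator's state on the host, then materialize it as a {seed, offset} pair in device memory so dropout kernels can reproduce the random stream. A sketch of that pairing, under the assumption that the state unpacks into two int64 slots (this is not the library's own call site):

  at::Tensor make_rng_state(at::CUDAGeneratorImpl* gen, size_t rng_elts_per_thread) {
    at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
    auto rng_state = at::empty({2}, at::TensorOptions().dtype(at::kLong).device(at::kCUDA));
    philox_unpack(philox_args, rng_state.data_ptr<int64_t>());
    return rng_state;
  }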
transformer_engine/pytorch/csrc/extensions.h
View file @ 53fa872c
...
@@ -73,28 +73,36 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
 NVTE_Fused_Attn_Backend get_fused_attn_backend(
     bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
-    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
-    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+    float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
+    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
+    int64_t window_size_right);
+
+std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
+                                                      const std::vector<size_t>& shape,
+                                                      DType dtype, bool create_hp_tensor_for_cs,
+                                                      std::optional<at::Tensor> data);
 std::vector<py::object> fused_attn_fwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
     bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
-    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
-    const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
-    const std::optional<at::Tensor> cu_seqlens_q_padded,
+    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+    const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
+    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
+    const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded,
     const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
     py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
-    const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
+    const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
+    size_t rng_elts_per_thread);
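Both additions to these signatures serve softmax variants with an extra term in the normalizer (off-by-one and learnable-sink softmax): softmax_type selects the variant and SoftmaxOffset carries the learnable offset tensor. Assuming the usual formulation with an offset z (an assumption here; the enum values themselves live in the core library):

  \mathrm{softmax}_z(x)_i = \frac{e^{x_i}}{e^{z} + \sum_j e^{x_j}}

with z = 0 recovering the off-by-one ("+1 in the denominator") variant and z = -\infty the vanilla softmax.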
 std::vector<py::object> fused_attn_bwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
-    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
-    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
+    NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
+    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
+    const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
+    const at::ScalarType fake_dtype, const DType dqkv_type,
     const std::vector<at::Tensor> Aux_CTX_Tensors,
     const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
...
transformer_engine/pytorch/csrc/extensions/activation.cpp
View file @ 53fa872c
...
@@ -8,179 +8,269 @@
 #include "common.h"
 #include "pybind.h"
-namespace transformer_engine::pytorch {
+namespace transformer_engine {
+namespace pytorch {

-template <void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t)>
-py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) {
+namespace {
+
+py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t),
+                              const at::Tensor& input, py::handle quantizer,
+                              int shape_divisor = 1) {
   init_extension();

   // Input tensor
   auto input_tensor = input.contiguous();
-  const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
+  const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor);

   // Construct output tensor
   auto quantizer_cpp = convert_quantizer(quantizer);
-  const auto input_shape = input_cpp.shape();
+  const auto input_shape = input_nvte.shape();
   std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
   output_shape.back() /= shape_divisor;
   auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
-  auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);
+  auto [out_nvte, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);

-  // Compute activation
+  // Choose implementation
+  enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 };
+  Impl impl = Impl::UNFUSED;
   if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
       detail::IsMXFP8Quantizers(quantizer.ptr())) {
-    // Compute activation directly
-    NVTE_SCOPED_GIL_RELEASE(
-        { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); });
+    impl = Impl::FULLY_FUSED;
   } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
-    // Compute activation in high-precision fused together with amax, then quantize.
-    auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
-    auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype);
-    NVTE_SCOPED_GIL_RELEASE(
-        { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
-    quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp);
-  } else {
-    // Compute activation in high-precision, then quantize
-    auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
-    NVTE_SCOPED_GIL_RELEASE(
-        { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
-    quantizer_cpp->quantize(temp_cpp, out_cpp);
+    impl = Impl::FUSED_ACTIVATION_AMAX_FP8;
+  } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
+    auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
+    NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
+      // Post-RHT amax is handled within NVFP4 quantizer
+      impl = Impl::UNFUSED;
+    } else {
+      impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
+    }
   }

+  // Perform compute
+  auto stream = at::cuda::getCurrentCUDAStream();
+  switch (impl) {
+    case Impl::UNFUSED:
+      // Compute activation in high precision, then quantize
+      {
+        auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
+        NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
+        quantizer_cpp->quantize(temp_nvte, out_nvte);
+      }
+      break;
+    case Impl::FULLY_FUSED:
+      // Compute activation directly
+      {
+        NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), out_nvte.data(), stream); });
+      }
+      break;
+    case Impl::FUSED_ACTIVATION_AMAX_FP8:
+      // Compute activation and amax in high precision, then quantize to FP8
+      {
+        auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
+        NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
+        auto [temp_nvte, _] =
+            fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype);
+        NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
+        fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
+      }
+      break;
+    case Impl::FUSED_ACTIVATION_AMAX_NVFP4:
+      // Compute activation and amax in high precision, then quantize to NVFP4
+      {
+        auto nvfp4_quantizer_cpp =
+            static_cast<NVFP4Quantizer*>(quantizer_cpp.get());  // Already checked cast is valid
+        auto [temp_nvte, _] =
+            nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype);
+        NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
+        nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
+      }
+      break;
+    default:
+      NVTE_ERROR("Invalid activation implementation (", static_cast<int>(impl), ")");
+  }

   return out_py;
 }
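Note why the RHT-with-post-RHT-amax configuration falls back to UNFUSED above: the fused activation kernel would record the amax of the raw activation values, but when a random Hadamard transform H is applied before quantization, the scale must be derived from the transformed values, and in general

  \max_i |(Hx)_i| \neq \max_i |x_i|,

so the post-RHT amax is left to the NVFP4 quantizer itself, exactly as the in-code comment says.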
-template <void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
-py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input,
-                              py::handle quantizer) {
+py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor,
+                                                 cudaStream_t),
+                               const at::Tensor& grad_output, const at::Tensor& input,
+                               py::handle quantizer) {
   init_extension();

   // Grad output and input tensors
   auto grad_output_tensor = grad_output.contiguous();
   auto input_tensor = input.contiguous();
-  const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor);
-  const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
+  const TensorWrapper& grad_output_nvte = makeTransformerEngineTensor(grad_output_tensor);
+  const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor);

   // Construct grad input tensor
   auto quantizer_cpp = convert_quantizer(quantizer);
-  const auto input_shape_te = input_cpp.shape();
+  const auto input_shape_te = input_nvte.shape();
   const std::vector<size_t> input_shape(input_shape_te.data,
                                         input_shape_te.data + input_shape_te.ndim);
   auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
-  auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);
+  auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);

-  // Compute activation backward
+  // Choose implementation
+  enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 };
+  Impl impl = Impl::UNFUSED;
   if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
       detail::IsMXFP8Quantizers(quantizer.ptr())) {
-    // Compute activation backward directly
-    NVTE_SCOPED_GIL_RELEASE({
-      dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(),
-                at::cuda::getCurrentCUDAStream());
-    });
+    impl = Impl::FULLY_FUSED;
   } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
-    // Compute activation backward in high-precision fused together with amax, then quantize.
-    auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
-    auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype);
-    NVTE_SCOPED_GIL_RELEASE({
-      dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
-                at::cuda::getCurrentCUDAStream());
-    });
-    quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp);
-  } else {
-    // Compute activation backward in high-precision, then quantize
-    auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
-    NVTE_SCOPED_GIL_RELEASE({
-      dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
-                at::cuda::getCurrentCUDAStream());
-    });
-    quantizer_cpp->quantize(temp_cpp, grad_input_cpp);
+    impl = Impl::FUSED_ACTIVATION_AMAX_FP8;
+  } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
+    auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
+    NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
+      // Post-RHT amax is handled within NVFP4 quantizer
+      impl = Impl::UNFUSED;
+    } else {
+      impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
+    }
   }

+  // Perform compute
+  auto stream = at::cuda::getCurrentCUDAStream();
+  switch (impl) {
+    case Impl::UNFUSED:
+      // Compute activation backward in high precision, then quantize
+      {
+        auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
+        NVTE_SCOPED_GIL_RELEASE(
+            { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); });
+        quantizer_cpp->quantize(temp_nvte, grad_input_nvte);
+      }
+      break;
+    case Impl::FULLY_FUSED:
+      // Compute activation backward directly
+      {
+        NVTE_SCOPED_GIL_RELEASE({
+          dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream);
+        });
+      }
+      break;
+    case Impl::FUSED_ACTIVATION_AMAX_FP8:
+      // Compute activation and amax in high precision, then quantize to FP8
+      {
+        auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
+        NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
+        auto [temp_nvte, _] =
+            fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype);
+        NVTE_SCOPED_GIL_RELEASE(
+            { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); });
+        fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
+      }
+      break;
+    case Impl::FUSED_ACTIVATION_AMAX_NVFP4:
+      // Compute activation and amax in high precision, then quantize to NVFP4
+      {
+        auto nvfp4_quantizer_cpp =
+            static_cast<NVFP4Quantizer*>(quantizer_cpp.get());  // Already checked cast is valid
+        auto [temp_nvte, _] =
+            nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(grad_input_nvte, fake_dtype);
+        NVTE_SCOPED_GIL_RELEASE(
+            { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); });
+        nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
+      }
+      break;
+    default:
+      NVTE_ERROR("Invalid activation implementation (", static_cast<int>(impl), ")");
+  }

   return grad_input_py;
 }
+}  // namespace
+
 /* GELU and variants */
 py::object gelu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_gelu>(input, quantizer);
+  return activation_forward(nvte_gelu, input, quantizer);
 }

 py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dgelu>(grad, input, quantizer);
+  return activation_backward(nvte_dgelu, grad, input, quantizer);
 }

 py::object geglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_geglu>(input, quantizer, 2);
+  return activation_forward(nvte_geglu, input, quantizer, 2);
 }

 py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dgeglu>(grad, input, quantizer);
+  return activation_backward(nvte_dgeglu, grad, input, quantizer);
 }

 py::object qgelu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_qgelu>(input, quantizer);
+  return activation_forward(nvte_qgelu, input, quantizer);
 }

 py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dqgelu>(grad, input, quantizer);
+  return activation_backward(nvte_dqgelu, grad, input, quantizer);
 }

 py::object qgeglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_qgeglu>(input, quantizer, 2);
+  return activation_forward(nvte_qgeglu, input, quantizer, 2);
 }

 py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dqgeglu>(grad, input, quantizer);
+  return activation_backward(nvte_dqgeglu, grad, input, quantizer);
 }

 /* ReLU and variants */

 py::object relu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_relu>(input, quantizer);
+  return activation_forward(nvte_relu, input, quantizer);
 }

 py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_drelu>(grad, input, quantizer);
+  return activation_backward(nvte_drelu, grad, input, quantizer);
 }

 py::object reglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_reglu>(input, quantizer, 2);
+  return activation_forward(nvte_reglu, input, quantizer, 2);
 }

 py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dreglu>(grad, input, quantizer);
+  return activation_backward(nvte_dreglu, grad, input, quantizer);
 }

 py::object srelu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_srelu>(input, quantizer);
+  return activation_forward(nvte_srelu, input, quantizer);
 }

 py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dsrelu>(grad, input, quantizer);
+  return activation_backward(nvte_dsrelu, grad, input, quantizer);
 }

 py::object sreglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_sreglu>(input, quantizer, 2);
+  return activation_forward(nvte_sreglu, input, quantizer, 2);
 }

 py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dsreglu>(grad, input, quantizer);
+  return activation_backward(nvte_dsreglu, grad, input, quantizer);
 }

 /* Silu and variants */

 py::object silu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_silu>(input, quantizer);
+  return activation_forward(nvte_silu, input, quantizer);
 }

 py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dsilu>(grad, input, quantizer);
+  return activation_backward(nvte_dsilu, grad, input, quantizer);
 }

 py::object swiglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_swiglu>(input, quantizer, 2);
+  return activation_forward(nvte_swiglu, input, quantizer, 2);
 }

 py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dswiglu>(grad, input, quantizer);
+  return activation_backward(nvte_dswiglu, grad, input, quantizer);
 }
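Since the wrappers now call activation_forward / activation_backward with a runtime function pointer instead of a template parameter, roughly twenty template instantiations of the helpers collapse into two shared function bodies, and the wrappers can be bound directly. A hypothetical registration sketch, not this commit's actual binding code (the real m.def calls live in the extension's pybind translation unit):

  #include <pybind11/pybind11.h>
  namespace py = pybind11;

  PYBIND11_MODULE(te_activation_sketch, m) {
    m.def("gelu", &transformer_engine::pytorch::gelu, py::arg("input"), py::arg("quantizer"));
    m.def("dgelu", &transformer_engine::pytorch::dgelu, py::arg("grad"), py::arg("input"),
          py::arg("quantizer"));
  }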
-}  // namespace transformer_engine::pytorch
+}  // namespace pytorch
+}  // namespace transformer_engine