Commit 53fa872c authored by wenjh

Merge branch 'nv_release_v2.8' into release_v2.8

parents 27ddce40 40c69e75
...@@ -87,5 +87,31 @@ constexpr struct Alignment {
std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise);
template <typename T, typename... Rest>
void hash_combine(int64_t &seed, const T &v, Rest... rest) {
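// 0x9e3779b9 is the 32-bit golden-ratio constant used by boost-style hash_combine
// to decorrelate bits; the fold expression below folds the remaining arguments in.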
seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
(hash_combine(seed, rest), ...);
}
enum class JAXX_Collective_Op : int64_t {
NONE = 0,
ALL_GATHER = 1,
REDUCE_SCATTER = 2,
};
static CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) {
switch (op) {
case JAXX_Collective_Op::ALL_GATHER:
return CommOverlapType::AG;
break;
case JAXX_Collective_Op::REDUCE_SCATTER:
return CommOverlapType::RS;
break;
default:
NVTE_ERROR("Invalid Collective Op ", static_cast<int>(op));
break;
}
}
} // namespace jax
} // namespace transformer_engine
...@@ -180,6 +180,42 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<bool>("is_2x"),
FFI_CudaGraph_Traits);
Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
Buffer_Type gamma_buf, Buffer_Type beta_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type mu_buf,
Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type,
bool zero_centered_gamma, double epsilon, int64_t sm_margin,
JAXX_Scaling_Mode scaling_mode, bool is_2x) {
return wrapInStreamCapture(
std::function(NormForwardFFI), stream, x_buf, scale_buf, gamma_buf, beta_buf, output_buf,
colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, mu_buf, rsigma_buf,
wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, scaling_mode, is_2x);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI,
FFI::Bind<FFI_Initialize>()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise_output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // colwise_scale_inv
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("norm_type")
.Attr<bool>("zero_centered_gamma")
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"));
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, NVTE_Norm_Type norm_type,
bool zero_centered_gamma, int sm_margin) {
...@@ -305,5 +341,32 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardHandler, NormBackwardFFI,
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
Error_Type NormBackwardInitializeFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
Buffer_Type mu_buf, Buffer_Type rsigma_buf,
Buffer_Type gamma_buf, Result_Type xgrad_buf,
Result_Type wgrad_buf, Result_Type dbeta_buf,
Result_Type wkspace_buf, int64_t norm_type,
bool zero_centered_gamma, int64_t sm_margin) {
return wrapInStreamCapture(std::function(NormBackwardFFI), stream, dz_buf, x_buf, mu_buf,
rsigma_buf, gamma_buf, xgrad_buf, wgrad_buf, dbeta_buf, wkspace_buf,
norm_type, zero_centered_gamma, sm_margin);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardInitializeHandler, NormBackwardInitializeFFI,
FFI::Bind<FFI_Initialize>()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // dz
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // mu
.Arg<Buffer_Type>() // rsigma
.Arg<Buffer_Type>() // gamma
.Ret<Buffer_Type>() // xgrad
.Ret<Buffer_Type>() // wgrad
.Ret<Buffer_Type>() // dbeta
.Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("norm_type")
.Attr<bool>("zero_centered_gamma")
.Attr<int64_t>("sm_margin"));
} // namespace jax
} // namespace transformer_engine
...@@ -5,6 +5,8 @@
************************************************************************/
#include "../extensions.h"
#include "cgemm_helper.h"
#include "common/util/cuda_runtime.h"
namespace transformer_engine {
namespace jax {
...@@ -20,8 +22,12 @@ pybind11::dict Registrations() {
pybind11::dict dict;
// Activation
dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); dict["te_act_lu_ffi"] =
dict["te_dact_dbias_quantize_ffi"] = EncapsulateFFI(DActLuDBiasQuantizeHandler); pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(ActLuInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(ActLuHandler));
dict["te_dact_dbias_quantize_ffi"] = pybind11::dict(
pybind11::arg("initialize") = EncapsulateFFI(DActLuDBiasQuantizeInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(DActLuDBiasQuantizeHandler));
// Quantization
dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler);
...@@ -42,9 +48,11 @@ pybind11::dict Registrations() {
// Normalization
dict["te_norm_forward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("initialize") = EncapsulateFFI(NormForwardInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(NormForwardHandler));
dict["te_norm_backward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("initialize") = EncapsulateFFI(NormBackwardInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(NormBackwardHandler));
// Attention
...@@ -57,7 +65,7 @@ pybind11::dict Registrations() {
// GEMM
dict["te_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CollectiveGemmInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GemmHandler));
// Grouped GEMM
...@@ -84,6 +92,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported);
m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator);
m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams);
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
...@@ -159,6 +169,12 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
.export_values();
pybind11::enum_<JAXX_Collective_Op>(m, "JAXX_Collective_Op", pybind11::module_local())
.value("NONE", JAXX_Collective_Op::NONE)
.value("ALL_GATHER", JAXX_Collective_Op::ALL_GATHER)
.value("REDUCE_SCATTER", JAXX_Collective_Op::REDUCE_SCATTER)
.export_values();
}
} // namespace jax
...
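For context, a hedged sketch of how the bindings added above might be consumed from the Python side. The module attribute name `registrations` and the dict-of-stages registration call are assumptions inferred from the pattern above; this diff does not show them.

# Illustrative only: assumes the extension imports as transformer_engine_jax,
# exposes Registrations() as `registrations`, and that the installed JAX accepts a
# dict of FFI stage capsules ("prepare"/"initialize"/"execute") per target.
import jax
import transformer_engine_jax as tex_cpp

# The collective-op enum bound above; values mirror the C++ JAXX_Collective_Op.
assert int(tex_cpp.JAXX_Collective_Op.ALL_GATHER) == 1

for name, target in tex_cpp.registrations().items():
    # `target` is either a single execute capsule or a dict of stage capsules.
    jax.ffi.register_ffi_target(name, target, platform="CUDA")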
...@@ -11,10 +11,12 @@ customizable contracting dimensions for flexible tensor operations.
from typing import Tuple, Sequence
from functools import partial
import warnings
import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .quantize import (
ScaledTensorFactory,
ScalingMode,
...@@ -61,8 +63,12 @@ def dense(
kernel: jnp.ndarray,
bias: jnp.ndarray = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
batch_sequence_transpose: bool = False,
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
output_axes: Tuple[str, ...] = None,
using_global_amax_of_x: bool = False,
collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set,
quantizer_set: QuantizerSet = noop_quantizer_set,
):
"""Perform dense layer transformation with optional quantization. """Perform dense layer transformation with optional quantization.
...@@ -76,11 +82,20 @@ def dense( ...@@ -76,11 +82,20 @@ def dense(
kernel: Weight matrix for the dense layer transformation kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
output_axes: Logical axes for sharding the output
using_global_amax_of_x: Whether to use the global amax for x. Only effective with current scaling. Default is False.
collective_op_set: A set of CollectiveOp objects for forward and backward passes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Transformed output tensor
"""
if batch_sequence_transpose:
warnings.warn("batch_sequence_transpose is not well tested, use with caution!")
if not get_quantize_config().is_fp8_enabled():
input_dtype = x.dtype
kernel = kernel.astype(input_dtype)
...@@ -90,29 +105,30 @@ def dense(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
input_axes,
kernel_axes,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set,
)
return output
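For reference, a minimal usage sketch of the extended public API. The import paths, shapes, and dtypes below are assumptions for illustration; the new keyword arguments keep their defaults unless noted.

# Hypothetical example of calling dense() with the new arguments left at safe values.
import jax.numpy as jnp
from transformer_engine.jax import cpp_extensions as tex  # assumed path
from transformer_engine.jax.dense import dense            # assumed path

x = jnp.ones((4, 128, 512), dtype=jnp.bfloat16)      # (batch, seq, hidden_in)
kernel = jnp.ones((512, 1024), dtype=jnp.bfloat16)   # (hidden_in, hidden_out)

out = dense(
    x,
    kernel,
    contracting_dims=((2,), (0,)),
    batch_sequence_transpose=False,                  # new flag; warns when True
    using_global_amax_of_x=False,                    # only meaningful with current scaling
    collective_op_set=tex.noop_collective_op_set,    # new: no communication/GEMM overlap
)
assert out.shape == (4, 128, 1024)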
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9))
def _dense(
x,
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
input_axes,
kernel_axes,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set, # need to be a diff_arg for DelayedScaling state management
):
"""Internal implementation of dense layer transformation with custom VJP. """Internal implementation of dense layer transformation with custom VJP.
...@@ -124,8 +140,12 @@ def _dense( ...@@ -124,8 +140,12 @@ def _dense(
kernel: Weight matrix kernel: Weight matrix
bias: Optional bias tensor bias: Optional bias tensor
contracting_dims: Contracting dimensions specification contracting_dims: Contracting dimensions specification
batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input input_axes: Logical axes for sharding the activation input
output_axes: Logical axes for sharding the output_axes
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False.
collective_op_set: A set of CollectiveOp objects for forward and backward passes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
...@@ -136,8 +156,12 @@ def _dense(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
input_axes,
kernel_axes,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set,
)
return output
...@@ -148,8 +172,12 @@ def _dense_fwd_rule(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
input_axes,
kernel_axes,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set,
):
"""Forward pass rule for dense layer transformation.
...@@ -175,6 +203,7 @@ def _dense_fwd_rule(
x,
flatten_axis=flatten_axis_x,
quantizer=quantizer_set.x,
amax_scope=AmaxScope.TPSP if using_global_amax_of_x else AmaxScope.LOCAL,
)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
...@@ -182,6 +211,7 @@ def _dense_fwd_rule(
kernel,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
...@@ -191,9 +221,12 @@ def _dense_fwd_rule(
casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set.forward,
)
output = with_sharding_constraint_by_logical_axes(output, output_axes)
if use_bias and tex.gemm_uses_jax_dot():
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
...@@ -212,8 +245,16 @@ def _dense_bwd_rule(
def _dense_bwd_rule(
contracting_dims,
batch_sequence_transpose,
input_axes,
kernel_axes,
output_axes,
using_global_amax_of_x,
collective_op_set,
ctx,
grad,
):
"""Backward pass rule for dense layer transformation. """Backward pass rule for dense layer transformation.
Returns: Returns:
...@@ -228,6 +269,7 @@ def _dense_bwd_rule( ...@@ -228,6 +269,7 @@ def _dense_bwd_rule(
quantizer_set, quantizer_set,
flatten_axis_k, flatten_axis_k,
) = ctx ) = ctx
grad = with_sharding_constraint_by_logical_axes(grad, output_axes)
fwd_x_contracting_dims, fwd_k_contracting_dims = map( fwd_x_contracting_dims, fwd_k_contracting_dims = map(
tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
...@@ -238,6 +280,7 @@ def _dense_bwd_rule(
is_dbias=use_bias,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad,
amax_scope=AmaxScope.LOCAL if using_global_amax_of_x else AmaxScope.TPSP,
)
# GEMM NT
...@@ -254,8 +297,9 @@ def _dense_bwd_rule(
casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim),
transpose_batch_sequence=batch_sequence_transpose,
collective_op=collective_op_set.backward,
)
# GEMM TN
# x_non_contracting_dims
...@@ -267,7 +311,10 @@ def _dense_bwd_rule(
casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dim, g_contracting_dim),
transpose_batch_sequence=batch_sequence_transpose,
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
return dgrad, wgrad, dbias, quantizer_set
...
...@@ -53,6 +53,7 @@ def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[
return drop_path_shape
# TODO(Phuong): move this function to sharding.py
def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
"""
Extend the given Flax logical axis rules with the predefined TransformerLayer's
...
...@@ -21,6 +21,7 @@ import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .layernorm import canonicalize_norm_type
from .quantize import (
with_sharding_constraint_by_logical_axes,
...@@ -40,6 +41,7 @@ def layernorm_mlp(
norm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
batch_sequence_transpose: bool = False,
norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
...@@ -48,6 +50,10 @@ def layernorm_mlp(
ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
collective_op_sets: Tuple[tex.CollectiveOpSet] = (
tex.noop_collective_op_set,
tex.noop_collective_op_set,
),
quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
"""Apply layer normalization followed by MLP block.
...@@ -71,6 +77,7 @@ def layernorm_mlp(
norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication
...@@ -79,6 +86,7 @@ def layernorm_mlp(
ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation
collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations
quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations
Returns:
...@@ -121,6 +129,7 @@ def layernorm_mlp(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
...@@ -129,12 +138,13 @@ def layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
collective_op_sets,
quantizer_sets,
)
return output
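Similarly, a hedged sketch of how the collective_op_sets tuple is meant to be supplied to layernorm_mlp. Only the tuple construction is shown; the full call is elided because the remaining positional arguments are not changed by this diff, and the import path is an assumption.

# Illustrative only: one CollectiveOpSet per GEMM; the no-op sets disable overlap.
from transformer_engine.jax import cpp_extensions as tex  # assumed path

collective_op_sets = (
    tex.noop_collective_op_set,  # FC1: forward may all-gather, never reduce-scatter
    tex.noop_collective_op_set,  # FC2: forward may reduce-scatter, never all-gather
)
# output = layernorm_mlp(..., collective_op_sets=collective_op_sets, quantizer_sets=quantizer_sets)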
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
def _layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
...@@ -146,6 +156,7 @@ def _layernorm_mlp(
norm_type: str,
zero_centered_gamma: bool,
epsilon: float,
batch_sequence_transpose: bool,
norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...],
...@@ -154,6 +165,7 @@ def _layernorm_mlp(
ffn1_ckpt_name: str,
ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
collective_op_sets: Tuple[tex.CollectiveOpSet],
quantizer_sets,
):
"""Internal implementation of layernorm_mlp with custom VJP.
...@@ -173,12 +185,16 @@ def _layernorm_mlp(
norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for layernorm sharding
dot_1_input_axes: Logical axes for first matrix multiplication sharding
dot_2_input_axes: Logical axes for second matrix multiplication sharding
kernel_1_axes: Logical axes for first weight matrix sharding
kernel_2_axes: Logical axes for second weight matrix sharding
ffn1_ckpt_name: Name for first feed-forward network checkpointing
ffn2_ckpt_name: Name for second feed-forward network checkpointing
activation_type: Activation function(s)
collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations
quantizer_sets: Tuple of quantizer sets
Returns:
...@@ -195,6 +211,7 @@ def _layernorm_mlp(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
...@@ -203,6 +220,7 @@ def _layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
collective_op_sets,
quantizer_sets,
)
return output
...@@ -219,6 +237,7 @@ def _layernorm_mlp_fwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
...@@ -227,6 +246,7 @@ def _layernorm_mlp_fwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
collective_op_sets,
quantizer_sets,
):
"""Forward pass rule for layernorm_mlp.
...@@ -246,6 +266,10 @@ def _layernorm_mlp_fwd_rule(
del kernel_1_axes, kernel_2_axes
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
collective_op_set_1, collective_op_set_2 = collective_op_sets
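# In a sequence/tensor-parallel MLP the first GEMM may only all-gather its input and
# the second may only reduce-scatter its output, hence the two checks below.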
assert not collective_op_set_1.forward.is_reduce_scatter
assert not collective_op_set_2.forward.is_all_gather
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
...@@ -272,13 +296,12 @@ def _layernorm_mlp_fwd_rule(
epsilon,
norm_type,
quantizer=ffn1_quantizer_set.x,
amax_scope=AmaxScope.TPSP,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(
kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP
)
# NN GEMM
...@@ -287,8 +310,10 @@ def _layernorm_mlp_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_1.forward,
)
if use_bias_1 and tex.gemm_uses_jax_dot():
...@@ -317,6 +342,7 @@ def _layernorm_mlp_fwd_rule(
casted_kernel_2 = tex.quantize(
kernel_2,
quantizer=ffn2_quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
)
# NN GEMM
...@@ -325,8 +351,10 @@ def _layernorm_mlp_fwd_rule(
casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_2.forward,
)
if use_bias_2 and tex.gemm_uses_jax_dot():
...@@ -334,6 +362,8 @@ def _layernorm_mlp_fwd_rule(
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
# sharding of outputs should be the same as dot_1's input
dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_1_input_axes)
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = (
...@@ -363,6 +393,7 @@ def _layernorm_mlp_bwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
...@@ -371,6 +402,7 @@ def _layernorm_mlp_bwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
collective_op_sets,
ctx,
grad,
):
...@@ -409,6 +441,10 @@ def _layernorm_mlp_bwd_rule(
) = ctx
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
collective_op_set_1, collective_op_set_2 = collective_op_sets
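# Backward mirrors the forward constraints: the first GEMM's dgrad may reduce-scatter
# (never all-gather) and the second GEMM's dgrad may all-gather (never reduce-scatter).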
assert not collective_op_set_1.backward.is_all_gather
assert not collective_op_set_2.backward.is_reduce_scatter
# Since the sharding of outputs should be the same as dot_1's input
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
...@@ -417,6 +453,7 @@ def _layernorm_mlp_bwd_rule(
grad,
is_dbias=use_bias_2,
quantizer=ffn1_quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
...@@ -434,6 +471,8 @@ def _layernorm_mlp_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel_2,
contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
transpose_batch_sequence=batch_sequence_transpose,
collective_op=collective_op_set_2.backward,
)
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
...@@ -448,6 +487,7 @@ def _layernorm_mlp_bwd_rule(
casted_act_out,
casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
)
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
...@@ -474,6 +514,8 @@ def _layernorm_mlp_bwd_rule(
casted_dact_out.get_tensor(TensorUsage.LHS),
casted_kernel_1,
contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
transpose_batch_sequence=batch_sequence_transpose,
collective_op=collective_op_set_1.backward,
)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
...@@ -484,6 +526,7 @@ def _layernorm_mlp_bwd_rule(
casted_ln_out,
casted_dact_out.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
)
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
...
...@@ -17,7 +17,7 @@ from functools import reduce, lru_cache
import operator
import numpy as np
from jax.experimental.custom_partitioning import BATCHING, CompoundFactor
from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp
...@@ -152,12 +152,15 @@ class ScalingModeMetadataImpl(ABC):
@abstractmethod
def get_shardy_sharding_rules(
self,
input_shape,
unique_var,
flatten_axis,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
...@@ -232,12 +235,15 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (n_groups,)
def get_shardy_sharding_rules(
self,
input_shape,
unique_var,
flatten_axis,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
...@@ -245,7 +251,7 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
The Shardy rules for the scaling mode
"""
del flatten_axis
input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
...@@ -323,20 +329,23 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (n_groups,)
def get_shardy_sharding_rules(
self,
input_shape,
unique_var,
flatten_axis,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization
Returns:
The Shardy rules for the scaling mode
"""
del flatten_axis
input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
...@@ -562,52 +571,55 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (n_block_x * n_block_y,)
def get_shardy_sharding_rules(
self,
input_shape,
unique_var,
flatten_axis,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization
Returns:
The Shardy rules for the scaling mode
"""
input_rank = len(input_shape)
input_spec = [f"{unique_var}_{i}" for i in range(input_rank)]
flatten_axis = (flatten_axis + input_rank) % input_rank
# This implementation needs to be updated for different block dims.
assert self._block_dims == (1, 32)
# We have to use two different factors in the two CompoundFactors because of Shardy
# verifier requirements, even though they are the same.
blocksizes = {}
colwise_var = f"{unique_var}_None"
rowwise_var = f"{unique_var}_None"
if not input_shape[-1] == 32:
rowwise_var = input_spec[-1] + "_compound"
input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x")
blocksizes["blocksize_x"] = 32
if not input_shape[flatten_axis - 1] == 32:
colwise_var = input_spec[flatten_axis - 1] + "_compound"
input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y")
blocksizes["blocksize_y"] = 32
# The rowwise and colwise scale tensors should be sharded the same way as the input.
# However, we need to adjust the dimensions where the block scaling factor applies.
rowwise = input_spec.copy()
rowwise[-1] = rowwise_var
colwise = input_spec.copy()
colwise[flatten_axis - 1] = colwise_var
return QuantizeShardyRules(
tuple(input_spec),
tuple(rowwise),
tuple(colwise),
blocksizes,
)
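For intuition, a small self-contained sketch of the factor bookkeeping above, with plain strings standing in for CompoundFactor; the shape and names are illustrative only.

# Mirrors the dimension logic above for a hypothetical (256, 128) tensor,
# flatten_axis=-1 and 1x32 blocks; not the real Shardy objects.
input_shape, flatten_axis, unique_var = (256, 128), -1, "x"
rank = len(input_shape)
spec = [f"{unique_var}_{i}" for i in range(rank)]
flatten_axis = (flatten_axis + rank) % rank                # -1 -> 1
rowwise_var = colwise_var = f"{unique_var}_None"
if input_shape[-1] != 32:                                  # 128 != 32: split last dim
    rowwise_var = spec[-1] + "_compound"
    spec[-1] = f"CompoundFactor({rowwise_var}, blocksize_x=32)"
if input_shape[flatten_axis - 1] != 32:                    # 256 != 32: split colwise dim
    colwise_var = spec[flatten_axis - 1] + "_compound"
    spec[flatten_axis - 1] = f"CompoundFactor({colwise_var}, blocksize_y=32)"
rowwise, colwise = spec.copy(), spec.copy()
rowwise[-1], colwise[flatten_axis - 1] = rowwise_var, colwise_var
print(spec, rowwise, colwise, sep="\n")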
...@@ -697,18 +709,22 @@ class ScalingMode(Enum):
return self._get_impl().get_quantize_layout(usage)
def get_shardy_sharding_rules(
self,
input_shape,
unique_var,
flatten_axis=-1,
) -> Tuple[Tuple[str]]:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns:
The Shardy rules for the scaling mode
"""
return self._get_impl().get_shardy_sharding_rules(input_shape, unique_var, flatten_axis)
def get_grouped_scale_shape_2x(
self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1
...
...@@ -13,6 +13,7 @@ from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Optional
import warnings
import jax
import jax.numpy as jnp
from jax.interpreters import pxla
...@@ -364,3 +365,21 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes
if axis != global_mesh_resource().pp_resource:
x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
return x
def tpsp_axis_size():
"""
Get the size of the tensor parallelism axis.
Return 1 if no TP axis is set.
"""
return get_mesh_axis_size(global_mesh_resource().tpsp_resource)
def dp_or_fsdp_axis_size():
"""
Get the size of the data parallelism or FSDP axis.
Return 1 if no DP/FSDP axis is set.
"""
dp_size = get_mesh_axis_size(global_mesh_resource().dp_resource)
fsdp_size = get_mesh_axis_size(global_mesh_resource().fsdp_resource)
return dp_size if dp_size > 1 else fsdp_size
...@@ -13,17 +13,20 @@ import logging
from packaging.version import Version as PkgVersion
import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
split_tensor_along_dim,
)
from transformer_engine.pytorch.utils import attention_mask_func, nvtx_range_push, nvtx_range_pop
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensorBase,
prepare_for_saving,
restore_from_saved,
)
...@@ -40,7 +43,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
META_O,
META_QKV,
)
from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype, FP8GlobalStateManager
from transformer_engine.pytorch.distributed import get_distributed_world_size
from transformer_engine.pytorch.jit import no_torch_dynamo
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
...@@ -53,6 +56,9 @@ from transformer_engine.pytorch.attention.inference import InferenceParams
import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils as fa_utils,
combine_and_quantize,
combine_and_dequantize,
print_quantizers,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
AttentionLogging as attn_log,
...@@ -131,6 +137,58 @@ if not IS_HIP_EXTENSION:
fa_utils.set_flash_attention_3_params()
# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1"
class FP8EmulationFunc(torch.autograd.Function):
"""
Emulate the effects of FP8 quantization on tensors. Used in UnfusedDotProductAttention as follows:
- forward : QKV (quantize+dequantize), P (pass-through), S (quantize+dequantize), O (pass-through)
- backward: dO (quantize+dequantize), dS (pass-through), dP (quantize+dequantize), dQKV (pass-through)
"""
@staticmethod
def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout):
# pylint: disable=missing-function-docstring
if quantizer_name == "QKV_quantizer":
query_layer, key_layer, value_layer = [
x.contiguous() for x in [tensor1, tensor2, tensor3]
]
q_fp8, k_fp8, v_fp8 = combine_and_quantize(
qkv_layout, query_layer, key_layer, value_layer, quantizer
)
tensors = combine_and_dequantize(
qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype
)
elif quantizer_name in ["S_quantizer", "O_quantizer"]:
t_fp8 = quantizer(tensor1)
tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3)
else:
tensors = (tensor1, tensor2, tensor3)
ctx.quantizer = quantizer
ctx.quantizer_name = quantizer_name
ctx.qkv_layout = qkv_layout
return tensors[0], tensors[1], tensors[2]
@staticmethod
def backward(ctx, grad1, grad2, grad3):
# pylint: disable=missing-function-docstring
if ctx.quantizer_name in ["dO_quantizer", "dP_quantizer"]:
dt_fp8 = ctx.quantizer(grad1)
tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3
elif ctx.quantizer_name == "dQKV_quantizer":
query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]]
dq_fp8, dk_fp8, dv_fp8 = combine_and_quantize(
ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer
)
tensors = combine_and_dequantize(
ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype
)
else:
tensors = grad1, grad2, grad3
return tensors[0], tensors[1], tensors[2], None, None, None
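The per-tensor round trip that FP8EmulationFunc performs is quantize-then-dequantize. The standalone sketch below illustrates the idea with plain torch casts as a stand-in for the TE quantizers (it assumes a torch build that ships float8 dtypes; it is not the code path used above).

# Hypothetical stand-in for quantizer(t).dequantize(...): per-tensor current scaling
# to float8_e4m3fn and back, so downstream math sees FP8 rounding but stays in F16/BF16.
import torch

def fp8_roundtrip(t: torch.Tensor) -> torch.Tensor:
    amax = t.abs().amax().clamp(min=1e-12)
    scale = 448.0 / amax                      # 448 is the e4m3 max-normal magnitude
    q = (t * scale).to(torch.float8_e4m3fn)   # quantize
    return q.to(t.dtype) / scale              # dequantize back to the original dtype

x = torch.randn(8, 16, dtype=torch.bfloat16)
print((x - fp8_roundtrip(x)).abs().max())     # small, quantization-induced error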
class UnfusedDotProductAttention(torch.nn.Module):
"""Parallel attention w/o QKV and Proj Gemms
...@@ -144,6 +202,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
layer_number: Optional[int] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
...@@ -151,6 +210,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
self.attention_type = attention_type
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number
self.softmax_type = softmax_type
def mask_func(x, y):
return (
...@@ -187,6 +247,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
inference_params: Optional[InferenceParams] = None,
softmax_offset: torch.Tensor = None,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
quantizers=None,
fp8_output: bool = False,
) -> torch.Tensor:
"""Unfused attention fprop"""
assert (
...@@ -284,6 +349,35 @@ class UnfusedDotProductAttention(torch.nn.Module):
if apply_qk_layer_scaling:
scale /= self.layer_number
if fp8:
# get quantizers from DPA; all Nones if not fp8
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
dpa_utils.get_attention_quantizers(fp8, quantizers)
)
# S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
if fp8_recipe.float8_current_scaling():
S_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=S_quantizer.dtype, device="cuda"
)
dP_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=dP_quantizer.dtype, device="cuda"
)
if "2" in qkv_layout or "3" in qkv_layout:
qkv_format, *_ = dpa_utils.get_qkv_format(qkv_layout)
qkv_layout = "_".join([qkv_format] * 3)
# quantize and dequantize QKV to emulate FP8
query_layer, key_layer, value_layer = FP8EmulationFunc.apply(
query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", qkv_layout
)
# quantize and dequantize dQKV to emulate FP8
query_layer, key_layer, value_layer = FP8EmulationFunc.apply(
query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", qkv_layout
)
# Raw attention scores. [b * np, sq, sk]
if core_attention_bias_type == "no_bias":
matmul_result = torch.baddbmm(
...@@ -328,7 +422,27 @@ class UnfusedDotProductAttention(torch.nn.Module):
dtype=query_layer.dtype
)
if fp8:
# quantize and dequantize dP to emulate FP8
matmul_result, *_ = FP8EmulationFunc.apply(
matmul_result, None, None, dP_quantizer, "dP_quantizer", None
)
# add attention sink to the last column: [b, np, sq, sk+1]
if self.softmax_type != "vanilla":
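# softmax_offset acts as an extra "sink" key column: the softmax can park probability
# mass there, and that column is dropped again right after the softmax below.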
matmul_result = torch.cat(
[
matmul_result,
softmax_offset.to(dtype=matmul_result.dtype).expand(
matmul_result.size(0), -1, matmul_result.size(2), -1
),
],
dim=-1,
)
attention_mask = F.pad(attention_mask, (0, 1), mode="constant", value=False)
attn_mask_type = "arbitrary"
# attention scores and attention mask
softmax_scale = self.layer_number if apply_qk_layer_scaling else None
attention_probs = self.scale_mask_softmax(
matmul_result, attention_mask, attn_mask_type, softmax_scale
...@@ -339,6 +453,10 @@ class UnfusedDotProductAttention(torch.nn.Module):
if "padding" in attn_mask_type:
attention_probs = attention_probs.masked_fill(attention_mask, 0)
# remove attention sink: [b, np, sq, sk]
if self.softmax_type != "vanilla":
attention_probs = attention_probs[..., :-1]
# This is actually dropping out entire tokens to attend to, which might # This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper. # seem a bit unusual, but is taken from the original Transformer paper.
with self.attention_dropout_ctx(): with self.attention_dropout_ctx():
...@@ -359,6 +477,12 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -359,6 +477,12 @@ class UnfusedDotProductAttention(torch.nn.Module):
# change view [b * np, sq, sk] # change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
if fp8:
# quantize and dequantize S to emulate FP8
attention_probs, *_ = FP8EmulationFunc.apply(
attention_probs, None, None, S_quantizer, "S_quantizer", None
)
# matmul: [b * np, sq, hn] # matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
...@@ -393,6 +517,20 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -393,6 +517,20 @@ class UnfusedDotProductAttention(torch.nn.Module):
# [tq, np, hn] --> [tq, hp] # [tq, np, hn] --> [tq, hp]
context_layer = context_layer.view(total_tokens, -1) context_layer = context_layer.view(total_tokens, -1)
if fp8:
# quantize and dequantize O to emulate FP8
context_layer, *_ = FP8EmulationFunc.apply(
context_layer, None, None, O_quantizer, "O_quantizer", None
)
# quantize and dequantize dO to emulate FP8
context_layer, *_ = FP8EmulationFunc.apply(
context_layer, None, None, dO_quantizer, "dO_quantizer", None
)
# quantize O
if fp8_output:
context_layer = O_quantizer(context_layer)
return context_layer return context_layer
...@@ -491,6 +629,7 @@ class FlashAttention(torch.nn.Module): ...@@ -491,6 +629,7 @@ class FlashAttention(torch.nn.Module):
quantizers=None, quantizers=None,
inference_params: Optional[InferenceParams] = None, inference_params: Optional[InferenceParams] = None,
flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"), flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
fp8_output: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
"""flash-attn fprop""" """flash-attn fprop"""
...@@ -696,6 +835,7 @@ class FlashAttention(torch.nn.Module): ...@@ -696,6 +835,7 @@ class FlashAttention(torch.nn.Module):
quantizers=quantizers, quantizers=quantizers,
pad_between_seqs=False, pad_between_seqs=False,
use_flash_attn_3=use_flash_attn_3, use_flash_attn_3=use_flash_attn_3,
fp8_output=fp8_output,
) )
else: else:
from transformer_engine.pytorch.cpu_offload import ( from transformer_engine.pytorch.cpu_offload import (
...@@ -795,8 +935,6 @@ class FlashAttention(torch.nn.Module): ...@@ -795,8 +935,6 @@ class FlashAttention(torch.nn.Module):
) )
return out return out
# "fp8_mha" decides outputs in fp8, while inputs are inferred from
# the real dtype
assert isinstance(key_layer, query_layer.__class__) and isinstance( assert isinstance(key_layer, query_layer.__class__) and isinstance(
value_layer, query_layer.__class__ value_layer, query_layer.__class__
), "q, k, and v must have the same type." ), "q, k, and v must have the same type."
...@@ -843,7 +981,7 @@ class FlashAttention(torch.nn.Module): ...@@ -843,7 +981,7 @@ class FlashAttention(torch.nn.Module):
if fp8: if fp8:
output = output.to(dtype=torch_orig_dtype) output = output.to(dtype=torch_orig_dtype)
if fp8 and fp8_meta["recipe"].fp8_mha: if fp8 and fp8_output:
O_quantizer = quantizers["scaling_fwd"][META_O] O_quantizer = quantizers["scaling_fwd"][META_O]
output = O_quantizer(output) output = O_quantizer(output)
...@@ -871,7 +1009,7 @@ class FlashAttention(torch.nn.Module): ...@@ -871,7 +1009,7 @@ class FlashAttention(torch.nn.Module):
if q_format == "sbhd": if q_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd) # (bs)hd -> bs(hd) -> sb(hd)
if fp8 and fp8_meta["recipe"].fp8_mha: if fp8 and fp8_output:
output_data = ( output_data = (
output._data.reshape(batch_size, max_seqlen_q // cp_size, -1) output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
.transpose(0, 1) .transpose(0, 1)
...@@ -895,7 +1033,7 @@ class FlashAttention(torch.nn.Module): ...@@ -895,7 +1033,7 @@ class FlashAttention(torch.nn.Module):
class FusedAttnFunc(torch.autograd.Function): class FusedAttnFunc(torch.autograd.Function):
"""Function for FusedAttention with separate Q, K, V tensors""" """FusedAttention forward and backward implementation"""
@staticmethod @staticmethod
def forward( def forward(
...@@ -919,6 +1057,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -919,6 +1057,7 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout, qkv_layout,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
softmax_type,
window_size, window_size,
rng_gen, rng_gen,
fused_attention_backend, fused_attention_backend,
...@@ -927,55 +1066,72 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -927,55 +1066,72 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_meta, fp8_meta,
quantizers, quantizers,
deterministic, deterministic,
softmax_offset,
fp8_output,
layer_number,
): ):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False
# FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e4m3fn
fake_dtype = q.dtype
# add NVTX range
nvtx_label = "transformer_engine.FusedAttnFunc.forward"
nvtx_range_push(f"{nvtx_label}")
# recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE;
# may be different from fp8_meta["recipe"]
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
# input types are inferred from the real data while output types are controlled by fp8_output
# fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha)
assert isinstance(k, q.__class__) and isinstance(
v, q.__class__
), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor."
is_input_fp8 = isinstance(q, Float8Tensor)
is_output_fp8 = fp8_output
# whether fwd kernel in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa)
# whether bwd kernel in FP8:
is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
# get quantizers from DPA; all Nones if not fp8
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) dpa_utils.get_attention_quantizers(fp8, quantizers)
) )
# get nominal data type for out
# FP16/BF16 attention: torch.float16 or torch.bfloat16
# FP8 attention: torch.float16 or torch.bfloat16
out_nominal_dtype = q.dtype
if fp8: if fp8:
fused_attention_backend = FusedAttnBackend["FP8"] fused_attention_backend = FusedAttnBackend["FP8"]
assert isinstance(k, q.__class__) and isinstance(
v, q.__class__
), "q, k, and v must have the same type."
is_input_fp8 = isinstance(q, Float8Tensor) # q, k, v: torch.Tensor; dtype = torch.float16 or torch.bfloat16
q_fp8, k_fp8, v_fp8 = None, None, None # q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
if is_input_fp8: if is_input_fp8:
q_fp8, k_fp8, v_fp8 = q, k, v q_fp8, k_fp8, v_fp8 = q, k, v
else: else:
# 1: qkv packed, 2: kv packed, 3: qkv separate q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer)
qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_"))
match qkv_group: # print quantizers
case 1: print_quantizers(
dim = qkv_layout.find("3") "FusedAttnFunc.forward >> before: ",
qkv = combine_tensors([q, k, v], dim) layer_number,
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) QKV_quantizer,
qkv_fp8 = QKV_quantizer(qkv) O_quantizer,
q_fp8, k_fp8, v_fp8 = SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1], True) S_quantizer,
case 2: dQKV_quantizer,
q_fp8 = QKV_quantizer(q) dO_quantizer,
dim = qkv_layout.split("_")[1].find("2") dP_quantizer,
kv = combine_tensors([k, v], dim) )
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv_fp8 = QKV_quantizer(kv_c) # out_:
k_fp8, v_fp8 = SplitAlongDim.apply(kv_fp8, dim, [1, 1], True) # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
case 3: # fp8_dtype = tex.DType.kFloat8E4M3
q_fp8 = QKV_quantizer(q) # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
k_fp8 = QKV_quantizer(k) out_, aux_ctx_tensors = fused_attn_fwd(
v_fp8 = QKV_quantizer(v)
case _:
raise "Invalid qkv_layout " + qkv_layout
# q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn
out_fp8, aux_ctx_tensors = fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
max_seqlen_kv, max_seqlen_kv,
...@@ -984,7 +1140,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -984,7 +1140,7 @@ class FusedAttnFunc(torch.autograd.Function):
q_fp8, q_fp8,
k_fp8, k_fp8,
v_fp8, v_fp8,
fake_dtype, out_nominal_dtype,
fused_attention_backend, fused_attention_backend,
attn_bias, attn_bias,
cu_seqlens_q_padded, cu_seqlens_q_padded,
...@@ -999,45 +1155,59 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -999,45 +1155,59 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout, qkv_layout,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
softmax_type,
window_size, window_size,
rng_gen, rng_gen,
softmax_offset,
) )
if is_output_fp8:
out_ret = out_fp8 # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
# out: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_fp8 = out_
out = out_
if isinstance(out_, Float8Tensor):
if not is_output_fp8 or not is_bwd_fp8:
out = out_.dequantize().view(out_.shape)
else: else:
out_ret = out_fp8.dequantize().view(out_fp8.shape) if is_output_fp8 or (
# is_output_fp8 = False: out_save.dtype = torch.float16 or torch.bfloat16 is_bwd_fp8
# is_output_fp8 = True: out_save.dtype = torch.float8_e4m3fn and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)
out_save = out_ret ):
out_fp8 = O_quantizer(out_)
# print quantizers
print_quantizers(
"FusedAttnFunc.forward >> after: ",
layer_number,
QKV_quantizer,
O_quantizer,
S_quantizer,
dQKV_quantizer,
dO_quantizer,
dP_quantizer,
)
if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): # return appropriate tensors
# 1: qkv packed, 2: kv packed, 3: qkv separate out_ret = out_fp8 if is_output_fp8 else out
# save appropriate tensors
fp8_tensors = (None, None, None, None)
qkvo_tensors = (None, None, None, None)
if is_bwd_fp8:
if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16:
fp8_tensors = (q_fp8, k_fp8, v_fp8, None)
qkvo_tensors = (None, None, None, out)
else:
fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8)
else:
if is_input_fp8: if is_input_fp8:
qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8)
if qkv_group == 1: qkvo_tensors = (q, k, v, out)
dim = qkv_layout.find("3")
qkv = combine_tensors([q, k, v], dim)
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
qkv_no_fp8 = qkv_c.dequantize().view(qkv.shape)
q, k, v = SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True)
if qkv_group == 2:
q = q.dequantize()
dim = qkv_layout.replace("paged_kv_", "").split("_")[1].find("2")
kv = combine_tensors([k, v], dim)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv_no_fp8 = kv.dequantize()
k, v = SplitAlongDim.apply(kv_no_fp8, dim, [1, 1], True)
if qkv_group == 3:
q = q.dequantize()
k = k.dequantize()
v = v.dequantize()
if is_output_fp8:
out_save = out_fp8.dequantize()
fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8)
else: else:
# q, k, v, out_ret: torch.float16 or torch.bfloat16 # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_ret, aux_ctx_tensors = fused_attn_fwd( out_, aux_ctx_tensors = fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
max_seqlen_kv, max_seqlen_kv,
...@@ -1046,7 +1216,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1046,7 +1216,7 @@ class FusedAttnFunc(torch.autograd.Function):
q, q,
k, k,
v, v,
fake_dtype, out_nominal_dtype,
fused_attention_backend, fused_attention_backend,
attn_bias, attn_bias,
cu_seqlens_q_padded, cu_seqlens_q_padded,
...@@ -1061,13 +1231,23 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1061,13 +1231,23 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout, qkv_layout,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
softmax_type,
window_size, window_size,
rng_gen, rng_gen,
softmax_offset,
) )
out_save = out_ret out = out_
out_ret = out_
fp8_tensors = (None, None, None, None) fp8_tensors = (None, None, None, None)
qkvo_tensors = (q, k, v, out)
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) nvtx_range_pop(f"{nvtx_label}")
ctx.fp8_recipe = fp8_recipe
ctx.fp8 = is_bwd_fp8
# assume fwd and bwd always use the same high precision, i.e. torch.float16 or torch.bfloat16
# used when some tensors are base tensors and lose the "dtype" attribute
ctx.nominal_dtype = out_nominal_dtype
from transformer_engine.pytorch.cpu_offload import ( from transformer_engine.pytorch.cpu_offload import (
CPUOffloadEnabled, CPUOffloadEnabled,
...@@ -1078,7 +1258,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1078,7 +1258,7 @@ class FusedAttnFunc(torch.autograd.Function):
if ctx.fp8: if ctx.fp8:
tensor_list = fp8_tensors tensor_list = fp8_tensors
else: else:
tensor_list = [q, k, v, out_save] tensor_list = [q, k, v, out]
qkv_layout = "sbhd_sbhd_sbhd" qkv_layout = "sbhd_sbhd_sbhd"
mark_activation_offload(*tensor_list) mark_activation_offload(*tensor_list)
...@@ -1086,7 +1266,6 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1086,7 +1266,6 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.is_input_fp8 = is_input_fp8 ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8 ctx.is_output_fp8 = is_output_fp8
qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
tensors_to_save, tensor_objects = prepare_for_saving( tensors_to_save, tensor_objects = prepare_for_saving(
*fp8_tensors, *fp8_tensors,
*qkvo_tensors, *qkvo_tensors,
...@@ -1100,11 +1279,14 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1100,11 +1279,14 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.tensor_objects = tensor_objects ctx.tensor_objects = tensor_objects
ctx.fp8_meta = fp8_meta ctx.fp8_meta = fp8_meta
ctx.layer_number = layer_number
ctx.QKV_quantizer = QKV_quantizer
ctx.O_quantizer = O_quantizer
ctx.dQKV_quantizer = dQKV_quantizer ctx.dQKV_quantizer = dQKV_quantizer
ctx.dO_quantizer = dO_quantizer ctx.dO_quantizer = dO_quantizer
ctx.dP_quantizer = dP_quantizer ctx.dP_quantizer = dP_quantizer
ctx.S_quantizer = S_quantizer ctx.S_quantizer = S_quantizer
if ctx.fp8: if ctx.fp8 and isinstance(ctx.S_quantizer, Float8Quantizer):
ctx.S_quantizer = S_quantizer.copy() ctx.S_quantizer = S_quantizer.copy()
ctx.S_quantizer.scale = S_quantizer.scale.clone() ctx.S_quantizer.scale = S_quantizer.scale.clone()
...@@ -1116,6 +1298,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1116,6 +1298,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout = qkv_layout ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type ctx.attn_mask_type = attn_mask_type
ctx.softmax_type = softmax_type
ctx.window_size = window_size ctx.window_size = window_size
ctx.fused_attention_backend = ( ctx.fused_attention_backend = (
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
...@@ -1128,17 +1311,15 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1128,17 +1311,15 @@ class FusedAttnFunc(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, d_out): def backward(ctx, d_out):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
if ctx.is_output_fp8:
assert isinstance(
d_out, Float8Tensor
), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
# FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e5m2
fake_dtype = d_out.dtype
d_out = d_out.contiguous() # d_out is expected to be in FP8 if is_output_fp8=True,
# but in case it is not, convert it to FP8 before any other operation
if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorBase):
d_out = ctx.dO_quantizer(d_out)
if not ctx.use_FAv2_bwd:
d_out._data = d_out._data.contiguous()
elif not ctx.use_FAv2_bwd:
d_out = d_out.contiguous()
( (
q_fp8, q_fp8,
k_fp8, k_fp8,
...@@ -1192,16 +1373,55 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1192,16 +1373,55 @@ class FusedAttnFunc(torch.autograd.Function):
dk = dk[..., : d_out.shape[-1]] dk = dk[..., : d_out.shape[-1]]
dv = dv[..., : d_out.shape[-1]] dv = dv[..., : d_out.shape[-1]]
else: else:
with torch.cuda.nvtx.range("_FusedAttn"): with torch.cuda.nvtx.range("FusedAttnFunc.backward"):
# get nominal data type of dq, dk, dv
# FP16/BF16 attention: torch.float16 or torch.bfloat16
# FP8 attention: torch.float16 or torch.bfloat16
dqkv_nominal_dtype = ctx.nominal_dtype
if ctx.fp8: if ctx.fp8:
# d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16
# d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E5M2
if ctx.is_output_fp8: if ctx.is_output_fp8:
d_out_fp8 = d_out d_out_fp8 = d_out
else: else:
d_out_fp8 = ctx.dO_quantizer(d_out) d_out_fp8 = ctx.dO_quantizer(d_out)
dqkv_dtype = TE_DType[d_out_fp8._data.dtype]
# q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn # print quantizers
# d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2 print_quantizers(
dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( "FusedAttnFunc.backward >> before: ",
ctx.layer_number,
ctx.QKV_quantizer,
ctx.O_quantizer,
ctx.S_quantizer,
ctx.dQKV_quantizer,
ctx.dO_quantizer,
ctx.dP_quantizer,
)
# get tex.DType for dq, dk, dv data
dqkv_te_dtype = d_out_fp8._fp8_dtype
# q_fp8, k_fp8, v_fp8, out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16,
# fp8_dtype = tex.DType.kFloat8E4M3
# d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E5M2
# out_:
# DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
# Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
#
# dq_, dk_, dv_:
# DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E5M2
# Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_ = (
out
if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16
else out_fp8
)
dq_, dk_, dv_, *rest = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_kv, ctx.max_seqlen_kv,
cu_seqlens_q, cu_seqlens_q,
...@@ -1209,10 +1429,10 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1209,10 +1429,10 @@ class FusedAttnFunc(torch.autograd.Function):
q_fp8, q_fp8,
k_fp8, k_fp8,
v_fp8, v_fp8,
out_fp8, out_,
d_out_fp8, d_out_fp8,
fake_dtype, dqkv_nominal_dtype,
dqkv_dtype, dqkv_te_dtype,
aux_ctx_tensors, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
cu_seqlens_q_padded, cu_seqlens_q_padded,
...@@ -1226,44 +1446,45 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1226,44 +1446,45 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout, ctx.qkv_layout,
ctx.attn_bias_type, ctx.attn_bias_type,
ctx.attn_mask_type, ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size, ctx.window_size,
ctx.deterministic, ctx.deterministic,
) )
# is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16 # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16
# is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2 dq, dk, dv = dq_, dk_, dv_
if not ctx.is_input_fp8: is_float8tensor = isinstance(dq_, Float8Tensor)
qkv_group = len(ctx.qkv_layout.replace("paged_kv_", "").split("_")) if is_float8tensor and not ctx.is_input_fp8:
if qkv_group == 1: # return in F16
dim = ctx.qkv_layout.find("3") dq, dk, dv = combine_and_dequantize(
dqkv_fp8_data = combine_tensors( ctx.qkv_layout,
[dq_fp8._data, dk_fp8._data, dv_fp8._data], dim dq_,
) dk_,
dqkv_fp8 = dq_fp8.make_like( dv_,
tensor=dq_fp8, data=dqkv_fp8_data, shape=dqkv_fp8_data.shape src_nominal_dtype=dq_.dtype,
) )
dqkv = dqkv_fp8.dequantize() if not is_float8tensor and ctx.is_input_fp8:
dq, dk, dv = SplitAlongDim.apply(dqkv, dim, [1, 1, 1], True) # return in FP8
if qkv_group == 2: dq, dk, dv = combine_and_quantize(
dq = dq_fp8.dequantize() ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer
dim = ctx.qkv_layout.split("_")[1].find("2") )
dkv_fp8 = combine_tensors([dk_fp8, dv_fp8], dim)
dkv_c_fp8 = dkv_fp8.view( # print quantizers
-1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] print_quantizers(
) "FusedAttnFunc.backward >> after: ",
dkv = dkv_c_fp8.dequantize() ctx.layer_number,
dk, dv = SplitAlongDim.apply(dkv, dim, [1, 1], True) ctx.QKV_quantizer,
if qkv_group == 3: ctx.O_quantizer,
dq = dq_fp8.dequantize() ctx.S_quantizer,
dk = dk_fp8.dequantize() ctx.dQKV_quantizer,
dv = dv_fp8.dequantize() ctx.dO_quantizer,
else: ctx.dP_quantizer,
dq, dk, dv = dq_fp8, dk_fp8, dv_fp8 )
else: else:
if isinstance(d_out, QuantizedTensor): if isinstance(d_out, QuantizedTensorBase):
d_out = d_out.dequantize() d_out = d_out.dequantize(dtype=ctx.nominal_dtype)
dqkv_dtype = TE_DType[d_out.dtype] dqkv_te_dtype = TE_DType[d_out.dtype]
# q, k, v, out, d_out, dq, dk, dv: torch.float16 or torch.bfloat16 # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16
dq, dk, dv, *rest = fused_attn_bwd( dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_kv, ctx.max_seqlen_kv,
...@@ -1274,8 +1495,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1274,8 +1495,8 @@ class FusedAttnFunc(torch.autograd.Function):
v, v,
out, out,
d_out, d_out,
fake_dtype, dqkv_nominal_dtype,
dqkv_dtype, dqkv_te_dtype,
aux_ctx_tensors, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
cu_seqlens_q_padded, cu_seqlens_q_padded,
...@@ -1289,42 +1510,17 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1289,42 +1510,17 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout, ctx.qkv_layout,
ctx.attn_bias_type, ctx.attn_bias_type,
ctx.attn_mask_type, ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size, ctx.window_size,
ctx.deterministic, ctx.deterministic,
) )
# if no_bias or alibi, return dqkv d_bias = None
if ctx.attn_bias_type in ["no_bias", "alibi"]: if ctx.attn_bias_type not in ["no_bias", "alibi"]:
return ( d_bias = rest[0]
None, d_softmax_offset = None
None, if ctx.softmax_type != "vanilla":
None, d_softmax_offset = rest[1]
None,
None,
None,
None,
None,
None,
dq,
dk,
dv,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
# else, return (dqkv, dbias)
return ( return (
None, None,
None, None,
...@@ -1338,7 +1534,10 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1338,7 +1534,10 @@ class FusedAttnFunc(torch.autograd.Function):
dq, dq,
dk, dk,
dv, dv,
rest[0], d_bias,
None,
None,
None,
None, None,
None, None,
None, None,
...@@ -1351,6 +1550,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1351,6 +1550,7 @@ class FusedAttnFunc(torch.autograd.Function):
None, None,
None, None,
None, None,
d_softmax_offset,
None, None,
None, None,
) )
...@@ -1392,6 +1592,7 @@ class FusedAttention(torch.nn.Module): ...@@ -1392,6 +1592,7 @@ class FusedAttention(torch.nn.Module):
attention_type: str = "self", attention_type: str = "self",
layer_number: Optional[int] = None, layer_number: Optional[int] = None,
deterministic: bool = False, deterministic: bool = False,
softmax_type: str = "vanilla",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -1404,6 +1605,7 @@ class FusedAttention(torch.nn.Module): ...@@ -1404,6 +1605,7 @@ class FusedAttention(torch.nn.Module):
) == "1" and get_device_compute_capability() == (9, 0) ) == "1" and get_device_compute_capability() == (9, 0)
self.layer_number = 1 if layer_number is None else layer_number self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic self.deterministic = deterministic
self.softmax_type = softmax_type
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
""" """
...@@ -1455,6 +1657,8 @@ class FusedAttention(torch.nn.Module): ...@@ -1455,6 +1657,8 @@ class FusedAttention(torch.nn.Module):
quantizers=None, quantizers=None,
pad_between_seqs: bool = False, pad_between_seqs: bool = False,
inference_params: Optional[InferenceParams] = None, inference_params: Optional[InferenceParams] = None,
softmax_offset: torch.Tensor = None,
fp8_output: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
"""fused attention fprop""" """fused attention fprop"""
assert ( assert (
...@@ -1555,15 +1759,27 @@ class FusedAttention(torch.nn.Module): ...@@ -1555,15 +1759,27 @@ class FusedAttention(torch.nn.Module):
) )
if fp8: if fp8:
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, ( assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}" f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
" is required for FP8 attention!" " is required for FP8 attention!"
) )
assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!" assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!"
assert not context_parallel or fp8_meta["recipe"].reduce_amax, ( if fp8_recipe.delayed():
"Amax reduction across TP+CP group is necessary when using context parallelism with" assert not context_parallel or fp8_recipe.reduce_amax, (
" FP8!" "Amax reduction across TP+CP group is necessary when using context parallelism"
) " with FP8!"
)
if fp8_recipe.float8_current_scaling() and context_parallel:
all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers)
for q in all_quantizers:
if isinstance(q, Float8CurrentScalingQuantizer):
q.with_amax_reduction = True
q.amax_reduction_group = (
cp_group[0] if cp_comm_type == "a2a+p2p" else cp_group
)
if context_parallel: if context_parallel:
assert ( assert (
...@@ -1605,6 +1821,10 @@ class FusedAttention(torch.nn.Module): ...@@ -1605,6 +1821,10 @@ class FusedAttention(torch.nn.Module):
fp8_meta=fp8_meta, fp8_meta=fp8_meta,
quantizers=quantizers, quantizers=quantizers,
pad_between_seqs=pad_between_seqs, pad_between_seqs=pad_between_seqs,
softmax_type=self.softmax_type,
softmax_offset=softmax_offset,
fp8_output=fp8_output,
layer_number=self.layer_number,
) )
else: else:
with self.attention_dropout_ctx(): with self.attention_dropout_ctx():
...@@ -1628,6 +1848,7 @@ class FusedAttention(torch.nn.Module): ...@@ -1628,6 +1848,7 @@ class FusedAttention(torch.nn.Module):
qkv_layout, qkv_layout,
core_attention_bias_type, core_attention_bias_type,
attn_mask_type, attn_mask_type,
self.softmax_type,
window_size, window_size,
None, # rng_gen None, # rng_gen
fused_attention_backend, fused_attention_backend,
...@@ -1636,6 +1857,9 @@ class FusedAttention(torch.nn.Module): ...@@ -1636,6 +1857,9 @@ class FusedAttention(torch.nn.Module):
fp8_meta, fp8_meta,
quantizers, quantizers,
self.deterministic, self.deterministic,
softmax_offset,
fp8_output,
self.layer_number,
) )
# ...hd -> ...(hd) # ...hd -> ...(hd)
......
...@@ -11,10 +11,25 @@ import warnings ...@@ -11,10 +11,25 @@ import warnings
import logging import logging
import torch import torch
from torch.nn.parameter import Parameter
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
Format,
Recipe,
DelayedScaling,
Float8CurrentScaling,
)
from transformer_engine.pytorch.utils import get_cudnn_version from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype from transformer_engine.pytorch.fp8 import (
get_fp8_te_dtype,
FP8GlobalStateManager,
RecipeState,
DelayedScalingRecipeState,
MXFP8BlockScalingRecipeState,
Float8CurrentScalingRecipeState,
Float8BlockScalingRecipeState,
)
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.export import is_in_onnx_export_mode
...@@ -72,6 +87,67 @@ _alibi_cache = { ...@@ -72,6 +87,67 @@ _alibi_cache = {
"_alibi_bias_require_update": False, "_alibi_bias_require_update": False,
} }
"""
This feature is **experimental** and subject to change.
Some models may use different FP8 recipes for their linear layers and attention layers. To support this,
users can either use multiple, nested fp8_autocast() contexts to assign a distinct recipe for each layer,
or use a single fp8_autocast() for the non-attention layers and configure the recipe for the attention
layers as follows.
+-------------------+-----------+-----------------------------------------------------------------------------------+
| Linear | Attention | Configuration |
+===================+===========+===================================================================================+
| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to fp8_autocast(); |
| | | export NVTE_DPA_FP8_RECIPE="F16" |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8DS | FP8DS | Pass FP8DS to fp8_autocast(); |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8CS | FP8DS | Pass FP8CS to fp8_autocast(); |
| | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; |
| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| NVFP4 | FP8DS | Pass NVFP4 to fp8_autocast(); |
| | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; |
| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS |
| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8DS | FP8CS | Pass FP8DS to fp8_autocast(); |
| | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,|
| | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; |
| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8CS | FP8CS | Pass FP8CS to fp8_autocast(); |
| | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe |
| | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| NVFP4 | FP8CS | Pass NVFP4 to fp8_autocast(); |
| | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe |
| | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: |
| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS |
| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
"""
_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
formats = {"HYBRID": Format.HYBRID, "E4M3": Format.E4M3, "E5M2": Format.E5M2}
_dpa_fp8_format = formats[os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID")]
_dpa_fp8ds_amax_algo = os.getenv("NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent")
_dpa_fp8ds_amax_histlen = int(os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1"))
_dpa_fp8ds_reduce_amax = os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") == "1"
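A hedged usage sketch of the override described in the table above: linear layers run with Float8CurrentScaling while attention falls back to DelayedScaling. The env vars are read at import time of this module, so they must be set before transformer_engine is imported; the layer sizes and the recipe flags below are placeholders that mirror the constructor calls used elsewhere in this file, not a tested configuration.

import os
os.environ["NVTE_DPA_FP8_RECIPE"] = "DelayedScaling"
os.environ["NVTE_DPA_FP8DS_AMAX_ALGO"] = "most_recent"
os.environ["NVTE_DPA_FP8DS_AMAX_HISTLEN"] = "1"

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

recipe = Float8CurrentScaling(fp8_dpa=True)          # flag mirrors the recipe fields above
layer = te.TransformerLayer(1024, 4096, 16).cuda()   # placeholder sizes
x = torch.randn(128, 2, 1024, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)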
__all__ = ["DotProductAttention"] __all__ = ["DotProductAttention"]
...@@ -168,6 +244,17 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -168,6 +244,17 @@ class DotProductAttention(TransformerEngineBaseModule):
softmax_scale: Optional[float], default = `None` softmax_scale: Optional[float], default = `None`
softmax scale for the attention scores. If `None`, defaults to softmax scale for the attention scores. If `None`, defaults to
`1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`. `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -223,6 +310,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -223,6 +310,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_stream: torch.cuda.Stream = None, cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p", cp_comm_type: str = "p2p",
softmax_scale: Optional[float] = None, softmax_scale: Optional[float] = None,
softmax_type: str = "vanilla",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -307,6 +395,20 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -307,6 +395,20 @@ class DotProductAttention(TransformerEngineBaseModule):
self.attention_type = attention_type self.attention_type = attention_type
self.attention_dropout = attention_dropout self.attention_dropout = attention_dropout
self.softmax_type = softmax_type
if self.softmax_type == "vanilla":
self.softmax_offset = None
if self.softmax_type == "off-by-one":
self.softmax_offset = torch.zeros(
self.num_attention_heads // self.tp_size, device="cuda"
)
if self.softmax_type == "learnable":
self.register_parameter(
"softmax_offset",
Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")),
get_rng_state_tracker=get_rng_state_tracker,
)
attn_kwargs = { attn_kwargs = {
"attention_dropout": attention_dropout, "attention_dropout": attention_dropout,
"attention_dropout_ctx": attention_dropout_ctx, "attention_dropout_ctx": attention_dropout_ctx,
...@@ -328,6 +430,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -328,6 +430,7 @@ class DotProductAttention(TransformerEngineBaseModule):
layer_number=layer_number, layer_number=layer_number,
deterministic=self.deterministic, deterministic=self.deterministic,
**attn_kwargs, **attn_kwargs,
softmax_type=self.softmax_type,
) )
self.unfused_attention = UnfusedDotProductAttention( self.unfused_attention = UnfusedDotProductAttention(
...@@ -335,6 +438,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -335,6 +438,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attention_type=attention_type, attention_type=attention_type,
**attn_kwargs, **attn_kwargs,
layer_number=layer_number, layer_number=layer_number,
softmax_type=self.softmax_type,
) )
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
...@@ -433,6 +537,231 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -433,6 +537,231 @@ class DotProductAttention(TransformerEngineBaseModule):
self.cp_stream = cp_stream self.cp_stream = cp_stream
self.cp_comm_type = cp_comm_type self.cp_comm_type = cp_comm_type
def init_fp8_metadata(self, num_gemms: int = 1) -> None:
"""
Override TransformerEngineBaseModule.init_fp8_metadata to allow for more flexible recipe support.
Initialize fp8 related metadata and tensors during fprop.
"""
_original_recipe = self.fp8_meta.get("recipe", None)
# global recipe set in fp8_autocast()
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
# switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to
# a different recipe than fp8_recipe. DPA.quantizers may be a mix of different quantizers as well.
#
# fp8_recipe | NVTE_DPA_FP8_RECIPE | self.fp8_meta["recipe"] | self.quantizers
# --------------------------------------------------------------------------------------------
# DelayedScaling (DS) | unset | DS | all DS
# Float8CurrentScaling (CS) | unset | DS | CS for QKV, O, dO, dQKV; DS for S, dP
# x={DS, CS} | y | refer to row x=y | refer to row x=y
fp8_recipe_dpa = fp8_recipe
fp8_recipes = fp8_recipe
if _dpa_fp8_recipe == "F16":
# ignore the recipe from fp8_autocast, set fp8_dpa = False, fp8_mha = False
fp8_recipe.fp8_dpa = False
fp8_recipe.fp8_mha = False
elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe == "DelayedScaling":
# reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a DS recipe
fake_recipe = DelayedScaling(
fp8_format=fp8_recipe.fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
)
fp8_recipe_dpa = fake_recipe
fp8_recipes = fp8_recipe_dpa
elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "DelayedScaling":
# reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a DS recipe
fake_recipe = DelayedScaling(
fp8_format=_dpa_fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
)
fp8_recipe_dpa = fake_recipe
fp8_recipes = fp8_recipe_dpa
elif fp8_recipe.delayed() and _dpa_fp8_recipe == "Float8CurrentScaling":
# reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a CS+DS recipe
fake_recipes = [
Float8CurrentScaling(
fp8_format=fp8_recipe.fp8_format,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
),
fp8_recipe,
]
fp8_recipe_dpa = fake_recipes[1]
fp8_recipes = fake_recipes
elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe in (
"",
"Float8CurrentScaling",
):
# use fp8_recipe for QKV, O, dO, dQKV, and construct a DS recipe for S, dP
# reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe
fake_recipe = DelayedScaling(
fp8_format=fp8_recipe.fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
)
fp8_recipe_dpa = fake_recipe
fp8_recipes = [fp8_recipe, fp8_recipe_dpa]
elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "Float8CurrentScaling":
# reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format
# construct a CS recipe for QKV, O, dO, dQKV and a DS recipe for S, dP
fake_recipes = [
Float8CurrentScaling(
fp8_format=_dpa_fp8_format,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
),
DelayedScaling(
fp8_format=_dpa_fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
),
]
fp8_recipe_dpa = fake_recipes[1]
fp8_recipes = fake_recipes
# DPA only supports DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False
if not fp8_recipe_dpa.float8_per_tensor_scaling():
assert not (
fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha
), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe"
# reduce over TP+CP groups; expect fp8_group to be set up accordingly and
# assume attention uses the same fp8_group as GEMMs
fp8_group = FP8GlobalStateManager.get_fp8_group()
self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
fp8_enabled = self.fp8 or self.fp8_calibration
self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
if self.fp8_parameters or fp8_enabled:
self.fp8_meta["global_recipe"] = fp8_recipe
self.fp8_meta["local_recipes"] = (
fp8_recipes if isinstance(fp8_recipes, List) else [fp8_recipes]
)
if self.fp8_parameters or fp8_enabled:
if self.fp8_initialized and fp8_recipe_dpa == self.fp8_meta["recipe"]:
# FP8 init has already been run and recipe is the same, don't do anything.
return
self.fp8_meta["recipe"] = fp8_recipe_dpa
if fp8_recipe != fp8_recipe_dpa:
# fp8_recipe has changed, rehash the key.
autocast_key = FP8GlobalStateManager.get_unique_autocast_key(
fp8_recipe_dpa, fp8_group
)
FP8GlobalStateManager.autocast_arguments[autocast_key] = (
fp8_recipe_dpa,
fp8_group,
)
else:
# If fp8 isn't enabled, turn off and return.
self.fp8_initialized = False
return
if self.fp8_parameters and not self.fp8_initialized:
self.fp8_meta["num_gemms"] = num_gemms
self.init_fp8_meta_tensors(fp8_recipes)
if fp8_enabled:
# Set FP8 and other FP8 metadata
self.fp8_meta["num_gemms"] = num_gemms
self.fp8_meta["fp8_group"] = fp8_group
# Set FP8_MAX per tensor according to recipe
self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
# Allocate scales and amaxes
self.init_fp8_meta_tensors(fp8_recipes)
self.fp8_initialized = True
self.fp8_meta["recipe"] = fp8_recipe_dpa
if fp8_recipe != fp8_recipe_dpa:
# fp8_recipe has changed, rehash the key.
autocast_key = FP8GlobalStateManager.get_unique_autocast_key(
fp8_recipe_dpa, fp8_group
)
FP8GlobalStateManager.autocast_arguments[autocast_key] = (
fp8_recipe_dpa,
fp8_group,
)
_current_recipe = self.fp8_meta["recipe"]
if _original_recipe is not None and not (
issubclass(_current_recipe.__class__, _original_recipe.__class__)
or issubclass(_original_recipe.__class__, _current_recipe.__class__)
):
warnings.warn(
f"Recipe type changed from {_original_recipe.__class__.__name__} "
f"to {_current_recipe.__class__.__name__}. "
"This may affect model behavior."
)
# Clear cached workspaces as they were created with the old recipe/quantizer type
self._fp8_workspaces.clear()
def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> None:
"""Override to allow multiple recipes. Init scales and amaxes for fwd | bwd."""
if isinstance(recipe, Recipe):
recipe = [recipe]
fp8_recipe_dpa = recipe[-1]
fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
# Return early if recipe state matches recipe
if self.fp8_meta_tensors_initialized:
recipe_state = self.fp8_meta[fp8_meta_tensor_key]
if fp8_recipe_dpa.delayed() and isinstance(recipe_state, DelayedScalingRecipeState):
self.adjust_amax_history_length(fp8_recipe_dpa.amax_history_len, fwd=fwd)
return
if fp8_recipe_dpa.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
return
if fp8_recipe_dpa.float8_current_scaling() and isinstance(
recipe_state, Float8CurrentScalingRecipeState
):
return
if fp8_recipe_dpa.float8_block_scaling() and isinstance(
recipe_state, Float8BlockScalingRecipeState
):
return
# When fp8_recipe=Float8CurrentScaling, recipe=[CS, DS], and QKV/dQKV, O/dO use CS quantizers, S/dP use DS quantizers.
# See table above in init_fp8_metadata for more detail.
num_gemms = [2, 1] if len(recipe) == 2 else [3]
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
num_fp8_tensors = [x * 3 if fwd else x * 2 for x in num_gemms]
# Initialize recipe state and quantizers
recipe_states = [
RecipeState.create(
recipe[i],
mode=("forward" if fwd else "backward"),
num_quantizers=num_fp8_tensors[i],
)
for i in range(len(recipe))
]
self.fp8_meta[fp8_meta_tensor_key] = (
recipe_states[-1] if len(recipe) == 2 else recipe_states[0]
)
self.quantizers[fp8_meta_tensor_key] = []
for recipe_state in recipe_states:
self.quantizers[fp8_meta_tensor_key].extend(recipe_state.make_quantizers())
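A quick arithmetic check of the sizing logic above (no TE calls; the counts come from the comments in this method): a single recipe covers 3 GEMMs and therefore 9 forward quantizers, while the mixed [CS, DS] case splits the same 9 into 6 for the QKV/O side and 3 for the S/dP side; the extend() loop appends them in recipe order, CS first, DS last.

for n_gemms in ([3], [2, 1]):
    n_fwd = [x * 3 for x in n_gemms]  # 3 FP8 tensors per GEMM in fwd
    assert sum(n_fwd) == 9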
@no_torch_dynamo(recursive=False) @no_torch_dynamo(recursive=False)
def forward( def forward(
self, self,
...@@ -456,6 +785,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -456,6 +785,7 @@ class DotProductAttention(TransformerEngineBaseModule):
fast_zero_fill: bool = True, fast_zero_fill: bool = True,
inference_params: Optional[InferenceParams] = None, inference_params: Optional[InferenceParams] = None,
pad_between_seqs: Optional[bool] = None, pad_between_seqs: Optional[bool] = None,
fp8_output: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Dot Product Attention Layer. Dot Product Attention Layer.
...@@ -628,12 +958,15 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -628,12 +958,15 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs: Optional[bool], default = `None` pad_between_seqs: Optional[bool], default = `None`
If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If true, there are padding tokens between individual sequences in a packed batch. If true, there are padding tokens between individual sequences in a packed batch.
fp8_output: Optional[bool], default = `False`
Whether to force the output to be in FP8.
""" """
with torch.cuda.device(query_layer.device), self.prepare_forward( with torch.cuda.device(query_layer.device), self.prepare_forward(
query_layer, query_layer,
num_gemms=3, num_gemms=3,
allow_non_contiguous=True, allow_non_contiguous=True,
allow_different_data_and_param_types=self.softmax_type != "vanilla",
) as query_layer: ) as query_layer:
# checks for RNG # checks for RNG
if self.rng_states_tracker is not None and is_graph_capturing(): if self.rng_states_tracker is not None and is_graph_capturing():
...@@ -663,6 +996,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -663,6 +996,8 @@ class DotProductAttention(TransformerEngineBaseModule):
tex.DType.kFloat8E4M3, tex.DType.kFloat8E4M3,
tex.DType.kFloat8E5M2, tex.DType.kFloat8E5M2,
], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types.""" ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
else:
fp8_output = False
# checks for q/k/v shapes # checks for q/k/v shapes
assert ( assert (
...@@ -922,6 +1257,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -922,6 +1257,7 @@ class DotProductAttention(TransformerEngineBaseModule):
False False
), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes" ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"
# check if there is padding between sequences when qkv_format='thd'
if pad_between_seqs is None: if pad_between_seqs is None:
if qkv_format == "thd": if qkv_format == "thd":
pad_between_seqs = ( pad_between_seqs = (
...@@ -957,11 +1293,13 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -957,11 +1293,13 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs=pad_between_seqs, pad_between_seqs=pad_between_seqs,
attention_dropout=self.attention_dropout, attention_dropout=self.attention_dropout,
context_parallel=context_parallel, context_parallel=context_parallel,
cp_comm_type=self.cp_comm_type,
deterministic=self.deterministic, deterministic=self.deterministic,
is_training=self.training, is_training=self.training,
fp8=self.fp8, fp8=self.fp8,
fp8_meta=self.fp8_meta, fp8_meta=self.fp8_meta,
inference_params=inference_params, inference_params=inference_params,
softmax_type=self.softmax_type,
) )
global _attention_backends global _attention_backends
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
...@@ -1022,6 +1360,12 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1022,6 +1360,12 @@ class DotProductAttention(TransformerEngineBaseModule):
) )
# run attention # run attention
softmax_offset = (
self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32)
if self.softmax_offset is not None
else None
)
if use_flash_attention: if use_flash_attention:
if core_attention_bias_type == "alibi": if core_attention_bias_type == "alibi":
alibi_slopes, _ = dpa_utils.get_alibi( alibi_slopes, _ = dpa_utils.get_alibi(
...@@ -1053,6 +1397,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1053,6 +1397,7 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers, quantizers=self.quantizers,
inference_params=inference_params, inference_params=inference_params,
flash_attention_backend=flash_attention_backend, flash_attention_backend=flash_attention_backend,
fp8_output=fp8_output,
) )
if use_fused_attention: if use_fused_attention:
...@@ -1071,7 +1416,6 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1071,7 +1416,6 @@ class DotProductAttention(TransformerEngineBaseModule):
bias_dtype=query_layer.dtype, bias_dtype=query_layer.dtype,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
) )
# checkpoint_core_attention=False
if checkpoint_core_attention: if checkpoint_core_attention:
return self._checkpointed_attention_forward( return self._checkpointed_attention_forward(
self.fused_attention, self.fused_attention,
...@@ -1101,6 +1445,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1101,6 +1445,8 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers, quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs, pad_between_seqs=pad_between_seqs,
inference_params=inference_params, inference_params=inference_params,
softmax_offset=softmax_offset,
fp8_output=fp8_output,
) )
return self.fused_attention( return self.fused_attention(
query_layer, query_layer,
...@@ -1129,6 +1475,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1129,6 +1475,8 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers, quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs, pad_between_seqs=pad_between_seqs,
inference_params=inference_params, inference_params=inference_params,
softmax_offset=softmax_offset,
fp8_output=fp8_output,
) )
from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
...@@ -1140,6 +1488,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1140,6 +1488,7 @@ class DotProductAttention(TransformerEngineBaseModule):
) )
if use_unfused_attention: if use_unfused_attention:
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
if checkpoint_core_attention: if checkpoint_core_attention:
return self._checkpointed_attention_forward( return self._checkpointed_attention_forward(
self.unfused_attention, self.unfused_attention,
...@@ -1157,6 +1506,11 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1157,6 +1506,11 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
inference_params=inference_params, inference_params=inference_params,
softmax_offset=softmax_offset,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation,
fp8_meta=self.fp8_meta,
quantizers=self.quantizers,
fp8_output=fp8_output,
) )
return self.unfused_attention( return self.unfused_attention(
_alibi_cache, _alibi_cache,
...@@ -1173,5 +1527,10 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1173,5 +1527,10 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
inference_params=inference_params, inference_params=inference_params,
softmax_offset=softmax_offset,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation,
fp8_meta=self.fp8_meta,
quantizers=self.quantizers,
fp8_output=fp8_output,
) )
return None return None
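A hedged sketch of turning on the unfused FP8 emulation path used above: NVTE_UnfusedDPA_Emulate_FP8 is read each time forward() runs, while forcing the unfused backend through the usual backend-selection env vars is an assumption about the surrounding TE configuration, not something this diff introduces.

import os

os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"  # enable quantize/dequantize emulation
os.environ["NVTE_FLASH_ATTN"] = "0"              # assumed: disable the flash backend
os.environ["NVTE_FUSED_ATTN"] = "0"              # assumed: disable the fused backend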
...@@ -17,6 +17,7 @@ import numpy as np ...@@ -17,6 +17,7 @@ import numpy as np
from packaging.version import Version as PkgVersion from packaging.version import Version as PkgVersion
import torch import torch
import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
import transformer_engine_torch as tex import transformer_engine_torch as tex
import transformer_engine as te import transformer_engine as te
...@@ -24,6 +25,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import ( ...@@ -24,6 +25,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
QKVLayout, QKVLayout,
AttnBiasType, AttnBiasType,
AttnMaskType, AttnMaskType,
SoftmaxType,
FusedAttnBackend, FusedAttnBackend,
META_QKV, META_QKV,
META_DQKV, META_DQKV,
...@@ -31,11 +33,13 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import ( ...@@ -31,11 +33,13 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
META_DO, META_DO,
META_S, META_S,
META_DP, META_DP,
META_O_CP,
META_DQKV_CP,
) )
from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.constants import TE_DType
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
...@@ -43,6 +47,8 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION ...@@ -43,6 +47,8 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
get_device_compute_capability, get_device_compute_capability,
get_cudnn_version, get_cudnn_version,
SplitAlongDim,
combine_tensors,
) )
from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.export import is_in_onnx_export_mode
...@@ -53,6 +59,9 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) ...@@ -53,6 +59,9 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 # NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) _NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
# print quantizer info for a particular layer on a particular rank
_print_layer = int(os.getenv("NVTE_PRINT_LAYER_NUMBER", "1"))
_print_rank = int(os.getenv("NVTE_PRINT_RANK", "0"))
_cu_seqlens_cache = {} _cu_seqlens_cache = {}
...@@ -206,6 +215,8 @@ class AttentionParams: ...@@ -206,6 +215,8 @@ class AttentionParams:
Attention dropout. Attention dropout.
context_parallel: bool, default = `False` context_parallel: bool, default = `False`
Whether context parallelism is used or not. Whether context parallelism is used or not.
cp_comm_type: str, default = "p2p"
The communication type of context parallelism.
deterministic: bool, default = `False` deterministic: bool, default = `False`
Whether to run `DotProductAttention` with determinism or not. Whether to run `DotProductAttention` with determinism or not.
is_training: bool, default = `True` is_training: bool, default = `True`
...@@ -216,6 +227,8 @@ class AttentionParams: ...@@ -216,6 +227,8 @@ class AttentionParams:
The FP8 metadata tensor of `DotProductAttention`. The FP8 metadata tensor of `DotProductAttention`.
inference_params: Optional[InferenceParams], default = `None` inference_params: Optional[InferenceParams], default = `None`
Inference-related parameters. See InferenceParams for details. Inference-related parameters. See InferenceParams for details.
softmax_type: str, default = "vanilla"
The type of softmax operation. See DotProductAttention for details.
""" """
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
...@@ -237,11 +250,13 @@ class AttentionParams: ...@@ -237,11 +250,13 @@ class AttentionParams:
pad_between_seqs: bool = False pad_between_seqs: bool = False
attention_dropout: float = 0.0 attention_dropout: float = 0.0
context_parallel: bool = False context_parallel: bool = False
cp_comm_type: str = "p2p"
deterministic: bool = False deterministic: bool = False
is_training: bool = True is_training: bool = True
fp8: bool = False fp8: bool = False
fp8_meta: Union[Dict[str, Any], None] = None fp8_meta: Union[Dict[str, Any], None] = None
inference_params: Optional[InferenceParams] = None inference_params: Optional[InferenceParams] = None
softmax_type: str = "vanilla"
def __eq__(self, other): def __eq__(self, other):
""" """
...@@ -308,11 +323,13 @@ def get_attention_backend( ...@@ -308,11 +323,13 @@ def get_attention_backend(
pad_between_seqs = attention_params.pad_between_seqs pad_between_seqs = attention_params.pad_between_seqs
attention_dropout = attention_params.attention_dropout attention_dropout = attention_params.attention_dropout
context_parallel = attention_params.context_parallel context_parallel = attention_params.context_parallel
cp_comm_type = attention_params.cp_comm_type
deterministic = attention_params.deterministic deterministic = attention_params.deterministic
is_training = attention_params.is_training is_training = attention_params.is_training
fp8 = attention_params.fp8 fp8 = attention_params.fp8
fp8_meta = attention_params.fp8_meta fp8_meta = attention_params.fp8_meta
inference_params = attention_params.inference_params inference_params = attention_params.inference_params
softmax_type = attention_params.softmax_type
# Run config # Run config
logger = logging.getLogger("DotProductAttention") logger = logging.getLogger("DotProductAttention")
...@@ -341,8 +358,31 @@ def get_attention_backend( ...@@ -341,8 +358,31 @@ def get_attention_backend(
field.name: getattr(attention_params, field.name) for field in fields(attention_params) field.name: getattr(attention_params, field.name) for field in fields(attention_params)
} }
run_config.update(attention_params_dict) run_config.update(attention_params_dict)
# Add FP8 environment variables to config
if fp8: if fp8:
# all FP8 recipes: 1: (FP8 fwd, FP8 bwd), 0: (FP8 fwd, F16 bwd)
run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
# Float8CurrentScaling: 1: use F16 O in bwd, 0: use FP8 O in bwd
run_config["NVTE_DPA_FP8CS_O_in_F16"] = int(os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1"))
# switch recipe to "F16", "DelayedScaling", or "Float8CurrentScaling"
_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
run_config["NVTE_DPA_FP8_RECIPE"] = _dpa_fp8_recipe
if _dpa_fp8_recipe != "":
# configure the overriding recipe's attributes when one is selected
run_config["NVTE_DPA_FP8_FORMAT"] = os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID")
run_config["NVTE_DPA_FP8DS_AMAX_ALGO"] = os.getenv(
"NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent"
)
run_config["NVTE_DPA_FP8DS_AMAX_HISTLEN"] = int(
os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1")
)
run_config["NVTE_DPA_FP8DS_REDUCE_AMAX"] = int(
os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1")
)
# UnfusedDotProductAttention: 1: allow FP8 emulation, 0: do not allow
run_config["NVTE_UnfusedDPA_Emulate_FP8"] = int(
os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0")
)
logger.debug("Running with config=%s", run_config) logger.debug("Running with config=%s", run_config)
# The following sections check if `FlashAttention` supports the provided attention params, # The following sections check if `FlashAttention` supports the provided attention params,
...@@ -422,8 +462,20 @@ def get_attention_backend( ...@@ -422,8 +462,20 @@ def get_attention_backend(
logger.debug("Disabling FlashAttention 3 for FP8 training") logger.debug("Disabling FlashAttention 3 for FP8 training")
use_flash_attention_3 = False use_flash_attention_3 = False
if use_unfused_attention: if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
use_unfused_attention = False if not allow_emulation:
logger.debug("Disabling UnfusedDotProductAttention for FP8 attention")
use_unfused_attention = False
fp8_recipe = fp8_meta["recipe"]
if fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
if (
use_fused_attention
and fp8_recipe.float8_current_scaling()
and device_compute_capability < (10, 0)
):
logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100")
use_fused_attention = False
# TODO: rocm fused attention backends does not support fp8 yet # TODO: rocm fused attention backends does not support fp8 yet
if IS_HIP_EXTENSION and use_fused_attention: if IS_HIP_EXTENSION and use_fused_attention:
logger.debug("Disabling ROCm FusedAttention as it does not support FP8") logger.debug("Disabling ROCm FusedAttention as it does not support FP8")
...@@ -581,6 +633,51 @@ def get_attention_backend( ...@@ -581,6 +633,51 @@ def get_attention_backend(
logger.debug("Disabling FlashAttention 3 for dropout") logger.debug("Disabling FlashAttention 3 for dropout")
use_flash_attention_3 = False use_flash_attention_3 = False
# Filter: Softmax type
# context_parallel | softmax_type | supported backends
# ----------------------------------------------------------------------------------------------------
# no | vanilla | All
# no | off-by-one | FusedAttention, UnfusedDotProductAttention
# no | learnable | FusedAttention, UnfusedDotProductAttention
# yes | vanilla | FusedAttention, FlashAttention
# yes | off-by-one | FusedAttention
# yes | learnable | FusedAttention
if softmax_type != "vanilla":
logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type)
use_flash_attention = False
if fp8 and fp8_meta["recipe"].fp8_dpa:
logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type
)
use_unfused_attention = False
if qkv_format == "thd":
logger.debug(
"Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
softmax_type,
)
use_unfused_attention = False
if context_parallel:
logger.debug(
"Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
" = %s",
softmax_type,
)
use_unfused_attention = False
if cp_comm_type != "a2a":
logger.debug(
"Disabling FusedAttention for context parallelism with softmax_type = %s and"
" cp_comm_type = %s",
softmax_type,
cp_comm_type,
)
use_fused_attention = False
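The softmax-type filter above (summarized in the table that precedes it) can be restated as a lookup from (context_parallel, softmax_type) to the backends that remain eligible; a hedged, illustrative restatement, noting that the additional FP8, qkv_format, and cp_comm_type constraints in the code still apply on top of it:
# Illustrative only; the actual filtering is done by the if-blocks above.
SOFTMAX_TYPE_BACKENDS = {
    (False, "vanilla"): {"FlashAttention", "FusedAttention", "UnfusedDotProductAttention"},
    (False, "off-by-one"): {"FusedAttention", "UnfusedDotProductAttention"},
    (False, "learnable"): {"FusedAttention", "UnfusedDotProductAttention"},
    (True, "vanilla"): {"FusedAttention", "FlashAttention"},
    (True, "off-by-one"): {"FusedAttention"},
    (True, "learnable"): {"FusedAttention"},
}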
# Filter: Context parallelism # Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends # qkv_format | attn_mask_type | attn_bias_type | supported backends
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
...@@ -822,6 +919,7 @@ def get_attention_backend( ...@@ -822,6 +919,7 @@ def get_attention_backend(
QKVLayout[qkv_layout], QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type], AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type], AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
attention_dropout, attention_dropout,
num_heads, num_heads,
num_gqa_groups, num_gqa_groups,
...@@ -1836,11 +1934,10 @@ def check_set_window_size( ...@@ -1836,11 +1934,10 @@ def check_set_window_size(
return window_size return window_size
def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): def get_attention_quantizers(fp8, quantizers):
"""Get the list of quantizers used in attention from the quantizers list.""" """Get the list of quantizers used in attention from the quantizers list."""
if not fp8: if not fp8:
num_of_nones = 8 if cp_specific_quantizers else 6 return [None] * 6
return [None] * num_of_nones
QKV_quantizer = quantizers["scaling_fwd"][META_QKV] QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
QKV_quantizer.internal = True QKV_quantizer.internal = True
QKV_quantizer.set_usage(rowwise=True, columnwise=False) QKV_quantizer.set_usage(rowwise=True, columnwise=False)
...@@ -1849,6 +1946,7 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): ...@@ -1849,6 +1946,7 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
S_quantizer = quantizers["scaling_fwd"][META_S] S_quantizer = quantizers["scaling_fwd"][META_S]
S_quantizer.internal = True S_quantizer.internal = True
S_quantizer.set_usage(rowwise=True, columnwise=False) S_quantizer.set_usage(rowwise=True, columnwise=False)
dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV]
dQKV_quantizer.internal = True dQKV_quantizer.internal = True
dQKV_quantizer.set_usage(rowwise=True, columnwise=False) dQKV_quantizer.set_usage(rowwise=True, columnwise=False)
...@@ -1858,22 +1956,158 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): ...@@ -1858,22 +1956,158 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
dP_quantizer = quantizers["scaling_bwd"][META_DP] dP_quantizer = quantizers["scaling_bwd"][META_DP]
dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.set_usage(rowwise=True, columnwise=False)
dP_quantizer.internal = True dP_quantizer.internal = True
dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP]
dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False) return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer
dQKV_CP_quantizer.internal = True
O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP]
O_CP_quantizer.set_usage(rowwise=True, columnwise=False) def print_quantizers(
label,
if cp_specific_quantizers: layer_number,
return ( QKV_quantizer,
O_quantizer,
S_quantizer,
dQKV_quantizer,
dO_quantizer,
dP_quantizer,
):
"""Print the type and scale/amax of attention quantizers"""
_to_print = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL == 2
if (
_to_print
and _print_layer == layer_number
and (
not dist.is_initialized() or (dist.is_initialized() and dist.get_rank() == _print_rank)
)
):
names = [
"QKV_quantizer",
"S_quantizer",
"O_quantizer",
"dO_quantizer",
"dP_quantizer",
"dQKV_quantizer",
]
quantizers = [
QKV_quantizer, QKV_quantizer,
O_quantizer,
O_CP_quantizer,
S_quantizer, S_quantizer,
dQKV_quantizer, O_quantizer,
dQKV_CP_quantizer,
dO_quantizer, dO_quantizer,
dP_quantizer, dP_quantizer,
) dQKV_quantizer,
]
if "forward" in label:
names = names[:3]
quantizers = quantizers[:3]
if "backward" in label:
names = names[3:]
quantizers = quantizers[3:]
for i, q in enumerate(quantizers):
type_str = ""
if q is None:
type_str = "None"
elif isinstance(q, Float8Quantizer):
type_str = "DS"
elif isinstance(q, Float8CurrentScalingQuantizer):
type_str = "CS"
print(
f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x"
f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}"
)
return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer
def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer):
"""Combine q,k,v based on qkv_layout and quantize them together"""
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_layout = qkv_layout.replace("paged_kv_", "")
qkv_group = len(qkv_layout.split("_"))
src_nominal_dtype = q.dtype
match qkv_group:
case 1:
dim = qkv_layout.find("3")
qkv = combine_tensors([q, k, v], dim)
qkv_fp8 = qkv_quantizer(qkv)
q_data, k_data, v_data = SplitAlongDim.apply(qkv_fp8._data, dim, [1, 1, 1], True)
case 2:
dim = qkv_layout.split("_")[1].find("2")
kv = combine_tensors([k, v], dim)
tensors = [q, kv]
num_tensors = len(tensors)
shapes = [x.shape for x in tensors]
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv = torch.cat([x.view(-1) for x in tensors], dim=0)
qkv_fp8 = qkv_quantizer(qkv)
q_data, kv_data = [
qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)
]
k_data, v_data = SplitAlongDim.apply(kv_data, dim, [1, 1], True)
case 3:
tensors = [q, k, v]
num_tensors = len(tensors)
shapes = [x.shape for x in tensors]
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv = torch.cat([x.view(-1) for x in tensors], dim=0)
qkv_fp8 = qkv_quantizer(qkv)
q_data, k_data, v_data = [
qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)
]
case _:
raise RuntimeError("Invalid qkv_layout " + qkv_layout)
q_fp8, k_fp8, v_fp8 = [
Float8Tensor.make_like(qkv_fp8, data=x, dtype=src_nominal_dtype)
for x in [q_data, k_data, v_data]
]
return q_fp8, k_fp8, v_fp8
def combine_and_dequantize(
qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=None, des_nominal_dtype=None
):
"""Combine q,k,v based on qkv_layout and dequantize them together"""
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_layout = qkv_layout.replace("paged_kv_", "")
qkv_group = len(qkv_layout.split("_"))
if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]):
src_nominal_dtype = q_fp8.dtype
else:
assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!"
if des_nominal_dtype is None:
des_nominal_dtype = src_nominal_dtype
q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]]
match qkv_group:
case 1:
dim = qkv_layout.find("3")
qkv_data = combine_tensors([q_data, k_data, v_data], dim)
qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data)
qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
q, k, v = SplitAlongDim.apply(qkv, dim, [1, 1, 1], True)
case 2:
dim = qkv_layout.split("_")[1].find("2")
kv_data = combine_tensors([k_data, v_data], dim)
tensors = [q_data, kv_data]
num_tensors = len(tensors)
shapes = [x.shape for x in tensors]
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv_data = torch.cat([x.reshape(-1) for x in tensors], dim=0)
qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype)
qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
q, kv = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)]
k, v = SplitAlongDim.apply(kv, dim, [1, 1], True)
case 3:
tensors = [q_data, k_data, v_data]
num_tensors = len(tensors)
shapes = [x.shape for x in tensors]
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv_data = torch.cat([x.contiguous().reshape(-1) for x in tensors], dim=0)
qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype)
qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
q, k, v = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)]
case _:
raise RuntimeError("Invalid qkv_layout " + qkv_layout)
return q, k, v
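A possible round-trip through the two helpers above for the separate-q/k/v "sbhd_sbhd_sbhd" layout. The Float8Quantizer constructor arguments follow the (scale, amax, fp8_dtype) pattern used elsewhere in Transformer Engine and should be treated as an assumption here; shapes and values are illustrative:
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

# Illustrative q/k/v in sbhd layout: [seq, batch, heads, head_dim]
q = torch.randn(128, 2, 16, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Assumed constructor signature: Float8Quantizer(scale, amax, fp8_dtype)
qkv_quantizer = Float8Quantizer(
    torch.ones(1, device="cuda"),
    torch.zeros(1, device="cuda"),
    tex.DType.kFloat8E4M3,
)

q_fp8, k_fp8, v_fp8 = combine_and_quantize("sbhd_sbhd_sbhd", q, k, v, qkv_quantizer)
q_hp, k_hp, v_hp = combine_and_dequantize("sbhd_sbhd_sbhd", q_fp8, k_fp8, v_fp8)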
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Multi-head Attention.""" """Multi-head Attention."""
import os
import collections import collections
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
import torch import torch
...@@ -31,7 +32,13 @@ from transformer_engine.pytorch.distributed import ( ...@@ -31,7 +32,13 @@ from transformer_engine.pytorch.distributed import (
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
# Force DotProductAttention to use a different recipe than the fp8_recipe set in fp8_autocast().
# Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling"
# and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa.
_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
_dpa_fp8_recipe_dpa = os.getenv("NVTE_DPA_FP8_RECIPE_DPA", "0") == "1"
_dpa_fp8_recipe_mha = os.getenv("NVTE_DPA_FP8_RECIPE_MHA", "0") == "1"
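Because these module-level flags are read once at import time, the override only takes effect if the environment variables are set before this module is imported. A minimal, hedged usage sketch (values illustrative):
import os

# Run DotProductAttention with current scaling while the rest of the model keeps the
# recipe from fp8_autocast(); variable names come from the block above.
os.environ["NVTE_DPA_FP8_RECIPE"] = "Float8CurrentScaling"
os.environ["NVTE_DPA_FP8_RECIPE_DPA"] = "1"   # treat fp8_dpa as enabled for attention
os.environ["NVTE_DPA_FP8_RECIPE_MHA"] = "0"   # keep fp8_mha disabled

import transformer_engine.pytorch as te  # import only after setting the variables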
class MultiheadAttention(torch.nn.Module): class MultiheadAttention(torch.nn.Module):
...@@ -135,6 +142,17 @@ class MultiheadAttention(torch.nn.Module): ...@@ -135,6 +142,17 @@ class MultiheadAttention(torch.nn.Module):
For that, please use `get_qkv_layout` to gain the layout information. For that, please use `get_qkv_layout` to gain the layout information.
name: str, default = `None` name: str, default = `None`
name of the module, currently used for debugging purposes. name of the module, currently used for debugging purposes.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
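A hedged reference sketch of the three variants above, transcribed directly from the formulas; the function name and arguments are illustrative, and the usual max-subtraction trick for numerical stability is omitted to keep the mapping to the formulas obvious:
import torch

def sink_softmax(scores: torch.Tensor, softmax_type: str = "vanilla",
                 alpha: torch.Tensor = None) -> torch.Tensor:
    # scores: [b, h, s_q, s_kv]; alpha: [h], used only for 'learnable'
    exp_s = torch.exp(scores)
    denom = exp_s.sum(dim=-1, keepdim=True)
    if softmax_type == "off-by-one":
        denom = denom + 1.0                                  # zero sink
    elif softmax_type == "learnable":
        denom = denom + torch.exp(alpha).view(1, -1, 1, 1)   # learnable sink
    return exp_s / denom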
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -245,6 +263,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -245,6 +263,7 @@ class MultiheadAttention(torch.nn.Module):
qk_norm_before_rope: bool = False, qk_norm_before_rope: bool = False,
seq_length: Optional[int] = None, seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None, micro_batch_size: Optional[int] = None,
softmax_type: str = "vanilla",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -262,6 +281,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -262,6 +281,7 @@ class MultiheadAttention(torch.nn.Module):
self.return_bias = return_bias self.return_bias = return_bias
self.cp_size = 1 self.cp_size = 1
self.cp_rank = 0 self.cp_rank = 0
self.softmax_type = softmax_type
kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads) kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)
...@@ -416,6 +436,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -416,6 +436,7 @@ class MultiheadAttention(torch.nn.Module):
tp_group=tp_group, tp_group=tp_group,
layer_number=self.layer_number, layer_number=self.layer_number,
attention_type=self.attention_type, attention_type=self.attention_type,
softmax_type=self.softmax_type,
) )
# Linear # Linear
...@@ -556,10 +577,12 @@ class MultiheadAttention(torch.nn.Module): ...@@ -556,10 +577,12 @@ class MultiheadAttention(torch.nn.Module):
self.cp_size = get_distributed_world_size(cp_group) self.cp_size = get_distributed_world_size(cp_group)
self.cp_rank = get_distributed_rank(cp_group) self.cp_rank = get_distributed_rank(cp_group)
elif isinstance(cp_group, list): elif isinstance(cp_group, list):
assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
assert ( assert (
cp_comm_type == "a2a+p2p" cp_comm_type == "a2a+p2p"
), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!" ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!"
assert (
len(cp_group) == 2
), "cp_comm_type = a2a+p2p requires cp_group = [a2a_cp_group, p2p_cp_group]!"
cp_size_a2a = get_distributed_world_size(cp_group[0]) cp_size_a2a = get_distributed_world_size(cp_group[0])
cp_rank_a2a = get_distributed_rank(cp_group[0]) cp_rank_a2a = get_distributed_rank(cp_group[0])
cp_size_p2p = get_distributed_world_size(cp_group[1]) cp_size_p2p = get_distributed_world_size(cp_group[1])
...@@ -716,10 +739,22 @@ class MultiheadAttention(torch.nn.Module): ...@@ -716,10 +739,22 @@ class MultiheadAttention(torch.nn.Module):
# Query, Key, and Value # Query, Key, and Value
# ====================== # ======================
fp8_mha = ( fp8 = FP8GlobalStateManager.is_fp8_enabled()
FP8GlobalStateManager.is_fp8_enabled() if _dpa_fp8_recipe == "":
and FP8GlobalStateManager.get_fp8_recipe().fp8_mha fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
) fp8_dpa = fp8_recipe.fp8_dpa
fp8_mha = fp8_recipe.fp8_mha
float8_current_scaling = fp8_recipe.float8_current_scaling()
else:
fp8_dpa = _dpa_fp8_recipe_dpa
fp8_mha = _dpa_fp8_recipe_mha
float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling"
# QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe
qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling
# DPA: always produce FP8 output when fp8=True to take advantage of the O amax
dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha)
# Proj Gemm: match DPA output except for Float8CurrentScaling
proj_fp8_grad = dpa_fp8_output and not float8_current_scaling
layernorm_output = None layernorm_output = None
if self.attention_type == "self": if self.attention_type == "self":
...@@ -728,7 +763,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -728,7 +763,7 @@ class MultiheadAttention(torch.nn.Module):
layernorm_qkv_outputs = self.layernorm_qkv( layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states, hidden_states,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None, fp8_output=qkv_fp8_output,
) )
if self.return_layernorm_output: if self.return_layernorm_output:
mixed_x_layer, layernorm_output = layernorm_qkv_outputs mixed_x_layer, layernorm_output = layernorm_qkv_outputs
...@@ -738,7 +773,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -738,7 +773,7 @@ class MultiheadAttention(torch.nn.Module):
mixed_x_layer = self.qkv( mixed_x_layer = self.qkv(
hidden_states, hidden_states,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None, fp8_output=qkv_fp8_output,
) )
num_queries_per_key_value = ( num_queries_per_key_value = (
...@@ -792,7 +827,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -792,7 +827,7 @@ class MultiheadAttention(torch.nn.Module):
mixed_kv_layer = self.key_value( mixed_kv_layer = self.key_value(
encoder_output, encoder_output,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None, fp8_output=qkv_fp8_output,
) )
if self.qkv_weight_interleaved: if self.qkv_weight_interleaved:
...@@ -847,7 +882,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -847,7 +882,7 @@ class MultiheadAttention(torch.nn.Module):
layernorm_query_outputs = self.layernorm_query( layernorm_query_outputs = self.layernorm_query(
hidden_states, hidden_states,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None, fp8_output=qkv_fp8_output,
) )
if self.return_layernorm_output: if self.return_layernorm_output:
query_layer, layernorm_output = layernorm_query_outputs query_layer, layernorm_output = layernorm_query_outputs
...@@ -857,7 +892,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -857,7 +892,7 @@ class MultiheadAttention(torch.nn.Module):
query_layer = self.query_layer( query_layer = self.query_layer(
hidden_states, hidden_states,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None, fp8_output=qkv_fp8_output,
) )
# [sq, b, hp] --> [sq, b, np, hn] # [sq, b, hp] --> [sq, b, np, hn]
...@@ -958,6 +993,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -958,6 +993,7 @@ class MultiheadAttention(torch.nn.Module):
fast_zero_fill=fast_zero_fill, fast_zero_fill=fast_zero_fill,
inference_params=inference_params, inference_params=inference_params,
pad_between_seqs=pad_between_seqs, pad_between_seqs=pad_between_seqs,
fp8_output=dpa_fp8_output,
) )
# =================== # ===================
...@@ -966,7 +1002,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -966,7 +1002,7 @@ class MultiheadAttention(torch.nn.Module):
projection_output = self.proj( projection_output = self.proj(
context_layer, context_layer,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
fp8_grad=isinstance(context_layer, QuantizedTensor), fp8_grad=proj_fp8_grad,
) )
if self.return_bias: if self.return_bias:
......
...@@ -91,3 +91,5 @@ GemmParallelModes = ("row", "column", None) ...@@ -91,3 +91,5 @@ GemmParallelModes = ("row", "column", None)
dist_group_type = torch.distributed.ProcessGroup dist_group_type = torch.distributed.ProcessGroup
MXFP8_BLOCK_SCALING_SIZE = 32 MXFP8_BLOCK_SCALING_SIZE = 32
NVFP4_BLOCK_SCALING_SIZE = 16
...@@ -12,6 +12,7 @@ from transformer_engine_torch import ( ...@@ -12,6 +12,7 @@ from transformer_engine_torch import (
NVTE_QKV_Format, NVTE_QKV_Format,
NVTE_Bias_Type, NVTE_Bias_Type,
NVTE_Mask_Type, NVTE_Mask_Type,
NVTE_Softmax_Type,
NVTE_Fused_Attn_Backend, NVTE_Fused_Attn_Backend,
) )
from ..tensor.quantized_tensor import Quantizer from ..tensor.quantized_tensor import Quantizer
...@@ -86,6 +87,12 @@ AttnMaskType = { ...@@ -86,6 +87,12 @@ AttnMaskType = {
"padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK, "padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
} }
SoftmaxType = {
"vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX,
"off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX,
"learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX,
}
FusedAttnBackend = { FusedAttnBackend = {
"F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, "F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen,
"F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, "F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
...@@ -102,9 +109,6 @@ META_O = tex.FP8FwdTensors.GEMM2_INPUT ...@@ -102,9 +109,6 @@ META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2 META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3 META_DP = tex.FP8BwdTensors.GRAD_INPUT3
# repurpose some unused amax history buffers for partial results of CP fwd and bwd
META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT
META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1
def fused_attn_fwd( def fused_attn_fwd(
...@@ -131,8 +135,10 @@ def fused_attn_fwd( ...@@ -131,8 +135,10 @@ def fused_attn_fwd(
qkv_layout: str = "sbh3d", qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias", attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding", attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1), window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None, rng_gen: torch.Generator = None,
softmax_offset: torch.Tensor = None,
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for separate QKV input. """Fused Attention FWD for separate QKV input.
...@@ -197,6 +203,8 @@ def fused_attn_fwd( ...@@ -197,6 +203,8 @@ def fused_attn_fwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding" attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1) window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
...@@ -205,6 +213,9 @@ def fused_attn_fwd( ...@@ -205,6 +213,9 @@ def fused_attn_fwd(
rng_gen: torch.Generator, default = None rng_gen: torch.Generator, default = None
random number generator; random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
softmax_offset: torch.Tensor, default = None
softmax offset tensor in shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
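For the 'learnable' softmax type, the offset is a per-head parameter broadcast over batch and sequence. A hedged, illustrative sketch of constructing it (16 query heads assumed):
import torch

# Shape [1, h_q, 1, 1] as described above. Initialized to zero, the denominator starts at
# exp(0) + sum = 1 + sum, i.e. the 'off-by-one' behaviour.
softmax_offset = torch.nn.Parameter(torch.zeros(1, 16, 1, 1, device="cuda"))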
Returns Returns
---------- ----------
...@@ -286,6 +297,7 @@ def fused_attn_fwd( ...@@ -286,6 +297,7 @@ def fused_attn_fwd(
QKVLayout[qkv_layout], QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type], AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type], AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size, window_size,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_kv, cu_seqlens_kv,
...@@ -300,6 +312,7 @@ def fused_attn_fwd( ...@@ -300,6 +312,7 @@ def fused_attn_fwd(
s_quantizer, s_quantizer,
o_quantizer, o_quantizer,
attn_bias, attn_bias,
softmax_offset,
rng_gen, rng_gen,
rng_elts_per_thread, rng_elts_per_thread,
) )
...@@ -333,6 +346,7 @@ def fused_attn_bwd( ...@@ -333,6 +346,7 @@ def fused_attn_bwd(
qkv_layout: str = "sbh3d", qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias", attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding", attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1), window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False, deterministic: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
...@@ -398,6 +412,8 @@ def fused_attn_bwd( ...@@ -398,6 +412,8 @@ def fused_attn_bwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding" attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1) window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
...@@ -417,6 +433,9 @@ def fused_attn_bwd( ...@@ -417,6 +433,9 @@ def fused_attn_bwd(
d_bias: torch.Tensor, optional d_bias: torch.Tensor, optional
gradient tensor of Bias when attn_bias_type is "pre_scale_bias" gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
or "post_scale_bias"; same data type and shape as Bias or "post_scale_bias"; same data type and shape as Bias
d_softmax_offset: torch.Tensor, optional
gradient tensor of softmax offset in shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
""" """
if attn_scale is None: if attn_scale is None:
d = q.size(-1) d = q.size(-1)
...@@ -454,6 +473,7 @@ def fused_attn_bwd( ...@@ -454,6 +473,7 @@ def fused_attn_bwd(
QKVLayout[qkv_layout], QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type], AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type], AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size, window_size,
deterministic, deterministic,
cu_seqlens_q, cu_seqlens_q,
......
...@@ -20,6 +20,8 @@ from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_ ...@@ -20,6 +20,8 @@ from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched
from ..tensor.quantized_tensor import Quantizer from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..tensor.utils import is_experimental
from ..experimental.gemm import experimental_gemm
from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer from ...debug.pytorch.debug_quantization import DebugQuantizer
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
...@@ -169,6 +171,24 @@ def general_gemm( ...@@ -169,6 +171,24 @@ def general_gemm(
if not out.is_contiguous(): if not out.is_contiguous():
raise ValueError("Output tensor is not contiguous.") raise ValueError("Output tensor is not contiguous.")
# If A or B are experimental tensors -> dispatch to quantizers's qgemm implementation
if is_experimental(A) or is_experimental(B):
return experimental_gemm(
A,
B,
workspace,
out_dtype,
quantization_params,
gelu,
gelu_in,
accumulate,
layout,
out,
bias,
use_split_accumulator,
grad,
)
debug_quantizer = None debug_quantizer = None
if isinstance(quantization_params, DebugQuantizer): if isinstance(quantization_params, DebugQuantizer):
debug_quantizer = quantization_params debug_quantizer = quantization_params
......
...@@ -12,6 +12,20 @@ ...@@ -12,6 +12,20 @@
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
/*! Convert the shape of packed FP4 data (two FP4 values per byte) back to the original tensor shape */
std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose) {
std::vector<size_t> ret;
size_t start_idx = (transpose) ? 1 : 0;
for (size_t i = start_idx; i < shape.size() - 1; ++i) {
ret.push_back(shape[i]);
}
ret.push_back(shape.back() * 2);
if (transpose) {
ret.push_back(shape.front());
}
return ret;
}
std::vector<size_t> getTensorShape(const at::Tensor& t) { std::vector<size_t> getTensorShape(const at::Tensor& t) {
std::vector<size_t> shape; std::vector<size_t> shape;
for (auto s : t.sizes()) { for (auto s : t.sizes()) {
...@@ -291,4 +305,20 @@ size_t roundup(const size_t value, const size_t multiple) { ...@@ -291,4 +305,20 @@ size_t roundup(const size_t value, const size_t multiple) {
return ((value + multiple - 1) / multiple) * multiple; return ((value + multiple - 1) / multiple) * multiple;
} }
void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) {
NVTE_SCOPED_GIL_RELEASE({
nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val,
arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_,
at::cuda::getCurrentCUDAStream());
});
}
// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread) {
at::PhiloxCudaState philox_args;
std::lock_guard<std::mutex> lock(gen->mutex_);
philox_args = gen->philox_cuda_state(elts_per_thread);
return philox_args;
}
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include <transformer_engine/fused_rope.h> #include <transformer_engine/fused_rope.h>
#include <transformer_engine/fused_router.h> #include <transformer_engine/fused_router.h>
#include <transformer_engine/gemm.h> #include <transformer_engine/gemm.h>
#include <transformer_engine/hadamard_transform.h>
#include <transformer_engine/multi_stream.h> #include <transformer_engine/multi_stream.h>
#include <transformer_engine/multi_tensor.h> #include <transformer_engine/multi_tensor.h>
#include <transformer_engine/normalization.h> #include <transformer_engine/normalization.h>
...@@ -212,20 +213,25 @@ class Float8CurrentScalingQuantizer : public Quantizer { ...@@ -212,20 +213,25 @@ class Float8CurrentScalingQuantizer : public Quantizer {
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override; DType dtype) const override;
/*! @brief Construct a high precision tensor giving it this quantizer's amax /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer.
*
Note: this member function also zeros out the amax, as it is meant to be used in conjunction with * The amax is zeroed out. Most TE kernels that output amax expect
a kernel computing the amax, which might expect the amax to be initialized to zero * amax to be initialized to zero.
*/ */
std::pair<TensorWrapper, py::object> create_hp_tensor_with_amax(const std::vector<size_t>& shape, std::pair<TensorWrapper, py::object> create_unquantized_tensor_with_amax(
DType dtype); const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data = std::nullopt);
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override; std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out, void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override; const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
/*! @brief Convert to a quantized data format avoiding amax computation */ /*! @brief Quantize to FP8, skipping local amax computation
*
* The quantizer's amax pointer is assumed to already hold the local
* amax. The amax may still be reduced across the amax reduction
* group.
*/
void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, void quantize_with_amax(TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt); const std::optional<TensorWrapper>& noop_flag = std::nullopt);
...@@ -295,6 +301,60 @@ class MXFP8Quantizer : public Quantizer { ...@@ -295,6 +301,60 @@ class MXFP8Quantizer : public Quantizer {
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const; std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
}; };
class NVFP4Quantizer : public Quantizer {
public:
// fp4 dtype
DType dtype;
// amax reduction for low precision FP4 AG
bool with_amax_reduction;
c10::intrusive_ptr<dist_group_type> amax_reduction_group;
// random hadamard transform
bool with_rht;
bool with_post_rht_amax;
// 2D block scaling
bool with_2d_quantization;
bool stochastic_rounding;
int rht_matrix_random_sign_mask_t;
at::Tensor rht_matrix;
explicit NVFP4Quantizer(const py::handle& quantizer);
NVTEScalingMode get_scaling_mode() const override { return NVTE_NVFP4_1D_SCALING; }
void set_quantization_params(TensorWrapper* tensor) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
/*! @brief Construct an unquantized tensor that shares the NVFP4 tensor's amax pointer
*
* The amax is zeroed out. Most TE kernels that output amax expect
* amax to be initialized to zero.
*/
std::pair<TensorWrapper, py::object> create_unquantized_tensor_with_amax(
TensorWrapper& quantized_tensor, DType dtype);
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
/*! @brief Quantize to NVFP4, skipping local amax computation
*
* The input tensor's amax pointer is assumed to already hold the
* local amax. The amax may still be reduced across the amax
* reduction group.
*/
void quantize_with_amax(TensorWrapper& input, TensorWrapper& out);
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
private:
void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
};
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer); std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
std::vector<size_t> getTensorShape(const at::Tensor& t); std::vector<size_t> getTensorShape(const at::Tensor& t);
...@@ -445,6 +505,15 @@ std::vector<size_t> convertShape(const NVTEShape& shape); ...@@ -445,6 +505,15 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
size_t roundup(const size_t value, const size_t multiple); size_t roundup(const size_t value, const size_t multiple);
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose);
// unpack the PhiloxCudaState into CUDA tensor
void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr);
// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread);
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
namespace std { namespace std {
......
...@@ -73,28 +73,36 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T ...@@ -73,28 +73,36 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Fused_Attn_Backend get_fused_attn_backend(
bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right);
std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
const std::vector<size_t> &shape, DType dtype,
bool create_hp_tensor_for_cs,
std::optional<at::Tensor> data);
std::vector<py::object> fused_attn_fwd( std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const std::optional<at::Tensor> cu_seqlens_q_padded, const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v, const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias, py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread); const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<py::object> fused_attn_bwd( std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q, NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
const at::ScalarType fake_dtype, const DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded, const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer, const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
......
...@@ -8,179 +8,269 @@ ...@@ -8,179 +8,269 @@
#include "common.h" #include "common.h"
#include "pybind.h" #include "pybind.h"
namespace transformer_engine::pytorch { namespace transformer_engine {
namespace pytorch {
template <void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t)> namespace {
py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) {
py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t),
const at::Tensor& input, py::handle quantizer,
int shape_divisor = 1) {
init_extension(); init_extension();
// Input tensor // Input tensor
auto input_tensor = input.contiguous(); auto input_tensor = input.contiguous();
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor);
// Construct output tensor // Construct output tensor
auto quantizer_cpp = convert_quantizer(quantizer); auto quantizer_cpp = convert_quantizer(quantizer);
const auto input_shape = input_cpp.shape(); const auto input_shape = input_nvte.shape();
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim); std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
output_shape.back() /= shape_divisor; output_shape.back() /= shape_divisor;
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); auto [out_nvte, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);
// Compute activation // Choose implementation
enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 };
Impl impl = Impl::UNFUSED;
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
detail::IsMXFP8Quantizers(quantizer.ptr())) { detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Compute activation directly impl = Impl::FULLY_FUSED;
NVTE_SCOPED_GIL_RELEASE(
{ act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); });
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// Compute activation in high-precision fused together with amax, then quantize. impl = Impl::FUSED_ACTIVATION_AMAX_FP8;
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get()); auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
NVTE_SCOPED_GIL_RELEASE( if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
{ act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); // Post-RHT amax is handled within NVFP4 quantizer
quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp); impl = Impl::UNFUSED;
} else { } else {
// Compute activation in high-precision, then quantize impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
}
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); }
NVTE_SCOPED_GIL_RELEASE(
{ act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); // Perform compute
quantizer_cpp->quantize(temp_cpp, out_cpp); auto stream = at::cuda::getCurrentCUDAStream();
switch (impl) {
case Impl::UNFUSED:
// Compute activation in high precision, then quantize
{
auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
quantizer_cpp->quantize(temp_nvte, out_nvte);
}
break;
case Impl::FULLY_FUSED:
// Compute activation directly
{
NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), out_nvte.data(), stream); });
}
break;
case Impl::FUSED_ACTIVATION_AMAX_FP8:
// Compute activation and amax in high precision, then quantize to FP8
{
auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
auto [temp_nvte, _] =
fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
}
break;
case Impl::FUSED_ACTIVATION_AMAX_NVFP4:
// Compute activation and amax in high precision, then quantize to NVFP4
{
auto nvfp4_quantizer_cpp =
static_cast<NVFP4Quantizer*>(quantizer_cpp.get()); // Already checked cast is valid
auto [temp_nvte, _] =
nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype);
        NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
        nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
      }
      break;
    default:
      NVTE_ERROR("Invalid activation implementation (", static_cast<int>(impl), ")");
  }

  return out_py;
}

py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor,
                                                 cudaStream_t),
                               const at::Tensor& grad_output, const at::Tensor& input,
                               py::handle quantizer) {
  init_extension();

  // Grad output and input tensors
  auto grad_output_tensor = grad_output.contiguous();
  auto input_tensor = input.contiguous();
  const TensorWrapper& grad_output_nvte = makeTransformerEngineTensor(grad_output_tensor);
  const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor);

  // Construct grad input tensor
  auto quantizer_cpp = convert_quantizer(quantizer);
  const auto input_shape_te = input_nvte.shape();
  const std::vector<size_t> input_shape(input_shape_te.data,
                                        input_shape_te.data + input_shape_te.ndim);
  auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
  auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);

  // Choose implementation
  enum class Impl {
    UNFUSED,
    FULLY_FUSED,
    FUSED_ACTIVATION_AMAX_FP8,
    FUSED_ACTIVATION_AMAX_NVFP4
  };
  Impl impl = Impl::UNFUSED;
  if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
      detail::IsMXFP8Quantizers(quantizer.ptr())) {
    impl = Impl::FULLY_FUSED;
  } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
    impl = Impl::FUSED_ACTIVATION_AMAX_FP8;
  } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
    auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
    NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
      // Post-RHT amax is handled within NVFP4 quantizer
      impl = Impl::UNFUSED;
    } else {
      impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
    }
  }

  // Perform compute
  auto stream = at::cuda::getCurrentCUDAStream();
  switch (impl) {
    case Impl::UNFUSED:
      // Compute activation backward in high precision, then quantize
      {
        auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
        NVTE_SCOPED_GIL_RELEASE({
          dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
                    at::cuda::getCurrentCUDAStream());
        });
        quantizer_cpp->quantize(temp_nvte, grad_input_nvte);
      }
      break;
    case Impl::FULLY_FUSED:
      // Compute activation backward directly
      {
        NVTE_SCOPED_GIL_RELEASE({
          dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream);
        });
      }
      break;
    case Impl::FUSED_ACTIVATION_AMAX_FP8:
      // Compute activation and amax in high precision, then quantize to FP8
      {
        auto fp8_quantizer_cpp =
            dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
        NVTE_CHECK(fp8_quantizer_cpp != nullptr,
                   "Could not cast to FP8 current scaling quantizer");
        auto [temp_nvte, _] =
            fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype);
        NVTE_SCOPED_GIL_RELEASE(
            { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); });
        fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
      }
      break;
    case Impl::FUSED_ACTIVATION_AMAX_NVFP4:
      // Compute activation and amax in high precision, then quantize to NVFP4
      {
        auto nvfp4_quantizer_cpp =
            static_cast<NVFP4Quantizer*>(quantizer_cpp.get());  // Already checked cast is valid
        auto [temp_nvte, _] =
            nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(grad_input_nvte, fake_dtype);
        NVTE_SCOPED_GIL_RELEASE(
            { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); });
        nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
      }
      break;
    default:
      NVTE_ERROR("Invalid activation implementation (", static_cast<int>(impl), ")");
  }

  return grad_input_py;
}

}  // namespace
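// Note on dispatch (summary of the helpers above): no quantizer, FP8 delayed-scaling, and MXFP8
// quantizers run the fused NVTE activation kernel directly into the output tensor
// (Impl::FULLY_FUSED). FP8 current-scaling and NVFP4 quantizers first run the activation in high
// precision with a fused amax reduction and then quantize (Impl::FUSED_ACTIVATION_AMAX_*), except
// that NVFP4 with RHT and post-RHT amax falls back to Impl::UNFUSED, since the amax is computed
// inside the NVFP4 quantizer itself.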
/* GELU and variants */
py::object gelu(const at::Tensor& input, py::handle quantizer) {
  return activation_forward(nvte_gelu, input, quantizer);
}
py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return activation_backward(nvte_dgelu, grad, input, quantizer);
}
py::object geglu(const at::Tensor& input, py::handle quantizer) {
  return activation_forward(nvte_geglu, input, quantizer, 2);
}
py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return activation_backward(nvte_dgeglu, grad, input, quantizer);
}
py::object qgelu(const at::Tensor& input, py::handle quantizer) {
  return activation_forward(nvte_qgelu, input, quantizer);
}
py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return activation_backward(nvte_dqgelu, grad, input, quantizer);
}
py::object qgeglu(const at::Tensor& input, py::handle quantizer) {
  return activation_forward(nvte_qgeglu, input, quantizer, 2);
}
py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return activation_backward(nvte_dqgeglu, grad, input, quantizer);
}

/* ReLU and variants */
py::object relu(const at::Tensor& input, py::handle quantizer) {
  return activation_forward(nvte_relu, input, quantizer);
}
py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return activation_backward(nvte_drelu, grad, input, quantizer);
}
py::object reglu(const at::Tensor& input, py::handle quantizer) {
  return activation_forward(nvte_reglu, input, quantizer, 2);
}
py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return activation_backward(nvte_dreglu, grad, input, quantizer);
}
py::object srelu(const at::Tensor& input, py::handle quantizer) {
  return activation_forward(nvte_srelu, input, quantizer);
}
py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return activation_backward(nvte_dsrelu, grad, input, quantizer);
}
py::object sreglu(const at::Tensor& input, py::handle quantizer) {
  return activation_forward(nvte_sreglu, input, quantizer, 2);
}
py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return activation_backward(nvte_dsreglu, grad, input, quantizer);
}

/* SiLU and variants */
py::object silu(const at::Tensor& input, py::handle quantizer) {
  return activation_forward(nvte_silu, input, quantizer);
}
py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return activation_backward(nvte_dsilu, grad, input, quantizer);
}
py::object swiglu(const at::Tensor& input, py::handle quantizer) {
  return activation_forward(nvte_swiglu, input, quantizer, 2);
}
py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
  return activation_backward(nvte_dswiglu, grad, input, quantizer);
}
}  // namespace pytorch
}  // namespace transformer_engine