Commit 53fa872c authored by wenjh

Merge branch 'nv_release_v2.8' into release_v2.8

parents 27ddce40 40c69e75
......@@ -87,5 +87,31 @@ constexpr struct Alignment {
std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise);
template <typename T, typename... Rest>
void hash_combine(int64_t &seed, const T &v, Rest... rest) {
seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
(hash_combine(seed, rest), ...);
}
enum class JAXX_Collective_Op : int64_t {
NONE = 0,
ALL_GATHER = 1,
REDUCE_SCATTER = 2,
};
static CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) {
switch (op) {
case JAXX_Collective_Op::ALL_GATHER:
return CommOverlapType::AG;
break;
case JAXX_Collective_Op::REDUCE_SCATTER:
return CommOverlapType::RS;
break;
default:
NVTE_ERROR("Invalid Collective Op ", static_cast<int>(op));
break;
}
}
} // namespace jax
} // namespace transformer_engine
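The hash_combine helper above follows the boost-style mixing pattern; a minimal Python restatement (illustrative only, not part of this change) of how successive values fold into one seed:

def hash_combine(seed: int, *values) -> int:
    # Mix each value's hash into the running seed with the golden-ratio constant,
    # emulating 64-bit wraparound with a mask (the C++ version relies on int64_t arithmetic).
    mask = (1 << 64) - 1
    for v in values:
        seed ^= (hash(v) + 0x9E3779B9 + ((seed << 6) & mask) + (seed >> 2)) & mask
        seed &= mask
    return seed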
......@@ -180,6 +180,42 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<bool>("is_2x"),
FFI_CudaGraph_Traits);
Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
Buffer_Type gamma_buf, Buffer_Type beta_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type mu_buf,
Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type,
bool zero_centered_gamma, double epsilon, int64_t sm_margin,
JAXX_Scaling_Mode scaling_mode, bool is_2x) {
return wrapInStreamCapture(
std::function(NormForwardFFI), stream, x_buf, scale_buf, gamma_buf, beta_buf, output_buf,
colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, mu_buf, rsigma_buf,
wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, scaling_mode, is_2x);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI,
FFI::Bind<FFI_Initialize>()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise_output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // colwise_scale_inv
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("norm_type")
.Attr<bool>("zero_centered_gamma")
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"));
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, NVTE_Norm_Type norm_type,
bool zero_centered_gamma, int sm_margin) {
......@@ -305,5 +341,32 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardHandler, NormBackwardFFI,
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
Error_Type NormBackwardInitializeFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
Buffer_Type mu_buf, Buffer_Type rsigma_buf,
Buffer_Type gamma_buf, Result_Type xgrad_buf,
Result_Type wgrad_buf, Result_Type dbeta_buf,
Result_Type wkspace_buf, int64_t norm_type,
bool zero_centered_gamma, int64_t sm_margin) {
return wrapInStreamCapture(std::function(NormBackwardFFI), stream, dz_buf, x_buf, mu_buf,
rsigma_buf, gamma_buf, xgrad_buf, wgrad_buf, dbeta_buf, wkspace_buf,
norm_type, zero_centered_gamma, sm_margin);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardInitializeHandler, NormBackwardInitializeFFI,
FFI::Bind<FFI_Initialize>()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // dz
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // mu
.Arg<Buffer_Type>() // rsigma
.Arg<Buffer_Type>() // gamma
.Ret<Buffer_Type>() // xgrad
.Ret<Buffer_Type>() // wgrad
.Ret<Buffer_Type>() // dbeta
.Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("norm_type")
.Attr<bool>("zero_centered_gamma")
.Attr<int64_t>("sm_margin"));
} // namespace jax
} // namespace transformer_engine
......@@ -5,6 +5,8 @@
************************************************************************/
#include "../extensions.h"
#include "cgemm_helper.h"
#include "common/util/cuda_runtime.h"
namespace transformer_engine {
namespace jax {
......@@ -20,8 +22,12 @@ pybind11::dict Registrations() {
pybind11::dict dict;
// Activation
dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
dict["te_dact_dbias_quantize_ffi"] = EncapsulateFFI(DActLuDBiasQuantizeHandler);
dict["te_act_lu_ffi"] =
pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(ActLuInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(ActLuHandler));
dict["te_dact_dbias_quantize_ffi"] = pybind11::dict(
pybind11::arg("initialize") = EncapsulateFFI(DActLuDBiasQuantizeInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(DActLuDBiasQuantizeHandler));
// Quantization
dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler);
......@@ -42,9 +48,11 @@ pybind11::dict Registrations() {
// Normalization
dict["te_norm_forward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("initialize") = EncapsulateFFI(NormForwardInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(NormForwardHandler));
dict["te_norm_backward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("initialize") = EncapsulateFFI(NormBackwardInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(NormBackwardHandler));
// Attention
......@@ -57,7 +65,7 @@ pybind11::dict Registrations() {
// GEMM
dict["te_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CollectiveGemmInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GemmHandler));
// Grouped GEMM
......@@ -84,6 +92,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported);
m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator);
m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams);
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
......@@ -159,6 +169,12 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
.export_values();
pybind11::enum_<JAXX_Collective_Op>(m, "JAXX_Collective_Op", pybind11::module_local())
.value("NONE", JAXX_Collective_Op::NONE)
.value("ALL_GATHER", JAXX_Collective_Op::ALL_GATHER)
.value("REDUCE_SCATTER", JAXX_Collective_Op::REDUCE_SCATTER)
.export_values();
}
} // namespace jax
......
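For orientation, a hedged Python-side sketch of what these bindings expose after this change (the way Registrations() is reached from Python is assumed; the module name, dict keys, and enum values are taken from the diff):

import transformer_engine_jax as tex_cpp

# Each FFI target now maps execution stages to handler capsules, e.g. for norm forward:
#   {"prepare": <cudnn handle init>, "initialize": <graph-capture warm-up>, "execute": <kernel>}
regs = tex_cpp.registrations()  # hypothetical accessor; the C++ side builds this dict in Registrations()
assert set(regs["te_norm_forward_ffi"]) == {"prepare", "initialize", "execute"}

# The collective-op enum used by the collective GEMM path.
ag = tex_cpp.JAXX_Collective_Op.ALL_GATHER
rs = tex_cpp.JAXX_Collective_Op.REDUCE_SCATTER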
......@@ -11,10 +11,12 @@ customizable contracting dimensions for flexible tensor operations.
from typing import Tuple, Sequence
from functools import partial
import warnings
import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .quantize import (
ScaledTensorFactory,
ScalingMode,
......@@ -61,8 +63,12 @@ def dense(
kernel: jnp.ndarray,
bias: jnp.ndarray = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
batch_sequence_transpose: bool = False,
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
output_axes: Tuple[str, ...] = None,
using_global_amax_of_x: bool = False,
collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set,
quantizer_set: QuantizerSet = noop_quantizer_set,
):
"""Perform dense layer transformation with optional quantization.
......@@ -76,11 +82,20 @@ def dense(
kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
output_axes: Logical axes for sharding the output
using_global_amax_of_x: Indicates whether to use the global amax for x. Only takes effect with current scaling. Default is False.
collective_op_set: A set of CollectiveOp objects for forward and backward passes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Transformed output tensor
"""
if batch_sequence_transpose:
warnings.warn("batch_sequence_transpose is not well tested, use with caution!")
if not get_quantize_config().is_fp8_enabled():
input_dtype = x.dtype
kernel = kernel.astype(input_dtype)
......@@ -90,29 +105,30 @@ def dense(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
input_axes,
kernel_axes,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set,
)
return output
@partial(
jax.custom_vjp,
nondiff_argnums=(
3,
4,
5,
),
)
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9))
def _dense(
x,
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
input_axes,
kernel_axes,
quantizer_set,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set, # need to be a diff_arg for DelayedScaling state management
):
"""Internal implementation of dense layer transformation with custom VJP.
......@@ -124,8 +140,12 @@ def _dense(
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input
output_axes: Logical axes for sharding the output
kernel_axes: Logical axes for sharding the weight matrix
using_global_amax_of_x: Indicates whether to use the global amax for x. Only takes effect with current scaling. Default is False.
collective_op_set: A set of CollectiveOp objects for forward and backward passes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
......@@ -136,8 +156,12 @@ def _dense(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
input_axes,
kernel_axes,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set,
)
return output
......@@ -148,8 +172,12 @@ def _dense_fwd_rule(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
input_axes,
kernel_axes,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set,
):
"""Forward pass rule for dense layer transformation.
......@@ -175,6 +203,7 @@ def _dense_fwd_rule(
x,
flatten_axis=flatten_axis_x,
quantizer=quantizer_set.x,
amax_scope=AmaxScope.TPSP if using_global_amax_of_x else AmaxScope.LOCAL,
)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
......@@ -182,6 +211,7 @@ def _dense_fwd_rule(
kernel,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
......@@ -191,9 +221,12 @@ def _dense_fwd_rule(
casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set.forward,
)
output = with_sharding_constraint_by_logical_axes(output, output_axes)
if use_bias and tex.gemm_uses_jax_dot():
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
......@@ -212,8 +245,16 @@ def _dense_fwd_rule(
def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, ctx, grad
): # pylint: disable=unused-argument
contracting_dims,
batch_sequence_transpose,
input_axes,
kernel_axes,
output_axes,
using_global_amax_of_x,
collective_op_set,
ctx,
grad,
):
"""Backward pass rule for dense layer transformation.
Returns:
......@@ -228,6 +269,7 @@ def _dense_bwd_rule(
quantizer_set,
flatten_axis_k,
) = ctx
grad = with_sharding_constraint_by_logical_axes(grad, output_axes)
fwd_x_contracting_dims, fwd_k_contracting_dims = map(
tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
......@@ -238,6 +280,7 @@ def _dense_bwd_rule(
is_dbias=use_bias,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad,
amax_scope=AmaxScope.LOCAL if using_global_amax_of_x else AmaxScope.TPSP,
)
# GEMM NT
......@@ -254,8 +297,9 @@ def _dense_bwd_rule(
casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim),
transpose_batch_sequence=batch_sequence_transpose,
collective_op=collective_op_set.backward,
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
# GEMM TN
# x_non_contracting_dims
......@@ -267,7 +311,10 @@ def _dense_bwd_rule(
casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dim, g_contracting_dim),
transpose_batch_sequence=batch_sequence_transpose,
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
return dgrad, wgrad, dbias, quantizer_set
......
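A minimal call-site sketch for the extended dense() signature (axis names, the quantizer set, and a non-noop CollectiveOpSet are assumptions, not taken from this diff):

out = dense(
    x, kernel, bias,
    contracting_dims=((2,), (0,)),
    batch_sequence_transpose=False,
    input_axes=("batch", "seq", "hidden"),          # assumed logical axis names
    kernel_axes=("hidden", "mlp"),
    output_axes=("batch", "seq", "mlp"),
    using_global_amax_of_x=True,                     # only meaningful with current scaling
    collective_op_set=tex.noop_collective_op_set,    # swap in a real CollectiveOpSet for comm overlap
    quantizer_set=quantizer_set,
)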
......@@ -53,6 +53,7 @@ def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[
return drop_path_shape
# TODO(Phuong): move this function to sharding.py
def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
"""
Extend the given Flax logical axis rules with the predefined TransformerLayer's
......
......@@ -21,6 +21,7 @@ import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .layernorm import canonicalize_norm_type
from .quantize import (
with_sharding_constraint_by_logical_axes,
......@@ -40,6 +41,7 @@ def layernorm_mlp(
norm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
batch_sequence_transpose: bool = False,
norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
......@@ -48,6 +50,10 @@ def layernorm_mlp(
ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
collective_op_sets: Tuple[tex.CollectiveOpSet] = (
tex.noop_collective_op_set,
tex.noop_collective_op_set,
),
quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
"""Apply layer normalization followed by MLP block.
......@@ -71,6 +77,7 @@ def layernorm_mlp(
norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication
......@@ -79,6 +86,7 @@ def layernorm_mlp(
ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation
collective_op_sets: Tuple of two CollectiveOpSet objects, one per dense layer transformation
quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations
Returns:
......@@ -121,6 +129,7 @@ def layernorm_mlp(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -129,12 +138,13 @@ def layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
collective_op_sets,
quantizer_sets,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
def _layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
......@@ -146,6 +156,7 @@ def _layernorm_mlp(
norm_type: str,
zero_centered_gamma: bool,
epsilon: float,
batch_sequence_transpose: bool,
norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...],
......@@ -154,6 +165,7 @@ def _layernorm_mlp(
ffn1_ckpt_name: str,
ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
collective_op_sets: Tuple[tex.CollectiveOpSet],
quantizer_sets,
):
"""Internal implementation of layernorm_mlp with custom VJP.
......@@ -173,12 +185,16 @@ def _layernorm_mlp(
norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for layernorm sharding
dot_1_input_axes: Logical axes for first matrix multiplication sharding
dot_2_input_axes: Logical axes for second matrix multiplication sharding
kernel_1_axes: Logical axes for first weight matrix sharding
kernel_2_axes: Logical axes for second weight matrix sharding
ffn1_ckpt_name: Name for first feed-forward network checkpointing
ffn2_ckpt_name: Name for second feed-forward network checkpointing
activation_type: Activation function(s)
collective_op_sets: Tuple of two CollectiveOpSet objects, one per dense layer transformation
quantizer_sets: Tuple of quantizer sets
Returns:
......@@ -195,6 +211,7 @@ def _layernorm_mlp(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -203,6 +220,7 @@ def _layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
collective_op_sets,
quantizer_sets,
)
return output
......@@ -219,6 +237,7 @@ def _layernorm_mlp_fwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -227,6 +246,7 @@ def _layernorm_mlp_fwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
collective_op_sets,
quantizer_sets,
):
"""Forward pass rule for layernorm_mlp.
......@@ -246,6 +266,10 @@ def _layernorm_mlp_fwd_rule(
del kernel_1_axes, kernel_2_axes
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
collective_op_set_1, collective_op_set_2 = collective_op_sets
assert not collective_op_set_1.forward.is_reduce_scatter
assert not collective_op_set_2.forward.is_all_gather
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
......@@ -272,13 +296,12 @@ def _layernorm_mlp_fwd_rule(
epsilon,
norm_type,
quantizer=ffn1_quantizer_set.x,
amax_scope=AmaxScope.TPSP,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(
kernel_1,
flatten_axis=-2,
quantizer=ffn1_quantizer_set.kernel,
kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP
)
# NN GEMM
......@@ -287,8 +310,10 @@ def _layernorm_mlp_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_1.forward,
)
if use_bias_1 and tex.gemm_uses_jax_dot():
......@@ -317,6 +342,7 @@ def _layernorm_mlp_fwd_rule(
casted_kernel_2 = tex.quantize(
kernel_2,
quantizer=ffn2_quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
)
# NN GEMM
......@@ -325,8 +351,10 @@ def _layernorm_mlp_fwd_rule(
casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_2.forward,
)
if use_bias_2 and tex.gemm_uses_jax_dot():
......@@ -334,6 +362,8 @@ def _layernorm_mlp_fwd_rule(
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
# sharding of outputs should be the same as dot_1's input
dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_1_input_axes)
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = (
......@@ -363,6 +393,7 @@ def _layernorm_mlp_bwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -371,6 +402,7 @@ def _layernorm_mlp_bwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
collective_op_sets,
ctx,
grad,
):
......@@ -409,6 +441,10 @@ def _layernorm_mlp_bwd_rule(
) = ctx
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
collective_op_set_1, collective_op_set_2 = collective_op_sets
assert not collective_op_set_1.backward.is_all_gather
assert not collective_op_set_2.backward.is_reduce_scatter
# Since the sharding of outputs should be the same as dot_1's input
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
......@@ -417,6 +453,7 @@ def _layernorm_mlp_bwd_rule(
grad,
is_dbias=use_bias_2,
quantizer=ffn1_quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
......@@ -434,6 +471,8 @@ def _layernorm_mlp_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel_2,
contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
transpose_batch_sequence=batch_sequence_transpose,
collective_op=collective_op_set_2.backward,
)
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
......@@ -448,6 +487,7 @@ def _layernorm_mlp_bwd_rule(
casted_act_out,
casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
)
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
......@@ -474,6 +514,8 @@ def _layernorm_mlp_bwd_rule(
casted_dact_out.get_tensor(TensorUsage.LHS),
casted_kernel_1,
contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
transpose_batch_sequence=batch_sequence_transpose,
collective_op=collective_op_set_1.backward,
)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
......@@ -484,6 +526,7 @@ def _layernorm_mlp_bwd_rule(
casted_ln_out,
casted_dact_out.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
)
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
......
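The asserts in _layernorm_mlp_fwd_rule/_layernorm_mlp_bwd_rule constrain which collective each GEMM may overlap; a small self-contained restatement (the dataclasses below are illustrative stand-ins, not the actual tex.CollectiveOpSet type):

from dataclasses import dataclass

@dataclass
class _Op:
    is_all_gather: bool = False
    is_reduce_scatter: bool = False

@dataclass
class _OpSet:
    forward: _Op
    backward: _Op

def check_mlp_collectives(set_1: _OpSet, set_2: _OpSet) -> None:
    # FFN1 (column-parallel): forward may all-gather its input, never reduce-scatter its output.
    assert not set_1.forward.is_reduce_scatter
    # FFN2 (row-parallel): forward may reduce-scatter its output, never all-gather.
    assert not set_2.forward.is_all_gather
    # The directions flip in backward.
    assert not set_1.backward.is_all_gather
    assert not set_2.backward.is_reduce_scatter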
......@@ -17,7 +17,7 @@ from functools import reduce, lru_cache
import operator
import numpy as np
from jax.experimental.custom_partitioning import BATCHING
from jax.experimental.custom_partitioning import BATCHING, CompoundFactor
from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp
......@@ -152,12 +152,15 @@ class ScalingModeMetadataImpl(ABC):
@abstractmethod
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
self,
input_shape,
unique_var,
flatten_axis,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
......@@ -232,12 +235,15 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (n_groups,)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
self,
input_shape,
unique_var,
flatten_axis,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
......@@ -245,7 +251,7 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
The Shardy rules for the scaling mode
"""
del flatten_axis
input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank))
input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
......@@ -323,20 +329,23 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (n_groups,)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
self,
input_shape,
unique_var,
flatten_axis,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
flatten_axis: Axis along which data can be flattened to 2D for quantization
Returns:
The Shardy rules for the scaling mode
"""
del flatten_axis
input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank))
input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
......@@ -562,52 +571,55 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (n_block_x * n_block_y,)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
self,
input_shape,
unique_var,
flatten_axis,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization
Returns:
The Shardy rules for the scaling mode
"""
del flatten_axis
input_spec = [f"{unique_var}{i}" for i in range(input_rank)]
rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)]
colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)]
# NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors.
# Unfortunately, because Shardy rules are applied to the inner primitive, the
# only way to preserve the relationship is to lower unpadded scales to the
# underlying custom call and pad them in C++. Until that's implemented, the
# Shardy rules for block scales have to be completely disconnected from the
# Shardy rules for the tensor they belong to.
# # We have to use two different factors in the two CompoundFactors because of Shardy
# # verifier requirements, even though they are the same.
# rowwise_var = unique_var
# colwise_var = f"{unique_var}_"
# input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise")
# input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise")
# # The rowwise and colwise scale tensors should be sharded the same way as the input.
# # However, we need to adjust the dimensions where the block scaling factor applies.
# rowwise = input_spec.copy()
# rowwise[-1] = rowwise_var
# colwise = input_spec.copy()
# colwise[flatten_axis - 1] = colwise_var
# # This implementation needs to be updated for different block dims.
# assert self._block_dims == (1, 32)
input_rank = len(input_shape)
input_spec = [f"{unique_var}_{i}" for i in range(input_rank)]
flatten_axis = (flatten_axis + input_rank) % input_rank
# This implementation needs to be updated for different block dims.
assert self._block_dims == (1, 32)
# We have to use two different factors in the two CompoundFactors because of Shardy
# verifier requirements, even though they are the same.
blocksizes = {}
colwise_var = f"{unique_var}_None"
rowwise_var = f"{unique_var}_None"
if not input_shape[-1] == 32:
rowwise_var = input_spec[-1] + "_compound"
input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x")
blocksizes["blocksize_x"] = 32
if not input_shape[flatten_axis - 1] == 32:
colwise_var = input_spec[flatten_axis - 1] + "_compound"
input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y")
blocksizes["blocksize_y"] = 32
# The rowwise and colwise scale tensors should be sharded the same way as the input.
# However, we need to adjust the dimensions where the block scaling factor applies.
rowwise = input_spec.copy()
rowwise[-1] = rowwise_var
colwise = input_spec.copy()
colwise[flatten_axis - 1] = colwise_var
return QuantizeShardyRules(
tuple(input_spec),
tuple(rowwise),
tuple(colwise),
{}, # {"block_size_rowwise": 32, "block_size_colwise": 32},
blocksizes,
)
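The CompoundFactor ties a data dimension to its 1x32 block scales by size; an illustrative shape check under assumed shapes (not library code):

data_shape = (8, 1024, 4096)      # (batch, seq, hidden); hidden is the rowwise-scaled dim
block = 32
rowwise_scale_shape = data_shape[:-1] + (data_shape[-1] // block,)   # (8, 1024, 128)
assert rowwise_scale_shape[-1] * block == data_shape[-1]
# When a dim is exactly one block wide (== 32), no CompoundFactor is emitted for it.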
......@@ -697,18 +709,22 @@ class ScalingMode(Enum):
return self._get_impl().get_quantize_layout(usage)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis=-1
self,
input_shape,
unique_var,
flatten_axis=-1,
) -> Tuple[Tuple[str]]:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns:
The Shardy rules for the scaling mode
"""
return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis)
return self._get_impl().get_shardy_sharding_rules(input_shape, unique_var, flatten_axis)
def get_grouped_scale_shape_2x(
self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1
......
......@@ -13,6 +13,7 @@ from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Optional
import warnings
import jax
import jax.numpy as jnp
from jax.interpreters import pxla
......@@ -364,3 +365,21 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes
if axis != global_mesh_resource().pp_resource:
x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
return x
def tpsp_axis_size():
"""
Get the size of the tensor parallelism axis.
Return 1 if no TP axis is set.
"""
return get_mesh_axis_size(global_mesh_resource().tpsp_resource)
def dp_or_fsdp_axis_size():
"""
Get the size of the data parallelism or FSDP axis.
Return 1 if no DP/FSDP axis is set.
"""
dp_size = get_mesh_axis_size(global_mesh_resource().dp_resource)
fsdp_size = get_mesh_axis_size(global_mesh_resource().fsdp_resource)
return dp_size if dp_size > 1 else fsdp_size
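A brief illustration of the new mesh-size helpers (mesh setup assumed); both return 1 when the axis is absent, so callers can use them unconditionally:

tp_size = tpsp_axis_size()          # e.g. 8 on an 8-way TP/SP mesh, else 1
dp_size = dp_or_fsdp_axis_size()    # DP size if > 1, otherwise the FSDP size
replicas = tp_size * dp_size        # illustrative use, e.g. scaling an amax reduction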
......@@ -13,17 +13,20 @@ import logging
from packaging.version import Version as PkgVersion
import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import (
SplitAlongDim,
get_device_compute_capability,
combine_tensors,
split_tensor_along_dim,
)
from transformer_engine.pytorch.utils import attention_mask_func
from transformer_engine.pytorch.utils import attention_mask_func, nvtx_range_push, nvtx_range_pop
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorBase,
prepare_for_saving,
restore_from_saved,
)
......@@ -40,7 +43,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
META_O,
META_QKV,
)
from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype
from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype, FP8GlobalStateManager
from transformer_engine.pytorch.distributed import get_distributed_world_size
from transformer_engine.pytorch.jit import no_torch_dynamo
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
......@@ -53,6 +56,9 @@ from transformer_engine.pytorch.attention.inference import InferenceParams
import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils as fa_utils,
combine_and_quantize,
combine_and_dequantize,
print_quantizers,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
AttentionLogging as attn_log,
......@@ -131,6 +137,58 @@ if not IS_HIP_EXTENSION:
fa_utils.set_flash_attention_3_params()
# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default; this flag allows it in F16 instead
_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1"
class FP8EmulationFunc(torch.autograd.Function):
"""
Emulate the effects of FP8 quantization on tensors. Used in UnfusedDotProductAttention as follows:
- forward : QKV (quantize+dequantize), P (pass-through), S (quantize+dequantize), O (pass-through)
- backward: dO (quantize+dequantize), dS (pass-through), dP (quantize+dequantize), dQKV (pass-through)
"""
@staticmethod
def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout):
# pylint: disable=missing-function-docstring
if quantizer_name == "QKV_quantizer":
query_layer, key_layer, value_layer = [
x.contiguous() for x in [tensor1, tensor2, tensor3]
]
q_fp8, k_fp8, v_fp8 = combine_and_quantize(
qkv_layout, query_layer, key_layer, value_layer, quantizer
)
tensors = combine_and_dequantize(
qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype
)
elif quantizer_name in ["S_quantizer", "O_quantizer"]:
t_fp8 = quantizer(tensor1)
tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3)
else:
tensors = (tensor1, tensor2, tensor3)
ctx.quantizer = quantizer
ctx.quantizer_name = quantizer_name
ctx.qkv_layout = qkv_layout
return tensors[0], tensors[1], tensors[2]
@staticmethod
def backward(ctx, grad1, grad2, grad3):
# pylint: disable=missing-function-docstring
if ctx.quantizer_name in ["dO_quantizer", "dP_quantizer"]:
dt_fp8 = ctx.quantizer(grad1)
tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3
elif ctx.quantizer_name == "dQKV_quantizer":
query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]]
dq_fp8, dk_fp8, dv_fp8 = combine_and_quantize(
ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer
)
tensors = combine_and_dequantize(
ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype
)
else:
tensors = grad1, grad2, grad3
return tensors[0], tensors[1], tensors[2], None, None, None
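A hedged sketch of how UnfusedDotProductAttention drives this emulation (mirrors the call sites later in this file; the quantizer objects are assumed to come from dpa_utils.get_attention_quantizers):

# Inject FP8 rounding error into Q/K/V in forward and into dQ/dK/dV in backward.
q, k, v = FP8EmulationFunc.apply(q, k, v, QKV_quantizer, "QKV_quantizer", qkv_layout)
q, k, v = FP8EmulationFunc.apply(q, k, v, dQKV_quantizer, "dQKV_quantizer", qkv_layout)
# Single-tensor variants pass None for the unused slots.
probs, *_ = FP8EmulationFunc.apply(probs, None, None, S_quantizer, "S_quantizer", None)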
class UnfusedDotProductAttention(torch.nn.Module):
"""Parallel attention w/o QKV and Proj Gemms
......@@ -144,6 +202,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
layer_number: Optional[int] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -151,6 +210,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
self.attention_type = attention_type
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number
self.softmax_type = softmax_type
def mask_func(x, y):
return (
......@@ -187,6 +247,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
inference_params: Optional[InferenceParams] = None,
softmax_offset: torch.Tensor = None,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
quantizers=None,
fp8_output: bool = False,
) -> torch.Tensor:
"""Unfused attention fprop"""
assert (
......@@ -284,6 +349,35 @@ class UnfusedDotProductAttention(torch.nn.Module):
if apply_qk_layer_scaling:
scale /= self.layer_number
if fp8:
# get quantizers from DPA; all Nones if not fp8
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
dpa_utils.get_attention_quantizers(fp8, quantizers)
)
# S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
if fp8_recipe.float8_current_scaling():
S_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=S_quantizer.dtype, device="cuda"
)
dP_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=dP_quantizer.dtype, device="cuda"
)
if "2" in qkv_layout or "3" in qkv_layout:
qkv_format, *_ = dpa_utils.get_qkv_format(qkv_layout)
qkv_layout = "_".join([qkv_format] * 3)
# quantize and dequantize QKV to emulate FP8
query_layer, key_layer, value_layer = FP8EmulationFunc.apply(
query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", qkv_layout
)
# quantize and dequantize dQKV to emulate FP8
query_layer, key_layer, value_layer = FP8EmulationFunc.apply(
query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", qkv_layout
)
# Raw attention scores. [b * np, sq, sk]
if core_attention_bias_type == "no_bias":
matmul_result = torch.baddbmm(
......@@ -328,7 +422,27 @@ class UnfusedDotProductAttention(torch.nn.Module):
dtype=query_layer.dtype
)
# attention scores and attention mask [b, np, sq, sk]
if fp8:
# quantize and dequantize dP to emulate FP8
matmul_result, *_ = FP8EmulationFunc.apply(
matmul_result, None, None, dP_quantizer, "dP_quantizer", None
)
# add attention sink to the last column: [b, np, sq, sk+1]
if self.softmax_type != "vanilla":
matmul_result = torch.cat(
[
matmul_result,
softmax_offset.to(dtype=matmul_result.dtype).expand(
matmul_result.size(0), -1, matmul_result.size(2), -1
),
],
dim=-1,
)
attention_mask = F.pad(attention_mask, (0, 1), mode="constant", value=False)
attn_mask_type = "arbitrary"
# attention scores and attention mask
softmax_scale = self.layer_number if apply_qk_layer_scaling else None
attention_probs = self.scale_mask_softmax(
matmul_result, attention_mask, attn_mask_type, softmax_scale
......@@ -339,6 +453,10 @@ class UnfusedDotProductAttention(torch.nn.Module):
if "padding" in attn_mask_type:
attention_probs = attention_probs.masked_fill(attention_mask, 0)
# remove attention sink: [b, np, sq, sk]
if self.softmax_type != "vanilla":
attention_probs = attention_probs[..., :-1]
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with self.attention_dropout_ctx():
......@@ -359,6 +477,12 @@ class UnfusedDotProductAttention(torch.nn.Module):
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
if fp8:
# quantize and dequantize S to emulate FP8
attention_probs, *_ = FP8EmulationFunc.apply(
attention_probs, None, None, S_quantizer, "S_quantizer", None
)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
......@@ -393,6 +517,20 @@ class UnfusedDotProductAttention(torch.nn.Module):
# [tq, np, hn] --> [tq, hp]
context_layer = context_layer.view(total_tokens, -1)
if fp8:
# quantize and dequantize O to emulate FP8
context_layer, *_ = FP8EmulationFunc.apply(
context_layer, None, None, O_quantizer, "O_quantizer", None
)
# quantize and dequantize dO to emulate FP8
context_layer, *_ = FP8EmulationFunc.apply(
context_layer, None, None, dO_quantizer, "dO_quantizer", None
)
# quantize O
if fp8_output:
context_layer = O_quantizer(context_layer)
return context_layer
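The non-vanilla softmax path appends a learnable "sink" column before the softmax and strips it afterwards; a standalone sketch of that shape manipulation under assumed shapes:

import torch
b, np_, sq, sk = 2, 4, 8, 8
scores = torch.randn(b, np_, sq, sk)
offset = torch.zeros(1, np_, 1, 1)                                    # learnable softmax offset
scores = torch.cat([scores, offset.expand(b, -1, sq, -1)], dim=-1)    # [b, np, sq, sk+1]
probs = torch.softmax(scores, dim=-1)[..., :-1]                       # drop the sink column
# Probabilities over real keys now sum to <= 1; the remainder is absorbed by the sink.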
......@@ -491,6 +629,7 @@ class FlashAttention(torch.nn.Module):
quantizers=None,
inference_params: Optional[InferenceParams] = None,
flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
fp8_output: bool = False,
) -> torch.Tensor:
"""flash-attn fprop"""
......@@ -696,6 +835,7 @@ class FlashAttention(torch.nn.Module):
quantizers=quantizers,
pad_between_seqs=False,
use_flash_attn_3=use_flash_attn_3,
fp8_output=fp8_output,
)
else:
from transformer_engine.pytorch.cpu_offload import (
......@@ -795,8 +935,6 @@ class FlashAttention(torch.nn.Module):
)
return out
# "fp8_mha" decides outputs in fp8, while inputs are inferred from
# the real dtype
assert isinstance(key_layer, query_layer.__class__) and isinstance(
value_layer, query_layer.__class__
), "q, k, and v must have the same type."
......@@ -843,7 +981,7 @@ class FlashAttention(torch.nn.Module):
if fp8:
output = output.to(dtype=torch_orig_dtype)
if fp8 and fp8_meta["recipe"].fp8_mha:
if fp8 and fp8_output:
O_quantizer = quantizers["scaling_fwd"][META_O]
output = O_quantizer(output)
......@@ -871,7 +1009,7 @@ class FlashAttention(torch.nn.Module):
if q_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd)
if fp8 and fp8_meta["recipe"].fp8_mha:
if fp8 and fp8_output:
output_data = (
output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
.transpose(0, 1)
......@@ -895,7 +1033,7 @@ class FlashAttention(torch.nn.Module):
class FusedAttnFunc(torch.autograd.Function):
"""Function for FusedAttention with separate Q, K, V tensors"""
"""FusedAttention forward and backward implementation"""
@staticmethod
def forward(
......@@ -919,6 +1057,7 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
window_size,
rng_gen,
fused_attention_backend,
......@@ -927,55 +1066,72 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_meta,
quantizers,
deterministic,
softmax_offset,
fp8_output,
layer_number,
):
# pylint: disable=missing-function-docstring
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False
# FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e4m3fn
fake_dtype = q.dtype
# add NVTX range
nvtx_label = "transformer_engine.FusedAttnFunc.forward"
nvtx_range_push(f"{nvtx_label}")
# recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE;
# may be different from fp8_meta["recipe"]
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
# input types are inferred from the real data while output types are controlled by fp8_output
# fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha)
assert isinstance(k, q.__class__) and isinstance(
v, q.__class__
), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor."
is_input_fp8 = isinstance(q, Float8Tensor)
is_output_fp8 = fp8_output
# whether fwd kernel in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa)
# whether bwd kernel in FP8:
is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
# get quantizers from DPA; all Nones if not fp8
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
dpa_utils.get_attention_quantizers(fp8, quantizers)
)
# get nominal data type for out
# FP16/BF16 attention: torch.float16 or torch.bfloat16
# FP8 attention: torch.float16 or torch.bfloat16
out_nominal_dtype = q.dtype
if fp8:
fused_attention_backend = FusedAttnBackend["FP8"]
assert isinstance(k, q.__class__) and isinstance(
v, q.__class__
), "q, k, and v must have the same type."
is_input_fp8 = isinstance(q, Float8Tensor)
q_fp8, k_fp8, v_fp8 = None, None, None
# q, k, v: torch.Tensor; dtype = torch.float16 or torch.bfloat16
# q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
if is_input_fp8:
q_fp8, k_fp8, v_fp8 = q, k, v
else:
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_"))
match qkv_group:
case 1:
dim = qkv_layout.find("3")
qkv = combine_tensors([q, k, v], dim)
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
qkv_fp8 = QKV_quantizer(qkv)
q_fp8, k_fp8, v_fp8 = SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1], True)
case 2:
q_fp8 = QKV_quantizer(q)
dim = qkv_layout.split("_")[1].find("2")
kv = combine_tensors([k, v], dim)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv_fp8 = QKV_quantizer(kv_c)
k_fp8, v_fp8 = SplitAlongDim.apply(kv_fp8, dim, [1, 1], True)
case 3:
q_fp8 = QKV_quantizer(q)
k_fp8 = QKV_quantizer(k)
v_fp8 = QKV_quantizer(v)
case _:
raise "Invalid qkv_layout " + qkv_layout
# q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn
out_fp8, aux_ctx_tensors = fused_attn_fwd(
q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer)
# print quantizers
print_quantizers(
"FusedAttnFunc.forward >> before: ",
layer_number,
QKV_quantizer,
O_quantizer,
S_quantizer,
dQKV_quantizer,
dO_quantizer,
dP_quantizer,
)
# out_:
# DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
# Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_, aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv,
......@@ -984,7 +1140,7 @@ class FusedAttnFunc(torch.autograd.Function):
q_fp8,
k_fp8,
v_fp8,
fake_dtype,
out_nominal_dtype,
fused_attention_backend,
attn_bias,
cu_seqlens_q_padded,
......@@ -999,45 +1155,59 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
window_size,
rng_gen,
softmax_offset,
)
if is_output_fp8:
out_ret = out_fp8
# out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
# out: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_fp8 = out_
out = out_
if isinstance(out_, Float8Tensor):
if not is_output_fp8 or not is_bwd_fp8:
out = out_.dequantize().view(out_.shape)
else:
out_ret = out_fp8.dequantize().view(out_fp8.shape)
# is_output_fp8 = False: out_save.dtype = torch.float16 or torch.bfloat16
# is_output_fp8 = True: out_save.dtype = torch.float8_e4m3fn
out_save = out_ret
if is_output_fp8 or (
is_bwd_fp8
and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)
):
out_fp8 = O_quantizer(out_)
if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
# 1: qkv packed, 2: kv packed, 3: qkv separate
if is_input_fp8:
qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_"))
if qkv_group == 1:
dim = qkv_layout.find("3")
qkv = combine_tensors([q, k, v], dim)
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
qkv_no_fp8 = qkv_c.dequantize().view(qkv.shape)
q, k, v = SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True)
if qkv_group == 2:
q = q.dequantize()
dim = qkv_layout.replace("paged_kv_", "").split("_")[1].find("2")
kv = combine_tensors([k, v], dim)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv_no_fp8 = kv.dequantize()
k, v = SplitAlongDim.apply(kv_no_fp8, dim, [1, 1], True)
if qkv_group == 3:
q = q.dequantize()
k = k.dequantize()
v = v.dequantize()
if is_output_fp8:
out_save = out_fp8.dequantize()
# print quantizers
print_quantizers(
"FusedAttnFunc.forward >> after: ",
layer_number,
QKV_quantizer,
O_quantizer,
S_quantizer,
dQKV_quantizer,
dO_quantizer,
dP_quantizer,
)
# return appropriate tensors
out_ret = out_fp8 if is_output_fp8 else out
# save appropriate tensors
fp8_tensors = (None, None, None, None)
qkvo_tensors = (None, None, None, None)
if is_bwd_fp8:
if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16:
fp8_tensors = (q_fp8, k_fp8, v_fp8, None)
qkvo_tensors = (None, None, None, out)
else:
fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8)
else:
# q, k, v, out_ret: torch.float16 or torch.bfloat16
out_ret, aux_ctx_tensors = fused_attn_fwd(
if is_input_fp8:
q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8)
qkvo_tensors = (q, k, v, out)
else:
# q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_, aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv,
......@@ -1046,7 +1216,7 @@ class FusedAttnFunc(torch.autograd.Function):
q,
k,
v,
fake_dtype,
out_nominal_dtype,
fused_attention_backend,
attn_bias,
cu_seqlens_q_padded,
......@@ -1061,13 +1231,23 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
window_size,
rng_gen,
softmax_offset,
)
out_save = out_ret
out = out_
out_ret = out_
fp8_tensors = (None, None, None, None)
qkvo_tensors = (q, k, v, out)
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
nvtx_range_pop(f"{nvtx_label}")
ctx.fp8_recipe = fp8_recipe
ctx.fp8 = is_bwd_fp8
# assume fwd and bwd always use the same high precision, i.e. torch.float16 or torch.bfloat16
# used when some tensors are base tensors and lose the "dtype" attribute
ctx.nominal_dtype = out_nominal_dtype
from transformer_engine.pytorch.cpu_offload import (
CPUOffloadEnabled,
......@@ -1078,7 +1258,7 @@ class FusedAttnFunc(torch.autograd.Function):
if ctx.fp8:
tensor_list = fp8_tensors
else:
tensor_list = [q, k, v, out_save]
tensor_list = [q, k, v, out]
qkv_layout = "sbhd_sbhd_sbhd"
mark_activation_offload(*tensor_list)
......@@ -1086,7 +1266,6 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
tensors_to_save, tensor_objects = prepare_for_saving(
*fp8_tensors,
*qkvo_tensors,
......@@ -1100,11 +1279,14 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.tensor_objects = tensor_objects
ctx.fp8_meta = fp8_meta
ctx.layer_number = layer_number
ctx.QKV_quantizer = QKV_quantizer
ctx.O_quantizer = O_quantizer
ctx.dQKV_quantizer = dQKV_quantizer
ctx.dO_quantizer = dO_quantizer
ctx.dP_quantizer = dP_quantizer
ctx.S_quantizer = S_quantizer
if ctx.fp8:
if ctx.fp8 and isinstance(ctx.S_quantizer, Float8Quantizer):
ctx.S_quantizer = S_quantizer.copy()
ctx.S_quantizer.scale = S_quantizer.scale.clone()
......@@ -1116,6 +1298,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.softmax_type = softmax_type
ctx.window_size = window_size
ctx.fused_attention_backend = (
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
......@@ -1128,16 +1311,14 @@ class FusedAttnFunc(torch.autograd.Function):
@staticmethod
def backward(ctx, d_out):
# pylint: disable=missing-function-docstring
if ctx.is_output_fp8:
assert isinstance(
d_out, Float8Tensor
), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
# FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e5m2
fake_dtype = d_out.dtype
# d_out is expected to be in FP8 if is_output_fp8=True,
# but in case it is not, convert it to FP8 before any operation
if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorBase):
d_out = ctx.dO_quantizer(d_out)
if not ctx.use_FAv2_bwd:
d_out._data = d_out._data.contiguous()
elif not ctx.use_FAv2_bwd:
d_out = d_out.contiguous()
(
q_fp8,
......@@ -1192,16 +1373,55 @@ class FusedAttnFunc(torch.autograd.Function):
dk = dk[..., : d_out.shape[-1]]
dv = dv[..., : d_out.shape[-1]]
else:
with torch.cuda.nvtx.range("_FusedAttn"):
with torch.cuda.nvtx.range("FusedAttnFunc.backward"):
# get nominal data type of dq, dk, dv
# FP16/BF16 attention: torch.float16 or torch.bfloat16
# FP8 attention: torch.float16 or torch.bfloat16
dqkv_nominal_dtype = ctx.nominal_dtype
if ctx.fp8:
# d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16
# d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E5M2
if ctx.is_output_fp8:
d_out_fp8 = d_out
else:
d_out_fp8 = ctx.dO_quantizer(d_out)
dqkv_dtype = TE_DType[d_out_fp8._data.dtype]
# q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn
# d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2
dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
# print quantizers
print_quantizers(
"FusedAttnFunc.backward >> before: ",
ctx.layer_number,
ctx.QKV_quantizer,
ctx.O_quantizer,
ctx.S_quantizer,
ctx.dQKV_quantizer,
ctx.dO_quantizer,
ctx.dP_quantizer,
)
# get tex.DType for dq, dk, dv data
dqkv_te_dtype = d_out_fp8._fp8_dtype
# q_fp8, k_fp8, v_fp8, out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16,
# fp8_dtype = tex.DType.kFloat8E4M3
# d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E5M2
# out_:
# DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
# Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
#
# dq_, dk_, dv_:
# DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E5M2
# Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_ = (
out
if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16
else out_fp8
)
dq_, dk_, dv_, *rest = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
cu_seqlens_q,
......@@ -1209,10 +1429,10 @@ class FusedAttnFunc(torch.autograd.Function):
q_fp8,
k_fp8,
v_fp8,
out_fp8,
out_,
d_out_fp8,
fake_dtype,
dqkv_dtype,
dqkv_nominal_dtype,
dqkv_te_dtype,
aux_ctx_tensors,
ctx.fused_attention_backend,
cu_seqlens_q_padded,
......@@ -1226,44 +1446,45 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size,
ctx.deterministic,
)
# is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16
# is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2
if not ctx.is_input_fp8:
qkv_group = len(ctx.qkv_layout.replace("paged_kv_", "").split("_"))
if qkv_group == 1:
dim = ctx.qkv_layout.find("3")
dqkv_fp8_data = combine_tensors(
[dq_fp8._data, dk_fp8._data, dv_fp8._data], dim
)
dqkv_fp8 = dq_fp8.make_like(
tensor=dq_fp8, data=dqkv_fp8_data, shape=dqkv_fp8_data.shape
)
dqkv = dqkv_fp8.dequantize()
dq, dk, dv = SplitAlongDim.apply(dqkv, dim, [1, 1, 1], True)
if qkv_group == 2:
dq = dq_fp8.dequantize()
dim = ctx.qkv_layout.split("_")[1].find("2")
dkv_fp8 = combine_tensors([dk_fp8, dv_fp8], dim)
dkv_c_fp8 = dkv_fp8.view(
-1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
)
dkv = dkv_c_fp8.dequantize()
dk, dv = SplitAlongDim.apply(dkv, dim, [1, 1], True)
if qkv_group == 3:
dq = dq_fp8.dequantize()
dk = dk_fp8.dequantize()
dv = dv_fp8.dequantize()
else:
dq, dk, dv = dq_fp8, dk_fp8, dv_fp8
# dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16
dq, dk, dv = dq_, dk_, dv_
is_float8tensor = isinstance(dq_, Float8Tensor)
if is_float8tensor and not ctx.is_input_fp8:
# return in F16
dq, dk, dv = combine_and_dequantize(
ctx.qkv_layout,
dq_,
dk_,
dv_,
src_nominal_dtype=dq_.dtype,
)
if not is_float8tensor and ctx.is_input_fp8:
# return in FP8
dq, dk, dv = combine_and_quantize(
ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer
)
# print quantizers
print_quantizers(
"FusedAttnFunc.backward >> after: ",
ctx.layer_number,
ctx.QKV_quantizer,
ctx.O_quantizer,
ctx.S_quantizer,
ctx.dQKV_quantizer,
ctx.dO_quantizer,
ctx.dP_quantizer,
)
else:
if isinstance(d_out, QuantizedTensor):
d_out = d_out.dequantize()
dqkv_dtype = TE_DType[d_out.dtype]
# q, k, v, out, d_out, dq, dk, dv: torch.float16 or torch.bfloat16
if isinstance(d_out, QuantizedTensorBase):
d_out = d_out.dequantize(dtype=ctx.nominal_dtype)
dqkv_te_dtype = TE_DType[d_out.dtype]
# q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16
dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
......@@ -1274,8 +1495,8 @@ class FusedAttnFunc(torch.autograd.Function):
v,
out,
d_out,
fake_dtype,
dqkv_dtype,
dqkv_nominal_dtype,
dqkv_te_dtype,
aux_ctx_tensors,
ctx.fused_attention_backend,
cu_seqlens_q_padded,
......@@ -1289,12 +1510,17 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size,
ctx.deterministic,
)
# if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]:
d_bias = None
if ctx.attn_bias_type not in ["no_bias", "alibi"]:
d_bias = rest[0]
d_softmax_offset = None
if ctx.softmax_type != "vanilla":
d_softmax_offset = rest[1]
return (
None,
None,
......@@ -1308,6 +1534,7 @@ class FusedAttnFunc(torch.autograd.Function):
dq,
dk,
dv,
d_bias,
None,
None,
None,
......@@ -1323,34 +1550,7 @@ class FusedAttnFunc(torch.autograd.Function):
None,
None,
None,
)
# else, return (dqkv, dbias)
return (
None,
None,
None,
None,
None,
None,
None,
None,
None,
dq,
dk,
dv,
rest[0],
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
d_softmax_offset,
None,
None,
)
......@@ -1392,6 +1592,7 @@ class FusedAttention(torch.nn.Module):
attention_type: str = "self",
layer_number: Optional[int] = None,
deterministic: bool = False,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -1404,6 +1605,7 @@ class FusedAttention(torch.nn.Module):
) == "1" and get_device_compute_capability() == (9, 0)
self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic
self.softmax_type = softmax_type
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
"""
......@@ -1455,6 +1657,8 @@ class FusedAttention(torch.nn.Module):
quantizers=None,
pad_between_seqs: bool = False,
inference_params: Optional[InferenceParams] = None,
softmax_offset: torch.Tensor = None,
fp8_output: bool = False,
) -> torch.Tensor:
"""fused attention fprop"""
assert (
......@@ -1555,14 +1759,26 @@ class FusedAttention(torch.nn.Module):
)
if fp8:
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
" is required for FP8 attention!"
)
assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!"
assert not context_parallel or fp8_meta["recipe"].reduce_amax, (
"Amax reduction across TP+CP group is necessary when using context parallelism with"
" FP8!"
if fp8_recipe.delayed():
assert not context_parallel or fp8_recipe.reduce_amax, (
"Amax reduction across TP+CP group is necessary when using context parallelism"
" with FP8!"
)
if fp8_recipe.float8_current_scaling() and context_parallel:
all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers)
for q in all_quantizers:
if isinstance(q, Float8CurrentScalingQuantizer):
q.with_amax_reduction = True
q.amax_reduction_group = (
cp_group[0] if cp_comm_type == "a2a+p2p" else cp_group
)
if context_parallel:
......@@ -1605,6 +1821,10 @@ class FusedAttention(torch.nn.Module):
fp8_meta=fp8_meta,
quantizers=quantizers,
pad_between_seqs=pad_between_seqs,
softmax_type=self.softmax_type,
softmax_offset=softmax_offset,
fp8_output=fp8_output,
layer_number=self.layer_number,
)
else:
with self.attention_dropout_ctx():
......@@ -1628,6 +1848,7 @@ class FusedAttention(torch.nn.Module):
qkv_layout,
core_attention_bias_type,
attn_mask_type,
self.softmax_type,
window_size,
None, # rng_gen
fused_attention_backend,
......@@ -1636,6 +1857,9 @@ class FusedAttention(torch.nn.Module):
fp8_meta,
quantizers,
self.deterministic,
softmax_offset,
fp8_output,
self.layer_number,
)
# ...hd -> ...(hd)
......
......@@ -11,10 +11,25 @@ import warnings
import logging
import torch
from torch.nn.parameter import Parameter
import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
Format,
Recipe,
DelayedScaling,
Float8CurrentScaling,
)
from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.fp8 import (
get_fp8_te_dtype,
FP8GlobalStateManager,
RecipeState,
DelayedScalingRecipeState,
MXFP8BlockScalingRecipeState,
Float8CurrentScalingRecipeState,
Float8BlockScalingRecipeState,
)
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.export import is_in_onnx_export_mode
......@@ -72,6 +87,67 @@ _alibi_cache = {
"_alibi_bias_require_update": False,
}
"""
This feature is **experimental** and subject to change.
Some models may use different FP8 recipes for their linear layers and their attention layers. To support
this, users can either use multiple, nested fp8_autocast() contexts to assign a distinct recipe to each
layer, or use a single fp8_autocast() for the non-attention layers and configure the recipe for the
attention layers via the environment variables below (a usage sketch follows the variable definitions).
+-------------------+-----------+-----------------------------------------------------------------------------------+
| Linear | Attention | Configuration |
+===================+===========+===================================================================================+
| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to fp8_autocast(); |
| | | export NVTE_DPA_FP8_RECIPE="F16" |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8DS | FP8DS | Pass FP8DS to fp8_autocast(); |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8CS | FP8DS | Pass FP8CS to fp8_autocast(); |
| | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; |
| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| NVFP4 | FP8DS | Pass NVFP4 to fp8_autocast(); |
| | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; |
| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS |
| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8DS | FP8CS | Pass FP8DS to fp8_autocast(); |
| | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,|
| | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; |
| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8CS | FP8CS | Pass FP8CS to fp8_autocast(); |
| | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe |
| | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| NVFP4 | FP8CS | Pass NVFP4 to fp8_autocast(); |
| | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe |
| | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: |
| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS |
| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
"""
_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
formats = {"HYBRID": Format.HYBRID, "E4M3": Format.E4M3, "E5M2": Format.E5M2}
_dpa_fp8_format = formats[os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID")]
_dpa_fp8ds_amax_algo = os.getenv("NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent")
_dpa_fp8ds_amax_histlen = int(os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1"))
_dpa_fp8ds_reduce_amax = os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") == "1"
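# As an illustration only (not part of this change; model sizes and names are arbitrary), the
# "FP8CS | FP8DS" row of the table above could be exercised roughly as follows. The NVTE_DPA_*
# variables are read when this module is imported, so they must be set beforehand.
#
# import os
#
# os.environ["NVTE_DPA_FP8_RECIPE"] = "DelayedScaling"
# os.environ["NVTE_DPA_FP8DS_AMAX_ALGO"] = "most_recent"
# os.environ["NVTE_DPA_FP8DS_AMAX_HISTLEN"] = "1"
#
# import torch
# import transformer_engine.pytorch as te
# from transformer_engine.common.recipe import Float8CurrentScaling
#
# # Linear layers run with current scaling; attention switches to the delayed-scaling recipe.
# recipe = Float8CurrentScaling(fp8_dpa=True)
# layer = te.TransformerLayer(1024, 4096, 16, params_dtype=torch.bfloat16).cuda()
# x = torch.randn(128, 2, 1024, device="cuda", dtype=torch.bfloat16)
# with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
#     y = layer(x)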
__all__ = ["DotProductAttention"]
......@@ -168,6 +244,17 @@ class DotProductAttention(TransformerEngineBaseModule):
softmax_scale: Optional[float], default = `None`
softmax scale for the attention scores. If `None`, defaults to
`1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
Parallelism parameters
----------------------
......@@ -223,6 +310,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p",
softmax_scale: Optional[float] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -307,6 +395,20 @@ class DotProductAttention(TransformerEngineBaseModule):
self.attention_type = attention_type
self.attention_dropout = attention_dropout
self.softmax_type = softmax_type
if self.softmax_type == "vanilla":
self.softmax_offset = None
if self.softmax_type == "off-by-one":
self.softmax_offset = torch.zeros(
self.num_attention_heads // self.tp_size, device="cuda"
)
if self.softmax_type == "learnable":
self.register_parameter(
"softmax_offset",
Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")),
get_rng_state_tracker=get_rng_state_tracker,
)
attn_kwargs = {
"attention_dropout": attention_dropout,
"attention_dropout_ctx": attention_dropout_ctx,
......@@ -328,6 +430,7 @@ class DotProductAttention(TransformerEngineBaseModule):
layer_number=layer_number,
deterministic=self.deterministic,
**attn_kwargs,
softmax_type=self.softmax_type,
)
self.unfused_attention = UnfusedDotProductAttention(
......@@ -335,6 +438,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attention_type=attention_type,
**attn_kwargs,
layer_number=layer_number,
softmax_type=self.softmax_type,
)
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
......@@ -433,6 +537,231 @@ class DotProductAttention(TransformerEngineBaseModule):
self.cp_stream = cp_stream
self.cp_comm_type = cp_comm_type
def init_fp8_metadata(self, num_gemms: int = 1) -> None:
"""
Override TransformerEngineBaseModule.init_fp8_metadata to allow for more flexible recipe support.
Initialize fp8 related metadata and tensors during fprop.
"""
_original_recipe = self.fp8_meta.get("recipe", None)
# global recipe set in fp8_autocast()
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
# switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to
# a different recipe than fp8_recipe. DPA.quantizers may be a mix of different quantizers as well.
#
# fp8_recipe | NVTE_DPA_FP8_RECIPE | self.fp8_meta["recipe"] | self.quantizers
# --------------------------------------------------------------------------------------------
# DelayedScaling (DS) | unset | DS | all DS
# Float8CurrentScaling (CS) | unset | DS | CS for QKV, O, dO, dQKV; DS for S, dP
# x={DS, CS} | y | refer to row x=y | refer to row x=y
fp8_recipe_dpa = fp8_recipe
fp8_recipes = fp8_recipe
if _dpa_fp8_recipe == "F16":
# ignore the recipe from fp8_autocast, set fp8_dpa = False, fp8_mha = False
fp8_recipe.fp8_dpa = False
fp8_recipe.fp8_mha = False
elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe == "DelayedScaling":
# reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a DS recipe
fake_recipe = DelayedScaling(
fp8_format=fp8_recipe.fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
)
fp8_recipe_dpa = fake_recipe
fp8_recipes = fp8_recipe_dpa
elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "DelayedScaling":
# reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a DS recipe
fake_recipe = DelayedScaling(
fp8_format=_dpa_fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
)
fp8_recipe_dpa = fake_recipe
fp8_recipes = fp8_recipe_dpa
elif fp8_recipe.delayed() and _dpa_fp8_recipe == "Float8CurrentScaling":
# reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a CS+DS recipe
fake_recipes = [
Float8CurrentScaling(
fp8_format=fp8_recipe.fp8_format,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
),
fp8_recipe,
]
fp8_recipe_dpa = fake_recipes[1]
fp8_recipes = fake_recipes
elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe in (
"",
"Float8CurrentScaling",
):
# use fp8_recipe for QKV, O, dO, dQKV, and construct a DS recipe for S, dP
# reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe
fake_recipe = DelayedScaling(
fp8_format=fp8_recipe.fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
)
fp8_recipe_dpa = fake_recipe
fp8_recipes = [fp8_recipe, fp8_recipe_dpa]
elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "Float8CurrentScaling":
# reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format
# construct a CS recipe for QKV, O, dO, dQKV and a DS recipe for S, dP
fake_recipes = [
Float8CurrentScaling(
fp8_format=_dpa_fp8_format,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
),
DelayedScaling(
fp8_format=_dpa_fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
),
]
fp8_recipe_dpa = fake_recipes[1]
fp8_recipes = fake_recipes
# DPA only supports DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False
if not fp8_recipe_dpa.float8_per_tensor_scaling():
assert not (
fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha
), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe"
# reduce amax over TP+CP groups; expect fp8_group to be set up accordingly,
# i.e. assume attention uses the same fp8_group as the GEMMs
fp8_group = FP8GlobalStateManager.get_fp8_group()
self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
fp8_enabled = self.fp8 or self.fp8_calibration
self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
if self.fp8_parameters or fp8_enabled:
self.fp8_meta["global_recipe"] = fp8_recipe
self.fp8_meta["local_recipes"] = (
fp8_recipes if isinstance(fp8_recipes, list) else [fp8_recipes]
)
if self.fp8_parameters or fp8_enabled:
if self.fp8_initialized and fp8_recipe_dpa == self.fp8_meta["recipe"]:
# FP8 init has already been run and recipe is the same, don't do anything.
return
self.fp8_meta["recipe"] = fp8_recipe_dpa
if fp8_recipe != fp8_recipe_dpa:
# fp8_recipe has changed, rehash the key.
autocast_key = FP8GlobalStateManager.get_unique_autocast_key(
fp8_recipe_dpa, fp8_group
)
FP8GlobalStateManager.autocast_arguments[autocast_key] = (
fp8_recipe_dpa,
fp8_group,
)
else:
# If fp8 isn't enabled, turn off and return.
self.fp8_initialized = False
return
if self.fp8_parameters and not self.fp8_initialized:
self.fp8_meta["num_gemms"] = num_gemms
self.init_fp8_meta_tensors(fp8_recipes)
if fp8_enabled:
# Set FP8 and other FP8 metadata
self.fp8_meta["num_gemms"] = num_gemms
self.fp8_meta["fp8_group"] = fp8_group
# Set FP8_MAX per tensor according to recipe
self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
# Allocate scales and amaxes
self.init_fp8_meta_tensors(fp8_recipes)
self.fp8_initialized = True
self.fp8_meta["recipe"] = fp8_recipe_dpa
if fp8_recipe != fp8_recipe_dpa:
# fp8_recipe has changed, rehash the key.
autocast_key = FP8GlobalStateManager.get_unique_autocast_key(
fp8_recipe_dpa, fp8_group
)
FP8GlobalStateManager.autocast_arguments[autocast_key] = (
fp8_recipe_dpa,
fp8_group,
)
_current_recipe = self.fp8_meta["recipe"]
if _original_recipe is not None and not (
issubclass(_current_recipe.__class__, _original_recipe.__class__)
or issubclass(_original_recipe.__class__, _current_recipe.__class__)
):
warnings.warn(
f"Recipe type changed from {_original_recipe.__class__.__name__} "
f"to {_current_recipe.__class__.__name__}. "
"This may affect model behavior."
)
# Clear cached workspaces as they were created with the old recipe/quantizer type
self._fp8_workspaces.clear()
def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> None:
"""Override to allow multiple recipes. Init scales and amaxes for fwd | bwd."""
if isinstance(recipe, Recipe):
recipe = [recipe]
fp8_recipe_dpa = recipe[-1]
fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
# Return early if recipe state matches recipe
if self.fp8_meta_tensors_initialized:
recipe_state = self.fp8_meta[fp8_meta_tensor_key]
if fp8_recipe_dpa.delayed() and isinstance(recipe_state, DelayedScalingRecipeState):
self.adjust_amax_history_length(fp8_recipe_dpa.amax_history_len, fwd=fwd)
return
if fp8_recipe_dpa.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
return
if fp8_recipe_dpa.float8_current_scaling() and isinstance(
recipe_state, Float8CurrentScalingRecipeState
):
return
if fp8_recipe_dpa.float8_block_scaling() and isinstance(
recipe_state, Float8BlockScalingRecipeState
):
return
# When fp8_recipe=Float8CurrentScaling, recipe=[CS, DS], and QKV/dQKV, O/dO use CS quantizers, S/dP use DS quantizers.
# See the table above in init_fp8_metadata for more detail, and the layout sketch after this function.
num_gemms = [2, 1] if len(recipe) == 2 else [3]
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
num_fp8_tensors = [x * 3 if fwd else x * 2 for x in num_gemms]
# Initialize recipe state and quantizers
recipe_states = [
RecipeState.create(
recipe[i],
mode=("forward" if fwd else "backward"),
num_quantizers=num_fp8_tensors[i],
)
for i in range(len(recipe))
]
self.fp8_meta[fp8_meta_tensor_key] = (
recipe_states[-1] if len(recipe) == 2 else recipe_states[0]
)
self.quantizers[fp8_meta_tensor_key] = []
for recipe_state in recipe_states:
self.quantizers[fp8_meta_tensor_key].extend(recipe_state.make_quantizers())
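# For orientation, a sketch of the resulting mixed quantizer layout. It assumes the META_* indices
# and quantizer classes imported by the attention utils in this change; the helper name below is
# hypothetical and not part of the module.
#
# def _check_cs_ds_quantizer_layout(dpa_module):
#     # Sketch only: after init_fp8_metadata() under a Float8CurrentScaling autocast,
#     # forward quantizers 0..5 come from the CS recipe state and 6..8 from the DS one.
#     fwd_qs = dpa_module.quantizers["scaling_fwd"]
#     assert isinstance(fwd_qs[META_QKV], Float8CurrentScalingQuantizer)  # QKV -> CS
#     assert isinstance(fwd_qs[META_O], Float8CurrentScalingQuantizer)    # O   -> CS
#     assert isinstance(fwd_qs[META_S], Float8Quantizer)                  # S   -> DS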
@no_torch_dynamo(recursive=False)
def forward(
self,
......@@ -456,6 +785,7 @@ class DotProductAttention(TransformerEngineBaseModule):
fast_zero_fill: bool = True,
inference_params: Optional[InferenceParams] = None,
pad_between_seqs: Optional[bool] = None,
fp8_output: Optional[bool] = False,
) -> torch.Tensor:
"""
Dot Product Attention Layer.
......@@ -628,12 +958,15 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs: Optional[bool], default = `None`
If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If true, there are padding tokens between individual sequences in a packed batch.
fp8_output: Optional[bool], default = `False`
Whether to enforce output to be in FP8 or not.
"""
with torch.cuda.device(query_layer.device), self.prepare_forward(
query_layer,
num_gemms=3,
allow_non_contiguous=True,
allow_different_data_and_param_types=self.softmax_type != "vanilla",
) as query_layer:
# checks for RNG
if self.rng_states_tracker is not None and is_graph_capturing():
......@@ -663,6 +996,8 @@ class DotProductAttention(TransformerEngineBaseModule):
tex.DType.kFloat8E4M3,
tex.DType.kFloat8E5M2,
], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
else:
fp8_output = False
# checks for q/k/v shapes
assert (
......@@ -922,6 +1257,7 @@ class DotProductAttention(TransformerEngineBaseModule):
False
), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"
# check if there is padding between sequences when qkv_format='thd'
if pad_between_seqs is None:
if qkv_format == "thd":
pad_between_seqs = (
......@@ -957,11 +1293,13 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs=pad_between_seqs,
attention_dropout=self.attention_dropout,
context_parallel=context_parallel,
cp_comm_type=self.cp_comm_type,
deterministic=self.deterministic,
is_training=self.training,
fp8=self.fp8,
fp8_meta=self.fp8_meta,
inference_params=inference_params,
softmax_type=self.softmax_type,
)
global _attention_backends
if is_in_onnx_export_mode():
......@@ -1022,6 +1360,12 @@ class DotProductAttention(TransformerEngineBaseModule):
)
# run attention
softmax_offset = (
self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32)
if self.softmax_offset is not None
else None
)
if use_flash_attention:
if core_attention_bias_type == "alibi":
alibi_slopes, _ = dpa_utils.get_alibi(
......@@ -1053,6 +1397,7 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers,
inference_params=inference_params,
flash_attention_backend=flash_attention_backend,
fp8_output=fp8_output,
)
if use_fused_attention:
......@@ -1071,7 +1416,6 @@ class DotProductAttention(TransformerEngineBaseModule):
bias_dtype=query_layer.dtype,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
)
# checkpoint_core_attention=False
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.fused_attention,
......@@ -1101,6 +1445,8 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
softmax_offset=softmax_offset,
fp8_output=fp8_output,
)
return self.fused_attention(
query_layer,
......@@ -1129,6 +1475,8 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
softmax_offset=softmax_offset,
fp8_output=fp8_output,
)
from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
......@@ -1140,6 +1488,7 @@ class DotProductAttention(TransformerEngineBaseModule):
)
if use_unfused_attention:
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.unfused_attention,
......@@ -1157,6 +1506,11 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
inference_params=inference_params,
softmax_offset=softmax_offset,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation,
fp8_meta=self.fp8_meta,
quantizers=self.quantizers,
fp8_output=fp8_output,
)
return self.unfused_attention(
_alibi_cache,
......@@ -1173,5 +1527,10 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
inference_params=inference_params,
softmax_offset=softmax_offset,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation,
fp8_meta=self.fp8_meta,
quantizers=self.quantizers,
fp8_output=fp8_output,
)
return None
......@@ -17,6 +17,7 @@ import numpy as np
from packaging.version import Version as PkgVersion
import torch
import torch.distributed as dist
import torch.nn.functional as F
import transformer_engine_torch as tex
import transformer_engine as te
......@@ -24,6 +25,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
QKVLayout,
AttnBiasType,
AttnMaskType,
SoftmaxType,
FusedAttnBackend,
META_QKV,
META_DQKV,
......@@ -31,11 +33,13 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
META_DO,
META_S,
META_DP,
META_O_CP,
META_DQKV_CP,
)
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.constants import TE_DType
from torch.utils.cpp_extension import IS_HIP_EXTENSION
......@@ -43,6 +47,8 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
get_cudnn_version,
SplitAlongDim,
combine_tensors,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
......@@ -53,6 +59,9 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
# print quantizer info for a particular layer on a particular rank
_print_layer = int(os.getenv("NVTE_PRINT_LAYER_NUMBER", "1"))
_print_rank = int(os.getenv("NVTE_PRINT_RANK", "0"))
_cu_seqlens_cache = {}
......@@ -206,6 +215,8 @@ class AttentionParams:
Attention dropout.
context_parallel: bool, default = `False`
Whether context parallelism is used or not.
cp_comm_type: str, default = "p2p"
The communication type of context parallelism.
deterministic: bool, default = `False`
Whether to run `DotProductAttention` with determinism or not.
is_training: bool, default = `True`
......@@ -216,6 +227,8 @@ class AttentionParams:
The FP8 metadata tensor of `DotProductAttention`.
inference_params: Optional[InferenceParams], default = `None`
Inference-related parameters. See InferenceParams for details.
softmax_type: str, default = "vanilla"
The type of softmax operation. See DotProductAttention for details.
"""
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
......@@ -237,11 +250,13 @@ class AttentionParams:
pad_between_seqs: bool = False
attention_dropout: float = 0.0
context_parallel: bool = False
cp_comm_type: str = "p2p"
deterministic: bool = False
is_training: bool = True
fp8: bool = False
fp8_meta: Union[Dict[str, Any], None] = None
inference_params: Optional[InferenceParams] = None
softmax_type: str = "vanilla"
def __eq__(self, other):
"""
......@@ -308,11 +323,13 @@ def get_attention_backend(
pad_between_seqs = attention_params.pad_between_seqs
attention_dropout = attention_params.attention_dropout
context_parallel = attention_params.context_parallel
cp_comm_type = attention_params.cp_comm_type
deterministic = attention_params.deterministic
is_training = attention_params.is_training
fp8 = attention_params.fp8
fp8_meta = attention_params.fp8_meta
inference_params = attention_params.inference_params
softmax_type = attention_params.softmax_type
# Run config
logger = logging.getLogger("DotProductAttention")
......@@ -341,8 +358,31 @@ def get_attention_backend(
field.name: getattr(attention_params, field.name) for field in fields(attention_params)
}
run_config.update(attention_params_dict)
# Add FP8 environment variables to config
if fp8:
# all FP8 recipes: 1: (FP8 fwd, FP8 bwd), 0: (FP8 fwd, F16 bwd)
run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
# Float8CurrentScaling: 1: use F16 O in bwd, 0: use FP8 O in bwd
run_config["NVTE_DPA_FP8CS_O_in_F16"] = int(os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1"))
# switch recipe to "F16", "DelayedScaling", or "Float8CurrentScaling"
_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
run_config["NVTE_DPA_FP8_RECIPE"] = _dpa_fp8_recipe
if _dpa_fp8_recipe != "":
# config new recipe if switched
run_config["NVTE_DPA_FP8_FORMAT"] = os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID")
run_config["NVTE_DPA_FP8DS_AMAX_ALGO"] = os.getenv(
"NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent"
)
run_config["NVTE_DPA_FP8DS_AMAX_HISTLEN"] = int(
os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1")
)
run_config["NVTE_DPA_FP8DS_REDUCE_AMAX"] = int(
os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1")
)
# UnfusedDotProductAttention: 1: allow FP8 emulation, 0: do not allow
run_config["NVTE_UnfusedDPA_Emulate_FP8"] = int(
os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0")
)
logger.debug("Running with config=%s", run_config)
# The following sections check if `FlashAttention` supports the provided attention params,
......@@ -422,8 +462,20 @@ def get_attention_backend(
logger.debug("Disabling FlashAttention 3 for FP8 training")
use_flash_attention_3 = False
if use_unfused_attention:
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
if not allow_emulation:
logger.debug("Disabling UnfusedDotProductAttention for FP8 attention")
use_unfused_attention = False
fp8_recipe = fp8_meta["recipe"]
if fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
if (
use_fused_attention
and fp8_recipe.float8_current_scaling()
and device_compute_capability < (10, 0)
):
logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100")
use_fused_attention = False
# TODO: rocm fused attention backends does not support fp8 yet
if IS_HIP_EXTENSION and use_fused_attention:
logger.debug("Disabling ROCm FusedAttention as it does not support FP8")
......@@ -581,6 +633,51 @@ def get_attention_backend(
logger.debug("Disabling FlashAttention 3 for dropout")
use_flash_attention_3 = False
# Filter: Softmax type
# context_parallel | softmax_type | supported backends
# ----------------------------------------------------------------------------------------------------
# no | vanilla | All
# no | off-by-one | FusedAttention, UnfusedDotProductAttention
# no | learnable | FusedAttention, UnfusedDotProductAttention
# yes | vanilla | FusedAttention, FlashAttention
# yes | off-by-one | FusedAttention
# yes | learnable | FusedAttention
if softmax_type != "vanilla":
logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type)
use_flash_attention = False
if fp8 and fp8_meta["recipe"].fp8_dpa:
logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type
)
use_unfused_attention = False
if qkv_format == "thd":
logger.debug(
"Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
softmax_type,
)
use_unfused_attention = False
if context_parallel:
logger.debug(
"Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
" = %s",
softmax_type,
)
use_unfused_attention = False
if cp_comm_type != "a2a":
logger.debug(
"Disabling FusedAttention for context parallelism with softmax_type = %s and"
" cp_comm_type = %s",
softmax_type,
cp_comm_type,
)
use_fused_attention = False
# Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends
# ----------------------------------------------------------------------------------------------------
......@@ -822,6 +919,7 @@ def get_attention_backend(
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
attention_dropout,
num_heads,
num_gqa_groups,
......@@ -1836,11 +1934,10 @@ def check_set_window_size(
return window_size
def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
def get_attention_quantizers(fp8, quantizers):
"""Get the list of quantizers used in attention from the quantizers list."""
if not fp8:
num_of_nones = 8 if cp_specific_quantizers else 6
return [None] * num_of_nones
return [None] * 6
QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
QKV_quantizer.internal = True
QKV_quantizer.set_usage(rowwise=True, columnwise=False)
......@@ -1849,6 +1946,7 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
S_quantizer = quantizers["scaling_fwd"][META_S]
S_quantizer.internal = True
S_quantizer.set_usage(rowwise=True, columnwise=False)
dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV]
dQKV_quantizer.internal = True
dQKV_quantizer.set_usage(rowwise=True, columnwise=False)
......@@ -1858,22 +1956,158 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
dP_quantizer = quantizers["scaling_bwd"][META_DP]
dP_quantizer.set_usage(rowwise=True, columnwise=False)
dP_quantizer.internal = True
dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP]
dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False)
dQKV_CP_quantizer.internal = True
O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP]
O_CP_quantizer.set_usage(rowwise=True, columnwise=False)
if cp_specific_quantizers:
return (
return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer
def print_quantizers(
label,
layer_number,
QKV_quantizer,
O_quantizer,
S_quantizer,
dQKV_quantizer,
dO_quantizer,
dP_quantizer,
):
"""Print the type and scale/amax of attention quantizers"""
_to_print = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL == 2
if (
_to_print
and _print_layer == layer_number
and (
not dist.is_initialized() or (dist.is_initialized() and dist.get_rank() == _print_rank)
)
):
names = [
"QKV_quantizer",
"S_quantizer",
"O_quantizer",
"dO_quantizer",
"dP_quantizer",
"dQKV_quantizer",
]
quantizers = [
QKV_quantizer,
S_quantizer,
O_quantizer,
dO_quantizer,
dP_quantizer,
dQKV_quantizer,
]
if "forward" in label:
names = names[:3]
quantizers = quantizers[:3]
if "backward" in label:
names = names[3:]
quantizers = quantizers[3:]
for i, q in enumerate(quantizers):
type_str = ""
if q is None:
type_str = "None"
elif isinstance(q, Float8Quantizer):
type_str = "DS"
elif isinstance(q, Float8CurrentScalingQuantizer):
type_str = "CS"
print(
f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x"
f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}"
)
return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer
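# Usage note (sketch only): the printout above is gated by variables read at import time of this
# module, so they must be set before importing Transformer Engine. For example, to enable it for
# layer 1 on rank 0:
#
# import os
#
# os.environ["NVTE_DEBUG"] = "1"          # _NVTE_DEBUG * _NVTE_DEBUG_LEVEL must equal 2
# os.environ["NVTE_DEBUG_LEVEL"] = "2"
# os.environ["NVTE_PRINT_LAYER_NUMBER"] = "1"
# os.environ["NVTE_PRINT_RANK"] = "0"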
def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer):
"""Combine q,k,v based on qkv_layout and quantize them together"""
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_layout = qkv_layout.replace("paged_kv_", "")
qkv_group = len(qkv_layout.split("_"))
src_nominal_dtype = q.dtype
match qkv_group:
case 1:
dim = qkv_layout.find("3")
qkv = combine_tensors([q, k, v], dim)
qkv_fp8 = qkv_quantizer(qkv)
q_data, k_data, v_data = SplitAlongDim.apply(qkv_fp8._data, dim, [1, 1, 1], True)
case 2:
dim = qkv_layout.split("_")[1].find("2")
kv = combine_tensors([k, v], dim)
tensors = [q, kv]
num_tensors = len(tensors)
shapes = [x.shape for x in tensors]
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv = torch.cat([x.view(-1) for x in tensors], dim=0)
qkv_fp8 = qkv_quantizer(qkv)
q_data, kv_data = [
qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)
]
k_data, v_data = SplitAlongDim.apply(kv_data, dim, [1, 1], True)
case 3:
tensors = [q, k, v]
num_tensors = len(tensors)
shapes = [x.shape for x in tensors]
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv = torch.cat([x.view(-1) for x in tensors], dim=0)
qkv_fp8 = qkv_quantizer(qkv)
q_data, k_data, v_data = [
qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)
]
case _:
raise RuntimeError("Invalid qkv_layout " + qkv_layout)
q_fp8, k_fp8, v_fp8 = [
Float8Tensor.make_like(qkv_fp8, data=x, dtype=src_nominal_dtype)
for x in [q_data, k_data, v_data]
]
return q_fp8, k_fp8, v_fp8
def combine_and_dequantize(
qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=None, des_nominal_dtype=None
):
"""Combine q,k,v based on qkv_layout and dequantize them together"""
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_layout = qkv_layout.replace("paged_kv_", "")
qkv_group = len(qkv_layout.split("_"))
if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]):
src_nominal_dtype = q_fp8.dtype
else:
assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!"
if des_nominal_dtype is None:
des_nominal_dtype = src_nominal_dtype
q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]]
match qkv_group:
case 1:
dim = qkv_layout.find("3")
qkv_data = combine_tensors([q_data, k_data, v_data], dim)
qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data)
qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
q, k, v = SplitAlongDim.apply(qkv, dim, [1, 1, 1], True)
case 2:
dim = qkv_layout.split("_")[1].find("2")
kv_data = combine_tensors([k_data, v_data], dim)
tensors = [q_data, kv_data]
num_tensors = len(tensors)
shapes = [x.shape for x in tensors]
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv_data = torch.cat([x.reshape(-1) for x in tensors], dim=0)
qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype)
qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
q, kv = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)]
k, v = SplitAlongDim.apply(kv, dim, [1, 1], True)
case 3:
tensors = [q_data, k_data, v_data]
num_tensors = len(tensors)
shapes = [x.shape for x in tensors]
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv_data = torch.cat([x.contiguous().reshape(-1) for x in tensors], dim=0)
qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype)
qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
q, k, v = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)]
case _:
raise RuntimeError("Invalid qkv_layout " + qkv_layout)
return q, k, v
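# A minimal round-trip sketch of the two helpers above. `q`, `k`, `v` stand for BF16 activations in
# a separate-tensor layout, and `quantizers` is the quantizer dict of a DotProductAttention module
# running under fp8_autocast; both are placeholders, not defined in this file.
#
# QKV_quantizer, *_ = get_attention_quantizers(fp8=True, quantizers=quantizers)
# q_fp8, k_fp8, v_fp8 = combine_and_quantize("sbhd_sbhd_sbhd", q, k, v, QKV_quantizer)
# # ... the FP8 fused attention kernel consumes q_fp8 / k_fp8 / v_fp8 here ...
# q_hp, k_hp, v_hp = combine_and_dequantize("sbhd_sbhd_sbhd", q_fp8, k_fp8, v_fp8)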
......@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""Multi-head Attention."""
import os
import collections
from typing import Callable, List, Optional, Tuple, Union
import torch
......@@ -31,7 +32,13 @@ from transformer_engine.pytorch.distributed import (
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
# Force DotProductAttention to use a different recipe than the fp8_recipe set in fp8_autocast().
# Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling"
# and "Float8CurrentScaling". The related variables below (e.g. NVTE_DPA_FP8_RECIPE_DPA) define
# the fp8_dpa/fp8_mha settings of the overriding recipe.
_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
_dpa_fp8_recipe_dpa = os.getenv("NVTE_DPA_FP8_RECIPE_DPA", "0") == "1"
_dpa_fp8_recipe_mha = os.getenv("NVTE_DPA_FP8_RECIPE_MHA", "0") == "1"
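# Hedged usage note (values illustrative): these variables are read when this module is imported,
# so they must be set before importing transformer_engine.pytorch, e.g.:
#
# os.environ["NVTE_DPA_FP8_RECIPE"] = "Float8CurrentScaling"
# os.environ["NVTE_DPA_FP8_RECIPE_DPA"] = "1"  # fp8_dpa of the overriding recipe
# os.environ["NVTE_DPA_FP8_RECIPE_MHA"] = "0"  # fp8_mha of the overriding recipe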
class MultiheadAttention(torch.nn.Module):
......@@ -135,6 +142,17 @@ class MultiheadAttention(torch.nn.Module):
For that, please use `get_qkv_layout` to gain the layout information.
name: str, default = `None`
name of the module, currently used for debugging purposes.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
Parallelism parameters
----------------------
......@@ -245,6 +263,7 @@ class MultiheadAttention(torch.nn.Module):
qk_norm_before_rope: bool = False,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -262,6 +281,7 @@ class MultiheadAttention(torch.nn.Module):
self.return_bias = return_bias
self.cp_size = 1
self.cp_rank = 0
self.softmax_type = softmax_type
kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)
......@@ -416,6 +436,7 @@ class MultiheadAttention(torch.nn.Module):
tp_group=tp_group,
layer_number=self.layer_number,
attention_type=self.attention_type,
softmax_type=self.softmax_type,
)
# Linear
......@@ -556,10 +577,12 @@ class MultiheadAttention(torch.nn.Module):
self.cp_size = get_distributed_world_size(cp_group)
self.cp_rank = get_distributed_rank(cp_group)
elif isinstance(cp_group, list):
assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
assert (
cp_comm_type == "a2a+p2p"
), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!"
assert (
len(cp_group) == 2
), "cp_comm_type = a2a+p2p requires cp_group = [a2a_cp_group, p2p_cp_group]!"
cp_size_a2a = get_distributed_world_size(cp_group[0])
cp_rank_a2a = get_distributed_rank(cp_group[0])
cp_size_p2p = get_distributed_world_size(cp_group[1])
......@@ -716,10 +739,22 @@ class MultiheadAttention(torch.nn.Module):
# Query, Key, and Value
# ======================
fp8_mha = (
FP8GlobalStateManager.is_fp8_enabled()
and FP8GlobalStateManager.get_fp8_recipe().fp8_mha
)
fp8 = FP8GlobalStateManager.is_fp8_enabled()
if _dpa_fp8_recipe == "":
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
fp8_dpa = fp8_recipe.fp8_dpa
fp8_mha = fp8_recipe.fp8_mha
float8_current_scaling = fp8_recipe.float8_current_scaling()
else:
fp8_dpa = _dpa_fp8_recipe_dpa
fp8_mha = _dpa_fp8_recipe_mha
float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling"
# QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe
qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling
# DPA: produce FP8 output whenever FP8 attention is enabled, to take advantage of the O amax
dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha)
# Proj Gemm: match DPA output except for Float8CurrentScaling
proj_fp8_grad = dpa_fp8_output and not float8_current_scaling
layernorm_output = None
if self.attention_type == "self":
......@@ -728,7 +763,7 @@ class MultiheadAttention(torch.nn.Module):
layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None,
fp8_output=qkv_fp8_output,
)
if self.return_layernorm_output:
mixed_x_layer, layernorm_output = layernorm_qkv_outputs
......@@ -738,7 +773,7 @@ class MultiheadAttention(torch.nn.Module):
mixed_x_layer = self.qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None,
fp8_output=qkv_fp8_output,
)
num_queries_per_key_value = (
......@@ -792,7 +827,7 @@ class MultiheadAttention(torch.nn.Module):
mixed_kv_layer = self.key_value(
encoder_output,
is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None,
fp8_output=qkv_fp8_output,
)
if self.qkv_weight_interleaved:
......@@ -847,7 +882,7 @@ class MultiheadAttention(torch.nn.Module):
layernorm_query_outputs = self.layernorm_query(
hidden_states,
is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None,
fp8_output=qkv_fp8_output,
)
if self.return_layernorm_output:
query_layer, layernorm_output = layernorm_query_outputs
......@@ -857,7 +892,7 @@ class MultiheadAttention(torch.nn.Module):
query_layer = self.query_layer(
hidden_states,
is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None,
fp8_output=qkv_fp8_output,
)
# [sq, b, hp] --> [sq, b, np, hn]
......@@ -958,6 +993,7 @@ class MultiheadAttention(torch.nn.Module):
fast_zero_fill=fast_zero_fill,
inference_params=inference_params,
pad_between_seqs=pad_between_seqs,
fp8_output=dpa_fp8_output,
)
# ===================
......@@ -966,7 +1002,7 @@ class MultiheadAttention(torch.nn.Module):
projection_output = self.proj(
context_layer,
is_first_microbatch=is_first_microbatch,
fp8_grad=isinstance(context_layer, QuantizedTensor),
fp8_grad=proj_fp8_grad,
)
if self.return_bias:
......
......@@ -91,3 +91,5 @@ GemmParallelModes = ("row", "column", None)
dist_group_type = torch.distributed.ProcessGroup
MXFP8_BLOCK_SCALING_SIZE = 32
NVFP4_BLOCK_SCALING_SIZE = 16
......@@ -12,6 +12,7 @@ from transformer_engine_torch import (
NVTE_QKV_Format,
NVTE_Bias_Type,
NVTE_Mask_Type,
NVTE_Softmax_Type,
NVTE_Fused_Attn_Backend,
)
from ..tensor.quantized_tensor import Quantizer
......@@ -86,6 +87,12 @@ AttnMaskType = {
"padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
}
SoftmaxType = {
"vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX,
"off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX,
"learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX,
}
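# For reference, a plain-PyTorch emulation (sketch only, not used by the fused kernels) of the
# three softmax variants mapped above, following the formulas in the DotProductAttention docstring.
#
# import torch
#
# def reference_softmax(scores, softmax_type="vanilla", offset=None):
#     """scores: [b, h, s_q, s_kv]; offset: learnable per-head logits alpha of shape [h]."""
#     if softmax_type == "vanilla":
#         return torch.softmax(scores, dim=-1)
#     if softmax_type == "off-by-one":
#         # implicit sink logit of 0 adds 1 to the denominator
#         sink = torch.zeros_like(scores[..., :1])
#     else:  # "learnable"
#         # exp(alpha[j]) is added to the denominator of head j
#         sink = offset.view(1, -1, 1, 1).expand(*scores.shape[:-1], 1)
#     probs = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
#     return probs[..., :-1]  # drop the sink column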
FusedAttnBackend = {
"F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen,
"F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
......@@ -102,9 +109,6 @@ META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
# repurpose some unused amax history buffers for partial results of CP fwd and bwd
META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT
META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1
def fused_attn_fwd(
......@@ -131,8 +135,10 @@ def fused_attn_fwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None,
softmax_offset: torch.Tensor = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for separate QKV input.
......@@ -197,6 +203,8 @@ def fused_attn_fwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
......@@ -205,6 +213,9 @@ def fused_attn_fwd(
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
softmax_offset: torch.Tensor, default = None
softmax offset tensor in shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
Returns
----------
......@@ -286,6 +297,7 @@ def fused_attn_fwd(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
cu_seqlens_q,
cu_seqlens_kv,
......@@ -300,6 +312,7 @@ def fused_attn_fwd(
s_quantizer,
o_quantizer,
attn_bias,
softmax_offset,
rng_gen,
rng_elts_per_thread,
)
......@@ -333,6 +346,7 @@ def fused_attn_bwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
......@@ -398,6 +412,8 @@ def fused_attn_bwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
......@@ -417,6 +433,9 @@ def fused_attn_bwd(
d_bias: torch.Tensor, optional
gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
or "post_scale_bias"; same data type and shape as Bias
d_softmax_offset: torch.Tensor, optional
gradient tensor of softmax offset in shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
"""
if attn_scale is None:
d = q.size(-1)
......@@ -454,6 +473,7 @@ def fused_attn_bwd(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
deterministic,
cu_seqlens_q,
......
......@@ -20,6 +20,8 @@ from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..tensor.utils import is_experimental
from ..experimental.gemm import experimental_gemm
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
......@@ -169,6 +171,24 @@ def general_gemm(
if not out.is_contiguous():
raise ValueError("Output tensor is not contiguous.")
# If A or B are experimental tensors -> dispatch to quantizers's qgemm implementation
if is_experimental(A) or is_experimental(B):
return experimental_gemm(
A,
B,
workspace,
out_dtype,
quantization_params,
gelu,
gelu_in,
accumulate,
layout,
out,
bias,
use_split_accumulator,
grad,
)
debug_quantizer = None
if isinstance(quantization_params, DebugQuantizer):
debug_quantizer = quantization_params
......
......@@ -12,6 +12,20 @@
namespace transformer_engine::pytorch {
/*! convert fp4 data shape back to original shape */
std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose) {
std::vector<size_t> ret;
size_t start_idx = (transpose) ? 1 : 0;
for (size_t i = start_idx; i < shape.size() - 1; ++i) {
ret.push_back(shape[i]);
}
ret.push_back(shape.back() * 2);
if (transpose) {
ret.push_back(shape.front());
}
return ret;
}
std::vector<size_t> getTensorShape(const at::Tensor& t) {
std::vector<size_t> shape;
for (auto s : t.sizes()) {
......@@ -291,4 +305,20 @@ size_t roundup(const size_t value, const size_t multiple) {
return ((value + multiple - 1) / multiple) * multiple;
}
void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) {
NVTE_SCOPED_GIL_RELEASE({
nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val,
arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_,
at::cuda::getCurrentCUDAStream());
});
}
// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread) {
at::PhiloxCudaState philox_args;
std::lock_guard<std::mutex> lock(gen->mutex_);
philox_args = gen->philox_cuda_state(elts_per_thread);
return philox_args;
}
} // namespace transformer_engine::pytorch
......@@ -35,6 +35,7 @@
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/fused_router.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/hadamard_transform.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/normalization.h>
......@@ -212,20 +213,25 @@ class Float8CurrentScalingQuantizer : public Quantizer {
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
/*! @brief Construct a high precision tensor giving it this quantizer's amax
Note: this member function also zeros out the amax, as it is meant to be used in conjunction with
a kernel computing the amax, which might expect the amax to be initialized to zero
/*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer.
*
* The amax is zeroed out. Most TE kernels that output amax expect
* amax to be initialized to zero.
*/
std::pair<TensorWrapper, py::object> create_hp_tensor_with_amax(const std::vector<size_t>& shape,
DType dtype);
std::pair<TensorWrapper, py::object> create_unquantized_tensor_with_amax(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data = std::nullopt);
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
/*! @brief Convert to a quantized data format avoiding amax computation */
/*! @brief Quantize to FP8, skipping local amax computation
*
* The quantizer's amax pointer is assumed to already hold the local
* amax. The amax may still be reduced across the amax reduction
* group.
*/
void quantize_with_amax(TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt);
......@@ -295,6 +301,60 @@ class MXFP8Quantizer : public Quantizer {
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};
class NVFP4Quantizer : public Quantizer {
public:
// fp4 dtype
DType dtype;
// amax reduction for low precision FP4 AG
bool with_amax_reduction;
c10::intrusive_ptr<dist_group_type> amax_reduction_group;
// random hadamard transform
bool with_rht;
bool with_post_rht_amax;
// 2D block scaling
bool with_2d_quantization;
bool stochastic_rounding;
int rht_matrix_random_sign_mask_t;
at::Tensor rht_matrix;
explicit NVFP4Quantizer(const py::handle& quantizer);
NVTEScalingMode get_scaling_mode() const override { return NVTE_NVFP4_1D_SCALING; }
void set_quantization_params(TensorWrapper* tensor) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
/*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer
*
* The amax is zeroed out. Most TE kernels that output amax expect
* amax to be initialized to zero.
*/
std::pair<TensorWrapper, py::object> create_unquantized_tensor_with_amax(
TensorWrapper& quantized_tensor, DType dtype);
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
/*! @brief Quantize to NVFP4, skipping local amax computation
*
* The input tensor's amax pointer is assumed to already hold the
* local amax. The amax may still be reduced across the amax
* reduction group.
*/
void quantize_with_amax(TensorWrapper& input, TensorWrapper& out);
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
private:
void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
};
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
std::vector<size_t> getTensorShape(const at::Tensor& t);
......@@ -445,6 +505,15 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
size_t roundup(const size_t value, const size_t multiple);
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose);
// unpack the PhiloxCudaState into CUDA tensor
void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr);
// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread);
} // namespace transformer_engine::pytorch
namespace std {
......
......@@ -73,28 +73,36 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
NVTE_Fused_Attn_Backend get_fused_attn_backend(
bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right);
std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
const std::vector<size_t> &shape, DType dtype,
bool create_hp_tensor_for_cs,
std::optional<at::Tensor> data);
std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const std::optional<at::Tensor> cu_seqlens_q_padded,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
const at::ScalarType fake_dtype, const DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
......
......@@ -8,179 +8,269 @@
#include "common.h"
#include "pybind.h"
namespace transformer_engine::pytorch {
namespace transformer_engine {
namespace pytorch {
template <void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t)>
py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) {
namespace {
py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t),
const at::Tensor& input, py::handle quantizer,
int shape_divisor = 1) {
init_extension();
// Input tensor
auto input_tensor = input.contiguous();
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor);
// Construct output tensor
auto quantizer_cpp = convert_quantizer(quantizer);
const auto input_shape = input_cpp.shape();
const auto input_shape = input_nvte.shape();
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
output_shape.back() /= shape_divisor;
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);
auto [out_nvte, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);
// Compute activation
// Choose implementation
enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 };
Impl impl = Impl::UNFUSED;
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Compute activation directly
NVTE_SCOPED_GIL_RELEASE(
{ act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); });
impl = Impl::FULLY_FUSED;
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// Compute activation in high-precision fused together with amax, then quantize.
auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE(
{ act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp);
impl = Impl::FUSED_ACTIVATION_AMAX_FP8;
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
impl = Impl::UNFUSED;
} else {
// Compute activation in high-precision, then quantize
impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
}
}
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE(
{ act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
quantizer_cpp->quantize(temp_cpp, out_cpp);
// Perform compute
auto stream = at::cuda::getCurrentCUDAStream();
switch (impl) {
case Impl::UNFUSED:
// Compute activation in high precision, then quantize
{
auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
quantizer_cpp->quantize(temp_nvte, out_nvte);
}
break;
case Impl::FULLY_FUSED:
// Compute activation directly
{
NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), out_nvte.data(), stream); });
}
break;
case Impl::FUSED_ACTIVATION_AMAX_FP8:
// Compute activation and amax in high precision, then quantize to FP8
{
auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
auto [temp_nvte, _] =
fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
}
break;
case Impl::FUSED_ACTIVATION_AMAX_NVFP4:
// Compute activation and amax in high precision, then quantize to NVFP4
{
auto nvfp4_quantizer_cpp =
static_cast<NVFP4Quantizer*>(quantizer_cpp.get()); // Already checked cast is valid
auto [temp_nvte, _] =
nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
}
break;
default:
NVTE_ERROR("Invalid activation implementation (", static_cast<int>(impl), ")");
}
return out_py;
}
template <void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input,
py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor,
cudaStream_t),
const at::Tensor& grad_output, const at::Tensor& input,
py::handle quantizer) {
init_extension();
// Grad output and input tensors
auto grad_output_tensor = grad_output.contiguous();
auto input_tensor = input.contiguous();
const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor);
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
const TensorWrapper& grad_output_nvte = makeTransformerEngineTensor(grad_output_tensor);
const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor);
// Construct grad input tensor
auto quantizer_cpp = convert_quantizer(quantizer);
const auto input_shape_te = input_cpp.shape();
const auto input_shape_te = input_nvte.shape();
const std::vector<size_t> input_shape(input_shape_te.data,
input_shape_te.data + input_shape_te.ndim);
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);
auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);
// Compute activation backward
// Choose implementation
enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 };
Impl impl = Impl::UNFUSED;
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Compute activation backward directly
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(),
at::cuda::getCurrentCUDAStream());
});
impl = Impl::FULLY_FUSED;
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// Compute activation backward in high-precision fused together with amax, then quantize.
auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype);
impl = Impl::FUSED_ACTIVATION_AMAX_FP8;
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
impl = Impl::UNFUSED;
} else {
impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
}
}
// Perform compute
auto stream = at::cuda::getCurrentCUDAStream();
switch (impl) {
case Impl::UNFUSED:
// Compute activation backward in high precision, then quantize
{
auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
at::cuda::getCurrentCUDAStream());
});
quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp);
} else {
// Compute activation backward in high-precision, then quantize
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
quantizer_cpp->quantize(temp_nvte, grad_input_nvte);
}
break;
case Impl::FULLY_FUSED:
// Compute activation backward directly
{
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
at::cuda::getCurrentCUDAStream());
dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream);
});
quantizer_cpp->quantize(temp_cpp, grad_input_cpp);
}
break;
case Impl::FUSED_ACTIVATION_AMAX_FP8:
// Compute activation and amax in high precision, then quantize to FP8
{
auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
auto [temp_nvte, _] =
fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE(
{ dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); });
fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
}
break;
case Impl::FUSED_ACTIVATION_AMAX_NVFP4:
// Compute activation and amax in high precision, then quantize to NVFP4
{
auto nvfp4_quantizer_cpp =
static_cast<NVFP4Quantizer*>(quantizer_cpp.get()); // Already checked cast is valid
auto [temp_nvte, _] =
nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(grad_input_nvte, fake_dtype);
NVTE_SCOPED_GIL_RELEASE(
{ dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); });
nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
}
break;
default:
NVTE_ERROR("Invalid activation implementation (", static_cast<int>(impl), ")");
}
return grad_input_py;
}
/* GELU and variants*/
} // namespace
/* GELU and variants */
py::object gelu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_gelu>(input, quantizer);
return activation_forward(nvte_gelu, input, quantizer);
}
py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dgelu>(grad, input, quantizer);
return activation_backward(nvte_dgelu, grad, input, quantizer);
}
py::object geglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_geglu>(input, quantizer, 2);
return activation_forward(nvte_geglu, input, quantizer, 2);
}
py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dgeglu>(grad, input, quantizer);
return activation_backward(nvte_dgeglu, grad, input, quantizer);
}
py::object qgelu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_qgelu>(input, quantizer);
return activation_forward(nvte_qgelu, input, quantizer);
}
py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dqgelu>(grad, input, quantizer);
return activation_backward(nvte_dqgelu, grad, input, quantizer);
}
py::object qgeglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_qgeglu>(input, quantizer, 2);
return activation_forward(nvte_qgeglu, input, quantizer, 2);
}
py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dqgeglu>(grad, input, quantizer);
return activation_backward(nvte_dqgeglu, grad, input, quantizer);
}
/* ReLU and variants*/
/* ReLU and variants */
py::object relu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_relu>(input, quantizer);
return activation_forward(nvte_relu, input, quantizer);
}
py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_drelu>(grad, input, quantizer);
return activation_backward(nvte_drelu, grad, input, quantizer);
}
py::object reglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_reglu>(input, quantizer, 2);
return activation_forward(nvte_reglu, input, quantizer, 2);
}
py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dreglu>(grad, input, quantizer);
return activation_backward(nvte_dreglu, grad, input, quantizer);
}
py::object srelu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_srelu>(input, quantizer);
return activation_forward(nvte_srelu, input, quantizer);
}
py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dsrelu>(grad, input, quantizer);
return activation_backward(nvte_dsrelu, grad, input, quantizer);
}
py::object sreglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_sreglu>(input, quantizer, 2);
return activation_forward(nvte_sreglu, input, quantizer, 2);
}
py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dsreglu>(grad, input, quantizer);
return activation_backward(nvte_dsreglu, grad, input, quantizer);
}
/* Silu and variants*/
/* Silu and variants */
py::object silu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_silu>(input, quantizer);
return activation_forward(nvte_silu, input, quantizer);
}
py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dsilu>(grad, input, quantizer);
return activation_backward(nvte_dsilu, grad, input, quantizer);
}
py::object swiglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_swiglu>(input, quantizer, 2);
return activation_forward(nvte_swiglu, input, quantizer, 2);
}
py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dswiglu>(grad, input, quantizer);
return activation_backward(nvte_dswiglu, grad, input, quantizer);
}
} // namespace transformer_engine::pytorch
} // namespace pytorch
} // namespace transformer_engine
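/* Hedged sketch of how these wrappers are typically exposed to Python with pybind11.
 * The real registration lives elsewhere in the extension (not shown in this commit);
 * the module name and keyword-argument names below are assumptions. Assumes the
 * Torch<->pybind11 casters and the `py` alias provided by pybind.h are available. */
PYBIND11_MODULE(transformer_engine_torch_sketch, m) {
  using namespace transformer_engine::pytorch;
  m.def("gelu", &gelu, py::arg("input"), py::arg("quantizer"));
  m.def("dgelu", &dgelu, py::arg("grad"), py::arg("input"), py::arg("quantizer"));
  m.def("swiglu", &swiglu, py::arg("input"), py::arg("quantizer"));
  m.def("dswiglu", &dswiglu, py::arg("grad"), py::arg("input"), py::arg("quantizer"));
}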