"vscode:/vscode.git/clone" did not exist on "b7c568326c969c59a5f90e4731dc5b91f260c6f0"
Unverified commit 71e51eae authored by Ming-Xu Huang, committed by GitHub

[JAX] Migrating from Xmap to Custom Partitioning for All Custom Calls (#472)



* Refactor sharding.py for the further custom_partitioning migration

* Migrate both FWD and BWD of LayerNorm/RMSNorm from xmap to custom_partitioning

* Migrate both FWD and BWD of all kinds of softmax from xmap to custom_partitioning

* Fix the wrong order of parameters to LN/RMSN bwd in ln_mlp_fp8

* WAR to LN/RMSN_fp8 before migrating to CP

* Fix the wrong order of parameters of bwd of LN/RMSN_fp8

* Modify following review feedback

* Force the hidden dim in Norm ops to no sharding and add a warning msg

* Reuse fwd_rule in VJP functions

* Migrate both FWD and BWD of self-fused-attn from xmap to custom_partitioning

* Migrate both FWD and BWD of cross-fused-attn from xmap to custom_partitioning

* Add gelu and dgelu

* Reuse fwd_rule in VJP functions for attentions

* Apply native FP8 dtypes to fp8.py

* Migrate cast_and_transpose from xmap to custom_partitioning

* Migrate transpose from xmap to custom_partitioning

* Apply XLA pattern match to perform FP8 GEMM

* Migrate layernorm_fp8 to custom_partitioning

* Unify code style

* Extend support of Transpose with FP8

* Implement layernorm_fp8_dot based on migrated custom calls

* Rename variables and publish NVTE_FP8_COLLECTION_NAME

* Replace Q/DQ custom calls with native XLA implementations

* Migrate gelu_fp to custom_partitioning

* Minor fix

* Support custom calls with multi-dims

* Support general dot indices in _fp8_dot_impl

* Implement layernorm_geglu_fp8_mlp

* Remove GEMM custom calls

* Remove xmap related code

* Fix typo and add query-function to FP8MetaPackage

* Fix some bugs of custom calls

* Fix CT's bugs

* Update UTs/examples to adapt to the API changes

* Unify kernel initialization in MLP

* Modify following code review feedback

* Update README and add deprecation warning to *ShardingType

* Canonicalize the dtype

* Add assertion for non-supported batch dims

* Add doc/examples to _multidim_transpose

* Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules

* Apply dtype-based rtol/atol to UTs

* Deprecate QKV_INTERLEAVED enum

* Skip test_distributed_custom_ops.py

* Fix the wrong sharding of bias in SelfAttn

* WAR to fix the wrong cu_seqlen of MHA when DP/FSDP enabled

* Add distributed ops unit tests

* Add license to test_distributed_*

* Follow review feedback to modify

* Use total bytes involved in collective ops as criteria

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
Co-authored-by: Donglin Yang <dongliny@nvidia.com>
parent 7976bd00
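The diff below replaces dedicated Q/DQ custom calls with native quantize/dequantize ops that XLA can pattern-match into FP8 GEMMs. A minimal NumPy stand-in for the delayed-scaling round trip (the `quantize`/`dequantize` names and clip-to-range logic mirror the new code in the diff; NumPy and the float32 carrier type are illustration-only assumptions, not the JAX implementation):

```python
import numpy as np

def quantize(x, dtype_max, scale):
    # Scale up and clip to the representable range, as in the PR's `quantize`.
    return np.clip(x * scale, -dtype_max, dtype_max)

def dequantize(q, scale_inv):
    # Undo the scaling, as in the PR's `dequantize`.
    return q * scale_inv

amax = 3.0             # running max of |x| under delayed scaling
dtype_max = 448.0      # largest normal value of float8_e4m3
scale = dtype_max / amax
x = np.array([0.5, -2.0, 3.0])
q = quantize(x, dtype_max, scale)
y = dequantize(q, 1.0 / scale)
print(np.allclose(x, y))    # True: values within amax round-trip exactly here
```

In the real code the clipped values are additionally cast to an FP8 dtype, so the round trip is lossy; the point is only that both directions are plain elementwise XLA ops.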
@@ -4,227 +4,167 @@
 """JAX te modules"""
 from typing import Tuple, Sequence
-from functools import partial, reduce
-import operator
+from functools import partial
 import jax
 import jax.numpy as jnp
-from transformer_engine_jax import DType as TEDType
-from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype
-from .fp8 import FP8Helper, FP8GemmPackage
-from .sharding import ShardingType, get_dot_sharding_meta, get_fp8_meta_sharding_meta
-from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
-from .sharding import xmap_runner, extend_fsdp_sharding_meta
-
-jax.config.update('experimental_xmap_spmd_lowering', True)
-jax.config.update('experimental_xmap_spmd_lowering_manual', True)
+from .cpp_extensions import cast_transpose
+from .fp8 import FP8Helper, FP8MetaPackage
+
+
+def type_safe_dot_general(
+        x,
+        kernel,
+        fp8_meta_pkg: FP8MetaPackage = None,
+        contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,))
+) -> jnp.ndarray:
+    """
+    Type safe dot_general, including FP8.
+    """
+    if fp8_meta_pkg is None:
+        kernel = jnp.asarray(kernel, x.dtype)
+        return jax.lax.dot_general(x, kernel, (contracting_dims, ((), ())))
+
+    fp8_max = fp8_meta_pkg.fp8_max
+    amax = fp8_meta_pkg.amax
+    scale = fp8_meta_pkg.scale
+    scale_inv = fp8_meta_pkg.scale_inv
+    fwd_dtype = FP8Helper.FWD_DTYPE
+    bwd_dtype = FP8Helper.BWD_DTYPE
+    return _fp8_dot(x, kernel, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype,
+                    contracting_dims)
+
+
+def quantize(x, q_dtype, scale):
+    """
+    Quantize with scale.
+    """
+    dtype_max = (jnp.finfo(q_dtype).max).astype(x.dtype)
+    scale = scale.astype(x.dtype)
+    clipped_scaled_x = jnp.clip((x * scale), -dtype_max, dtype_max)
+    return clipped_scaled_x.astype(q_dtype)
+
+
+def dequantize(x, dq_dtype, scale_inv):
+    """
+    Dequantize with scale_inv.
+    """
+    return x.astype(dq_dtype) * scale_inv.astype(dq_dtype)
+
+
+# Apply jit to guarantee correctness of FP8 GEMM.
+@partial(jax.jit, static_argnums=(4, 5))
+def fp8_dot_impl(
+        q_lhs: jnp.ndarray,
+        q_rhs: jnp.ndarray,
+        lhs_scale_inv: jnp.ndarray,
+        rhs_scale_inv: jnp.ndarray,
+        ctype: jnp.dtype,    # computing type
+        contracting_dims: Tuple[Sequence[int], Sequence[int]]):
+    """
+    FP8 GEMM for XLA pattern match
+    """
+    dim_nums = (contracting_dims, ((), ()))
+
+    lhs = dequantize(q_lhs, ctype, lhs_scale_inv)
+    rhs = dequantize(q_rhs, ctype, rhs_scale_inv)
+
+    return jax.lax.dot_general(lhs, rhs, dim_nums)
 
-def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
-            fwd_dtype: TEDType,
-            bwd_dtype: TEDType,
-            contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
-            sharding_type: ShardingType = ShardingType.SINGLE,
-            dp_dim_index: int = 0) -> jnp.ndarray:
-    """
-    FP8 dot wrapper
-    """
-    assert fp8_gemm_pkg.num_of_gemm == 1
-    inputs = fp8_gemm_pkg.inputs
-    kernel = fp8_gemm_pkg.kernels[0]
-    fp8_max = fp8_gemm_pkg.fp8_max
-    amax = fp8_gemm_pkg.amax
-    scale = fp8_gemm_pkg.scale
-    scale_inv = fp8_gemm_pkg.scale_inv
-
-    if sharding_type is ShardingType.SINGLE:
-        res = _fp8_dot(inputs,
-                       kernel,
-                       fp8_max,
-                       amax,
-                       scale,
-                       scale_inv,
-                       fwd_dtype=fwd_dtype,
-                       bwd_dtype=bwd_dtype,
-                       contracting_dims=contracting_dims,
-                       sharding_type=sharding_type,
-                       dp_axis_name="",
-                       tp_axis_name="",
-                       fsdp_axis_name="")
-    else:
-        dp_axis_name = "batch"
-        tp_axis_name = "model"
-
-        kernel_tp_index = None
-        # TODO (Ming Huang): Should we add a new argument to support general sharding to kernel?    # pylint: disable=fixme
-        if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
-            kernel_tp_index = len(kernel.shape) - 1
-        elif sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
-            kernel_tp_index = 0
-
-        input_tp_index = len(inputs.shape) - 1
-        sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
-                                              dp_dim_index, input_tp_index, kernel_tp_index,
-                                              contracting_dims, dp_axis_name, tp_axis_name)
-        sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
-        inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0])    # 0 for input
-        kernel_ = jnp.reshape(kernel, sharding_meta.input_shapes[1])    # 1 for kernel
-
-        num_of_fp8_meta_kind = 4    # fp8_max, amax, scale, scale_inv
-        fp8_sharding_meta = get_fp8_meta_sharding_meta(sharding_type, num_of_fp8_meta_kind,
-                                                       dp_axis_name, tp_axis_name)
-
-        axis_resources = merge_axis_resources(
-            [sharding_meta.axis_resources, fp8_sharding_meta.axis_resources])
-
-        partial_fp8_dot = partial(_fp8_dot,
-                                  fwd_dtype=fwd_dtype,
-                                  bwd_dtype=bwd_dtype,
-                                  contracting_dims=contracting_dims,
-                                  sharding_type=sharding_type,
-                                  dp_axis_name=dp_axis_name,
-                                  tp_axis_name=tp_axis_name,
-                                  fsdp_axis_name=fsdp_axis_name)
-        res = xmap_runner(partial_fp8_dot, (*sharding_meta.in_axes, *fp8_sharding_meta.in_axes),
-                          sharding_meta.out_axes, axis_resources,
-                          (inputs_, kernel_, fp8_max, amax, scale, scale_inv))
-        res = jnp.reshape(res, sharding_meta.output_shapes[0])
-
-    return res
 
-@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12))
-def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
-             scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: TEDType, bwd_dtype: TEDType,
-             contracting_dims: Tuple[Sequence[int], Sequence[int]], sharding_type: ShardingType,
-             dp_axis_name: str, tp_axis_name: str, fsdp_axis_name: str):
-    res, _ = _fp8_dot_fwd(inputs,
-                          kernel,
-                          fp8_maxs,
-                          amax,
-                          scale,
-                          scale_inv,
-                          fwd_dtype,
-                          bwd_dtype,
-                          contracting_dims=contracting_dims,
-                          sharding_type=sharding_type,
-                          dp_axis_name=dp_axis_name,
-                          tp_axis_name=tp_axis_name,
-                          fsdp_axis_name=fsdp_axis_name)
-    return res
+@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
+def _fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
+             scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
+             contracting_dims: Tuple[Sequence[int], Sequence[int]]):
+    output, _ = _fp8_dot_fwd_rule(x, kernel, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype,
+                                  contracting_dims)
+    return output
 
-def _fp8_dot_fwd(
-        inputs,
+def _fp8_dot_fwd_rule(
+        x,
         kernel,
-        fp8_maxs,
+        fp8_max,
         amax,
         scale,
         scale_inv,
         fwd_dtype,
         bwd_dtype,    # pylint: disable=unused-argument
-        contracting_dims,
-        sharding_type,
-        dp_axis_name,    # pylint: disable=unused-argument
-        tp_axis_name,
-        fsdp_axis_name):    # pylint: disable=unused-argument
+        contracting_dims):
     lhs_contracting_dims, rhs_contracting_dims = contracting_dims
-    input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
-    input_shape_suf = inputs.shape[min(lhs_contracting_dims):]
+
+    x_shape_suf = x.shape[min(lhs_contracting_dims):]
     kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
-    kernel_shape_suf = kernel.shape[max(rhs_contracting_dims) + 1:]
-    input_contracting_size = reduce(operator.mul, input_shape_suf)
-    kernel_contracting_size = reduce(operator.mul, kernel_shape_pre)
-    assert input_contracting_size == kernel_contracting_size
-    inputs_ = jnp.reshape(inputs, (-1, input_contracting_size))
-    kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1))
+    assert x_shape_suf == kernel_shape_pre
+
     amax = FP8Helper.update_amax_history(amax)
-    gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
-    input_amax = amax[gemm_input_idx, 0:1]
-    input_scale = scale[gemm_input_idx]
-    input_scale_inv = scale_inv[gemm_input_idx]
-    input_cast, input_cast_trans, input_amax = cast_transpose(inputs_, input_amax, input_scale,
-                                                              input_scale_inv, fwd_dtype)
+
+    gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
+
+    x_amax = amax[gemm_x_idx, 0:1]
+    x_scale = scale[gemm_x_idx]
+    x_scale_inv = scale_inv[gemm_x_idx]
+
+    casted_x, casted_xt, updated_x_amax = \
+        cast_transpose(x, x_amax, x_scale, x_scale_inv, fwd_dtype, static_axis_boundary=-1,
+                       transpose_axis_boundary=min(lhs_contracting_dims))
+
     kernel_amax = amax[gemm_kernel_idx, 0:1]
     kernel_scale = scale[gemm_kernel_idx]
     kernel_scale_inv = scale_inv[gemm_kernel_idx]
-    kernel_cast, kernel_cast_trans, kernel_amax = cast_transpose(kernel_, kernel_amax, kernel_scale,
-                                                                 kernel_scale_inv, fwd_dtype)
-    res = gemm(kernel_cast_trans, kernel_scale_inv, fwd_dtype, True, input_cast, input_scale_inv,
-               fwd_dtype, False, jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP)
-
-    if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
-        res = jax.lax.psum(res, tp_axis_name)
-
-    # (input_shape_pre, input_shape_suf)
-    # x (kernel_shape_pre, kernel_shape_suf)
-    # = (input_shape_pre, kernel_shape_suf)
-    output_shape = input_shape_pre + kernel_shape_suf
-    res = jnp.reshape(res, output_shape)
-
-    ctx = (input_cast_trans, kernel_cast, fp8_maxs, amax, scale, scale_inv, input_amax, kernel_amax,
-           inputs.shape, kernel.shape)
-    return res, ctx
+
+    casted_kernel, casted_kernel_t, updated_kernel_amax = \
+        cast_transpose(kernel, kernel_amax, kernel_scale, kernel_scale_inv,
+                       fwd_dtype, static_axis_boundary=-1,
+                       transpose_axis_boundary=(max(rhs_contracting_dims) + 1))
+
+    rhs_t_contracting_dims = tuple(range(kernel.ndim - len(rhs_contracting_dims), kernel.ndim))
+    output = fp8_dot_impl(casted_x, casted_kernel_t, x_scale_inv, kernel_scale_inv, x.dtype,
+                          (lhs_contracting_dims, rhs_t_contracting_dims))
+
+    ctx = (casted_xt, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
+           updated_kernel_amax, x.shape, kernel.shape)
+    return output, ctx
 
-def _fp8_dot_bwd(
-        fwd_dtype,
-        bwd_dtype,
-        contracting_dims,    # pylint: disable=unused-argument
-        sharding_type,
-        dp_axis_name,
-        tp_axis_name,
-        fsdp_axis_name,
-        ctx,
-        g):
-    input_cast_trans, kernel_cast, \
-        fp8_maxs, amax, scale, scale_inv, \
-        input_amax, kernel_amax, \
-        inputs_shape, kernel_shape = ctx
+def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad):    # pylint: disable=unused-argument
+    lhs_contracting_dims, rhs_contracting_dims = contracting_dims
 
-    gemm_input_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
+    casted_xt, casted_kernel, fp8_max, amax, scale, scale_inv, \
+        updated_x_amax, updated_kernel_amax, x_shape, kernel_shape = ctx
+
+    gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
 
     grad_amax = amax[gemm_grad_idx, 0:1]
     grad_scale = scale[gemm_grad_idx]
     grad_scale_inv = scale_inv[gemm_grad_idx]
-    g = jnp.reshape(g, (input_cast_trans.shape[1], -1))
-    grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv,
-                                                           bwd_dtype)
-    input_scale_inv = scale_inv[gemm_input_idx]
-    wgrad = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype,
-                 True, input_cast_trans, input_scale_inv, fwd_dtype, False,
-                 jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD)
-
-    kernel_scale_inv = scale_inv[gemm_kernel_idx]
-    dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
-                 bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)
-
-    amax = amax.at[gemm_input_idx, 0].set(input_amax[0])
-    amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0])
-    amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0])
-
-    if is_dp_enabled(sharding_type.value[0]):
-        wgrad = jax.lax.psum(wgrad, dp_axis_name)
-        amax = jax.lax.pmax(amax, dp_axis_name)
-
-    if len(fsdp_axis_name) > 0:
-        wgrad = jax.lax.psum(wgrad, fsdp_axis_name)
-        amax = jax.lax.pmax(amax, fsdp_axis_name)
-
-    if is_tp_enabled(sharding_type.value[0]):
-        amax = jax.lax.pmax(amax, tp_axis_name)
-
-    if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
-        dgrad = jax.lax.psum(dgrad, tp_axis_name)
-
-    dgrad = jnp.reshape(dgrad, inputs_shape)
-    wgrad = jnp.reshape(wgrad, kernel_shape)
-    return dgrad, wgrad, fp8_maxs, amax, scale, scale_inv
+    casted_grad, casted_grad_t, updated_grad_amax = \
+        cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv,
+                       bwd_dtype, static_axis_boundary=-1,
+                       transpose_axis_boundary=min(lhs_contracting_dims))
+
+    xt_contracting_dim = tuple(range(len(lhs_contracting_dims), len(x_shape)))
+    gt_contracting_dim = tuple(range(grad.ndim - len(xt_contracting_dim), grad.ndim))
+    x_scale_inv = scale_inv[gemm_x_idx]
+    wgrad = fp8_dot_impl(casted_xt, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
+                         (xt_contracting_dim, gt_contracting_dim))
+
+    g_contracting_dim = tuple(
+        range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim))
+    k_contracting_dim = tuple(range(len(rhs_contracting_dims), len(kernel_shape)))
+    kernel_scale_inv = scale_inv[gemm_kernel_idx]
+    dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype,
+                         (g_contracting_dim, k_contracting_dim))
+
+    amax = amax.at[gemm_x_idx, 0].set(updated_x_amax[0])
+    amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0])
+    amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])
+
+    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
+
+    return dgrad, wgrad, fp8_max, amax, scale, scale_inv
 
-_fp8_dot.defvjp(_fp8_dot_fwd, _fp8_dot_bwd)
+_fp8_dot.defvjp(_fp8_dot_fwd_rule, _fp8_dot_bwd_rule)
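The forward rule contracts the transposed kernel over its trailing axes via `rhs_t_contracting_dims`, because `cast_transpose` returns a copy with the leading (contracting) axes rotated to the end. A NumPy sketch checking that dimension bookkeeping (a plain `np.transpose` stands in for the `transpose_axis_boundary` behavior, which is an assumption about that kernel, not code from the diff):

```python
import numpy as np

# x: (batch, seq, hidden); kernel: (hidden, heads, dim); contract over hidden.
x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
kernel = np.arange(40, dtype=np.float32).reshape(4, 2, 5)
lhs_cd, rhs_cd = (2,), (0,)

# Reference: contract x's last axis against kernel's first axis.
ref = np.tensordot(x, kernel, axes=(lhs_cd, rhs_cd))

# Emulate cast_transpose: move the kernel axes before the boundary to the end.
boundary = max(rhs_cd) + 1
kernel_t = np.transpose(kernel, tuple(range(boundary, kernel.ndim)) + tuple(range(boundary)))

# The transposed kernel is now contracted over its trailing axes.
rhs_t_cd = tuple(range(kernel.ndim - len(rhs_cd), kernel.ndim))
out = np.tensordot(x, kernel_t, axes=(lhs_cd, rhs_t_cd))
print(np.allclose(ref, out))    # True
```

The same arithmetic drives the wgrad/dgrad index tuples in the backward rule.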
...@@ -27,9 +27,8 @@ from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout ...@@ -27,9 +27,8 @@ from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from ..fused_attn import is_fused_attn_kernel_available from ..fused_attn import is_fused_attn_kernel_available
from ..fused_attn import self_fused_attn, cross_fused_attn from ..fused_attn import self_fused_attn, cross_fused_attn
from ..softmax import SoftmaxType from ..softmax import SoftmaxType
from ..sharding import infer_major_sharding_type, infer_sharding_type from ..sharding import global_mesh_resource, num_of_devices
from ..sharding import global_shard_resource, with_sharding_constraint from ..sharding import with_sharding_constraint
from ..sharding import ShardingType
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
...@@ -102,7 +101,7 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: ...@@ -102,7 +101,7 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
else: else:
rules_map[key] = [val] rules_map[key] = [val]
gsr = global_shard_resource() gsr = global_mesh_resource()
batch_dim_rule = [] batch_dim_rule = []
if gsr.dp_resource is not None: if gsr.dp_resource is not None:
...@@ -186,7 +185,6 @@ def core_attention(query: Array, ...@@ -186,7 +185,6 @@ def core_attention(query: Array,
scale_factor: float, scale_factor: float,
transpose_batch_sequence: bool, transpose_batch_sequence: bool,
softmax_type: SoftmaxType = SoftmaxType.SCALED, softmax_type: SoftmaxType = SoftmaxType.SCALED,
softmax_sharding_type: ShardingType = ShardingType.SINGLE,
mask: Optional[Array] = None, mask: Optional[Array] = None,
bias: Optional[Array] = None, bias: Optional[Array] = None,
dropout_rng: Optional[PRNGKey] = None, dropout_rng: Optional[PRNGKey] = None,
...@@ -226,9 +224,7 @@ def core_attention(query: Array, ...@@ -226,9 +224,7 @@ def core_attention(query: Array,
fused_scale_factor = scale_factor fused_scale_factor = scale_factor
attn_weights = Softmax(softmax_type=softmax_type, attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=fused_scale_factor, scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype)
sharding_type=softmax_sharding_type)(attn_weights, mask,
bias).astype(dtype)
if not deterministic and dropout_rate > 0.: if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate keep_prob = 1.0 - dropout_rate
...@@ -482,8 +478,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -482,8 +478,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
f"Fused attention is not enabled. Because " \ f"Fused attention is not enabled. Because " \
f"{reason}fall back to unfused attention.") f"{reason}fall back to unfused attention.")
first_sharding_type, second_sharding_type = infer_sharding_type()
residual = inputs_q residual = inputs_q
if self.fuse_qkv: if self.fuse_qkv:
if is_self_attn: if is_self_attn:
...@@ -494,7 +488,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -494,7 +488,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
axis=-1, axis=-1,
features=(3, self.num_heads * self.head_dim), features=(3, self.num_heads * self.head_dim),
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm, return_layernorm_output=self.apply_residual_connection_post_layernorm,
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
...@@ -516,7 +509,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -516,7 +509,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
axis=-1, axis=-1,
features=self.num_heads * self.head_dim, features=self.num_heads * self.head_dim,
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm, return_layernorm_output=self.apply_residual_connection_post_layernorm,
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
...@@ -530,7 +522,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -530,7 +522,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
name='query')(inputs_q) name='query')(inputs_q)
kv_proj = DenseGeneral(axis=-1, kv_proj = DenseGeneral(axis=-1,
features=(2, self.num_heads * self.head_dim), features=(2, self.num_heads * self.head_dim),
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_init=kv_init, kernel_init=kv_init,
...@@ -546,7 +537,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -546,7 +537,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
DenseGeneral, DenseGeneral,
axis=-1, axis=-1,
features=self.num_heads * self.head_dim, features=self.num_heads * self.head_dim,
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_TP_AXES), kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias, use_bias=self.use_bias,
...@@ -560,7 +550,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -560,7 +550,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
axis=-1, axis=-1,
features=self.num_heads * self.head_dim, features=self.num_heads * self.head_dim,
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=True, return_layernorm_output=True,
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
...@@ -648,7 +637,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -648,7 +637,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
seed = None seed = None
if dropout_rng is not None: if dropout_rng is not None:
seed = jax.random.split(dropout_rng, len(jax.devices())) seed = jax.random.split(dropout_rng, num_of_devices())
# ensure the old key never used # ensure the old key never used
del dropout_rng del dropout_rng
...@@ -665,8 +654,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -665,8 +654,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scale_factor, scaling_factor=scale_factor,
dropout_probability=self.dropout_rate, dropout_probability=self.dropout_rate,
is_training=not deterministic, is_training=not deterministic)
sharding_type=first_sharding_type)
else: else:
assert bias is None assert bias is None
query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim)) query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
...@@ -685,8 +673,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -685,8 +673,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scale_factor, scaling_factor=scale_factor,
dropout_probability=self.dropout_rate, dropout_probability=self.dropout_rate,
is_training=not deterministic, is_training=not deterministic)
sharding_type=first_sharding_type)
else: else:
def convert_to_softmax_type(attn_mask_type, mask): def convert_to_softmax_type(attn_mask_type, mask):
...@@ -710,7 +697,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -710,7 +697,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
scale_factor=scale_factor, scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
softmax_type=softmax_type, softmax_type=softmax_type,
softmax_sharding_type=first_sharding_type,
mask=mask, mask=mask,
bias=bias, bias=bias,
dropout_rng=dropout_rng, dropout_rng=dropout_rng,
...@@ -728,7 +714,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -728,7 +714,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
x = _with_sharding_constraint(x, attn_context_sharding_constraint) x = _with_sharding_constraint(x, attn_context_sharding_constraint)
 out = DenseGeneral(features=inputs_q.shape[-1],
-                   sharding_type=second_sharding_type,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    axis=-1,
                    kernel_init=self.kernel_init,
@@ -1175,7 +1160,6 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
     layernorm_type=self.layernorm_type,
     zero_centered_gamma=self.zero_centered_gamma,
     epsilon=self.layernorm_epsilon,
-    major_sharding_type=infer_major_sharding_type(),
     transpose_batch_sequence=self.transpose_batch_sequence,
     return_layernorm_output=self.apply_residual_connection_post_layernorm,
     intermediate_dim=self.mlp_hidden_size,
@@ -1208,7 +1192,6 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
         z = z + residual

         if self.output_layernorm:
-            ln_sharding_type, _ = infer_sharding_type()
             z = LayerNorm(layernorm_type=self.layernorm_type,
                           zero_centered_gamma=self.zero_centered_gamma,
                           epsilon=self.layernorm_epsilon,
@@ -1216,7 +1199,6 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
                           bias_axes=(W_NO_SHARD_AXES,),
                           transpose_batch_sequence=self.transpose_batch_sequence,
                           dtype=self.dtype,
-                          sharding_type=ln_sharding_type,
                           name="output_layer_norm")(z)
         return z
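The hunks above drop the explicit `sharding_type` / `major_sharding_type` arguments from the layer modules: after the migration to `custom_partitioning`, partitioning is driven by the mesh and the shardings of the inputs that `jax.jit` observes, not by per-call flags. A minimal sketch of that usage model, assuming a single-device mesh and using `rmsnorm` as a hypothetical stand-in for a TE op (not the actual TE kernel):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical stand-in for a TE op: no sharding_type argument; the
# partitioning of the computation follows the sharding of its inputs.
def rmsnorm(x, gamma, eps=1e-6):
    rsigma = jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)
    return x * rsigma * gamma

mesh = Mesh(np.array(jax.devices()[:1]), ("dp",))
# Shard the batch dim over the "dp" mesh axis; keep the hidden dim whole.
x = jax.device_put(jnp.ones((4, 8)),
                   NamedSharding(mesh, PartitionSpec("dp", None)))
gamma = jnp.ones((8,))
out = jax.jit(rmsnorm)(x, gamma)
print(out.shape)  # (4, 8)
```

The same call works unchanged on a multi-device mesh; only the `device_put` sharding spec changes.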
@@ -6,7 +6,7 @@ Helper module for fp8 meta management
 """
 from contextlib import contextmanager
 from enum import Enum
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

 import jax
 import jax.numpy as jnp
@@ -17,7 +17,7 @@ from transformer_engine_jax import get_cublasLt_version
 from transformer_engine_jax import get_cuda_version, get_device_compute_capability
 from transformer_engine.common.recipe import DelayedScaling, Format
 from transformer_engine.jax.sharding import global_shard_guard
-from transformer_engine.jax.sharding import ShardingResource
+from transformer_engine.jax.sharding import MeshResource

 _is_fp8_available = None
 _reason_for_no_fp8 = ""
@@ -59,37 +59,29 @@ def is_fp8_available(gpu_id=None) -> Tuple[bool, str]:
 def _format2dtypes(format_: Format):
     if format_ == Format.E4M3:
-        return DType.kFloat8E4M3, DType.kFloat8E4M3
+        return jnp.float8_e4m3fn, jnp.float8_e4m3fn
     if format_ == Format.E5M2:
-        return DType.kFloat8E5M2, DType.kFloat8E5M2
+        return jnp.float8_e5m2, jnp.float8_e5m2
     if format_ == Format.HYBRID:
-        return DType.kFloat8E4M3, DType.kFloat8E5M2
-    return DType.kBFloat16, DType.kBFloat16
+        return jnp.float8_e4m3fn, jnp.float8_e5m2
+    return jnp.bfloat16, jnp.bfloat16


-class FP8GemmPackage:
+class FP8MetaPackage:
     """
-    A container that contains all required data for
-    FP8 GEMM
+    A container that contains all required meta data for FP8
     """

     def __init__(
             self,
             num_of_gemm: int,
-            inputs: jnp.ndarray,
-            kernels: List[jnp.ndarray],
             fp8_max: jnp.ndarray,
             amax: jnp.ndarray,
             scale: jnp.ndarray,
             scale_inv: jnp.ndarray,
     ) -> None:
+        total_num_of_meta = num_of_gemm * FP8Helper.NUM_META_PER_GEMM
         self._num_of_gemm = num_of_gemm
-        self._inputs = inputs
-        assert len(kernels) == self._num_of_gemm
-        self._kernels = kernels
-        total_num_of_meta = self._num_of_gemm * FP8Helper.NUM_META_PER_GEMM
         assert fp8_max.shape[0] == total_num_of_meta
         self._fp8_max = fp8_max
         assert amax.shape[0] == total_num_of_meta
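`_format2dtypes` now returns native JAX FP8 dtypes rather than TE `DType` enum members, which is what lets a later hunk replace `FP8_FORMAT.value.max_fwd/max_bwd` with `jnp.finfo(...).max`. A quick check of those format limits (assumes a JAX recent enough to expose the `float8_*` dtypes):

```python
import jax.numpy as jnp

# E4M3 trades range for precision; E5M2 the reverse. finfo exposes the
# representable maxima the amax/scale logic relies on.
print(jnp.finfo(jnp.float8_e4m3fn).max)  # 448
print(jnp.finfo(jnp.float8_e5m2).max)    # 57344
```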
@@ -106,20 +98,6 @@ class FP8GemmPackage:
         """
         return self._num_of_gemm

-    @property
-    def inputs(self) -> jnp.ndarray:
-        """
-        inputs of this package
-        """
-        return self._inputs
-
-    @property
-    def kernels(self) -> List[jnp.ndarray]:
-        """
-        kernels of this package
-        """
-        return self._kernels
-
     @property
     def fp8_max(self) -> jnp.ndarray:
         """
@@ -148,6 +126,19 @@ class FP8GemmPackage:
         """
         return self._scale_inv

+    def get_package_by_gemm_idx(self, gemm_idx):
+        """
+        Get a sub package by gemm_idx
+        """
+        assert self.num_of_gemm > gemm_idx
+
+        meta_start_idx = gemm_idx * FP8Helper.NUM_META_PER_GEMM
+        meta_end_idx = (gemm_idx + 1) * FP8Helper.NUM_META_PER_GEMM
+        return FP8MetaPackage(1, self.fp8_max[meta_start_idx:meta_end_idx],
+                              self.amax[meta_start_idx:meta_end_idx],
+                              self.scale[meta_start_idx:meta_end_idx],
+                              self.scale_inv[meta_start_idx:meta_end_idx])
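The new `get_package_by_gemm_idx` simply slices out one GEMM's rows of each meta array. Assuming `NUM_META_PER_GEMM == 3` (one row each for input, kernel, and grad, matching the three `*_META_IDX_PER_GEMM` constants), the index arithmetic works out as:

```python
import numpy as np

NUM_META_PER_GEMM = 3  # input, kernel, grad metas per GEMM (assumed)
num_of_gemm = 2
amax = np.arange(num_of_gemm * NUM_META_PER_GEMM)  # [0 1 2 3 4 5]

# Rows [3:6] of the packed meta arrays belong to the second GEMM.
gemm_idx = 1
start = gemm_idx * NUM_META_PER_GEMM
end = (gemm_idx + 1) * NUM_META_PER_GEMM
print(amax[start:end])  # [3 4 5]
```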
 class AmaxComputeAlgo(Enum):
     """AmaxComputeAlgo."""
@@ -155,6 +146,9 @@ class AmaxComputeAlgo(Enum):
     MOST_RECENT = "most_recent"


+NVTE_FP8_COLLECTION_NAME = "fp8_meta_collection"


 class FP8Helper:
     """
     FP8 helper to manage the FP8 meta
@@ -162,8 +156,8 @@ class FP8Helper:
     INITIALIZED = False
     MARGIN: float = 0.0
     FP8_FORMAT: Format = Format.HYBRID
-    FWD_DTYPE: DType = DType.kFloat8E4M3
-    BWD_DTYPE: DType = DType.kFloat8E5M2
+    FWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[0]
+    BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1]
     UPDATE_FP8META_INTERVAL: int = 1
     AMAX_HISTORY_LEN: int = 1024
     AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
@@ -171,7 +165,7 @@ class FP8Helper:
     INPUT_META_IDX_PER_GEMM: int = 0
     KERNEL_META_IDX_PER_GEMM: int = 1
     GRAD_META_IDX_PER_GEMM: int = 2
-    FP8_COLLECTION_NAME: str = "fp8_meta_collection"
+    FP8_COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
     FP8_AMAX_NAME: str = "fp8_meta_amax"
     FP8_SCALE_NAME: str = "fp8_meta_scale"
     FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv"
@@ -216,21 +210,12 @@ class FP8Helper:
         FP8Helper.INITIALIZED = False
         FP8Helper.MARGIN = 0.0
         FP8Helper.FP8_FORMAT = Format.HYBRID
-        FP8Helper.FWD_DTYPE = DType.kFloat8E4M3
-        FP8Helper.BWD_DTYPE = DType.kFloat8E5M2
+        FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
+            _format2dtypes(FP8Helper.FP8_FORMAT)
         FP8Helper.UPDATE_FP8META_INTERVAL = 1
         FP8Helper.AMAX_HISTORY_LEN = 1024
         FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX

-    @staticmethod
-    def update_amax_history(amax_buffers: jnp.ndarray) -> jnp.ndarray:
-        """
-        Update the amax history
-        """
-        updated_amax_buffers = jnp.roll(amax_buffers, -1, 1)
-        updated_amax_buffers = updated_amax_buffers.at[:, 0].set(0)
-        return updated_amax_buffers
-
     @staticmethod
     def update_collections(new: Collection, original: Collection) -> Collection:
         """
@@ -270,8 +255,8 @@ class FP8Helper:
         Generate the FP8 max array
         """
         num_of_gemm = num_of_meta // FP8Helper.NUM_META_PER_GEMM
-        fp8_max_fwd = FP8Helper.FP8_FORMAT.value.max_fwd
-        fp8_max_bwd = FP8Helper.FP8_FORMAT.value.max_bwd
+        fp8_max_fwd = jnp.finfo(FP8Helper.FWD_DTYPE).max
+        fp8_max_bwd = jnp.finfo(FP8Helper.BWD_DTYPE).max
         fp8_max_per_gemm = []
         for i in range(FP8Helper.NUM_META_PER_GEMM):
             val = fp8_max_bwd if i == FP8Helper.GRAD_META_IDX_PER_GEMM \
@@ -318,11 +303,40 @@ class FP8Helper:
         return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)

+    @staticmethod
+    def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray:
+        """
+        Update the amax history
+        """
+        updated_amax = jnp.roll(amax, -1, -1)
+        updated_amax = updated_amax.at[..., 0].set(0)
+        return updated_amax
+
+    @staticmethod
+    @jax.jit
+    def update_fp8_scale(fp8_max: jnp.ndarray, amax: jnp.ndarray,
+                         scale: jnp.ndarray) -> jnp.ndarray:
+        """
+        Calculate fp8 scale and scale_inv based on given amax.
+        """
+        if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
+            amax = jnp.max(amax, axis=-1, keepdims=True)
+        else:
+            amax = amax[..., 0:1]
+
+        sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
+        sf = jnp.where(amax > 0.0, sf, scale)
+        sf = jnp.where(jnp.isfinite(amax), sf, scale)
+        scale = sf
+        scale_inv = 1 / sf
+
+        return scale, scale_inv
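The relocated `update_amax_history` now rolls along the last axis (so the same code works on full histories and per-GEMM slices), and `update_fp8_scale` derives the scale from `fp8_max / amax`. The arithmetic on a toy history, mirroring those two methods:

```python
import jax.numpy as jnp

# Roll the history left along the last axis and clear slot 0 for the
# next observed amax, as in the new update_amax_history.
hist = jnp.array([[1.0, 2.0, 3.0]])
rolled = jnp.roll(hist, -1, -1).at[..., 0].set(0)
print(rolled)  # [[0. 3. 1.]]

# Scale factor with MARGIN = 0: sf = fp8_max / amax.
fp8_max, amax, margin = 448.0, 2.0, 0.0
sf = (fp8_max / amax) / (2**margin)
print(sf)  # 224.0
```

Scaling an amax of 2.0 by 224 maps it exactly onto the E4M3 maximum of 448, which is the point of delayed scaling.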
 @contextmanager
 def fp8_autocast(enabled: bool = False,
                  fp8_recipe: Optional[DelayedScaling] = None,
-                 sharding_resource: Optional[ShardingResource] = None) -> None:
+                 mesh_resource: Optional[MeshResource] = None) -> None:
     r"""
     Context manager for FP8 usage.
@@ -334,9 +348,9 @@ def fp8_autocast(enabled: bool = False,
         devices = np.asarray(jax.devices()).reshape(*mesh_shape)
         with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
-            sharding_resource = ShardingResource(dp_mesh_axis_name, tp_mesh_axis_name)
+            mesh_resource = MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)

-            with fp8_autocast(enabled=True, sharding_resource=sharding_resource):
+            with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
                 rules = extend_logical_axis_rules(tuple())
                 transformer = TransformerLayer()
@@ -356,7 +370,7 @@ def fp8_autocast(enabled: bool = False,
         Whether or not to enable fp8
     fp8_recipe: recipe.DelayedScaling, default = None
         Recipe used for FP8 training.
-    sharding_resource: ShardingResource, default = None
+    mesh_resource: MeshResource, default = None
         Specify the mesh axes for data and tensor parallelism to shard along.
         If set to None, then no data or tensor parallelism will be used.
@@ -373,11 +387,11 @@ def fp8_autocast(enabled: bool = False,
             "DelayedScaling override_linear_precision isn't supported by TE/JAX.")
     assert fp8_recipe.reduce_amax, ("DelayedScaling reduce_amax should be enabled for TE/JAX.")

-    if sharding_resource is None:
-        sharding_resource = ShardingResource()
+    if mesh_resource is None:
+        mesh_resource = MeshResource()

     try:
-        with global_shard_guard(sharding_resource):
+        with global_shard_guard(mesh_resource):
             if enabled:
                 fp8_available, reason_for_no_fp8 = is_fp8_available()
                 assert fp8_available, reason_for_no_fp8
@@ -15,12 +15,6 @@ from transformer_engine_jax import NVTE_QKV_Layout
 from .cpp_extensions import FusedAttnHelper
 from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
 from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
-from .sharding import get_fused_attn_sharding_meta
-from .sharding import ShardingType
-from .sharding import xmap_runner, extend_fsdp_sharding_meta
-
-jax.config.update('experimental_xmap_spmd_lowering', True)
-jax.config.update('experimental_xmap_spmd_lowering_manual', True)


 class AttnBiasType(Enum):
@@ -54,62 +48,24 @@ def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type,
                           head_dim).is_fused_attn_kernel_available()

-def self_fused_attn(qkv: jnp.ndarray,
-                    bias: jnp.ndarray,
-                    mask: jnp.ndarray,
-                    seed: jnp.ndarray,
-                    attn_bias_type: AttnBiasType,
-                    attn_mask_type: AttnMaskType,
-                    scaling_factor: float,
-                    dropout_probability: float,
-                    is_training: bool,
-                    sharding_type: ShardingType = ShardingType.SINGLE):
+def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
+                    attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
+                    scaling_factor: float, dropout_probability: float, is_training: bool):
     """
     Self fused attention wrapper
     """
-    assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
-        "self_fused_attn does not support row-split tensor parallelism currently."
-
-    if sharding_type is ShardingType.SINGLE:
-        output = _self_fused_attn(qkv,
-                                  bias,
-                                  mask,
-                                  seed,
-                                  attn_bias_type=attn_bias_type,
-                                  attn_mask_type=attn_mask_type,
-                                  scaling_factor=scaling_factor,
-                                  dropout_probability=dropout_probability,
-                                  is_training=is_training)
-    else:
-        dp_axis_name = "batch"
-        tp_axis_name = "model"
-
-        inputs = [qkv, bias, mask, seed]
-        batch, seqlen, _, num_head, head_dim = qkv.shape
-        output_shape = [batch, seqlen, num_head, head_dim]
-        sharding_meta = get_fused_attn_sharding_meta(
-            sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
-            dp_dims=([0, None, 0, 0], [0]),
-            tp_dims=([3, 1, None, 0], [2]),
-            dp_axis_name=dp_axis_name,
-            tp_axis_name=tp_axis_name)
-        sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})
-
-        inputs_ = tuple(
-            jnp.reshape(x, new_shape) if x is not None else None
-            for x, new_shape in zip(inputs, sharding_meta.input_shapes))
-
-        partial_self_fused_attn = partial(_self_fused_attn,
-                                          attn_bias_type=attn_bias_type,
-                                          attn_mask_type=attn_mask_type,
-                                          scaling_factor=scaling_factor,
-                                          dropout_probability=dropout_probability,
-                                          is_training=is_training)
-
-        output_ = xmap_runner(partial_self_fused_attn, sharding_meta.in_axes,
-                              sharding_meta.out_axes, sharding_meta.axis_resources, inputs_)
-
-        output = jnp.reshape(output_, sharding_meta.output_shapes)
+    assert attn_mask_type is not AttnMaskType.NO_MASK, \
+        "Currently not support AttnMaskType.NO_MASK."
+    output = _self_fused_attn(qkv,
+                              bias,
+                              mask,
+                              seed,
+                              attn_bias_type=attn_bias_type,
+                              attn_mask_type=attn_mask_type,
+                              scaling_factor=scaling_factor,
+                              dropout_probability=dropout_probability,
+                              is_training=is_training)

     return output
@@ -118,119 +74,70 @@ def self_fused_attn(qkv: jnp.ndarray,
 def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                      attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                      scaling_factor: float, dropout_probability: float, is_training: bool):
-    output, _ = _self_fused_attn_fwd(qkv,
-                                     bias,
-                                     mask,
-                                     seed,
-                                     attn_bias_type=attn_bias_type,
-                                     attn_mask_type=attn_mask_type,
-                                     scaling_factor=scaling_factor,
-                                     dropout_probability=dropout_probability,
-                                     is_training=is_training)
-    return output
+    output, _ = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type, attn_mask_type,
+                                          scaling_factor, dropout_probability, is_training)
+    return output


-def _self_fused_attn_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
-                         dropout_probability, is_training):
-    seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
-    cu_seqlen = jnp.cumsum(seqlen)
-    cu_seqlen = jnp.hstack((0, cu_seqlen))
+def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
+                              seed: jnp.ndarray, attn_bias_type: AttnBiasType,
+                              attn_mask_type: AttnMaskType, scaling_factor: float,
+                              dropout_probability: float, is_training: bool):
+    squeezed_mask = mask[:, :, :, 0]

     output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
                                                          bias,
-                                                         cu_seqlen,
+                                                         squeezed_mask,
                                                          seed,
                                                          attn_bias_type=attn_bias_type.value,
                                                          attn_mask_type=attn_mask_type.value,
                                                          scaling_factor=scaling_factor,
                                                          dropout_probability=dropout_probability,
                                                          is_training=is_training)
-    return output, (qkv, softmax_aux, rng_state, output, cu_seqlen)
+    return output, (qkv, softmax_aux, rng_state, output, squeezed_mask)


-def _self_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
-                         is_training, ctx, grad):
-    qkv, softmax_aux, rng_state, output, cu_seqlen = ctx
-
-    doutput = grad
+def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
+                              is_training, ctx, dz):
+    qkv, softmax_aux, rng_state, output, squeezed_mask = ctx

     grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
                                               softmax_aux,
                                               rng_state,
                                               output,
-                                              doutput,
-                                              cu_seqlen,
+                                              dz,
+                                              squeezed_mask,
                                               attn_bias_type=attn_bias_type.value,
                                               attn_mask_type=attn_mask_type.value,
                                               scaling_factor=scaling_factor,
                                               dropout_probability=dropout_probability,
                                               is_training=is_training)

-    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
+    if attn_bias_type == AttnBiasType.NO_BIAS:
         grad_bias = None

     return grad_qkv, grad_bias, None, None


-_self_fused_attn.defvjp(_self_fused_attn_fwd, _self_fused_attn_bwd)
+_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule)
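The `*_fwd_rule`/`*_bwd_rule` naming above follows the standard `jax.custom_vjp` contract: non-differentiable arguments are listed in `nondiff_argnums` and arrive as leading arguments of the bwd rule, the fwd rule returns `(output, residuals)`, and the bwd rule returns one cotangent per differentiable input. A self-contained sketch of the same pattern on a toy op (`scaled_square` is illustrative, not a TE function):

```python
from functools import partial
import jax

# scaling_factor (arg index 1) is non-differentiable, like the
# attn_bias_type/scaling_factor/... arguments above.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled_square(x, scaling_factor):
    out, _ = _fwd_rule(x, scaling_factor)  # reuse the fwd rule, as the PR does
    return out

def _fwd_rule(x, scaling_factor):
    # Return (primal output, residuals saved for the backward pass).
    return scaling_factor * x * x, (x,)

def _bwd_rule(scaling_factor, ctx, dz):
    (x,) = ctx
    # One cotangent per differentiable input (only x here).
    return (2.0 * scaling_factor * x * dz,)

scaled_square.defvjp(_fwd_rule, _bwd_rule)

print(jax.grad(scaled_square)(3.0, 0.5))  # d/dx 0.5*x^2 at x=3 -> 3.0
```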
-def cross_fused_attn(q: jnp.ndarray,
-                     kv: jnp.ndarray,
-                     mask: jnp.ndarray,
-                     seed: jnp.ndarray,
-                     attn_bias_type: AttnBiasType,
-                     attn_mask_type: AttnMaskType,
-                     scaling_factor: float,
-                     dropout_probability: float,
-                     is_training: bool,
-                     sharding_type: ShardingType = ShardingType.SINGLE):
+def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
+                     attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
+                     scaling_factor: float, dropout_probability: float, is_training: bool):
     """
     Cross multi-head attention wrapper
     """
-    assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
-        "cross_fused_attn does not support row-split tensor parallelism currently."
-
-    if sharding_type is ShardingType.SINGLE:
-        output = _cross_fused_attn(q,
-                                   kv,
-                                   mask,
-                                   seed,
-                                   attn_bias_type=attn_bias_type,
-                                   attn_mask_type=attn_mask_type,
-                                   scaling_factor=scaling_factor,
-                                   dropout_probability=dropout_probability,
-                                   is_training=is_training)
-    else:
-        dp_axis_name = "batch"
-        tp_axis_name = "model"
-
-        inputs = [q, kv, mask, seed]
-        output_shape = q.shape
-        sharding_meta = get_fused_attn_sharding_meta(
-            sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
-            dp_dims=([0, 0, 0, None], [0]),
-            tp_dims=([2, 3, None, None], [2]),
-            dp_axis_name=dp_axis_name,
-            tp_axis_name=tp_axis_name)
-        sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})
-
-        inputs_ = tuple(
-            jnp.reshape(x, new_shape) if x is not None else None
-            for x, new_shape in zip(inputs, sharding_meta.input_shapes))
-
-        partial_cross_fused_attn = partial(_cross_fused_attn,
-                                           attn_bias_type=attn_bias_type,
-                                           attn_mask_type=attn_mask_type,
-                                           scaling_factor=scaling_factor,
-                                           dropout_probability=dropout_probability,
-                                           is_training=is_training)
-
-        output_ = xmap_runner(partial_cross_fused_attn, sharding_meta.in_axes,
-                              sharding_meta.out_axes, sharding_meta.axis_resources, inputs_)
-
-        output = jnp.reshape(output_, sharding_meta.output_shapes)
+    output = _cross_fused_attn(q,
+                               kv,
+                               mask,
+                               seed,
+                               attn_bias_type=attn_bias_type,
+                               attn_mask_type=attn_mask_type,
+                               scaling_factor=scaling_factor,
+                               dropout_probability=dropout_probability,
+                               is_training=is_training)

     return output
@@ -240,54 +147,40 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed:
                       attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                       scaling_factor: float, dropout_probability: float, is_training: bool):
-    output, _ = _cross_fused_attn_fwd(q,
-                                      kv,
-                                      mask,
-                                      seed,
-                                      attn_bias_type=attn_bias_type,
-                                      attn_mask_type=attn_mask_type,
-                                      scaling_factor=scaling_factor,
-                                      dropout_probability=dropout_probability,
-                                      is_training=is_training)
+    output, _ = _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type,
+                                           scaling_factor, dropout_probability, is_training)

     return output


-def _cross_fused_attn_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
-                          dropout_probability, is_training):
-    q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
-    q_cu_seqlen = jnp.cumsum(q_seqlen)
-    q_cu_seqlen = jnp.hstack((0, q_cu_seqlen))
-
-    kv_seqlen = jnp.sum(mask[:, :, 0, :] == 0, axis=(-1, -2), dtype=jnp.int32)
-    kv_cu_seqlen = jnp.cumsum(kv_seqlen)
-    kv_cu_seqlen = jnp.hstack((0, kv_cu_seqlen))
+def _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
+                               dropout_probability, is_training):
+    q_squeezed_mask = mask[:, :, :, 0]
+    kv_squeezed_mask = mask[:, :, 0, :]

     output, softmax_aux = cross_fused_attn_fwd(q,
                                                kv,
-                                               q_cu_seqlen,
-                                               kv_cu_seqlen,
+                                               q_squeezed_mask,
+                                               kv_squeezed_mask,
                                                seed,
                                                attn_bias_type=attn_bias_type.value,
                                                attn_mask_type=attn_mask_type.value,
                                                scaling_factor=scaling_factor,
                                                dropout_probability=dropout_probability,
                                                is_training=is_training)

-    return output, (softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen)
+    return output, (softmax_aux, q, kv, q_squeezed_mask, kv_squeezed_mask)


-def _cross_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
-                          is_training, ctx, grad):
-    softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen = ctx
-
-    doutput = grad
+def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
+                               is_training, ctx, dz):
+    softmax_aux, q, kv, q_squeezed_mask, kv_squeezed_mask = ctx

     grad_q, grad_kv = cross_fused_attn_bwd(q,
                                            kv,
                                            softmax_aux,
-                                           doutput,
-                                           q_cu_seqlen,
-                                           kv_cu_seqlen,
+                                           dz,
+                                           q_squeezed_mask,
+                                           kv_squeezed_mask,
                                            attn_bias_type=attn_bias_type.value,
                                            attn_mask_type=attn_mask_type.value,
                                            scaling_factor=scaling_factor,
@@ -297,4 +190,4 @@ def _cross_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropou
     return grad_q, grad_kv, None, None


-_cross_fused_attn.defvjp(_cross_fused_attn_fwd, _cross_fused_attn_bwd)
+_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)
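The substantive change in these fwd rules is that the Python-level reduction of the mask into cumulative sequence lengths is gone: the rules now pass a squeezed mask slice straight to the kernel, presumably so the primitive's inputs keep a per-batch layout that custom_partitioning can shard, instead of a cross-batch cumulative sum. For reference, the old reduction computed:

```python
import jax.numpy as jnp

# Old-style reduction: count unmasked (== 0) positions per batch from the
# first mask column, then build cumulative sequence lengths.
mask = jnp.zeros((2, 1, 4, 4)).at[:, :, 3:, :].set(1)  # last position masked
seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
cu_seqlen = jnp.hstack((jnp.zeros(1, jnp.int32), jnp.cumsum(seqlen)))
print(cu_seqlen)  # [0 3 6]

# New-style: hand the kernel the squeezed per-batch slice instead.
squeezed_mask = mask[:, :, :, 0]
print(squeezed_mask.shape)  # (2, 1, 4)
```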