Unverified Commit 71e51eae authored by Ming-Xu Huang, committed by GitHub

[JAX] Migrating from Xmap to Custom Partitioning for All Custom Calls (#472)



* Refactor sharding.py to prepare for the custom_partitioning migration
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating both FWD and BWD of LayerNorm/RMSNorm from xmap to custom_partitioning.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating both FWD and BWD of all kinds of softmax from xmap to custom_partitioning.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix the wrong order of parameters to LN/RMSN bwd in ln_mlp_fp8.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Add a WAR to LN/RMSN_fp8 before migrating to CP.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix the wrong order of parameters in the bwd of LN/RMSN_fp8.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Modify following review feedback
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Force the hidden dim in Norm ops to be unsharded and add a warning message.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Reuse fwd_rule in VJP functions
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating both FWD and BWD of self-fused-attn from xmap to custom_partitioning.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating both FWD and BWD of cross-fused-attn from xmap to custom_partitioning.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Add gelu and dgelu.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Reuse fwd_rule in VJP functions for attentions
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Apply native FP8 dtypes to fp8.py
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating cast_and_transpose from xmap to custom_partitioning
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating transpose from xmap to custom_partitioning
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Apply XLA pattern matching to perform FP8 GEMM.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrate layernorm_fp8 to custom_partitioning.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Unify code style
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Extend Transpose support to FP8
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Implementing layernorm_fp8_dot based on migrated custom calls.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Rename variables and publish NVTE_FP8_COLLECTION_NAME
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Replace Q/DQ custom calls with native XLA implementations
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrate gelu_fp to custom_partitioning.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Minor fix
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Support custom calls with multi-dims
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Support general dot indices in _fp8_dot_impl
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Implementing layernorm_geglu_fp8_mlp
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Remove GEMM custom calls
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Remove xmap-related code
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix typo and add query function to FP8MetaPackage
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix some bugs in custom calls
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix CT's bugs
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Update UTs/examples to adapt to the API changes.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Unify kernel initialization in MLP.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Modify following code review feedback
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Update README and add deprecation warning to *ShardingType
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Canonicalize the dtype
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding an assertion for unsupported batch dims.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding docs/examples to _multidim_transpose
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Apply dtype-based rtol/atol to UTs
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Deprecate QKV_INTERLEAVED enum
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Skip test_distributed_custom_ops.py
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the wrong sharding of bias in SelfAttn
Signed-off-by: Ming Huang <mingh@nvidia.com>

* WAR to fix the wrong cu_seqlen of MHA when DP/FSDP is enabled
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding distributed ops unit tests
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding license to test_distributed_*
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Modify following review feedback
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Use the total bytes involved in collective ops as the criterion.
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
Co-authored-by: Donglin Yang <dongliny@nvidia.com>
parent 7976bd00
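The pattern all of these commits converge on is `jax.experimental.custom_partitioning` in place of xmap. Below is a minimal, self-contained sketch of that pattern, not TE code: the op `my_scale` and its two rules are placeholders, and the callback signatures follow the public JAX documentation, while the real wrappers around TE's custom calls add FP8- and attention-specific details.

```python
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import NamedSharding


@custom_partitioning
def my_scale(x, scale):
    # Reference (unpartitioned) computation; GSPMD treats the call as opaque
    # until the rules registered below tell it how to shard the op.
    return x * scale


def _infer_sharding_from_operands(mesh, arg_infos, result_infos):
    # The output keeps the sharding of the first operand.
    del result_infos
    return NamedSharding(mesh, arg_infos[0].sharding.spec)


def _partition(mesh, arg_infos, result_infos):
    del result_infos
    arg_shardings = tuple(arg.sharding for arg in arg_infos)
    out_sharding = arg_shardings[0]

    def lower_fn(x, scale):
        # Runs on per-device shards; a real TE rule would invoke the
        # corresponding C++ custom call here instead of plain multiply.
        return x * scale

    return mesh, lower_fn, out_sharding, arg_shardings


my_scale.def_partition(infer_sharding_from_operands=_infer_sharding_from_operands,
                       partition=_partition)
```

Each migrated op registers rules like these once, and GSPMD then propagates shardings across DP/TP/FSDP meshes without the per-call ShardingType plumbing that xmap required.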
@@ -4,227 +4,167 @@
"""JAX te modules"""
from typing import Tuple, Sequence
from functools import partial, reduce
import operator
from functools import partial
import jax
import jax.numpy as jnp
from transformer_engine_jax import DType as TEDType
from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype
from .fp8 import FP8Helper, FP8GemmPackage
from .sharding import ShardingType, get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
from .sharding import xmap_runner, extend_fsdp_sharding_meta
from .cpp_extensions import cast_transpose
from .fp8 import FP8Helper, FP8MetaPackage
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
def type_safe_dot_general(
x,
kernel,
fp8_meta_pkg: FP8MetaPackage = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,))
) -> jnp.ndarray:
"""
Type safe dot_general, including FP8.
"""
if fp8_meta_pkg is None:
kernel = jnp.asarray(kernel, x.dtype)
return jax.lax.dot_general(x, kernel, (contracting_dims, ((), ())))
fp8_max = fp8_meta_pkg.fp8_max
amax = fp8_meta_pkg.amax
scale = fp8_meta_pkg.scale
scale_inv = fp8_meta_pkg.scale_inv
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
return _fp8_dot(x, kernel, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype,
contracting_dims)
def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0) -> jnp.ndarray:
def quantize(x, q_dtype, scale):
"""
FP8 dot wrapper
Quantize with scale.
"""
assert fp8_gemm_pkg.num_of_gemm == 1
inputs = fp8_gemm_pkg.inputs
kernel = fp8_gemm_pkg.kernels[0]
fp8_max = fp8_gemm_pkg.fp8_max
amax = fp8_gemm_pkg.amax
scale = fp8_gemm_pkg.scale
scale_inv = fp8_gemm_pkg.scale_inv
if sharding_type is ShardingType.SINGLE:
res = _fp8_dot(inputs,
kernel,
fp8_max,
amax,
scale,
scale_inv,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name="",
tp_axis_name="",
fsdp_axis_name="")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
kernel_tp_index = None
# TODO (Ming Huang): Should we add a new argument to support general sharding to kernel? # pylint: disable=fixme
if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
kernel_tp_index = len(kernel.shape) - 1
elif sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
kernel_tp_index = 0
input_tp_index = len(inputs.shape) - 1
sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
dp_dim_index, input_tp_index, kernel_tp_index,
contracting_dims, dp_axis_name, tp_axis_name)
sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
kernel_ = jnp.reshape(kernel, sharding_meta.input_shapes[1]) # 1 for kernel
num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv
fp8_sharding_meta = get_fp8_meta_sharding_meta(sharding_type, num_of_fp8_meta_kind,
dp_axis_name, tp_axis_name)
axis_resources = merge_axis_resources(
[sharding_meta.axis_resources, fp8_sharding_meta.axis_resources])
partial_fp8_dot = partial(_fp8_dot,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
res = xmap_runner(partial_fp8_dot, (*sharding_meta.in_axes, *fp8_sharding_meta.in_axes),
sharding_meta.out_axes, axis_resources,
(inputs_, kernel_, fp8_max, amax, scale, scale_inv))
res = jnp.reshape(res, sharding_meta.output_shapes[0])
return res
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12))
def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: TEDType, bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]], sharding_type: ShardingType,
dp_axis_name: str, tp_axis_name: str, fsdp_axis_name: str):
res, _ = _fp8_dot_fwd(inputs,
kernel,
fp8_maxs,
amax,
scale,
scale_inv,
fwd_dtype,
bwd_dtype,
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
return res
dtype_max = (jnp.finfo(q_dtype).max).astype(x.dtype)
scale = scale.astype(x.dtype)
clipped_scaled_x = jnp.clip((x * scale), -dtype_max, dtype_max)
return clipped_scaled_x.astype(q_dtype)
def _fp8_dot_fwd(
inputs,
def dequantize(x, dq_dtype, scale_inv):
"""
Dequantize with scale_inv.
"""
return x.astype(dq_dtype) * scale_inv.astype(dq_dtype)
# Apply jit to guarantee correctness of FP8 GEMM.
@partial(jax.jit, static_argnums=(4, 5))
def fp8_dot_impl(
q_lhs: jnp.ndarray,
q_rhs: jnp.ndarray,
lhs_scale_inv: jnp.ndarray,
rhs_scale_inv: jnp.ndarray,
ctype: jnp.dtype, # computing type
contracting_dims: Tuple[Sequence[int], Sequence[int]]):
"""
FP8 GEMM for XLA pattern match
"""
dim_nums = (contracting_dims, ((), ()))
lhs = dequantize(q_lhs, ctype, lhs_scale_inv)
rhs = dequantize(q_rhs, ctype, rhs_scale_inv)
return jax.lax.dot_general(lhs, rhs, dim_nums)
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
def _fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
contracting_dims: Tuple[Sequence[int], Sequence[int]]):
output, _ = _fp8_dot_fwd_rule(x, kernel, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype,
contracting_dims)
return output
def _fp8_dot_fwd_rule(
x,
kernel,
fp8_maxs,
fp8_max,
amax,
scale,
scale_inv,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims,
sharding_type,
dp_axis_name, # pylint: disable=unused-argument
tp_axis_name,
fsdp_axis_name): # pylint: disable=unused-argument
contracting_dims):
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
input_shape_suf = inputs.shape[min(lhs_contracting_dims):]
x_shape_suf = x.shape[min(lhs_contracting_dims):]
kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
kernel_shape_suf = kernel.shape[max(rhs_contracting_dims) + 1:]
input_contracting_size = reduce(operator.mul, input_shape_suf)
kernel_contracting_size = reduce(operator.mul, kernel_shape_pre)
assert input_contracting_size == kernel_contracting_size
inputs_ = jnp.reshape(inputs, (-1, input_contracting_size))
kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1))
assert x_shape_suf == kernel_shape_pre
amax = FP8Helper.update_amax_history(amax)
gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
input_amax = amax[gemm_input_idx, 0:1]
input_scale = scale[gemm_input_idx]
input_scale_inv = scale_inv[gemm_input_idx]
input_cast, input_cast_trans, input_amax = cast_transpose(inputs_, input_amax, input_scale,
input_scale_inv, fwd_dtype)
x_amax = amax[gemm_x_idx, 0:1]
x_scale = scale[gemm_x_idx]
x_scale_inv = scale_inv[gemm_x_idx]
casted_x, casted_xt, updated_x_amax = \
cast_transpose(x, x_amax, x_scale, x_scale_inv, fwd_dtype, static_axis_boundary=-1,
transpose_axis_boundary=min(lhs_contracting_dims))
kernel_amax = amax[gemm_kernel_idx, 0:1]
kernel_scale = scale[gemm_kernel_idx]
kernel_scale_inv = scale_inv[gemm_kernel_idx]
kernel_cast, kernel_cast_trans, kernel_amax = cast_transpose(kernel_, kernel_amax, kernel_scale,
kernel_scale_inv, fwd_dtype)
res = gemm(kernel_cast_trans, kernel_scale_inv, fwd_dtype, True, input_cast, input_scale_inv,
fwd_dtype, False, jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP)
if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
res = jax.lax.psum(res, tp_axis_name)
casted_kerenl, casted_kerenl_t, updated_kernel_amax = \
cast_transpose(kernel, kernel_amax, kernel_scale, kernel_scale_inv,
fwd_dtype, static_axis_boundary=-1,
transpose_axis_boundary=(max(rhs_contracting_dims) + 1))
# (input_shape_pre, input_shape_suf)
# x (kernel_shape_pre, kernel_shape_suf)
# = (input_shape_pre, kernel_shape_suf)
output_shape = input_shape_pre + kernel_shape_suf
res = jnp.reshape(res, output_shape)
rhs_t_contracting_dims = tuple(range(kernel.ndim - len(rhs_contracting_dims), kernel.ndim))
output = fp8_dot_impl(casted_x, casted_kerenl_t, x_scale_inv, kernel_scale_inv, x.dtype,
(lhs_contracting_dims, rhs_t_contracting_dims))
ctx = (input_cast_trans, kernel_cast, fp8_maxs, amax, scale, scale_inv, input_amax, kernel_amax,
inputs.shape, kernel.shape)
return res, ctx
ctx = (casted_xt, casted_kerenl, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_kernel_amax, x.shape, kernel.shape)
return output, ctx
def _fp8_dot_bwd(
fwd_dtype,
bwd_dtype,
contracting_dims, # pylint: disable=unused-argument
sharding_type,
dp_axis_name,
tp_axis_name,
fsdp_axis_name,
ctx,
g):
input_cast_trans, kernel_cast, \
fp8_maxs, amax, scale, scale_inv, \
input_amax, kernel_amax, \
inputs_shape, kernel_shape = ctx
gemm_input_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # pylint: disable=unused-argument
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
casted_xt, casted_kerenl, fp8_max, amax, scale, scale_inv, \
updated_x_amax, updated_kernel_amax, x_shape, kernel_shape = ctx
gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
grad_amax = amax[gemm_grad_idx, 0:1]
grad_scale = scale[gemm_grad_idx]
grad_scale_inv = scale_inv[gemm_grad_idx]
g = jnp.reshape(g, (input_cast_trans.shape[1], -1))
grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv,
bwd_dtype)
input_scale_inv = scale_inv[gemm_input_idx]
wgrad = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype,
True, input_cast_trans, input_scale_inv, fwd_dtype, False,
jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD)
casted_grad, casted_grad_t, updated_grad_amax = \
cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv,
bwd_dtype, static_axis_boundary=-1,
transpose_axis_boundary=min(lhs_contracting_dims))
kernel_scale_inv = scale_inv[gemm_kernel_idx]
dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)
amax = amax.at[gemm_input_idx, 0].set(input_amax[0])
amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0])
amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0])
xt_constracting_dim = tuple(range(len(lhs_contracting_dims), len(x_shape)))
gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim))
x_scale_inv = scale_inv[gemm_x_idx]
wgrad = fp8_dot_impl(casted_xt, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
(xt_constracting_dim, gt_constracting_dim))
if is_dp_enabled(sharding_type.value[0]):
wgrad = jax.lax.psum(wgrad, dp_axis_name)
amax = jax.lax.pmax(amax, dp_axis_name)
if len(fsdp_axis_name) > 0:
wgrad = jax.lax.psum(wgrad, fsdp_axis_name)
amax = jax.lax.pmax(amax, fsdp_axis_name)
g_constracting_dim = tuple(
range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim))
k_constracting_dim = tuple(range(len(rhs_contracting_dims), len(kernel_shape)))
kernel_scale_inv = scale_inv[gemm_kernel_idx]
dgrad = fp8_dot_impl(casted_grad, casted_kerenl, grad_scale_inv, kernel_scale_inv, grad.dtype,
(g_constracting_dim, k_constracting_dim))
if is_tp_enabled(sharding_type.value[0]):
amax = jax.lax.pmax(amax, tp_axis_name)
amax = amax.at[gemm_x_idx, 0].set(updated_x_amax[0])
amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0])
amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])
if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
dgrad = jax.lax.psum(dgrad, tp_axis_name)
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
dgrad = jnp.reshape(dgrad, inputs_shape)
wgrad = jnp.reshape(wgrad, kernel_shape)
return dgrad, wgrad, fp8_maxs, amax, scale, scale_inv
return dgrad, wgrad, fp8_max, amax, scale, scale_inv
_fp8_dot.defvjp(_fp8_dot_fwd, _fp8_dot_bwd)
_fp8_dot.defvjp(_fp8_dot_fwd_rule, _fp8_dot_bwd_rule)
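With the GEMM custom call removed, the rules above express the FP8 matmul as dequantize plus `jax.lax.dot_general` and rely on XLA's FP8 pattern match to fuse that sequence into a single FP8 GEMM where the hardware supports it. A small usage sketch of the helpers defined above, with illustrative shapes and unit scales:

```python
import jax.numpy as jnp

x = jnp.ones((16, 32), jnp.bfloat16)
w = jnp.ones((32, 64), jnp.bfloat16)
x_scale = jnp.float32(1.0)
w_scale = jnp.float32(1.0)

# Cast to FP8 with scaling, then let fp8_dot_impl dequantize and contract
# x's dim 1 against w's dim 0; XLA may fuse this into one FP8 GEMM.
q_x = quantize(x, jnp.float8_e4m3fn, x_scale)
q_w = quantize(w, jnp.float8_e4m3fn, w_scale)
out = fp8_dot_impl(q_x, q_w, 1.0 / x_scale, 1.0 / w_scale,
                   jnp.bfloat16, ((1,), (0,)))
assert out.shape == (16, 64) and out.dtype == jnp.bfloat16
```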
@@ -27,9 +27,8 @@ from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from ..fused_attn import is_fused_attn_kernel_available
from ..fused_attn import self_fused_attn, cross_fused_attn
from ..softmax import SoftmaxType
from ..sharding import infer_major_sharding_type, infer_sharding_type
from ..sharding import global_shard_resource, with_sharding_constraint
from ..sharding import ShardingType
from ..sharding import global_mesh_resource, num_of_devices
from ..sharding import with_sharding_constraint
PRNGKey = Any
Shape = Tuple[int, ...]
@@ -102,7 +101,7 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
else:
rules_map[key] = [val]
gsr = global_shard_resource()
gsr = global_mesh_resource()
batch_dim_rule = []
if gsr.dp_resource is not None:
@@ -186,7 +185,6 @@ def core_attention(query: Array,
scale_factor: float,
transpose_batch_sequence: bool,
softmax_type: SoftmaxType = SoftmaxType.SCALED,
softmax_sharding_type: ShardingType = ShardingType.SINGLE,
mask: Optional[Array] = None,
bias: Optional[Array] = None,
dropout_rng: Optional[PRNGKey] = None,
@@ -226,9 +224,7 @@ def core_attention(query: Array,
fused_scale_factor = scale_factor
attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=fused_scale_factor,
sharding_type=softmax_sharding_type)(attn_weights, mask,
bias).astype(dtype)
scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype)
if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate
@@ -482,8 +478,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
f"Fused attention is not enabled. Because " \
f"{reason}fall back to unfused attention.")
first_sharding_type, second_sharding_type = infer_sharding_type()
residual = inputs_q
if self.fuse_qkv:
if is_self_attn:
@@ -494,7 +488,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon,
axis=-1,
features=(3, self.num_heads * self.head_dim),
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
scale_axes=(W_NO_SHARD_AXES,),
@@ -516,7 +509,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon,
axis=-1,
features=self.num_heads * self.head_dim,
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
scale_axes=(W_NO_SHARD_AXES,),
@@ -530,7 +522,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
name='query')(inputs_q)
kv_proj = DenseGeneral(axis=-1,
features=(2, self.num_heads * self.head_dim),
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_init=kv_init,
@@ -546,7 +537,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
DenseGeneral,
axis=-1,
features=self.num_heads * self.head_dim,
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias,
@@ -560,7 +550,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon,
axis=-1,
features=self.num_heads * self.head_dim,
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=True,
scale_axes=(W_NO_SHARD_AXES,),
@@ -648,7 +637,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
seed = None
if dropout_rng is not None:
seed = jax.random.split(dropout_rng, len(jax.devices()))
seed = jax.random.split(dropout_rng, num_of_devices())
# ensure the old key never used
del dropout_rng
@@ -665,8 +654,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor,
dropout_probability=self.dropout_rate,
is_training=not deterministic,
sharding_type=first_sharding_type)
is_training=not deterministic)
else:
assert bias is None
query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
@@ -685,8 +673,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor,
dropout_probability=self.dropout_rate,
is_training=not deterministic,
sharding_type=first_sharding_type)
is_training=not deterministic)
else:
def convert_to_softmax_type(attn_mask_type, mask):
@@ -710,7 +697,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
softmax_type=softmax_type,
softmax_sharding_type=first_sharding_type,
mask=mask,
bias=bias,
dropout_rng=dropout_rng,
@@ -728,7 +714,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
x = _with_sharding_constraint(x, attn_context_sharding_constraint)
out = DenseGeneral(features=inputs_q.shape[-1],
sharding_type=second_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
axis=-1,
kernel_init=self.kernel_init,
@@ -1175,7 +1160,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
major_sharding_type=infer_major_sharding_type(),
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
intermediate_dim=self.mlp_hidden_size,
@@ -1208,7 +1192,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
z = z + residual
if self.output_layernorm:
ln_sharding_type, _ = infer_sharding_type()
z = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
@@ -1216,7 +1199,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_NO_SHARD_AXES,),
transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype,
sharding_type=ln_sharding_type,
name="output_layer_norm")(z)
return z
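With the per-layer sharding_type/major_sharding_type arguments removed above, parallelism is configured once at the mesh level. The snippet below is a cleaned-up, hedged rendering of the usage example from the fp8_autocast docstring later in this diff; the 4x2 mesh shape, the "data"/"model" axis names, and the transformer_engine.jax.flax import path are illustrative assumptions.

```python
import numpy as np
import jax
from jax.experimental import maps

from transformer_engine.jax.fp8 import fp8_autocast
from transformer_engine.jax.sharding import MeshResource
# Assumed import path for the Flax-level symbols referenced in this diff.
from transformer_engine.jax.flax import TransformerLayer, extend_logical_axis_rules

devices = np.asarray(jax.devices()).reshape(4, 2)
with maps.Mesh(devices, ("data", "model")):
    with fp8_autocast(enabled=True, mesh_resource=MeshResource("data", "model")):
        # Logical axis rules map TE's named weight/activation axes onto the mesh;
        # no sharding_type is passed to the individual layers anymore.
        rules = extend_logical_axis_rules(tuple())
        transformer = TransformerLayer()
```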
@@ -6,7 +6,7 @@ Helper module for fp8 meta management
"""
from contextlib import contextmanager
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union
import jax
import jax.numpy as jnp
@@ -17,7 +17,7 @@ from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import get_cuda_version, get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import ShardingResource
from transformer_engine.jax.sharding import MeshResource
_is_fp8_available = None
_reason_for_no_fp8 = ""
@@ -59,37 +59,29 @@ def is_fp8_available(gpu_id=None) -> Tuple[bool, str]:
def _format2dtypes(format_: Format):
if format_ == Format.E4M3:
return DType.kFloat8E4M3, DType.kFloat8E4M3
return jnp.float8_e4m3fn, jnp.float8_e4m3fn
if format_ == Format.E5M2:
return DType.kFloat8E5M2, DType.kFloat8E5M2
return jnp.float8_e5m2, jnp.float8_e5m2
if format_ == Format.HYBRID:
return DType.kFloat8E4M3, DType.kFloat8E5M2
return DType.kBFloat16, DType.kBFloat16
return jnp.float8_e4m3fn, jnp.float8_e5m2
return jnp.bfloat16, jnp.bfloat16
class FP8GemmPackage:
class FP8MetaPackage:
"""
A container that contains all required data for
FP8 GEMM
A container that contains all required meta data for FP8
"""
def __init__(
self,
num_of_gemm: int,
inputs: jnp.ndarray,
kernels: List[jnp.ndarray],
fp8_max: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
) -> None:
total_num_of_meta = num_of_gemm * FP8Helper.NUM_META_PER_GEMM
self._num_of_gemm = num_of_gemm
self._inputs = inputs
assert len(kernels) == self._num_of_gemm
self._kernels = kernels
total_num_of_meta = self._num_of_gemm * FP8Helper.NUM_META_PER_GEMM
assert fp8_max.shape[0] == total_num_of_meta
self._fp8_max = fp8_max
assert amax.shape[0] == total_num_of_meta
@@ -106,20 +98,6 @@ class FP8GemmPackage:
"""
return self._num_of_gemm
@property
def inputs(self) -> jnp.ndarray:
"""
inputs of this package
"""
return self._inputs
@property
def kernels(self) -> List[jnp.ndarray]:
"""
kernels of this package
"""
return self._kernels
@property
def fp8_max(self) -> jnp.ndarray:
"""
@@ -148,6 +126,19 @@
"""
return self._scale_inv
def get_package_by_gemm_idx(self, gemm_idx):
"""
Get a sub package by gemm_idx
"""
assert self.num_of_gemm > gemm_idx
meta_start_idx = gemm_idx * FP8Helper.NUM_META_PER_GEMM
meta_end_idx = (gemm_idx + 1) * FP8Helper.NUM_META_PER_GEMM
return FP8MetaPackage(1, self.fp8_max[meta_start_idx:meta_end_idx],
self.amax[meta_start_idx:meta_end_idx],
self.scale[meta_start_idx:meta_end_idx],
self.scale_inv[meta_start_idx:meta_end_idx])
class AmaxComputeAlgo(Enum):
"""AmaxComputeAlgo."""
@@ -155,6 +146,9 @@
MOST_RECENT = "most_recent"
NVTE_FP8_COLLECTION_NAME = "fp8_meta_collection"
class FP8Helper:
"""
FP8 helper to manage the FP8 meta
@@ -162,8 +156,8 @@
INITIALIZED = False
MARGIN: float = 0.0
FP8_FORMAT: Format = Format.HYBRID
FWD_DTYPE: DType = DType.kFloat8E4M3
BWD_DTYPE: DType = DType.kFloat8E5M2
FWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[0]
BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1]
UPDATE_FP8META_INTERVAL: int = 1
AMAX_HISTORY_LEN: int = 1024
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
@@ -171,7 +165,7 @@
INPUT_META_IDX_PER_GEMM: int = 0
KERNEL_META_IDX_PER_GEMM: int = 1
GRAD_META_IDX_PER_GEMM: int = 2
FP8_COLLECTION_NAME: str = "fp8_meta_collection"
FP8_COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
FP8_AMAX_NAME: str = "fp8_meta_amax"
FP8_SCALE_NAME: str = "fp8_meta_scale"
FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv"
@@ -216,21 +210,12 @@
FP8Helper.INITIALIZED = False
FP8Helper.MARGIN = 0.0
FP8Helper.FP8_FORMAT = Format.HYBRID
FP8Helper.FWD_DTYPE = DType.kFloat8E4M3
FP8Helper.BWD_DTYPE = DType.kFloat8E5M2
FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
_format2dtypes(FP8Helper.FP8_FORMAT)
FP8Helper.UPDATE_FP8META_INTERVAL = 1
FP8Helper.AMAX_HISTORY_LEN = 1024
FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
@staticmethod
def update_amax_history(amax_buffers: jnp.ndarray) -> jnp.ndarray:
"""
Update the amax history
"""
updated_amax_buffers = jnp.roll(amax_buffers, -1, 1)
updated_amax_buffers = updated_amax_buffers.at[:, 0].set(0)
return updated_amax_buffers
@staticmethod
def update_collections(new: Collection, original: Collection) -> Collection:
"""
@@ -270,8 +255,8 @@
Generate the FP8 max array
"""
num_of_gemm = num_of_meta // FP8Helper.NUM_META_PER_GEMM
fp8_max_fwd = FP8Helper.FP8_FORMAT.value.max_fwd
fp8_max_bwd = FP8Helper.FP8_FORMAT.value.max_bwd
fp8_max_fwd = jnp.finfo(FP8Helper.FWD_DTYPE).max
fp8_max_bwd = jnp.finfo(FP8Helper.BWD_DTYPE).max
fp8_max_per_gemm = []
for i in range(FP8Helper.NUM_META_PER_GEMM):
val = fp8_max_bwd if i == FP8Helper.GRAD_META_IDX_PER_GEMM \
@@ -318,11 +303,40 @@
return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)
@staticmethod
def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray:
"""
Update the amax history
"""
updated_amax = jnp.roll(amax, -1, -1)
updated_amax = updated_amax.at[..., 0].set(0)
return updated_amax
@staticmethod
@jax.jit
def update_fp8_scale(fp8_max: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray) -> jnp.ndarray:
"""
Calculate fp8 scale and scale_inv based on given amax.
"""
if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(amax, axis=-1, keepdims=True)
else:
amax = amax[..., 0:1]
sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
scale = sf
scale_inv = 1 / sf
return scale, scale_inv
@contextmanager
def fp8_autocast(enabled: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
sharding_resource: Optional[ShardingResource] = None) -> None:
mesh_resource: Optional[MeshResource] = None) -> None:
r"""
Context manager for FP8 usage.
@@ -334,9 +348,9 @@
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
sharding_resource=ShardingResource(dp_mesh_axis_name, tp_mesh_axis_name)
mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)
with fp8_autocast(enabled=True, sharding_resource=sharding_resource):
with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
rules = extend_logical_axis_rules(tuple())
transformer = TransformerLayer()
@@ -356,7 +370,7 @@
Whether or not to enable fp8
fp8_recipe: recipe.DelayedScaling, default = None
Recipe used for FP8 training.
sharding_resource: ShardingResource, default = None
mesh_resource: MeshResource, default = None
Specify the mesh axes for data and tensor parallelism to shard along.
If set to None, then no data or tensor parallelism will be used.
@@ -373,11 +387,11 @@
"DelayedScaling override_linear_precision isn't supported by TE/JAX.")
assert fp8_recipe.reduce_amax, ("DelayedScaling reduce_amax should be enabled for TE/JAX.")
if sharding_resource is None:
sharding_resource = ShardingResource()
if mesh_resource is None:
mesh_resource = MeshResource()
try:
with global_shard_guard(sharding_resource):
with global_shard_guard(mesh_resource):
if enabled:
fp8_available, reason_for_no_fp8 = is_fp8_available()
assert fp8_available, reason_for_no_fp8
......
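As a quick sanity check of the delayed-scaling helpers added in this hunk (update_amax_history and update_fp8_scale), here is a small worked example assuming the default MARGIN = 0.0 and AMAX_COMPUTE_ALGO = MAX; the history values are illustrative.

```python
import jax.numpy as jnp

amax_history = jnp.array([[2.0, 0.5, 1.0]])   # one FP8 tensor, history length 3
fp8_max = jnp.array([[448.0]])                # float8_e4m3fn max
scale = jnp.array([[1.0]])

rolled = FP8Helper.update_amax_history(amax_history)
# -> [[0., 1., 2.]]: the history rolls left and index 0 is reset to 0,
#    ready to record the next amax.

new_scale, new_scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax_history, scale)
# AMAX_COMPUTE_ALGO=MAX reduces the history to 2.0, so
# new_scale == 448 / 2 == 224 and new_scale_inv == 1 / 224.
```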
@@ -15,12 +15,6 @@ from transformer_engine_jax import NVTE_QKV_Layout
from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .sharding import get_fused_attn_sharding_meta
from .sharding import ShardingType
from .sharding import xmap_runner, extend_fsdp_sharding_meta
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
class AttnBiasType(Enum):
@@ -54,23 +48,15 @@ def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type,
head_dim).is_fused_attn_kernel_available()
def self_fused_attn(qkv: jnp.ndarray,
bias: jnp.ndarray,
mask: jnp.ndarray,
seed: jnp.ndarray,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
sharding_type: ShardingType = ShardingType.SINGLE):
def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Self fused attention wrapper
"""
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"self_fused_attn does not support row-split tensor parallelism currently."
assert attn_mask_type is not AttnMaskType.NO_MASK, \
"Currently not support AttnMaskType.NO_MASK."
if sharding_type is ShardingType.SINGLE:
output = _self_fused_attn(qkv,
bias,
mask,
@@ -80,36 +66,6 @@ def self_fused_attn(qkv: jnp.ndarray,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
else:
dp_axis_name = "batch"
tp_axis_name = "model"
inputs = [qkv, bias, mask, seed]
batch, seqlen, _, num_head, head_dim = qkv.shape
output_shape = [batch, seqlen, num_head, head_dim]
sharding_meta = get_fused_attn_sharding_meta(
sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
dp_dims=([0, None, 0, 0], [0]),
tp_dims=([3, 1, None, 0], [2]),
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})
inputs_ = tuple(
jnp.reshape(x, new_shape) if x is not None else None
for x, new_shape in zip(inputs, sharding_meta.input_shapes))
partial_self_fused_attn = partial(_self_fused_attn,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output_ = xmap_runner(partial_self_fused_attn, sharding_meta.in_axes,
sharding_meta.out_axes, sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes)
return output
@@ -118,81 +74,61 @@ def self_fused_attn(qkv: jnp.ndarray,
def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
output, _ = _self_fused_attn_fwd(qkv,
bias,
mask,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output
output, _ = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training)
return output
def _self_fused_attn_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
cu_seqlen = jnp.cumsum(seqlen)
cu_seqlen = jnp.hstack((0, cu_seqlen))
def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
squeezed_mask = mask[:, :, :, 0]
output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
bias,
cu_seqlen,
squeezed_mask,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, (qkv, softmax_aux, rng_state, output, cu_seqlen)
return output, (qkv, softmax_aux, rng_state, output, squeezed_mask)
def _self_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, grad):
qkv, softmax_aux, rng_state, output, cu_seqlen = ctx
doutput = grad
def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
qkv, softmax_aux, rng_state, output, squeezed_mask = ctx
grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
softmax_aux,
rng_state,
output,
doutput,
cu_seqlen,
dz,
squeezed_mask,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
return grad_qkv, grad_bias, None, None
_self_fused_attn.defvjp(_self_fused_attn_fwd, _self_fused_attn_bwd)
_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule)
def cross_fused_attn(q: jnp.ndarray,
kv: jnp.ndarray,
mask: jnp.ndarray,
seed: jnp.ndarray,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
sharding_type: ShardingType = ShardingType.SINGLE):
def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Cross multi-head attention wrapper
"""
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"cross_fused_attn does not support row-split tensor parallelism currently."
if sharding_type is ShardingType.SINGLE:
output = _cross_fused_attn(q,
kv,
mask,
@@ -202,35 +138,6 @@ def cross_fused_attn(q: jnp.ndarray,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
else:
dp_axis_name = "batch"
tp_axis_name = "model"
inputs = [q, kv, mask, seed]
output_shape = q.shape
sharding_meta = get_fused_attn_sharding_meta(
sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
dp_dims=([0, 0, 0, None], [0]),
tp_dims=([2, 3, None, None], [2]),
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})
inputs_ = tuple(
jnp.reshape(x, new_shape) if x is not None else None
for x, new_shape in zip(inputs, sharding_meta.input_shapes))
partial_cross_fused_attn = partial(_cross_fused_attn,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output_ = xmap_runner(partial_cross_fused_attn, sharding_meta.in_axes,
sharding_meta.out_axes, sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes)
return output
@@ -240,54 +147,40 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed:
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
output, _ = _cross_fused_attn_fwd(q,
kv,
mask,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output, _ = _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training)
return output
def _cross_fused_attn_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
def _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
q_cu_seqlen = jnp.cumsum(q_seqlen)
q_cu_seqlen = jnp.hstack((0, q_cu_seqlen))
kv_seqlen = jnp.sum(mask[:, :, 0, :] == 0, axis=(-1, -2), dtype=jnp.int32)
kv_cu_seqlen = jnp.cumsum(kv_seqlen)
kv_cu_seqlen = jnp.hstack((0, kv_cu_seqlen))
q_squeezed_mask = mask[:, :, :, 0]
kv_squeezed_mask = mask[:, :, 0, :]
output, softmax_aux = cross_fused_attn_fwd(q,
kv,
q_cu_seqlen,
kv_cu_seqlen,
q_squeezed_mask,
kv_squeezed_mask,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, (softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen)
return output, (softmax_aux, q, kv, q_squeezed_mask, kv_squeezed_mask)
def _cross_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, grad):
softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen = ctx
doutput = grad
def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
softmax_aux, q, kv, q_squeezed_mask, kv_squeezed_mask = ctx
grad_q, grad_kv = cross_fused_attn_bwd(q,
kv,
softmax_aux,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
dz,
q_squeezed_mask,
kv_squeezed_mask,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
@@ -297,4 +190,4 @@ def _cross_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropou
return grad_q, grad_kv, None, None
_cross_fused_attn.defvjp(_cross_fused_attn_fwd, _cross_fused_attn_bwd)
_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)
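The *_fwd_rule / *_bwd_rule naming above reflects the "Reuse fwd_rule in VJP functions" commits: the public op simply calls its own forward rule and discards the residuals, so the primal computation and the VJP share one code path. A minimal, generic sketch of that layout with an illustrative op (not TE code):

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled_square(x, scale):
    # Reuse the forward rule so the primal path and the VJP path stay in sync.
    out, _ = scaled_square_fwd_rule(x, scale)
    return out


def scaled_square_fwd_rule(x, scale):
    out = scale * x * x
    return out, (x,)            # residuals saved for the backward rule


def scaled_square_bwd_rule(scale, ctx, dz):
    (x,) = ctx
    return (2.0 * scale * x * dz,)   # cotangent for the only differentiable arg


scaled_square.defvjp(scaled_square_fwd_rule, scaled_square_bwd_rule)

grad = jax.grad(scaled_square)(jnp.float32(3.0), 2.0)   # -> 12.0
```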
@@ -49,7 +49,7 @@ class TransformerEngineBaseLayer(BaseLayer):
fp8_collection_map = {
FP8Helper.FP8_COLLECTION_NAME: [
WeightHParamsCollection.SKIP_LP_REGULARIZATION,
WeightHParamsCollection.NON_TRAINABLE,
WeightHParamsCollection.OVERWRITE_WITH_GRADIENT,
WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION
]
}
@@ -92,8 +92,7 @@ class LayerNorm(TransformerEngineBaseLayer):
"ln_bias", self.bias_init),
bias_axes=self.bias_axes,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
sharding_type=self.sharding_type)
transpose_batch_sequence=self.transpose_batch_sequence)
self.create_layer("layer_norm", ln_cls)
@@ -115,8 +114,7 @@ class FusedSoftmax(TransformerEngineBaseLayer):
fused_softmax_cls = partial(Softmax,
scale_factor=self.scale_factor,
softmax_type=self.softmax_type,
sharding_type=self.sharding_type)
softmax_type=self.softmax_type)
self.create_layer("fused_softmax", fused_softmax_cls)
@@ -151,8 +149,7 @@ class Linear(TransformerEngineBaseLayer):
bias_axes=self.bias_axes,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
sharding_type=self.sharding_type)
transpose_batch_sequence=self.transpose_batch_sequence)
self.create_layer("linear", dense_general_cls)
@@ -208,8 +205,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
depth_scaling=self.depth_scaling,
sharding_type=self.sharding_type)
depth_scaling=self.depth_scaling)
self.create_layer("ln_linear", ln_dense_general_cls)
@@ -273,8 +269,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
intermediate_hidden_dropout_dims=self.intermediate_hidden_dropout_dims,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
major_sharding_type=self.major_sharding_type)
transpose_batch_sequence=self.transpose_batch_sequence)
self.create_layer("ln_mlp", ln_mlp_cls)
......