Unverified commit 71e51eae authored by Ming-Xu Huang, committed by GitHub

[JAX] Migrating from Xmap to Custom Partitioning for All Custom Calls (#472)



* Refactor sharding.py in preparation for the custom_partitioning migration
* Migrate both FWD and BWD of LayerNorm/RMSNorm from xmap to custom_partitioning
* Migrate both FWD and BWD of all kinds of softmax from xmap to custom_partitioning
* Fix the wrong order of parameters to LN/RMSN bwd in ln_mlp_fp8
* Add a workaround (WAR) to LN/RMSN_fp8 before migrating to custom_partitioning
* Fix the wrong order of parameters to the bwd of LN/RMSN_fp8
* Modify per review feedback
* Force the hidden dim in Norm ops to no sharding and add a warning message
* Reuse fwd_rule in VJP functions
* Migrate both FWD and BWD of self-fused-attn from xmap to custom_partitioning
* Migrate both FWD and BWD of cross-fused-attn from xmap to custom_partitioning
* Add gelu and dgelu
* Reuse fwd_rule in VJP functions for attentions
* Apply native FP8 dtypes to fp8.py
* Migrate cast_and_transpose from xmap to custom_partitioning
* Migrate transpose from xmap to custom_partitioning
* Apply XLA pattern match to perform FP8 GEMM
* Migrate layernorm_fp8 to custom_partitioning
* Unify code style
* Extend support for Transpose with FP8
* Implement layernorm_fp8_dot based on migrated custom calls
* Rename variables and publish NVTE_FP8_COLLECTION_NAME
* Replace Q/DQ custom calls with native XLA implementations
* Migrate gelu_fp8 to custom_partitioning
* Minor fix
* Support custom calls with multiple dimensions
* Support general dot indices in _fp8_dot_impl
* Implement layernrom_geglu_fp8_mlp
* Remove GEMM custom calls
* Remove xmap-related code
* Fix typo and add a query function to FP8MetaPackage
* Fix some bugs in custom calls
* Fix CT bugs
* Update UTs/examples to adapt to the API changes
* Unify kernel initialization in MLP
* Modify per code review feedback
* Update README and add deprecation warning to *ShardingType
* Canonicalize the dtype
* Add assertion for non-supported batch dims
* Add docs/examples to _multidim_transpose
* Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules
* Apply dtype-based rtol/atol to UTs
* Deprecate the QKV_INTERLEAVED enum
* Skip test_distributed_custom_ops.py
* Fix the wrong sharding of bias in SelfAttn
* Add a workaround for the wrong cu_seqlen of MHA when DP/FSDP is enabled
* Add distributed ops unit tests
* Add license to test_distributed_*
* Modify per review feedback
* Use total bytes involved in collective ops as the criterion

---------
---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
Co-authored-by: Donglin Yang <dongliny@nvidia.com>
parent 7976bd00
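All of the migrated custom calls carry the same four FP8 metadata tensors (fp8_max, amax, scale, scale_inv) that implement delayed scaling. The following is a dependency-free sketch of that bookkeeping; the clipping threshold and the scale-update rule here are illustrative simplifications, not TE's exact implementation:

```python
# Plain-Python emulation of delayed-scaling FP8 bookkeeping. "Quantizing"
# here is only scaling plus clipping to the E4M3 range; a real kernel would
# also round to 8 bits and keep amax as a rolling history window.
E4M3_MAX = 448.0

def update_scale(fp8_max, amax):
    # Choose a scale so the largest previously observed magnitude
    # lands at the edge of the representable range.
    return fp8_max / amax

def quantize(x, scale):
    scaled = x * scale
    return max(-E4M3_MAX, min(E4M3_MAX, scaled))   # cf. quantize() in dot.py

def dequantize(q, scale_inv):
    return q * scale_inv

amax = 16.0                            # running max |x| from earlier iterations
scale = update_scale(E4M3_MAX, amax)   # 28.0
scale_inv = 1.0 / scale

q = quantize(3.0, scale)               # 84.0, well inside the range
assert abs(dequantize(q, scale_inv) - 3.0) < 1e-6
```

Values whose magnitude exceeds the running amax are clipped until the next scale update catches up, which is why the amax history must be refreshed on every step.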
@@ -4,227 +4,167 @@
 """JAX te modules"""
 from typing import Tuple, Sequence
-from functools import partial, reduce
-import operator
+from functools import partial

 import jax
 import jax.numpy as jnp

-from transformer_engine_jax import DType as TEDType
-from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype
-from .fp8 import FP8Helper, FP8GemmPackage
-from .sharding import ShardingType, get_dot_sharding_meta, get_fp8_meta_sharding_meta
-from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
-from .sharding import xmap_runner, extend_fsdp_sharding_meta
-
-jax.config.update('experimental_xmap_spmd_lowering', True)
-jax.config.update('experimental_xmap_spmd_lowering_manual', True)
-
-
-def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
-            fwd_dtype: TEDType,
-            bwd_dtype: TEDType,
-            contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
-            sharding_type: ShardingType = ShardingType.SINGLE,
-            dp_dim_index: int = 0) -> jnp.ndarray:
-    """
-    FP8 dot wrapper
-    """
-    assert fp8_gemm_pkg.num_of_gemm == 1
-    inputs = fp8_gemm_pkg.inputs
-    kernel = fp8_gemm_pkg.kernels[0]
-    fp8_max = fp8_gemm_pkg.fp8_max
-    amax = fp8_gemm_pkg.amax
-    scale = fp8_gemm_pkg.scale
-    scale_inv = fp8_gemm_pkg.scale_inv
-
-    if sharding_type is ShardingType.SINGLE:
-        res = _fp8_dot(inputs,
-                       kernel,
-                       fp8_max,
-                       amax,
-                       scale,
-                       scale_inv,
-                       fwd_dtype=fwd_dtype,
-                       bwd_dtype=bwd_dtype,
-                       contracting_dims=contracting_dims,
-                       sharding_type=sharding_type,
-                       dp_axis_name="",
-                       tp_axis_name="",
-                       fsdp_axis_name="")
-    else:
-        dp_axis_name = "batch"
-        tp_axis_name = "model"
-
-        kernel_tp_index = None
-        # TODO (Ming Huang): Should we add a new argument to support general sharding to kernel?    # pylint: disable=fixme
-        if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
-            kernel_tp_index = len(kernel.shape) - 1
-        elif sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
-            kernel_tp_index = 0
-
-        input_tp_index = len(inputs.shape) - 1
-        sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
-                                              dp_dim_index, input_tp_index, kernel_tp_index,
-                                              contracting_dims, dp_axis_name, tp_axis_name)
-        sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
-        inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0])    # 0 for input
-        kernel_ = jnp.reshape(kernel, sharding_meta.input_shapes[1])    # 1 for kernel
-
-        num_of_fp8_meta_kind = 4    # fp8_max, amax, scale, scale_inv
-        fp8_sharding_meta = get_fp8_meta_sharding_meta(sharding_type, num_of_fp8_meta_kind,
-                                                       dp_axis_name, tp_axis_name)
-
-        axis_resources = merge_axis_resources(
-            [sharding_meta.axis_resources, fp8_sharding_meta.axis_resources])
-
-        partial_fp8_dot = partial(_fp8_dot,
-                                  fwd_dtype=fwd_dtype,
-                                  bwd_dtype=bwd_dtype,
-                                  contracting_dims=contracting_dims,
-                                  sharding_type=sharding_type,
-                                  dp_axis_name=dp_axis_name,
-                                  tp_axis_name=tp_axis_name,
-                                  fsdp_axis_name=fsdp_axis_name)
-        res = xmap_runner(partial_fp8_dot, (*sharding_meta.in_axes, *fp8_sharding_meta.in_axes),
-                          sharding_meta.out_axes, axis_resources,
-                          (inputs_, kernel_, fp8_max, amax, scale, scale_inv))
-
-        res = jnp.reshape(res, sharding_meta.output_shapes[0])
-
-    return res
-
-
-@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12))
-def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
-             scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: TEDType, bwd_dtype: TEDType,
-             contracting_dims: Tuple[Sequence[int], Sequence[int]], sharding_type: ShardingType,
-             dp_axis_name: str, tp_axis_name: str, fsdp_axis_name: str):
-    res, _ = _fp8_dot_fwd(inputs,
-                          kernel,
-                          fp8_maxs,
-                          amax,
-                          scale,
-                          scale_inv,
-                          fwd_dtype,
-                          bwd_dtype,
-                          contracting_dims=contracting_dims,
-                          sharding_type=sharding_type,
-                          dp_axis_name=dp_axis_name,
-                          tp_axis_name=tp_axis_name,
-                          fsdp_axis_name=fsdp_axis_name)
-    return res
-
-
-def _fp8_dot_fwd(
-        inputs,
-        kernel,
-        fp8_maxs,
-        amax,
-        scale,
-        scale_inv,
-        fwd_dtype,
-        bwd_dtype,    # pylint: disable=unused-argument
-        contracting_dims,
-        sharding_type,
-        dp_axis_name,    # pylint: disable=unused-argument
-        tp_axis_name,
-        fsdp_axis_name):    # pylint: disable=unused-argument
-    lhs_contracting_dims, rhs_contracting_dims = contracting_dims
-    input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
-    input_shape_suf = inputs.shape[min(lhs_contracting_dims):]
-    kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
-    kernel_shape_suf = kernel.shape[max(rhs_contracting_dims) + 1:]
-    input_contracting_size = reduce(operator.mul, input_shape_suf)
-    kernel_contracting_size = reduce(operator.mul, kernel_shape_pre)
-    assert input_contracting_size == kernel_contracting_size
-    inputs_ = jnp.reshape(inputs, (-1, input_contracting_size))
-    kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1))
-
-    amax = FP8Helper.update_amax_history(amax)
-    gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
-    input_amax = amax[gemm_input_idx, 0:1]
-    input_scale = scale[gemm_input_idx]
-    input_scale_inv = scale_inv[gemm_input_idx]
-    input_cast, input_cast_trans, input_amax = cast_transpose(inputs_, input_amax, input_scale,
-                                                              input_scale_inv, fwd_dtype)
-    kernel_amax = amax[gemm_kernel_idx, 0:1]
-    kernel_scale = scale[gemm_kernel_idx]
-    kernel_scale_inv = scale_inv[gemm_kernel_idx]
-    kernel_cast, kernel_cast_trans, kernel_amax = cast_transpose(kernel_, kernel_amax, kernel_scale,
-                                                                 kernel_scale_inv, fwd_dtype)
-    res = gemm(kernel_cast_trans, kernel_scale_inv, fwd_dtype, True, input_cast, input_scale_inv,
-               fwd_dtype, False, jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP)
-
-    if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
-        res = jax.lax.psum(res, tp_axis_name)
-
-    # (input_shape_pre, input_shape_suf)
-    # x (kernel_shape_pre, kernel_shape_suf)
-    # = (input_shape_pre, kernel_shape_suf)
-    output_shape = input_shape_pre + kernel_shape_suf
-    res = jnp.reshape(res, output_shape)
-
-    ctx = (input_cast_trans, kernel_cast, fp8_maxs, amax, scale, scale_inv, input_amax, kernel_amax,
-           inputs.shape, kernel.shape)
-    return res, ctx
-
-
-def _fp8_dot_bwd(
-        fwd_dtype,
-        bwd_dtype,
-        contracting_dims,    # pylint: disable=unused-argument
-        sharding_type,
-        dp_axis_name,
-        tp_axis_name,
-        fsdp_axis_name,
-        ctx,
-        g):
-    input_cast_trans, kernel_cast, \
-        fp8_maxs, amax, scale, scale_inv, \
-        input_amax, kernel_amax, \
-        inputs_shape, kernel_shape = ctx
-    gemm_input_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
-
-    grad_amax = amax[gemm_grad_idx, 0:1]
-    grad_scale = scale[gemm_grad_idx]
-    grad_scale_inv = scale_inv[gemm_grad_idx]
-
-    g = jnp.reshape(g, (input_cast_trans.shape[1], -1))
-    grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv,
-                                                           bwd_dtype)
-
-    input_scale_inv = scale_inv[gemm_input_idx]
-    wgrad = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype,
-                 True, input_cast_trans, input_scale_inv, fwd_dtype, False,
-                 jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD)
-
-    kernel_scale_inv = scale_inv[gemm_kernel_idx]
-    dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
-                 bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)
-
-    amax = amax.at[gemm_input_idx, 0].set(input_amax[0])
-    amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0])
-    amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0])
-
-    if is_dp_enabled(sharding_type.value[0]):
-        wgrad = jax.lax.psum(wgrad, dp_axis_name)
-        amax = jax.lax.pmax(amax, dp_axis_name)
-
-    if len(fsdp_axis_name) > 0:
-        wgrad = jax.lax.psum(wgrad, fsdp_axis_name)
-        amax = jax.lax.pmax(amax, fsdp_axis_name)
-
-    if is_tp_enabled(sharding_type.value[0]):
-        amax = jax.lax.pmax(amax, tp_axis_name)
-
-    if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
-        dgrad = jax.lax.psum(dgrad, tp_axis_name)
-
-    dgrad = jnp.reshape(dgrad, inputs_shape)
-    wgrad = jnp.reshape(wgrad, kernel_shape)
-
-    return dgrad, wgrad, fp8_maxs, amax, scale, scale_inv
-
-
-_fp8_dot.defvjp(_fp8_dot_fwd, _fp8_dot_bwd)
+from .cpp_extensions import cast_transpose
+from .fp8 import FP8Helper, FP8MetaPackage
+
+
+def type_safe_dot_general(
+        x,
+        kernel,
+        fp8_meta_pkg: FP8MetaPackage = None,
+        contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,))
+) -> jnp.ndarray:
+    """
+    Type safe dot_general, including FP8.
+    """
+    if fp8_meta_pkg is None:
+        kernel = jnp.asarray(kernel, x.dtype)
+        return jax.lax.dot_general(x, kernel, (contracting_dims, ((), ())))
+
+    fp8_max = fp8_meta_pkg.fp8_max
+    amax = fp8_meta_pkg.amax
+    scale = fp8_meta_pkg.scale
+    scale_inv = fp8_meta_pkg.scale_inv
+    fwd_dtype = FP8Helper.FWD_DTYPE
+    bwd_dtype = FP8Helper.BWD_DTYPE
+    return _fp8_dot(x, kernel, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype,
+                    contracting_dims)
+
+
+def quantize(x, q_dtype, scale):
+    """
+    Quantize with scale.
+    """
+    dtype_max = (jnp.finfo(q_dtype).max).astype(x.dtype)
+    scale = scale.astype(x.dtype)
+    clipped_scaled_x = jnp.clip((x * scale), -dtype_max, dtype_max)
+    return clipped_scaled_x.astype(q_dtype)
+
+
+def dequantize(x, dq_dtype, scale_inv):
+    """
+    Dequantize with scale_inv.
+    """
+    return x.astype(dq_dtype) * scale_inv.astype(dq_dtype)
+
+
+# Apply jit to guarantee correctness of FP8 GEMM.
+@partial(jax.jit, static_argnums=(4, 5))
+def fp8_dot_impl(
+        q_lhs: jnp.ndarray,
+        q_rhs: jnp.ndarray,
+        lhs_scale_inv: jnp.ndarray,
+        rhs_scale_inv: jnp.ndarray,
+        ctype: jnp.dtype,    # computing type
+        contracting_dims: Tuple[Sequence[int], Sequence[int]]):
+    """
+    FP8 GEMM for XLA pattern match
+    """
+    dim_nums = (contracting_dims, ((), ()))
+
+    lhs = dequantize(q_lhs, ctype, lhs_scale_inv)
+    rhs = dequantize(q_rhs, ctype, rhs_scale_inv)
+
+    return jax.lax.dot_general(lhs, rhs, dim_nums)
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
+def _fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
+             scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
+             contracting_dims: Tuple[Sequence[int], Sequence[int]]):
+    output, _ = _fp8_dot_fwd_rule(x, kernel, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype,
+                                  contracting_dims)
+    return output
+
+
+def _fp8_dot_fwd_rule(
+        x,
+        kernel,
+        fp8_max,
+        amax,
+        scale,
+        scale_inv,
+        fwd_dtype,
+        bwd_dtype,    # pylint: disable=unused-argument
+        contracting_dims):
+    lhs_contracting_dims, rhs_contracting_dims = contracting_dims
+
+    x_shape_suf = x.shape[min(lhs_contracting_dims):]
+    kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
+    assert x_shape_suf == kernel_shape_pre
+
+    amax = FP8Helper.update_amax_history(amax)
+
+    gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
+
+    x_amax = amax[gemm_x_idx, 0:1]
+    x_scale = scale[gemm_x_idx]
+    x_scale_inv = scale_inv[gemm_x_idx]
+
+    casted_x, casted_xt, updated_x_amax = \
+        cast_transpose(x, x_amax, x_scale, x_scale_inv, fwd_dtype, static_axis_boundary=-1,
+                       transpose_axis_boundary=min(lhs_contracting_dims))
+
+    kernel_amax = amax[gemm_kernel_idx, 0:1]
+    kernel_scale = scale[gemm_kernel_idx]
+    kernel_scale_inv = scale_inv[gemm_kernel_idx]
+
+    casted_kerenl, casted_kerenl_t, updated_kernel_amax = \
+        cast_transpose(kernel, kernel_amax, kernel_scale, kernel_scale_inv,
+                       fwd_dtype, static_axis_boundary=-1,
+                       transpose_axis_boundary=(max(rhs_contracting_dims) + 1))
+
+    rhs_t_contracting_dims = tuple(range(kernel.ndim - len(rhs_contracting_dims), kernel.ndim))
+    output = fp8_dot_impl(casted_x, casted_kerenl_t, x_scale_inv, kernel_scale_inv, x.dtype,
+                          (lhs_contracting_dims, rhs_t_contracting_dims))
+
+    ctx = (casted_xt, casted_kerenl, fp8_max, amax, scale, scale_inv, updated_x_amax,
+           updated_kernel_amax, x.shape, kernel.shape)
+    return output, ctx
+
+
+def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad):    # pylint: disable=unused-argument
+    lhs_contracting_dims, rhs_contracting_dims = contracting_dims
+
+    casted_xt, casted_kerenl, fp8_max, amax, scale, scale_inv, \
+        updated_x_amax, updated_kernel_amax, x_shape, kernel_shape = ctx
+
+    gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
+
+    grad_amax = amax[gemm_grad_idx, 0:1]
+    grad_scale = scale[gemm_grad_idx]
+    grad_scale_inv = scale_inv[gemm_grad_idx]
+
+    casted_grad, casted_grad_t, updated_grad_amax = \
+        cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv,
+                       bwd_dtype, static_axis_boundary=-1,
+                       transpose_axis_boundary=min(lhs_contracting_dims))
+
+    xt_constracting_dim = tuple(range(len(lhs_contracting_dims), len(x_shape)))
+    gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim))
+    x_scale_inv = scale_inv[gemm_x_idx]
+    wgrad = fp8_dot_impl(casted_xt, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
+                         (xt_constracting_dim, gt_constracting_dim))
+
+    g_constracting_dim = tuple(
+        range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim))
+    k_constracting_dim = tuple(range(len(rhs_contracting_dims), len(kernel_shape)))
+    kernel_scale_inv = scale_inv[gemm_kernel_idx]
+    dgrad = fp8_dot_impl(casted_grad, casted_kerenl, grad_scale_inv, kernel_scale_inv, grad.dtype,
+                         (g_constracting_dim, k_constracting_dim))
+
+    amax = amax.at[gemm_x_idx, 0].set(updated_x_amax[0])
+    amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0])
+    amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])
+
+    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
+
+    return dgrad, wgrad, fp8_max, amax, scale, scale_inv
+
+
+_fp8_dot.defvjp(_fp8_dot_fwd_rule, _fp8_dot_bwd_rule)
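The forward and backward rules reduce every FP8 GEMM to a dot against a transposed operand, with contracting dimensions derived from the original contracting_dims. A hedged numpy sketch of that index bookkeeping (numpy tensordot stands in for fp8_dot_impl; the shapes are illustrative):

```python
import numpy as np

# Illustrative shapes: x is (batch, seq, hidden), kernel is (hidden, features).
x = np.arange(24.0).reshape(2, 3, 4)
kernel = np.arange(20.0).reshape(4, 5)
lhs_cdims, rhs_cdims = (2,), (0,)

# Reference forward output.
ref = np.tensordot(x, kernel, axes=(lhs_cdims, rhs_cdims))            # (2, 3, 5)

# cast_transpose also returns the transposed operand: axes before the transpose
# boundary move to the tail, so kernel^T carries its contracting axes last --
# exactly rhs_t_contracting_dims in _fp8_dot_fwd_rule.
kernel_t = np.transpose(kernel, (1, 0))          # boundary = max(rhs_cdims) + 1
rhs_t_cdims = tuple(range(kernel.ndim - len(rhs_cdims), kernel.ndim))  # (1,)
out = np.tensordot(x, kernel_t, axes=(lhs_cdims, rhs_t_cdims))
assert np.allclose(out, ref)

# Backward, wgrad: contract the shared (batch, seq) axes of x^T and grad^T.
grad = np.ones_like(ref)
xt = np.transpose(x, (2, 0, 1))                  # boundary = min(lhs_cdims)
grad_t = np.transpose(grad, (2, 0, 1))
xt_cdims = tuple(range(len(lhs_cdims), x.ndim))                        # (1, 2)
gt_cdims = tuple(range(grad.ndim - len(xt_cdims), grad.ndim))          # (1, 2)
wgrad = np.tensordot(xt, grad_t, axes=(xt_cdims, gt_cdims))
assert wgrad.shape == kernel.shape

# Backward, dgrad: contract the feature axes of grad with those of the kernel.
g_cdims = tuple(range(grad.ndim - kernel.ndim + len(rhs_cdims), grad.ndim))  # (2,)
k_cdims = tuple(range(len(rhs_cdims), kernel.ndim))                          # (1,)
dgrad = np.tensordot(grad, kernel, axes=(g_cdims, k_cdims))
assert dgrad.shape == x.shape
```

Keeping both the casted and the casted-transposed copies of each operand is what lets all three GEMMs run without any runtime transpose.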
@@ -6,6 +6,7 @@ Wrapper module for Transformer related layers with FP8 support.
 """
 import functools
 import operator
+import warnings
 from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union

 import jax.numpy as jnp
@@ -16,14 +17,12 @@ from jax import lax
 from jax import nn as jax_nn
 from jax import random as jax_random

-from ..dot import fp8_dot
-from ..fp8 import FP8GemmPackage, FP8Helper
+from ..dot import type_safe_dot_general
+from ..fp8 import FP8Helper, FP8MetaPackage
 from ..layernorm import canonicalize_layernorm_type
 from ..layernorm import layernorm, layernorm_fp8_dot
-from ..mlp import fp8_ln_mlp, geglu
-from ..sharding import infer_sharding_type
+from ..mlp import layernrom_geglu_fp8_mlp, geglu
 from ..softmax import is_softmax_kernel_available
-from ..sharding import MajorShardingType, ShardingType
 from ..softmax import softmax, SoftmaxType

 PRNGKey = Any
@@ -119,16 +118,10 @@ class Softmax(nn.Module):    # pylint: disable=too-few-public-methods
         Scalar for the input to softmax.
     softmax_type : SoftmaxType, default = SoftmaxType.SCALED
         Indicate the type of softmax.
-
-    Optimization parameters
-    -----------------------
-    sharding_type : ShardingType, default = ShardingType.SINGLE
-        Indicate the sharding pattern.
     """
     scale_factor: float = 1.0
     softmax_type: SoftmaxType = SoftmaxType.SCALED
-    sharding_type: ShardingType = ShardingType.SINGLE

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
@@ -149,8 +142,7 @@
             if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                 mask_ = None
-            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type,
-                              self.sharding_type)
+            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
         else:
             attention_bias = None
             if mask is not None:
@@ -168,8 +160,7 @@
             # and kernel is unavailable, then try on pure scaled softmax custom calls.
             if is_softmax_kernel_available(SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen,
                                            dtype):
-                outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED,
-                                  self.sharding_type)
+                outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
             else:
                 outputs = jax_nn.softmax(logits * self.scale_factor)
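When no fused kernel is available, the module falls back to jax_nn.softmax(logits * self.scale_factor). A numpy sketch of that scaled-softmax fallback (numpy replaces jax.numpy here; the shift by the row max is the standard stability trick and does not change the result):

```python
import numpy as np

def scaled_softmax(logits, scale_factor=1.0):
    # Equivalent to jax_nn.softmax(logits * scale_factor), written stably.
    z = logits * scale_factor
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([[1.0, 2.0, 3.0]])
probs = scaled_softmax(logits, scale_factor=0.5)
assert np.allclose(probs.sum(axis=-1), 1.0)
assert probs.argmax() == 2    # a positive scale preserves the ordering
```

The scale factor only sharpens or flattens the distribution; for scale_factor > 0 the argmax is unchanged, which is why the fused and fallback paths are interchangeable.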
@@ -242,8 +233,6 @@ class LayerNorm(nn.Module):    # pylint: disable=too-few-public-methods
         Indicate whether the input tensors were switched axis of batch
         and sequence length dimension. If set to True, the input tensors
         should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
-    sharding_type : ShardingType, default = ShardingType.SINGLE
-        Indicate the sharding pattern.
     """
     epsilon: float = 1e-6
     layernorm_type: str = 'layernorm'
@@ -254,7 +243,7 @@
     bias_axes: Tuple[str, ...] = ('embed',)
     dtype: DType = jnp.float32
     transpose_batch_sequence: bool = False
-    sharding_type: ShardingType = ShardingType.SINGLE
+    sharding_type = None

     def __post_init__(self):
         self.scale_init = _obtain_default_layernorm_scale_init_if_need(
@@ -276,6 +265,8 @@
         outputs : jax.numpy.ndarray
             Output tensors.
         """
+        warnings.warn("sharding_type of LayerNorm would be removed in the near feature",
+                      DeprecationWarning)
         features = x.shape[-1]

         scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
@@ -286,9 +277,7 @@
                          ln_bias,
                          layernorm_type=self.layernorm_type,
                          zero_centered_gamma=self.zero_centered_gamma,
-                         epsilon=self.epsilon,
-                         sharding_type=self.sharding_type,
-                         dp_dim_index=1 if self.transpose_batch_sequence else 0)
+                         epsilon=self.epsilon)


 class TransformerEngineBase(nn.Module):
@@ -329,17 +318,15 @@ class TransformerEngineBase(nn.Module):
         return fp8_max.value, fp8_metas_amax.value, fp8_metas_scale.value, fp8_metas_scale_inv.value

     @staticmethod
-    def get_fp8_gemm_package(num_of_gemm: int, inputs: jnp.ndarray,
-                             kernels: List[jnp.ndarray]) -> FP8GemmPackage:
+    def get_fp8_meta_package(num_of_gemm: int) -> FP8MetaPackage:
         """
         Get the FP8 metas
         """
-        assert num_of_gemm == len(kernels)
         fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
             TransformerEngineBase.get_fp8_metas(num_of_gemm)

-        return FP8GemmPackage(num_of_gemm, inputs, kernels, fp8_max, fp8_metas_amax,
-                              fp8_metas_scale, fp8_metas_scale_inv)
+        return FP8MetaPackage(num_of_gemm, fp8_max, fp8_metas_amax, fp8_metas_scale,
+                              fp8_metas_scale_inv)


 class DenseGeneral(TransformerEngineBase):
@@ -376,8 +363,6 @@ class DenseGeneral(TransformerEngineBase):
         Indicate whether the input tensors were switched axis of batch
         and sequence length dimension. If set to True, the input tensors
         should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
-    sharding_type : ShardingType, default = ShardingType.SINGLE
-        Indicate the sharding pattern.
     """

     features: Union[Iterable[int], int]
@@ -389,7 +374,7 @@
     axis: Union[Iterable[int], int] = -1
     dtype: DType = jnp.float32
     transpose_batch_sequence: bool = False
-    sharding_type: ShardingType = ShardingType.SINGLE
+    sharding_type = None

     def __post_init__(self):
         if self.kernel_init is None:
@@ -411,6 +396,9 @@
         outputs : jax.numpy.ndarray
             Output tensors.
         """
+        warnings.warn("sharding_type of DenseGeneral would be removed in the near feature",
+                      DeprecationWarning)
+
         features = _canonicalize_tuple(self.features)
         axis = _canonicalize_tuple(self.axis)
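The *ShardingType fields are kept for API compatibility but now only trigger a DeprecationWarning. The stdlib pattern being added is the following (the function name and message text here are illustrative):

```python
import warnings

def call_with_deprecated_arg():
    # Same pattern the PR adds to LayerNorm/DenseGeneral __call__.
    warnings.warn("sharding_type would be removed in the near future",
                  DeprecationWarning)
    return "ok"

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = call_with_deprecated_arg()

assert result == "ok"
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```

Note that DeprecationWarning is ignored by default outside of __main__ and tests, so callers may need a warning filter to surface it.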
@@ -438,18 +426,15 @@
             bias = None

         contract_ind = tuple(range(0, len(axis)))

+        fp8_gemm_pkg = None
         if FP8Helper.is_fp8_enabled():
-            fp8_gemm_package = \
-                TransformerEngineBase.get_fp8_gemm_package(1, inputs, [kernel])
-            y = fp8_dot(fp8_gemm_package,
-                        FP8Helper.FWD_DTYPE,
-                        FP8Helper.BWD_DTYPE, (axis, contract_ind),
-                        sharding_type=self.sharding_type,
-                        dp_dim_index=1 if self.transpose_batch_sequence else 0)
-        else:
-            kernel = jnp.asarray(kernel, self.dtype)
-            y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
+            fp8_gemm_pkg = \
+                TransformerEngineBase.get_fp8_meta_package(1)
+
+        y = type_safe_dot_general(inputs,
+                                  kernel,
+                                  fp8_meta_pkg=fp8_gemm_pkg,
+                                  contracting_dims=(axis, contract_ind))

         if bias is not None:
             bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
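In the non-FP8 path, type_safe_dot_general reduces to lax.dot_general with dimension numbers ((axis, contract_ind), ((), ())). A numpy check of that contraction layout (tensordot plays the role of dot_general; the shapes are illustrative):

```python
import numpy as np

inputs = np.ones((2, 3, 4))
kernel = np.ones((4, 5))      # contracted dims lead, feature dims follow

axis = (2,)                                   # canonicalized input axes
contract_ind = tuple(range(0, len(axis)))     # (0,), as in DenseGeneral

y = np.tensordot(inputs, kernel, axes=(axis, contract_ind))
assert y.shape == (2, 3, 5)
assert float(y[0, 0, 0]) == 4.0    # sum over the contracted axis of length 4
```

Because the kernel is created with its contracting dimensions first, contract_ind is always the leading range, regardless of how many input axes are contracted.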
...@@ -528,8 +513,6 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -528,8 +513,6 @@ class LayerNormDenseGeneral(TransformerEngineBase):
depth_scaling: float, default = None depth_scaling: float, default = None
The factor to scale the output from `DenseGeneral`. It should be a float The factor to scale the output from `DenseGeneral`. It should be a float
value or None. When None is set, then no scaling is applied. value or None. When None is set, then no scaling is applied.
sharding_type : ShardingType, default = ShardingType.SINGLE
Indicate the sharding pattern.
""" """
features: Union[Iterable[int], int] features: Union[Iterable[int], int]
@@ -551,7 +534,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
     dtype: DType = jnp.float32
     transpose_batch_sequence: bool = True
     depth_scaling: float = None
-    sharding_type: ShardingType = ShardingType.SINGLE
+    sharding_type = None
     def __post_init__(self):
         if self.kernel_init is None:
@@ -578,12 +561,16 @@ class LayerNormDenseGeneral(TransformerEngineBase):
             The output tensors of layer normalization.
             If :attr:`return_layernorm_output=False`, then this would be None.
         """
+        warnings.warn("sharding_type of LayerNormDenseGeneral would be removed in the near future",
+                      DeprecationWarning)
+
         ln_output = None
 
         fuse_layernorm = FP8Helper.is_fp8_enabled(
         ) and not self.return_layernorm_output and self.enable_layernorm
 
         if self.enable_layernorm:
+            assert self.axis == -1    # Only support axis == -1 at this moment
             features = inputs.shape[-1]
 
             scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
@@ -597,9 +584,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
                                 ln_bias,
                                 layernorm_type=self.layernorm_type,
                                 zero_centered_gamma=self.zero_centered_gamma,
-                                epsilon=self.epsilon,
-                                sharding_type=self.sharding_type,
-                                dp_dim_index=1 if self.transpose_batch_sequence else 0)
+                                epsilon=self.epsilon)
         else:
             assert not self.return_layernorm_output
             y = inputs
@@ -627,30 +612,25 @@ class LayerNormDenseGeneral(TransformerEngineBase):
             contract_ind = tuple(range(0, len(axis)))
 
+            fp8_meta_package = None
             if FP8Helper.is_fp8_enabled():
-                fp8_gemm_package = \
-                    TransformerEngineBase.get_fp8_gemm_package(1, y, [kernel])
-                if not fuse_layernorm:
-                    z = fp8_dot(fp8_gemm_package,
-                                FP8Helper.FWD_DTYPE,
-                                FP8Helper.BWD_DTYPE, (axis, contract_ind),
-                                sharding_type=self.sharding_type,
-                                dp_dim_index=1 if self.transpose_batch_sequence else 0)
-                else:
-                    z = layernorm_fp8_dot(fp8_gemm_package,
-                                          scale,
-                                          ln_bias,
-                                          self.layernorm_type,
-                                          FP8Helper.FWD_DTYPE,
-                                          FP8Helper.BWD_DTYPE, (axis, contract_ind),
-                                          zero_centered_gamma=self.zero_centered_gamma,
-                                          epsilon=self.epsilon,
-                                          sharding_type=self.sharding_type,
-                                          dp_dim_index=1 if self.transpose_batch_sequence else 0)
+                fp8_meta_package = \
+                    TransformerEngineBase.get_fp8_meta_package(1)
+
+            if fuse_layernorm:
+                z = layernorm_fp8_dot(y,
+                                      kernel,
+                                      scale,
+                                      ln_bias,
+                                      fp8_meta_package,
+                                      self.layernorm_type,
+                                      zero_centered_gamma=self.zero_centered_gamma,
+                                      epsilon=self.epsilon)
             else:
-                kernel = jnp.asarray(kernel, self.dtype)
-                z = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ())))
+                z = type_safe_dot_general(y,
+                                          kernel,
+                                          fp8_meta_pkg=fp8_meta_package,
+                                          contracting_dims=(axis, contract_ind))
 
         bias = None
         if self.use_bias:
@@ -758,8 +738,6 @@ class LayerNormMLP(TransformerEngineBase):
         Indicate whether the input tensors were switched axis of batch
         and sequence length dimension. If set to True, the input tensors
         should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
-    major_sharding_type : MajorShardingType, default = MajorShardingType.SINGLE
-        Indicate the sharding pattern.
     """
     intermediate_dim: int = 2048
@@ -776,10 +754,7 @@ class LayerNormMLP(TransformerEngineBase):
     kernel_axes_2: Tuple[str, ...] = ('mlp', 'embed')
     use_bias: bool = False
     bias_init: Initializer = nn.initializers.zeros
-    bias_axes_1: Tuple[str, ...] = (
-        'act',
-        'mlp',
-    )
+    bias_axes_1: Tuple[str, ...] = ('act', 'mlp')
     bias_axes_2: Tuple[str, ...] = ('embed',)
     return_layernorm_output: bool = True
     activations: Sequence[Union[str, Callable]] = ('relu',)
@@ -789,7 +764,7 @@ class LayerNormMLP(TransformerEngineBase):
     axis: Union[Iterable[int], int] = -1
     dtype: DType = jnp.float32
     transpose_batch_sequence: bool = True
-    major_sharding_type: MajorShardingType = MajorShardingType.SINGLE
+    major_sharding_type = None
     def __post_init__(self):
         if self.kernel_init is None:
@@ -818,19 +793,32 @@ class LayerNormMLP(TransformerEngineBase):
             The output tensors of layer normalization.
             If :attr:`return_layernorm_output=False`, then this would be None.
         """
+        warnings.warn("major_sharding_type of LayerNormMLP would be removed in the near future",
+                      DeprecationWarning)
+
         ln_output = None
 
         fuse_layernorm = FP8Helper.is_fp8_enabled(
         ) and not self.return_layernorm_output and self.enable_layernorm
 
+        def is_geglu(acts):
+            geglu_act_pool = [('gelu', 'linear'), ('linear', 'gelu')]
+            normalize_acts = []
+            for act in acts:
+                if not isinstance(act, str):
+                    return False
+                normalize_acts.append(act.lower())
+            return tuple(normalize_acts) in geglu_act_pool
+
         use_fused_ln_mlp = fuse_layernorm \
-            and (not self.use_bias) and self.activations == ('gelu', 'linear') \
+            and (not self.use_bias) and is_geglu(self.activations) \
             and (self.intermediate_dropout_rate < 1e-3)
 
-        first_sharding_type, second_sharding_type = infer_sharding_type(self.major_sharding_type)
-
         # LayerNorm
         if self.enable_layernorm:
+            assert self.axis == -1    # Only support axis == -1 at this moment
             features = inputs.shape[-1]
 
             scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
@@ -844,9 +832,7 @@ class LayerNormMLP(TransformerEngineBase):
                                 ln_bias,
                                 layernorm_type=self.layernorm_type,
                                 zero_centered_gamma=self.zero_centered_gamma,
-                                epsilon=self.epsilon,
-                                sharding_type=first_sharding_type,
-                                dp_dim_index=1 if self.transpose_batch_sequence else 0)
+                                epsilon=self.epsilon)
         else:
             assert not self.return_layernorm_output
             y = inputs
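The `is_geglu` helper introduced in this hunk accepts either activation ordering and compares case-insensitively. A standalone re-creation for illustration (with the membership check written against tuples so it actually matches the tuple-valued pool):

```python
def is_geglu(acts):
    # Re-creation of the helper added in this diff: GEGLU is any ordering of
    # ('gelu', 'linear'), compared case-insensitively.
    geglu_act_pool = [('gelu', 'linear'), ('linear', 'gelu')]
    normalized = []
    for act in acts:
        if not isinstance(act, str):
            return False    # callables and other non-strings never qualify
        normalized.append(act.lower())
    return tuple(normalized) in geglu_act_pool

print(is_geglu(('GELU', 'linear')))    # True
print(is_geglu(('relu',)))             # False
```

Note the `tuple(...)` conversion: a plain list never compares equal to the tuples in the pool, so without it the predicate would always be false.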
@@ -864,107 +850,67 @@ class LayerNormMLP(TransformerEngineBase):
                 return jnp.stack(kernels, axis=stack_axis, dtype=jnp.float32)
 
         num_of_gemm = 2
-        if use_fused_ln_mlp:
-            num_activations = len(self.activations)
-            axis = _canonicalize_tuple(self.axis)
-            axis = _normalize_axes(axis, inputs.ndim)
-
-            intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
-            kernel_1_shape = tuple(inputs.shape[ax] for ax in axis) + intermediate_dim
-            kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
-            kernel_1 = nn_partitioning.param_with_axes('wi_kernel',
-                                                       kernel_1_init,
-                                                       num_activations,
-                                                       -2,
-                                                       kernel_1_each_shape,
-                                                       jnp.float32,
-                                                       axes=self.kernel_axes_1)
-            kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
-            hidden_size = inputs.shape[-1]
-            hidden_size_tuple = _canonicalize_tuple(hidden_size)
-            kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
-            kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple))
-            kernel_2 = nn_partitioning.param_with_axes('wo_kernel',
-                                                       self.kernel_init,
-                                                       kernel_2_param_shape,
-                                                       jnp.float32,
-                                                       axes=self.kernel_axes_2)
-            kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
-            contract_ind = tuple(range(0, len(axis)))
-            fp8_gemm_package = \
-                TransformerEngineBase.get_fp8_gemm_package(num_of_gemm, y, [kernel_1, kernel_2])
-            out = fp8_ln_mlp(fp8_gemm_package,
-                             scale,
-                             ln_bias,
-                             self.layernorm_type,
-                             FP8Helper.FWD_DTYPE,
-                             FP8Helper.BWD_DTYPE,
-                             zero_centered_gamma=self.zero_centered_gamma,
-                             epsilon=self.epsilon,
-                             contracting_dims=(axis, contract_ind),
-                             major_sharding_type=self.major_sharding_type,
-                             dp_dim_index=1 if self.transpose_batch_sequence else 0,
-                             activations=self.activations)
-        else:    # not use_fused_ln_mlp
-            def fp8_meta_generator():
-                fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = (None, None, None,
-                                                                                 None)
-                if FP8Helper.is_fp8_enabled():
-                    fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
-                        TransformerEngineBase.get_fp8_metas(num_of_gemm)
-                return fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
-
-            fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
-                fp8_meta_generator()
-
-            # DenseGeneral 1
-            activations = []
-            num_activations = len(self.activations)
-            axis = _canonicalize_tuple(self.axis)
-            axis = _normalize_axes(axis, y.ndim)
-
-            intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
-            kernel_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
-            kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
-            kernel = nn_partitioning.param_with_axes('wi_kernel',
-                                                     kernel_1_init,
-                                                     num_activations,
-                                                     -2,
-                                                     kernel_1_each_shape,
-                                                     jnp.float32,
-                                                     axes=self.kernel_axes_1)
-            kernel = jnp.reshape(kernel, kernel_shape)
-            contract_ind = tuple(range(0, len(axis)))
-
-            if FP8Helper.is_fp8_enabled():
-                fp8_gemm_package = FP8GemmPackage(
-                    1, y, [kernel], fp8_max[:FP8Helper.NUM_META_PER_GEMM, :],
-                    fp8_metas_amax[:FP8Helper.NUM_META_PER_GEMM, :],
-                    fp8_metas_scale[:FP8Helper.NUM_META_PER_GEMM, :],
-                    fp8_metas_scale_inv[:FP8Helper.NUM_META_PER_GEMM, :])
-                if not fuse_layernorm:
-                    x = fp8_dot(fp8_gemm_package,
-                                FP8Helper.FWD_DTYPE,
-                                FP8Helper.BWD_DTYPE, (axis, contract_ind),
-                                sharding_type=first_sharding_type,
-                                dp_dim_index=1 if self.transpose_batch_sequence else 0)
-                else:
-                    x = layernorm_fp8_dot(fp8_gemm_package,
-                                          scale,
-                                          ln_bias,
-                                          self.layernorm_type,
-                                          FP8Helper.FWD_DTYPE,
-                                          FP8Helper.BWD_DTYPE, (axis, contract_ind),
-                                          zero_centered_gamma=self.zero_centered_gamma,
-                                          epsilon=self.epsilon,
-                                          sharding_type=first_sharding_type,
-                                          dp_dim_index=1 if self.transpose_batch_sequence else 0)
-            else:    # not enable fp8
-                kernel = jnp.asarray(kernel, self.dtype)
-                x = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ())))
+        fp8_meta_package = None
+        if FP8Helper.is_fp8_enabled():
+            fp8_meta_package = \
+                TransformerEngineBase.get_fp8_meta_package(num_of_gemm)
+
+        num_activations = len(self.activations)
+        axis = _canonicalize_tuple(self.axis)
+        axis = _normalize_axes(axis, y.ndim)
+
+        intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
+        kernel_1_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
+        kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
+        kernel_1 = nn_partitioning.param_with_axes('wi_kernel',
+                                                   kernel_1_init,
+                                                   num_activations,
+                                                   -2,
+                                                   kernel_1_each_shape,
+                                                   jnp.float32,
+                                                   axes=self.kernel_axes_1)
+        kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
+        hidden_size = inputs.shape[-1]
+        hidden_size_tuple = _canonicalize_tuple(hidden_size)
+        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
+        kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple))
+        kernel_2 = nn_partitioning.param_with_axes('wo_kernel',
+                                                   self.kernel_init,
+                                                   kernel_2_param_shape,
+                                                   jnp.float32,
+                                                   axes=self.kernel_axes_2)
+        kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
+        contract_ind = tuple(range(0, len(axis)))
+
+        if use_fused_ln_mlp:
+            assert self.axis == -1    # Only support axis == -1 at this moment
+
+            out = layernrom_geglu_fp8_mlp(y,
+                                          scale,
+                                          ln_bias, [kernel_1, kernel_2],
+                                          fp8_meta_package,
+                                          self.layernorm_type,
+                                          zero_centered_gamma=self.zero_centered_gamma,
+                                          epsilon=self.epsilon)
+        else:    # not use_fused_ln_mlp
+
+            # DenseGeneral 1
+            gemm1_fp8_meta_package = None if fp8_meta_package is None \
+                else fp8_meta_package.get_package_by_gemm_idx(0)
+
+            if fuse_layernorm:
+                x = layernorm_fp8_dot(y,
+                                      kernel_1,
+                                      scale,
+                                      ln_bias,
+                                      gemm1_fp8_meta_package,
+                                      self.layernorm_type,
+                                      zero_centered_gamma=self.zero_centered_gamma,
+                                      epsilon=self.epsilon)
+            else:
+                x = type_safe_dot_general(y,
+                                          kernel_1,
+                                          fp8_meta_pkg=gemm1_fp8_meta_package,
+                                          contracting_dims=(axis, contract_ind))
 
             bias = None
             if self.use_bias:
@@ -977,11 +923,9 @@ class LayerNormMLP(TransformerEngineBase):
                 bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
                 x += jnp.reshape(bias, bias_shape)
 
-            if self.activations == ('gelu', 'linear'):
-                z = geglu(x,
-                          contracting_dims=(-2, -1),
-                          sharding_type=second_sharding_type,
-                          dp_dim_index=1 if self.transpose_batch_sequence else 0)
+            activations = []
+            if is_geglu(self.activations):
+                z = geglu(x)
             else:
                 x = jnp.split(x, num_activations, axis=-2)
                 for idx, act_fn in enumerate(self.activations):
@@ -996,37 +940,13 @@ class LayerNormMLP(TransformerEngineBase):
                 z, deterministic=deterministic)
 
             # DenseGeneral 2
-            hidden_size = inputs.shape[-1]
-            hidden_size_tuple = _canonicalize_tuple(hidden_size)
-            axis = _canonicalize_tuple(self.axis)
-            axis = _normalize_axes(axis, z.ndim)
-
-            kernel_shape = tuple(z.shape[ax] for ax in axis) + hidden_size_tuple
-            kernel_param_shape = (np.prod([z.shape[ax] for ax in axis]), np.prod(hidden_size_tuple))
-            kernel = nn_partitioning.param_with_axes('wo_kernel',
-                                                     self.kernel_init,
-                                                     kernel_param_shape,
-                                                     jnp.float32,
-                                                     axes=self.kernel_axes_2)
-            kernel = jnp.reshape(kernel, kernel_shape)
-            contract_ind = tuple(range(0, len(axis)))
-
-            if FP8Helper.is_fp8_enabled():
-                fp8_gemm_package = FP8GemmPackage(
-                    1, z, [kernel], fp8_max[FP8Helper.NUM_META_PER_GEMM:, :],
-                    fp8_metas_amax[FP8Helper.NUM_META_PER_GEMM:, :],
-                    fp8_metas_scale[FP8Helper.NUM_META_PER_GEMM:, :],
-                    fp8_metas_scale_inv[FP8Helper.NUM_META_PER_GEMM:, :])
-                out = fp8_dot(fp8_gemm_package,
-                              FP8Helper.FWD_DTYPE,
-                              FP8Helper.BWD_DTYPE, (axis, contract_ind),
-                              sharding_type=second_sharding_type,
-                              dp_dim_index=1 if self.transpose_batch_sequence else 0)
-            else:
-                kernel = jnp.asarray(kernel, self.dtype)
-                out = lax.dot_general(z, kernel, ((axis, contract_ind), ((), ())))
+            gemm2_fp8_meta_package = None if fp8_meta_package is None \
+                else fp8_meta_package.get_package_by_gemm_idx(1)
+
+            out = type_safe_dot_general(z,
+                                        kernel_2,
+                                        fp8_meta_pkg=gemm2_fp8_meta_package,
+                                        contracting_dims=(axis, contract_ind))
 
             bias = None
             if self.use_bias:
......
@@ -27,9 +27,8 @@ from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
 from ..fused_attn import is_fused_attn_kernel_available
 from ..fused_attn import self_fused_attn, cross_fused_attn
 from ..softmax import SoftmaxType
-from ..sharding import infer_major_sharding_type, infer_sharding_type
-from ..sharding import global_shard_resource, with_sharding_constraint
-from ..sharding import ShardingType
+from ..sharding import global_mesh_resource, num_of_devices
+from ..sharding import with_sharding_constraint
 
 PRNGKey = Any
 Shape = Tuple[int, ...]
@@ -102,7 +101,7 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
         else:
             rules_map[key] = [val]
 
-    gsr = global_shard_resource()
+    gsr = global_mesh_resource()
 
     batch_dim_rule = []
     if gsr.dp_resource is not None:
@@ -186,7 +185,6 @@ def core_attention(query: Array,
                    scale_factor: float,
                    transpose_batch_sequence: bool,
                    softmax_type: SoftmaxType = SoftmaxType.SCALED,
-                   softmax_sharding_type: ShardingType = ShardingType.SINGLE,
                    mask: Optional[Array] = None,
                    bias: Optional[Array] = None,
                    dropout_rng: Optional[PRNGKey] = None,
@@ -226,9 +224,7 @@ def core_attention(query: Array,
         fused_scale_factor = scale_factor
 
     attn_weights = Softmax(softmax_type=softmax_type,
-                           scale_factor=fused_scale_factor,
-                           sharding_type=softmax_sharding_type)(attn_weights, mask,
-                                                                bias).astype(dtype)
+                           scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype)
 
     if not deterministic and dropout_rate > 0.:
         keep_prob = 1.0 - dropout_rate
@@ -482,8 +478,6 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                     f"Fused attention is not enabled. Because " \
                     f"{reason}fall back to unfused attention.")
 
-        first_sharding_type, second_sharding_type = infer_sharding_type()
-
         residual = inputs_q
         if self.fuse_qkv:
             if is_self_attn:
@@ -494,7 +488,6 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                     epsilon=self.layernorm_epsilon,
                     axis=-1,
                     features=(3, self.num_heads * self.head_dim),
-                    sharding_type=first_sharding_type,
                     transpose_batch_sequence=self.transpose_batch_sequence,
                     return_layernorm_output=self.apply_residual_connection_post_layernorm,
                     scale_axes=(W_NO_SHARD_AXES,),
@@ -516,7 +509,6 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                     epsilon=self.layernorm_epsilon,
                     axis=-1,
                     features=self.num_heads * self.head_dim,
-                    sharding_type=first_sharding_type,
                     transpose_batch_sequence=self.transpose_batch_sequence,
                     return_layernorm_output=self.apply_residual_connection_post_layernorm,
                     scale_axes=(W_NO_SHARD_AXES,),
@@ -530,7 +522,6 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                                  name='query')(inputs_q)
                 kv_proj = DenseGeneral(axis=-1,
                                        features=(2, self.num_heads * self.head_dim),
-                                       sharding_type=first_sharding_type,
                                        transpose_batch_sequence=self.transpose_batch_sequence,
                                        kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                                        kernel_init=kv_init,
@@ -546,7 +537,6 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                 DenseGeneral,
                 axis=-1,
                 features=self.num_heads * self.head_dim,
-                sharding_type=first_sharding_type,
                 transpose_batch_sequence=self.transpose_batch_sequence,
                 kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                 use_bias=self.use_bias,
@@ -560,7 +550,6 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                     epsilon=self.layernorm_epsilon,
                     axis=-1,
                     features=self.num_heads * self.head_dim,
-                    sharding_type=first_sharding_type,
                     transpose_batch_sequence=self.transpose_batch_sequence,
                     return_layernorm_output=True,
                     scale_axes=(W_NO_SHARD_AXES,),
@@ -648,7 +637,7 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
             seed = None
             if dropout_rng is not None:
-                seed = jax.random.split(dropout_rng, len(jax.devices()))
+                seed = jax.random.split(dropout_rng, num_of_devices())
                 # ensure the old key never used
                 del dropout_rng
@@ -665,8 +654,7 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                     attn_mask_type=attn_mask_type,
                     scaling_factor=scale_factor,
                     dropout_probability=self.dropout_rate,
-                    is_training=not deterministic,
-                    sharding_type=first_sharding_type)
+                    is_training=not deterministic)
             else:
                 assert bias is None
                 query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
@@ -685,8 +673,7 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                     attn_mask_type=attn_mask_type,
                     scaling_factor=scale_factor,
                     dropout_probability=self.dropout_rate,
-                    is_training=not deterministic,
-                    sharding_type=first_sharding_type)
+                    is_training=not deterministic)
         else:
 
             def convert_to_softmax_type(attn_mask_type, mask):
@@ -710,7 +697,6 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                 scale_factor=scale_factor,
                 transpose_batch_sequence=self.transpose_batch_sequence,
                 softmax_type=softmax_type,
-                softmax_sharding_type=first_sharding_type,
                 mask=mask,
                 bias=bias,
                 dropout_rng=dropout_rng,
@@ -728,7 +714,6 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
         x = _with_sharding_constraint(x, attn_context_sharding_constraint)
 
         out = DenseGeneral(features=inputs_q.shape[-1],
-                           sharding_type=second_sharding_type,
                            transpose_batch_sequence=self.transpose_batch_sequence,
                            axis=-1,
                            kernel_init=self.kernel_init,
@@ -1175,7 +1160,6 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
             layernorm_type=self.layernorm_type,
             zero_centered_gamma=self.zero_centered_gamma,
             epsilon=self.layernorm_epsilon,
-            major_sharding_type=infer_major_sharding_type(),
             transpose_batch_sequence=self.transpose_batch_sequence,
             return_layernorm_output=self.apply_residual_connection_post_layernorm,
             intermediate_dim=self.mlp_hidden_size,
@@ -1208,7 +1192,6 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
             z = z + residual
 
         if self.output_layernorm:
-            ln_sharding_type, _ = infer_sharding_type()
             z = LayerNorm(layernorm_type=self.layernorm_type,
                           zero_centered_gamma=self.zero_centered_gamma,
                           epsilon=self.layernorm_epsilon,
@@ -1216,7 +1199,6 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
                           bias_axes=(W_NO_SHARD_AXES,),
                           transpose_batch_sequence=self.transpose_batch_sequence,
                           dtype=self.dtype,
-                          sharding_type=ln_sharding_type,
                           name="output_layer_norm")(z)
 
         return z
@@ -6,7 +6,7 @@ Helper module for fp8 meta management
 """
 from contextlib import contextmanager
 from enum import Enum
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -17,7 +17,7 @@ from transformer_engine_jax import get_cublasLt_version
 from transformer_engine_jax import get_cuda_version, get_device_compute_capability
 from transformer_engine.common.recipe import DelayedScaling, Format
 from transformer_engine.jax.sharding import global_shard_guard
-from transformer_engine.jax.sharding import ShardingResource
+from transformer_engine.jax.sharding import MeshResource
 
 _is_fp8_available = None
 _reason_for_no_fp8 = ""
@@ -59,37 +59,29 @@ def is_fp8_available(gpu_id=None) -> Tuple[bool, str]:
 def _format2dtypes(format_: Format):
     if format_ == Format.E4M3:
-        return DType.kFloat8E4M3, DType.kFloat8E4M3
+        return jnp.float8_e4m3fn, jnp.float8_e4m3fn
     if format_ == Format.E5M2:
-        return DType.kFloat8E5M2, DType.kFloat8E5M2
+        return jnp.float8_e5m2, jnp.float8_e5m2
     if format_ == Format.HYBRID:
-        return DType.kFloat8E4M3, DType.kFloat8E5M2
-    return DType.kBFloat16, DType.kBFloat16
+        return jnp.float8_e4m3fn, jnp.float8_e5m2
+    return jnp.bfloat16, jnp.bfloat16
 
 
-class FP8GemmPackage:
+class FP8MetaPackage:
     """
-    A container that contains all required data for
-    FP8 GEMM
+    A container that contains all required meta data for FP8
     """
 
     def __init__(
         self,
         num_of_gemm: int,
-        inputs: jnp.ndarray,
-        kernels: List[jnp.ndarray],
         fp8_max: jnp.ndarray,
         amax: jnp.ndarray,
         scale: jnp.ndarray,
         scale_inv: jnp.ndarray,
     ) -> None:
+        total_num_of_meta = num_of_gemm * FP8Helper.NUM_META_PER_GEMM
         self._num_of_gemm = num_of_gemm
-        self._inputs = inputs
-        assert len(kernels) == self._num_of_gemm
-        self._kernels = kernels
-        total_num_of_meta = self._num_of_gemm * FP8Helper.NUM_META_PER_GEMM
         assert fp8_max.shape[0] == total_num_of_meta
         self._fp8_max = fp8_max
         assert amax.shape[0] == total_num_of_meta
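The `_format2dtypes` rewrite above swaps TE's `DType` enum values for native `jax.numpy` float8 dtypes. A self-contained re-creation, using a stand-in `Fmt` enum (an assumption here, substituting for `transformer_engine.common.recipe.Format`):

```python
from enum import Enum
import jax.numpy as jnp

class Fmt(Enum):
    # Stand-in for transformer_engine.common.recipe.Format, used only so this
    # sketch is self-contained.
    E4M3 = "e4m3"
    E5M2 = "e5m2"
    HYBRID = "hybrid"

def format2dtypes(format_):
    # Mirrors the rewritten _format2dtypes: returns (fwd_dtype, bwd_dtype) as
    # native jax.numpy float8 dtypes instead of TE DType enum values.
    if format_ == Fmt.E4M3:
        return jnp.float8_e4m3fn, jnp.float8_e4m3fn
    if format_ == Fmt.E5M2:
        return jnp.float8_e5m2, jnp.float8_e5m2
    if format_ == Fmt.HYBRID:
        # HYBRID keeps the higher-precision e4m3 for activations/weights in the
        # forward pass and the wider-range e5m2 for gradients in the backward pass.
        return jnp.float8_e4m3fn, jnp.float8_e5m2
    return jnp.bfloat16, jnp.bfloat16
```

Returning real dtypes is what lets the rest of the diff drop `FP8Helper.FWD_DTYPE`/`BWD_DTYPE` arguments from call sites: the custom calls can read the dtype directly from the arrays they receive.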
@@ -106,20 +98,6 @@ class FP8GemmPackage:
         """
         return self._num_of_gemm
 
-    @property
-    def inputs(self) -> jnp.ndarray:
-        """
-        inputs of this package
-        """
-        return self._inputs
-
-    @property
-    def kernels(self) -> List[jnp.ndarray]:
-        """
-        kernels of this package
-        """
-        return self._kernels
-
     @property
     def fp8_max(self) -> jnp.ndarray:
         """
@@ -148,6 +126,19 @@ class FP8GemmPackage:
         """
         return self._scale_inv
 
+    def get_package_by_gemm_idx(self, gemm_idx):
+        """
+        Get a sub package by gemm_idx
+        """
+        assert self.num_of_gemm > gemm_idx
+
+        meta_start_idx = gemm_idx * FP8Helper.NUM_META_PER_GEMM
+        meta_end_idx = (gemm_idx + 1) * FP8Helper.NUM_META_PER_GEMM
+        return FP8MetaPackage(1, self.fp8_max[meta_start_idx:meta_end_idx],
+                              self.amax[meta_start_idx:meta_end_idx],
+                              self.scale[meta_start_idx:meta_end_idx],
+                              self.scale_inv[meta_start_idx:meta_end_idx])
+
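`get_package_by_gemm_idx` works because the stacked meta arrays are laid out as contiguous `NUM_META_PER_GEMM`-row blocks, one block per GEMM. A NumPy sketch of just that slicing (the helper name here is illustrative):

```python
import numpy as np

NUM_META_PER_GEMM = 3    # one row each for the input, kernel, and grad metas

def slice_gemm_meta(meta, gemm_idx):
    # Mirrors get_package_by_gemm_idx: GEMM i owns rows
    # [i * NUM_META_PER_GEMM, (i + 1) * NUM_META_PER_GEMM).
    start = gemm_idx * NUM_META_PER_GEMM
    return meta[start:start + NUM_META_PER_GEMM]

amax = np.arange(6.0).reshape(6, 1)    # stacked metas for two GEMMs
first_gemm_amax = slice_gemm_meta(amax, 0)
second_gemm_amax = slice_gemm_meta(amax, 1)
```

This is what lets `LayerNormMLP` allocate one two-GEMM package up front and hand each `type_safe_dot_general`/`layernorm_fp8_dot` call only its own slice.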
 class AmaxComputeAlgo(Enum):
     """AmaxComputeAlgo."""
@@ -155,6 +146,9 @@ class AmaxComputeAlgo(Enum):
     MOST_RECENT = "most_recent"
 
+NVTE_FP8_COLLECTION_NAME = "fp8_meta_collection"
+
 class FP8Helper:
     """
     FP8 helper to manage the FP8 meta
@@ -162,8 +156,8 @@ class FP8Helper:
     INITIALIZED = False
     MARGIN: float = 0.0
     FP8_FORMAT: Format = Format.HYBRID
-    FWD_DTYPE: DType = DType.kFloat8E4M3
-    BWD_DTYPE: DType = DType.kFloat8E5M2
+    FWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[0]
+    BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1]
     UPDATE_FP8META_INTERVAL: int = 1
     AMAX_HISTORY_LEN: int = 1024
     AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
@@ -171,7 +165,7 @@ class FP8Helper:
     INPUT_META_IDX_PER_GEMM: int = 0
     KERNEL_META_IDX_PER_GEMM: int = 1
     GRAD_META_IDX_PER_GEMM: int = 2
-    FP8_COLLECTION_NAME: str = "fp8_meta_collection"
+    FP8_COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
     FP8_AMAX_NAME: str = "fp8_meta_amax"
     FP8_SCALE_NAME: str = "fp8_meta_scale"
     FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv"
...@@ -216,21 +210,12 @@ class FP8Helper: ...@@ -216,21 +210,12 @@ class FP8Helper:
FP8Helper.INITIALIZED = False FP8Helper.INITIALIZED = False
FP8Helper.MARGIN = 0.0 FP8Helper.MARGIN = 0.0
FP8Helper.FP8_FORMAT = Format.HYBRID FP8Helper.FP8_FORMAT = Format.HYBRID
FP8Helper.FWD_DTYPE = DType.kFloat8E4M3 FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
FP8Helper.BWD_DTYPE = DType.kFloat8E5M2 _format2dtypes(FP8Helper.FP8_FORMAT)
FP8Helper.UPDATE_FP8META_INTERVAL = 1 FP8Helper.UPDATE_FP8META_INTERVAL = 1
FP8Helper.AMAX_HISTORY_LEN = 1024 FP8Helper.AMAX_HISTORY_LEN = 1024
FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
    @staticmethod
    def update_collections(new: Collection, original: Collection) -> Collection:
        """
@@ -270,8 +255,8 @@ class FP8Helper:
        Generate the FP8 max array
        """
        num_of_gemm = num_of_meta // FP8Helper.NUM_META_PER_GEMM
        fp8_max_fwd = jnp.finfo(FP8Helper.FWD_DTYPE).max
        fp8_max_bwd = jnp.finfo(FP8Helper.BWD_DTYPE).max
        fp8_max_per_gemm = []
        for i in range(FP8Helper.NUM_META_PER_GEMM):
            val = fp8_max_bwd if i == FP8Helper.GRAD_META_IDX_PER_GEMM \
@@ -318,11 +303,40 @@ class FP8Helper:
        return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)
    @staticmethod
    def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray:
        """
        Update the amax history
        """
        updated_amax = jnp.roll(amax, -1, -1)
        updated_amax = updated_amax.at[..., 0].set(0)
        return updated_amax

    @staticmethod
    @jax.jit
    def update_fp8_scale(fp8_max: jnp.ndarray, amax: jnp.ndarray,
                         scale: jnp.ndarray) -> jnp.ndarray:
        """
        Calculate fp8 scale and scale_inv based on the given amax.
        """
        if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
            amax = jnp.max(amax, axis=-1, keepdims=True)
        else:
            amax = amax[..., 0:1]

        sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
        sf = jnp.where(amax > 0.0, sf, scale)
        sf = jnp.where(jnp.isfinite(amax), sf, scale)
        scale = sf
        scale_inv = 1 / sf

        return scale, scale_inv
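The two helpers above can be mirrored in plain numpy to see the delayed-scaling mechanics: the history rolls left with slot 0 cleared, and the scale is `fp8_max / amax` unless amax is zero or non-finite, in which case the previous scale is kept. A sketch (448 is the E4M3 max, used here only as an illustrative `fp8_max`):

```python
import numpy as np

MARGIN = 0.0    # mirrors FP8Helper.MARGIN

def update_amax_history(amax):
    # Shift the history left along the last axis and clear slot 0,
    # like jnp.roll(amax, -1, -1) followed by .at[..., 0].set(0).
    updated = np.roll(amax, -1, axis=-1)
    updated[..., 0] = 0
    return updated

def update_fp8_scale(fp8_max, amax, scale):
    # AmaxComputeAlgo.MAX: reduce over the whole history window.
    amax = np.max(amax, axis=-1, keepdims=True)
    sf = (fp8_max / amax) / (2**MARGIN)
    # Keep the previous scale when amax is zero or non-finite.
    sf = np.where(amax > 0.0, sf, scale)
    sf = np.where(np.isfinite(amax), sf, scale)
    return sf, 1.0 / sf

hist = np.array([[1.0, 4.0, 2.0]])    # one tensor, history length 3
scale, scale_inv = update_fp8_scale(448.0, hist, np.array([[1.0]]))
print(scale)                          # 448 / 4 = 112
print(update_amax_history(hist))      # [[0. 2. 1.]]
```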
@contextmanager
def fp8_autocast(enabled: bool = False,
                 fp8_recipe: Optional[DelayedScaling] = None,
                 mesh_resource: Optional[MeshResource] = None) -> None:
    r"""
    Context manager for FP8 usage.
@@ -334,9 +348,9 @@ def fp8_autocast(enabled: bool = False,
        devices = np.asarray(jax.devices()).reshape(*mesh_shape)
        with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
            mesh_resource = MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)
            with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
                rules = extend_logical_axis_rules(tuple())
                transformer = TransformerLayer()
@@ -356,7 +370,7 @@ def fp8_autocast(enabled: bool = False,
        Whether or not to enable fp8
    fp8_recipe: recipe.DelayedScaling, default = None
        Recipe used for FP8 training.
    mesh_resource: MeshResource, default = None
        Specify the mesh axes for data and tensor parallelism to shard along.
        If set to None, then no data or tensor parallelism will be used.
@@ -373,11 +387,11 @@ def fp8_autocast(enabled: bool = False,
        "DelayedScaling override_linear_precision isn't supported by TE/JAX.")
    assert fp8_recipe.reduce_amax, ("DelayedScaling reduce_amax should be enabled for TE/JAX.")

    if mesh_resource is None:
        mesh_resource = MeshResource()

    try:
        with global_shard_guard(mesh_resource):
            if enabled:
                fp8_available, reason_for_no_fp8 = is_fp8_available()
                assert fp8_available, reason_for_no_fp8
...
@@ -15,12 +15,6 @@ from transformer_engine_jax import NVTE_QKV_Layout
from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
class AttnBiasType(Enum):
@@ -54,62 +48,24 @@ def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type,
                           head_dim).is_fused_attn_kernel_available()
def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                    attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                    scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Self fused attention wrapper
    """
    assert attn_mask_type is not AttnMaskType.NO_MASK, \
        "AttnMaskType.NO_MASK is currently not supported."
    output = _self_fused_attn(qkv,
                              bias,
                              mask,
                              seed,
                              attn_bias_type=attn_bias_type,
                              attn_mask_type=attn_mask_type,
                              scaling_factor=scaling_factor,
                              dropout_probability=dropout_probability,
                              is_training=is_training)
    return output
@@ -118,119 +74,70 @@ def self_fused_attn(qkv: jnp.ndarray,
def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                     attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                     scaling_factor: float, dropout_probability: float, is_training: bool):
    output, _ = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type, attn_mask_type,
                                          scaling_factor, dropout_probability, is_training)
    return output
def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                              seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                              attn_mask_type: AttnMaskType, scaling_factor: float,
                              dropout_probability: float, is_training: bool):
    squeezed_mask = mask[:, :, :, 0]
    output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
                                                         bias,
                                                         squeezed_mask,
                                                         seed,
                                                         attn_bias_type=attn_bias_type.value,
                                                         attn_mask_type=attn_mask_type.value,
                                                         scaling_factor=scaling_factor,
                                                         dropout_probability=dropout_probability,
                                                         is_training=is_training)
    return output, (qkv, softmax_aux, rng_state, output, squeezed_mask)
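The old xmap path reduced the padding mask to cumulative sequence lengths on the host side; the new rule just slices off the last (kv) dimension and hands the squeezed mask to the primitive. A numpy sketch of the difference (the mask layout `[batch, 1, q_len, kv_len]` with 0 = attend, 1 = padded is an assumption for illustration):

```python
import numpy as np

# Assumed mask layout: [batch, 1, q_len, kv_len], 0 = attend, 1 = padded.
mask = np.ones((2, 1, 4, 4))
mask[0, :, :3, :3] = 0    # sample 0: 3 valid tokens
mask[1, :, :2, :2] = 0    # sample 1: 2 valid tokens

# Old xmap path: host-side reduction to cumulative sequence lengths.
seqlen = np.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=np.int32)
cu_seqlen = np.hstack((0, np.cumsum(seqlen)))

# New custom_partitioning path: slice off the kv dim and pass the mask through.
squeezed_mask = mask[:, :, :, 0]

print(seqlen, cu_seqlen, squeezed_mask.shape)   # [3 2] [0 3 5] (2, 1, 4)
```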
def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                              is_training, ctx, dz):
    qkv, softmax_aux, rng_state, output, squeezed_mask = ctx
    grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
                                              softmax_aux,
                                              rng_state,
                                              output,
                                              dz,
                                              squeezed_mask,
                                              attn_bias_type=attn_bias_type.value,
                                              attn_mask_type=attn_mask_type.value,
                                              scaling_factor=scaling_factor,
                                              dropout_probability=dropout_probability,
                                              is_training=is_training)
    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None
    return grad_qkv, grad_bias, None, None


_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule)
def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                     attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                     scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Cross multi-head attention wrapper
    """
    output = _cross_fused_attn(q,
                               kv,
                               mask,
                               seed,
                               attn_bias_type=attn_bias_type,
                               attn_mask_type=attn_mask_type,
                               scaling_factor=scaling_factor,
                               dropout_probability=dropout_probability,
                               is_training=is_training)
    return output
@@ -240,54 +147,40 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed:
                      attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                      scaling_factor: float, dropout_probability: float, is_training: bool):
    output, _ = _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type,
                                           scaling_factor, dropout_probability, is_training)
    return output
def _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
                               dropout_probability, is_training):
    q_squeezed_mask = mask[:, :, :, 0]
    kv_squeezed_mask = mask[:, :, 0, :]
    output, softmax_aux = cross_fused_attn_fwd(q,
                                               kv,
                                               q_squeezed_mask,
                                               kv_squeezed_mask,
                                               seed,
                                               attn_bias_type=attn_bias_type.value,
                                               attn_mask_type=attn_mask_type.value,
                                               scaling_factor=scaling_factor,
                                               dropout_probability=dropout_probability,
                                               is_training=is_training)
    return output, (softmax_aux, q, kv, q_squeezed_mask, kv_squeezed_mask)
def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                               is_training, ctx, dz):
    softmax_aux, q, kv, q_squeezed_mask, kv_squeezed_mask = ctx
    grad_q, grad_kv = cross_fused_attn_bwd(q,
                                           kv,
                                           softmax_aux,
                                           dz,
                                           q_squeezed_mask,
                                           kv_squeezed_mask,
                                           attn_bias_type=attn_bias_type.value,
                                           attn_mask_type=attn_mask_type.value,
                                           scaling_factor=scaling_factor,
@@ -297,4 +190,4 @@ def _cross_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropou
    return grad_q, grad_kv, None, None


_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)
@@ -3,25 +3,15 @@
# See LICENSE for license information.
"""JAX layernorm modules""" """JAX layernorm modules"""
from typing import Tuple, Sequence from functools import partial
from functools import partial, reduce
import operator
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from transformer_engine_jax import DType as TEDType from .cpp_extensions import cast_transpose, transpose
from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype
from .cpp_extensions import transpose
from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
from .fp8 import FP8Helper, FP8GemmPackage from .dot import fp8_dot_impl
from .sharding import ShardingType, get_elementwise_sharding_meta from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
from .sharding import xmap_runner, extend_fsdp_sharding_meta
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
def canonicalize_layernorm_type(x):
@@ -38,421 +28,241 @@ def layernorm(inputs: jnp.ndarray,
              beta: jnp.ndarray,
              layernorm_type: str,
              zero_centered_gamma: bool = False,
              epsilon: float = 1e-6):
    """
    LN/RMSNorm wrapper
    Only layernorm_type in ['layernorm', 'rmsnorm'] is supported.
    """
    output = _layernorm(inputs,
                        gamma,
                        beta,
                        layernorm_type=layernorm_type,
                        zero_centered_gamma=zero_centered_gamma,
                        epsilon=epsilon)
    return output
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _layernorm(x,
               gamma,
               beta,
               layernorm_type: str,
               zero_centered_gamma: bool = False,
               epsilon: float = 1e-6):
    output, _ = _layernorm_fwd_rule(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon)
    return output
def _layernorm_fwd_rule(x,
                        gamma,
                        beta,
                        layernorm_type: str,
                        zero_centered_gamma: bool = False,
                        epsilon: float = 1e-6):
    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'layernorm':
        output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
    elif layernorm_type == 'rmsnorm':
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
                                        "if layernorm_type is 'rmsnorm'"
        output, rsigma = rmsnorm_fwd(x, gamma, epsilon)
        mu = None
    else:
        raise ValueError(f"{layernorm_type=} is not supported.")
    return output, (x, mu, rsigma, gamma)
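The residuals `(x, mu, rsigma, gamma)` are exactly what the backward kernel consumes, so nothing is recomputed there. A numpy sketch of the forward math (reference semantics only, not the fused TE kernel):

```python
import numpy as np

def layernorm_ref(x, gamma, beta, eps=1e-6):
    # mu and rsigma are the residuals saved for the backward pass
    mu = np.mean(x, axis=-1, keepdims=True)
    rsigma = 1.0 / np.sqrt(np.var(x, axis=-1, keepdims=True) + eps)
    return (x - mu) * rsigma * gamma + beta, mu, rsigma

x = np.random.default_rng(0).normal(size=(4, 8))
out, mu, rsigma = layernorm_ref(x, gamma=np.ones(8), beta=np.zeros(8))
print(out.mean(-1), out.std(-1))   # ~0 and ~1 per row
```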
def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
    x, mu, rsigma, gamma = ctx
    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dz,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    elif layernorm_type == 'rmsnorm':
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
                                        "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dz, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None
    else:
        raise ValueError(f"{layernorm_type=} is not supported.")
    return dx, dgamma, dbeta


_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)
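For intuition about what the VJP pair computes in the rmsnorm branch, the math can be written out in numpy and checked against finite differences (reference math only; the real `rmsnorm_bwd` is a fused kernel):

```python
import numpy as np

def rmsnorm_fwd_ref(x, gamma, eps=1e-6):
    # rsigma = 1/rms(x) is the residual saved for the backward pass
    rsigma = 1.0 / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x * rsigma * gamma, rsigma

def rmsnorm_bwd_ref(dz, x, rsigma, gamma):
    n = x.shape[-1]
    dgamma = np.sum(dz * x * rsigma, axis=tuple(range(x.ndim - 1)))
    # d(rsigma)/dx_j = -x_j * rsigma**3 / n, applied via the chain rule
    inner = np.sum(dz * gamma * x, axis=-1, keepdims=True)
    dx = dz * gamma * rsigma - x * rsigma**3 * inner / n
    return dx, dgamma

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 6))
gamma = rng.normal(size=(6,))
dz = rng.normal(size=(2, 6))

y, rsigma = rmsnorm_fwd_ref(x, gamma)
dx, dgamma = rmsnorm_bwd_ref(dz, x, rsigma, gamma)

# finite-difference check of dx on one coordinate
eps = 1e-6
xp = x.copy()
xp[0, 0] += eps
num = (np.sum(rmsnorm_fwd_ref(xp, gamma)[0] * dz) - np.sum(y * dz)) / eps
print(abs(num - dx[0, 0]))   # ~0
```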
def layernorm_fp8_dot(x: jnp.ndarray,
                      kernel: jnp.ndarray,
                      gamma: jnp.ndarray,
                      beta: jnp.ndarray,
                      fp8_meta_pkg: FP8MetaPackage,
                      layernorm_type: str,
                      zero_centered_gamma: bool = False,
                      epsilon: float = 1e-6) -> jnp.ndarray:
    """
    Layernorm + FP8 GEMM
    """
    fp8_max = fp8_meta_pkg.fp8_max
    amax = fp8_meta_pkg.amax
    scale = fp8_meta_pkg.scale
    scale_inv = fp8_meta_pkg.scale_inv
    fwd_dtype = FP8Helper.FWD_DTYPE
    bwd_dtype = FP8Helper.BWD_DTYPE
    output = _layernorm_fp8_dot(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
                                layernorm_type, fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon)
    return output
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12))
def _layernorm_fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                       fp8_max: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                       scale_inv: jnp.ndarray, layernorm_type: str, fwd_dtype: jnp.dtype,
                       bwd_dtype: jnp.dtype, zero_centered_gamma: bool, epsilon: float):
    output, _ = _layernorm_fp8_dot_fwd_rule(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
                                            layernorm_type, fwd_dtype, bwd_dtype,
                                            zero_centered_gamma, epsilon)
    return output
def _layernorm_fp8_dot_fwd_rule(
        x,
        kernel,
        gamma,
        beta,
        fp8_max,
        amax,
        scale,
        scale_inv,
        layernorm_type,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        zero_centered_gamma,
        epsilon):

    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)
    assert x.shape[-1] == kernel.shape[0]

    amax = FP8Helper.update_amax_history(amax)

    gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    x_amax = amax[gemm_x_idx, 0:1]
    x_scale = scale[gemm_x_idx]
    x_scale_inv = scale_inv[gemm_x_idx]

    if layernorm_type == 'layernorm':
        ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
                                                         gamma,
                                                         x_amax,
                                                         x_scale,
                                                         x_scale_inv,
                                                         out_dtype=fwd_dtype,
                                                         epsilon=epsilon)
        mu = None

    assert x.shape == ln_out.shape

    kernel_amax = amax[gemm_kernel_idx, 0:1]
    kernel_scale = scale[gemm_kernel_idx]
    kernel_scale_inv = scale_inv[gemm_kernel_idx]

    # Kernel in (hidden_in, hidden_out...)
    casted_kernel, casted_kernel_t, updated_kernel_amax = \
        cast_transpose(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype,
                       static_axis_boundary=-1, transpose_axis_boundary=1)

    # (batch..., hidden_in) x (hidden_in, hidden_out...)
    kt_contracting_dims = (kernel.ndim - 1,)
    output = fp8_dot_impl(ln_out, casted_kernel_t, x_scale_inv, kernel_scale_inv, x.dtype,
                          (x_contracting_dims, kt_contracting_dims))

    ctx = (ln_out, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
           updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims,
           k_contracting_dims)

    return output, ctx
def _layernorm_fp8_dot_bwd_rule(
        layernorm_type,
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        zero_centered_gamma,
        epsilon,
        ctx,
        grad):
    ln_out_, casted_kernel, fp8_max, amax, scale, scale_inv, \
        updated_x_amax, updated_kernel_amax, \
        x_shape, kernel_shape, mu, rsigma, x, gamma, \
        x_contracting_dims, k_contracting_dims = ctx

    ln_out_t = transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1)

    gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)

    grad_amax = amax[gemm_grad_idx, 0:1]
    grad_scale = scale[gemm_grad_idx]
    grad_scale_inv = scale_inv[gemm_grad_idx]

    casted_grad, casted_grad_t, updated_grad_amax = \
        cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
                       static_axis_boundary=-1, transpose_axis_boundary=min(x_contracting_dims))

    xt_contracting_dim = tuple(range(len(x_contracting_dims), len(x_shape)))
    gt_contracting_dim = tuple(range(grad.ndim - len(xt_contracting_dim), grad.ndim))
    x_scale_inv = scale_inv[gemm_x_idx]
    wgrad = fp8_dot_impl(ln_out_t, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
                         (xt_contracting_dim, gt_contracting_dim))

    g_contracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(k_contracting_dims), grad.ndim))
    k_contracting_dim = tuple(range(len(k_contracting_dims), len(kernel_shape)))
    kernel_scale_inv = scale_inv[gemm_kernel_idx]
    dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype,
                         (g_contracting_dim, k_contracting_dim))

    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dgrad,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dgrad, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None

    amax = amax.at[gemm_x_idx, 0].set(updated_x_amax[0])
    amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0])
    amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])

    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)

    return dx, wgrad, \
           dgamma, dbeta, \
           fp8_max, amax, scale, scale_inv


_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd_rule, _layernorm_fp8_dot_bwd_rule)
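For reference, the math that the FP8 forward path above approximates can be sketched in plain NumPy. This is a hypothetical high-precision reference (`layernorm_dot_reference` is not part of the library; it ignores FP8 casting, amax tracking, and the fused kernels), showing the contraction convention `x_contracting_dims = (x.ndim - 1,)`, `k_contracting_dims = (0,)`:

```python
import numpy as np


def layernorm_dot_reference(x, kernel, gamma, beta, epsilon=1e-6):
    """High-precision reference for layernorm followed by a dot.

    x:      (batch..., hidden_in)
    kernel: (hidden_in, hidden_out)
    Returns (batch..., hidden_out).
    """
    # LayerNorm over the last (hidden) dimension
    mu = x.mean(axis=-1, keepdims=True)
    rsigma = 1.0 / np.sqrt(x.var(axis=-1, keepdims=True) + epsilon)
    ln_out = (x - mu) * rsigma * gamma + beta
    # Contract the last dim of x with the first dim of kernel,
    # matching x_contracting_dims=(x.ndim - 1,) and k_contracting_dims=(0,)
    return ln_out @ kernel
```

The FP8 path computes the same product, but routes `ln_out` and `kernel` through FP8 casts with per-tensor scales before the GEMM.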
@@ -3,462 +3,307 @@
# See LICENSE for license information.
"""JAX MLP modules"""
from typing import List
from functools import partial

import jax
import jax.numpy as jnp

from .cpp_extensions import transpose, cast_transpose
from .cpp_extensions import gated_gelu, gated_gelu_fp8
from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage
def geglu(x: jnp.ndarray):
    """
    Gated gelu
    """
    assert x.shape[-2] == 2    # Linear + GeLU

    output = _geglu(x)

    return output


@partial(jax.custom_vjp)
def _geglu(x: jnp.ndarray):

    geglu_output, _ = _geglu_fwd_rule(x)

    return geglu_output


def _geglu_fwd_rule(x):
    geglu_output = gated_gelu(x)
    return geglu_output, (x,)


def _geglu_bwd_rule(ctx, g):
    x, = ctx
    assert x.dtype == g.dtype

    dgelu = dgated_gelu(g, x)
    dgelu = jnp.reshape(dgelu, x.shape)

    return (dgelu,)


_geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule)
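A plain-NumPy sketch of what the fused `gated_gelu` computes over a `(..., 2, hidden)` input. `geglu_reference` is illustrative only, and which of the two slices is gated by GeLU is an assumption for this sketch (the fused kernel fixes its own convention):

```python
import numpy as np


def gelu(x):
    # tanh approximation of GeLU, as commonly used in fused kernels
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def geglu_reference(x):
    """Gated GeLU: one half is passed through GeLU and gates the other half.

    x: (..., 2, hidden) -> (..., hidden)
    Half ordering is assumed here for illustration.
    """
    assert x.shape[-2] == 2
    gelu_in, linear_in = x[..., 0, :], x[..., 1, :]
    return gelu(gelu_in) * linear_in
```

This is why `geglu` asserts `x.shape[-2] == 2`: the size-2 axis carries the (linear, GeLU) pair produced by the first GEMM.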
def layernorm_geglu_fp8_mlp(x: jnp.ndarray,
                            gamma: jnp.ndarray,
                            beta: jnp.ndarray,
                            kernels: List[jnp.ndarray],
                            fp8_gemm_pkg: FP8MetaPackage,
                            layernorm_type: str,
                            zero_centered_gamma: bool = False,
                            epsilon: float = 1e-6) -> jnp.ndarray:
    """
    Layernorm + GEMM1 + GeGLU + GEMM2
    """

    assert len(kernels) == 2
    assert fp8_gemm_pkg.num_of_gemm == len(kernels)

    kernel_1 = kernels[0]
    kernel_2 = kernels[1]
    fp8_max = fp8_gemm_pkg.fp8_max
    amax = fp8_gemm_pkg.amax
    scale = fp8_gemm_pkg.scale
    scale_inv = fp8_gemm_pkg.scale_inv

    fwd_dtype = FP8Helper.FWD_DTYPE
    bwd_dtype = FP8Helper.BWD_DTYPE

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"

    output = _layernorm_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale,
                                      scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
                                      zero_centered_gamma, epsilon)
    return output


@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13))
def _layernorm_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                             kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray,
                             amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                             fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str,
                             zero_centered_gamma: bool, epsilon: float):
    output, _ = _layernorm_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max,
                                                  amax, scale, scale_inv, fwd_dtype, bwd_dtype,
                                                  layernorm_type, zero_centered_gamma, epsilon)
    return output
def _layernorm_geglu_fp8_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        fp8_max,
        amax,
        scale,
        scale_inv,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        layernorm_type,
        zero_centered_gamma,
        epsilon):

    # x should be in shape of (batch..., hidden)
    # kernel_1 should be in shape of (hidden_in, 2, hidden_out)
    # kernel_2 should be in shape of (hidden_in, hidden_out)
    assert len(kernel_1.shape) == 3
    assert kernel_1.shape[-2] == 2
    assert len(kernel_2.shape) == 2

    x_contracting_dims = (len(x.shape) - 1,)
    xt_batch_dims = tuple(range(1, x.ndim))

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
    assert kernel_1.shape[-1] == kernel_2.shape[0]

    amax = FP8Helper.update_amax_history(amax)

    gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    x_amax = amax[gemm1_x_idx, 0:1]
    x_scale = scale[gemm1_x_idx]
    x_scale_inv = scale_inv[gemm1_x_idx]

    if layernorm_type == 'layernorm':
        ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
                                                         gamma,
                                                         x_amax,
                                                         x_scale,
                                                         x_scale_inv,
                                                         out_dtype=fwd_dtype,
                                                         epsilon=epsilon)
        mu = None

    assert x.shape == ln_out.shape

    kernel_1_amax = amax[gemm1_kernel_idx, 0:1]
    kernel_1_scale = scale[gemm1_kernel_idx]
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]

    casted_kernel_1, casted_kernel_1_t, updated_kernel_1_amax = \
        cast_transpose(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype,
                       static_axis_boundary=-1, transpose_axis_boundary=-2)

    # (batch..., hidden_in) x (2, hidden_out, hidden_in)
    dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1_t, x_scale_inv, kernel_1_scale_inv, x.dtype,
                                (x_contracting_dims, (2,)))

    gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)

    geglu_out_amax = amax[gemm2_x_idx, 0:1]
    geglu_out_scale = scale[gemm2_x_idx]
    geglu_out_scale_inv = scale_inv[gemm2_x_idx]

    # (batch..., 2, hidden_out) -> (batch..., hidden_out)
    casted_geglu_out, updated_geglu_amax = gated_gelu_fp8(dot_1_output, geglu_out_amax,
                                                          geglu_out_scale, geglu_out_scale_inv,
                                                          fwd_dtype)

    kernel_2_amax = amax[gemm2_kernel_idx, 0:1]
    kernel_2_scale = scale[gemm2_kernel_idx]
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]

    casted_kernel_2, casted_kernel_2_t, updated_kernel_2_amax = \
        cast_transpose(kernel_2, kernel_2_amax, kernel_2_scale, kernel_2_scale_inv, fwd_dtype,
                       static_axis_boundary=-1, transpose_axis_boundary=-1)

    # (batch..., hidden_in) x (hidden_out, hidden_in)
    dot_2_output = fp8_dot_impl(casted_geglu_out, casted_kernel_2_t, geglu_out_scale_inv,
                                kernel_2_scale_inv, x.dtype, (x_contracting_dims, (1,)))

    ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kernel_1,
           casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax,
           updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims)

    return dot_2_output, ctx
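The amax/scale bookkeeping threaded through the rules above follows the usual FP8 delayed-scaling recipe: quantize with the current scale, record the tensor's new amax, and later recompute the scale from `fp8_max` and the observed amax. A hypothetical NumPy sketch (the exact margin/rounding policy inside `FP8Helper.update_fp8_scale` is an assumption here, not the library's implementation):

```python
import numpy as np

E4M3_MAX = 448.0    # representable max magnitude of FP8 E4M3


def quantize_with_scale(x, scale):
    """Simulated FP8 cast: record the tensor amax, then scale and clip."""
    amax = np.abs(x).max()
    q = np.clip(x * scale, -E4M3_MAX, E4M3_MAX)
    return q, amax


def update_scale(fp8_max, amax, margin=0.0):
    """Delayed scaling: derive the next scale from the last observed amax."""
    scale = (fp8_max / amax) / (2.0**margin) if amax > 0 else 1.0
    return scale, 1.0 / scale
```

The key point is that the scale used in the forward cast was computed from a *previous* step's amax; each rule returns the freshly observed amaxes so the wrapper can fold them into the history for the next step.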
def _layernorm_geglu_fp8_mlp_bwd_rule(
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        ctx,
        grad):
    x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, \
        casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
        updated_geglu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
        x_contracting_dims, xt_batch_dims = ctx

    gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)

    grad_amax = amax[gemm2_grad_idx, 0:1]
    grad_scale = scale[gemm2_grad_idx]
    grad_scale_inv = scale_inv[gemm2_grad_idx]

    casted_grad, casted_grad_t, updated_grad_amax = \
        cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
                       static_axis_boundary=-1, transpose_axis_boundary=-1)

    casted_geglu_out_t = transpose(casted_geglu_out,
                                   static_axis_boundary=-1,
                                   transpose_axis_boundary=-1)

    # (hidden, batch...) x (hidden, batch...)
    gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
    wgrad_2 = fp8_dot_impl(casted_geglu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv,
                           grad.dtype, (xt_batch_dims, xt_batch_dims))

    # (batch..., hidden_out) x (hidden_in, hidden_out)
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
                           grad.dtype, (x_contracting_dims, (1,)))

    gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)

    dgeglu_amax = amax[gemm1_grad_idx, 0:1]
    dgeglu_scale = scale[gemm1_grad_idx]
    dgeglu_scale_inv = scale_inv[gemm1_grad_idx]

    casted_dgeglu, casted_dgeglu_t, updated_dgeglu_amax = dgated_gelu_cast_transpose(
        dgrad_2,
        dot_1_output,
        dgeglu_amax,
        dgeglu_scale,
        dgeglu_scale_inv,
        bwd_dtype,
        static_axis_boundary=-1)

    ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)

    # (hidden, batch...) x (2, hidden, batch...)
    xt_batch_dims_plus_act_dim = tuple(i + 1 for i in xt_batch_dims)
    gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
    wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgeglu_t, gemm1_x_scale_inv, dgeglu_scale_inv,
                           grad.dtype, (xt_batch_dims, xt_batch_dims_plus_act_dim))

    # (batch..., 2, hidden_out) x (hidden_in, 2, hidden_out)
    x_contracting_dims_plus_act_dim = (min(x_contracting_dims),) + tuple(
        i + 1 for i in x_contracting_dims)
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
    dgrad_1 = fp8_dot_impl(casted_dgeglu, casted_kernel_1, dgeglu_scale_inv, kernel_1_scale_inv,
                           grad.dtype, (x_contracting_dims_plus_act_dim, (1, 2)))

    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None

    amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
    amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
    amax = amax.at[gemm1_grad_idx, 0].set(updated_dgeglu_amax[0])
    amax = amax.at[gemm2_x_idx, 0].set(updated_geglu_amax[0])
    amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax[0])
    amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])

    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)

    return dx, dgamma, dbeta, wgrad_1, wgrad_2, \
           fp8_max, amax, scale, scale_inv


_layernorm_geglu_fp8_mlp.defvjp(_layernorm_geglu_fp8_mlp_fwd_rule,
                                _layernorm_geglu_fp8_mlp_bwd_rule)
wgrad_2 = jnp.reshape(wgrad_2, kernel_2_shape)
return grad_input, grad_gamma, grad_beta, \
wgrad_1, wgrad_2, \
fp8_maxs, amax, scale, scale_inv
_fp8_mlp.defvjp(_fp8_mlp_fwd, _fp8_mlp_bwd)
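The `defvjp` wiring above follows JAX's standard custom-VJP pattern: a forward rule returns the primal output plus residuals, and a backward rule maps `(residuals, cotangent)` to one gradient per primal input. A minimal, self-contained sketch with a toy function (not TE code):

```python
import jax


@jax.custom_vjp
def scaled_square(x, s):
    return s * x * x


def scaled_square_fwd(x, s):
    # Return the primal output and the residuals needed by the backward pass.
    return s * x * x, (x, s)


def scaled_square_bwd(res, g):
    x, s = res
    # One cotangent per primal input, in order.
    return (g * 2.0 * s * x,    # d/dx of s*x^2
            g * x * x)          # d/ds of s*x^2


scaled_square.defvjp(scaled_square_fwd, scaled_square_bwd)

dx, ds = jax.grad(scaled_square, argnums=(0, 1))(3.0, 2.0)
```

With `x=3, s=2`, the custom rule yields `dx = 2*s*x = 12` and `ds = x*x = 9`, matching what autodiff of the primal would produce.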
...@@ -49,7 +49,7 @@ class TransformerEngineBaseLayer(BaseLayer):
        fp8_collection_map = {
            FP8Helper.FP8_COLLECTION_NAME: [
                WeightHParamsCollection.SKIP_LP_REGULARIZATION,
                WeightHParamsCollection.OVERWRITE_WITH_GRADIENT,
                WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION
            ]
        }
...@@ -92,8 +92,7 @@ class LayerNorm(TransformerEngineBaseLayer):
                             "ln_bias", self.bias_init),
                         bias_axes=self.bias_axes,
                         dtype=self.dtype,
                         transpose_batch_sequence=self.transpose_batch_sequence)
        self.create_layer("layer_norm", ln_cls)
...@@ -115,8 +114,7 @@ class FusedSoftmax(TransformerEngineBaseLayer):
        fused_softmax_cls = partial(Softmax,
                                    scale_factor=self.scale_factor,
                                    softmax_type=self.softmax_type)
        self.create_layer("fused_softmax", fused_softmax_cls)
...@@ -151,8 +149,7 @@ class Linear(TransformerEngineBaseLayer):
                                   bias_axes=self.bias_axes,
                                   axis=self.axis,
                                   dtype=self.dtype,
                                   transpose_batch_sequence=self.transpose_batch_sequence)
        self.create_layer("linear", dense_general_cls)
...@@ -208,8 +205,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
                                      axis=self.axis,
                                      dtype=self.dtype,
                                      transpose_batch_sequence=self.transpose_batch_sequence,
                                      depth_scaling=self.depth_scaling)
        self.create_layer("ln_linear", ln_dense_general_cls)
...@@ -273,8 +269,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
            intermediate_hidden_dropout_dims=self.intermediate_hidden_dropout_dims,
            axis=self.axis,
            dtype=self.dtype,
            transpose_batch_sequence=self.transpose_batch_sequence)
        self.create_layer("ln_mlp", ln_mlp_cls)
...
...@@ -8,17 +8,12 @@ Sharding Meta for xmap with CustomCall
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Callable

from jax.interpreters import pxla
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec

_PXLA_THREAD_RESOURCES = pxla.thread_resources
...@@ -29,6 +24,24 @@ def _get_mesh_info(resource: str):
    return mesh.shape[resource], resource
def get_all_mesh_axes():
    """
    Get the names of all mesh axes
    """
    mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
    return mesh.axis_names


def get_padded_spec(spec, ndim):
    """
    Get padded spec for partitioning from arguments' information
    """
    if spec is None:
        return (None,) * ndim
    assert len(spec) <= ndim
    return spec + (None,) * (ndim - len(spec))
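For reference, `get_padded_spec` pads a partition spec to an array's rank so that every dimension has an explicit (possibly `None`) entry. A standalone illustration with hypothetical axis names (the helper is re-stated here so the snippet runs on its own):

```python
def get_padded_spec(spec, ndim):
    # Mirrors the helper above: pad a partial spec with None up to ndim entries.
    if spec is None:
        return (None,) * ndim
    assert len(spec) <= ndim
    return spec + (None,) * (ndim - len(spec))


# A fully-replicated array has no spec at all:
assert get_padded_spec(None, 3) == (None, None, None)
# A rank-4 array sharded on its two leading dims:
assert get_padded_spec(('dp', 'tp'), 4) == ('dp', 'tp', None, None)
```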
def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
    """
    A wrapper function to jax.lax.with_sharding_constraint to
...@@ -40,8 +53,25 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
    return jax.lax.with_sharding_constraint(x, pspec)
def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str):
    """
    A wrapper function to invoke lax.p* operations, like psum.
    """
    if mesh_resource is not None:
        _, resource = _get_mesh_info(mesh_resource)
        return ops(x, resource)
    return x


def num_of_devices():
    """
    Get total number of detected devices
    """
    return len(jax.devices())
@dataclass
class MeshResource:
    """
    A data container to indicate which axis in Mesh for data parallelism and
    which for tensor parallelism.
...@@ -54,39 +84,73 @@ class ShardingResource:
    tp_resource : str, default = None
        The axis name in Mesh used to split the hidden dimensions along.
        If it is None, then tensor parallelism is disabled.
    fsdp_resource : str, default = None
        The axis name in Mesh used to split the batch and weights along.
        If it is None, then fully-sharded data parallelism is disabled.
    pp_resource : str, default = None
        The axis name in Mesh used to split model layers along.
        If it is None, then pipeline parallelism is disabled.
    """
    dp_resource: str = None
    tp_resource: str = None
    fsdp_resource: str = None
    pp_resource: str = None
_GLOBAL_MESH_RESOURCE = MeshResource()


@contextmanager
def global_shard_guard(resource: MeshResource):
    """
    A context manager to switch the global MeshResource
    """
    global _GLOBAL_MESH_RESOURCE
    prev_gmr = _GLOBAL_MESH_RESOURCE
    try:
        _GLOBAL_MESH_RESOURCE = resource
        yield
    finally:
        _GLOBAL_MESH_RESOURCE = prev_gmr
def global_mesh_resource() -> MeshResource:
    """
    A getter of the global MeshResource
    """
    return _GLOBAL_MESH_RESOURCE
def all_reduce_sum_along_dp_fsdp(x: jnp.array):
    """
    All-Reduce (Sum) along DP and FSDP mesh axes.
    """
    x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource)
    return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource)
def all_reduce_max_along_all_axes_except_PP(x: jnp.array):
    """
    All-Reduce (Max) along all mesh axes except the pipeline-parallel axis.
    """
    all_axes = get_all_mesh_axes()
    for axis in all_axes:
        if axis != global_mesh_resource().pp_resource:
            x = lax_paral_op(x, jax.lax.pmax, axis)
    return x
# Deprecating Items ---------------------------------------------------------------
ShardingResource = MeshResource
global_shard_resource = global_mesh_resource


class MajorShardingType(Enum):
    r"""
    The major sharding type to indicate sharding pattern.

    .. warning::
        MajorShardingType will be deprecated in the near future.

    Values
    ----------
...@@ -108,6 +172,8 @@ class MajorShardingType(Enum):


class ShardingType(Enum):
    """
    The sharding type to indicate sharding pattern.

    .. warning::
        ShardingType will be deprecated in the near future.

    Values
    ----------
...@@ -130,1058 +196,3 @@ class ShardingType(Enum):
    TP_ROW = (MajorShardingType.TP, "tp_row")
    DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
    DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")
def infer_major_sharding_type() -> MajorShardingType:
"""
Infer MajorShardingType from _GLOBAL_SHARD_RESOURCE
"""
gsr = global_shard_resource()
resources = [gsr.dp_resource, gsr.tp_resource, gsr.fsdp_resource]
for idx, rs in enumerate(resources):
try:
size, _ = _get_mesh_info(rs)
if size <= 1:
resources[idx] = None
except AssertionError as _:
resources[idx] = None
dp_resource = resources[0]
tp_resource = resources[1]
fsdp_resource = resources[2]
def dp_enabled():
return (fsdp_resource is not None) or (dp_resource is not None)
if dp_enabled() and tp_resource is not None:
return MajorShardingType.DPTP
if dp_enabled():
return MajorShardingType.DP
if tp_resource is not None:
return MajorShardingType.TP
return MajorShardingType.SINGLE
def infer_sharding_type(major_st: MajorShardingType = None) -> Tuple[ShardingType, ShardingType]:
"""
Infer ShardingType via given MajorShardingType
"""
if major_st is None:
major_st = infer_major_sharding_type()
if major_st is MajorShardingType.DP:
return ShardingType.DP, ShardingType.DP
if major_st is MajorShardingType.TP:
return ShardingType.TP_COL, ShardingType.TP_ROW
if major_st is MajorShardingType.DPTP:
return ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW
return ShardingType.SINGLE, ShardingType.SINGLE
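The two inference helpers being removed above reduce to a small decision table: which of DP and TP are active determines the (column, row) sharding pair. A standalone sketch of that table, using string stand-ins for the enums:

```python
def infer(dp_enabled, tp_enabled):
    # Mirrors the branching of infer_major_sharding_type + infer_sharding_type:
    # DPTP is checked first, then pure DP, then pure TP, else SINGLE.
    if dp_enabled and tp_enabled:
        return ('DP_TP_COL', 'DP_TP_ROW')
    if dp_enabled:
        return ('DP', 'DP')
    if tp_enabled:
        return ('TP_COL', 'TP_ROW')
    return ('SINGLE', 'SINGLE')


assert infer(True, True) == ('DP_TP_COL', 'DP_TP_ROW')
assert infer(False, True) == ('TP_COL', 'TP_ROW')
assert infer(False, False) == ('SINGLE', 'SINGLE')
```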
def is_dp_enabled(mst: MajorShardingType) -> bool:
"""
is_dp_enabled
"""
return mst in (MajorShardingType.DP, MajorShardingType.DPTP)
def is_tp_enabled(mst: MajorShardingType) -> bool:
"""
is_tp_enabled
"""
return mst in (MajorShardingType.TP, MajorShardingType.DPTP)
def merge_axis_resources(ars: Tuple[Dict]) -> Dict:
"""
merge_axis_resources
"""
output = {}
for ar in ars:
for key in ar:
if key not in output:
output[key] = ar[key]
else:
assert output[key] == ar[key]
return output
@dataclass
class ShardingMeta:
"""ShardingMeta"""
in_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]]
out_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]]
axis_resources: Dict
input_shapes: Tuple[Tuple[int, ...]]
output_shapes: Tuple[Tuple[int, ...]]
class ShardingMetaGenerator:
"""
ShardingMetaGenerator
"""
def __init__(self):
def get_single_sharding_meta(*argv, **kwargs) -> ShardingMeta: # pylint: disable=unused-argument
return None
self.sharding_type_meta_map = {
ShardingType.SINGLE: get_single_sharding_meta,
ShardingType.DP: self.get_dp_sharding_meta,
ShardingType.TP_COL: self.get_tp_col_sharding_meta,
ShardingType.TP_ROW: self.get_tp_row_sharding_meta,
ShardingType.DP_TP_COL: self.get_dp_tp_col_sharding_meta,
ShardingType.DP_TP_ROW: self.get_dp_tp_row_sharding_meta
}
def get_sharding_meta(self, stype: ShardingType, *argv, **kwargs) -> ShardingMeta:
"""get_sharding_meta"""
return self.sharding_type_meta_map[stype](*argv, **kwargs)
def get_dp_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_sharding_meta"""
raise NotImplementedError
def get_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_tp_col_sharding_meta"""
raise NotImplementedError
def get_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_tp_row_sharding_meta"""
raise NotImplementedError
def get_dp_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_tp_col_sharding_meta"""
raise NotImplementedError
def get_dp_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_tp_row_sharding_meta"""
raise NotImplementedError
class FP8MetaShardingMetaGenerator(ShardingMetaGenerator):
"""
FP8MetaShardingMetaGenerator
"""
def get_dp_sharding_meta(self,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.DP,
num_of_meta, dp_axis_name,
tp_axis_name)
def get_tp_col_sharding_meta(self,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.TP,
num_of_meta, dp_axis_name,
tp_axis_name)
def get_tp_row_sharding_meta(self,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.TP,
num_of_meta, dp_axis_name,
tp_axis_name)
def get_dp_tp_col_sharding_meta(self,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.DPTP,
num_of_meta, dp_axis_name,
tp_axis_name)
def get_dp_tp_row_sharding_meta(self,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.DPTP,
num_of_meta, dp_axis_name,
tp_axis_name)
@staticmethod
def _stack_axes_meta(num_of_meta: int, mapping: Dict) -> Tuple:
return tuple(mapping for _ in range(num_of_meta))
@staticmethod
def _generate_sharding_meta(type_: MajorShardingType,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
axis_resource = {}
if is_dp_enabled(type_):
axis_resource[dp_axis_name] = global_shard_resource().dp_resource
if is_tp_enabled(type_):
axis_resource[tp_axis_name] = global_shard_resource().tp_resource
return ShardingMeta(FP8MetaShardingMetaGenerator._stack_axes_meta(num_of_meta, {}),
FP8MetaShardingMetaGenerator._stack_axes_meta(num_of_meta, {}),
axis_resource, (), ())
class FusedAttnShardingMetaGenerator(ShardingMetaGenerator):
"""
FusedAttnShardingMetaGenerator
"""
def get_dp_sharding_meta(
self,
input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]],
tp_dims: Tuple[Tuple[int, ...]], # pylint: disable=unused-argument
dp_axis_name: str = 'data',
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_dp_sharding_meta"""
dummy_tp_dims = [repeat(None), repeat(None)]
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(input_shapes, output_shapes,
dp_dims, dummy_tp_dims,
dp_axis_name, None)
def get_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_tp_col_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_tp_sharding_meta(*argv, **kwargs)
def get_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_tp_row_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_tp_sharding_meta(*argv, **kwargs)
def get_dp_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_tp_col_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(*argv, **kwargs)
def get_dp_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_tp_row_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(*argv, **kwargs)
@staticmethod
def _get_tp_sharding_meta(
input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]], # pylint: disable=unused-argument
tp_dims: Tuple[Tuple[int, ...]],
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_tp_sharding_meta"""
dummy_dp_dims = [repeat(None), repeat(None)]
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(input_shapes, output_shapes,
dummy_dp_dims, tp_dims, None,
tp_axis_name)
@staticmethod
def _get_dptp_sharding_meta(input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]],
tp_dims: Tuple[Tuple[int, ...]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_sharding_meta"""
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
input_dp_dims, output_dp_dims = dp_dims
input_tp_dims, output_tp_dims = tp_dims
input_new_shapes = []
in_axes = []
for input_shape, dp_dim, tp_dim in zip(input_shapes, input_dp_dims, input_tp_dims):
in_axis = {}
if dp_dim is not None and input_shape is not None:
in_axis[dp_dim] = dp_axis_name
assert input_shape[dp_dim] % dp_size == 0, \
f"The dimension of batch in input_shape should be a multiple of " \
f"data parallelism size, but got {input_shape[dp_dim]=} and {dp_size=}."
input_shape = (*input_shape[:dp_dim], dp_size, input_shape[dp_dim] // dp_size,
*input_shape[dp_dim + 1:])
# the input shape has been expanded for dp_dim, tp_dim should +1 if tp_dim >= dp_dim
if tp_dim is not None and tp_dim >= dp_dim:
tp_dim = tp_dim + 1
if tp_dim is not None and input_shape is not None:
in_axis[tp_dim] = tp_axis_name
assert input_shape[tp_dim] % tp_size == 0, \
f"The dimension of tensor parallel in input_shape should be a multiple of " \
f"tensor parallelism size, but got {input_shape[tp_dim]=} and {tp_size=}."
input_shape = (*input_shape[:tp_dim], tp_size, input_shape[tp_dim] // tp_size,
*input_shape[tp_dim + 1:])
in_axes.append(in_axis)
input_new_shapes.append(input_shape)
output_new_shapes = output_shapes
out_axes = []
for dp_dim, tp_dim in zip(output_dp_dims, output_tp_dims):
out_axis = {}
if dp_dim is not None:
out_axis[dp_dim] = dp_axis_name
if tp_dim is not None and tp_dim >= dp_dim:
tp_dim = tp_dim + 1
if tp_dim is not None:
out_axis[tp_dim] = tp_axis_name
out_axes.append(out_axis)
assert len(out_axes) == 1, "Only allow single output at this moment."
assert len(output_new_shapes) == 1, "Only allow single output at this moment."
out_axes = out_axes[0]
output_new_shapes = output_new_shapes[0]
axis_resources = {}
if dp_axis_name is not None:
axis_resources[dp_axis_name] = dp_mesh_axis
if tp_axis_name is not None:
axis_resources[tp_axis_name] = tp_mesh_axis
return ShardingMeta(tuple(in_axes), out_axes, axis_resources, input_new_shapes,
output_new_shapes)
class DotShardingMetaGenerator(ShardingMetaGenerator):
"""
DotShardingMetaGenerator
"""
def get_dp_sharding_meta(
self,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int, # pylint: disable=unused-argument
model_dim_of_b: int, # pylint: disable=unused-argument
contracting_dims: Tuple[Sequence[int], Sequence[int]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, None,
contracting_dims)
out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims)
out_batch_dim = batch_dim_of_a
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
assert a_shape[batch_dim_of_a] % dp_size == 0, \
f"The dimension of batch in a_shape should be a multiple of data parallelism size," \
f" but got {a_shape[batch_dim_of_a]=} and {dp_size=}."
a_new_shape = (*a_shape[:batch_dim_of_a], dp_size, -1, *a_shape[batch_dim_of_a + 1:])
return ShardingMeta(({
batch_dim_of_a: dp_axis_name
}, {}), ({
out_batch_dim: dp_axis_name
}), {dp_axis_name: dp_mesh_axis}, [a_new_shape, b_shape], [out_shape])
def get_tp_col_sharding_meta(
self,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int, # pylint: disable=unused-argument
model_dim_of_b: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model') -> ShardingMeta:
DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, None,
contracting_dims)
out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims)
out_model_idx = len(out_shape) - (len(b_shape) - model_dim_of_b)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert b_shape[model_dim_of_b] % tp_size == 0, \
f"The dimension of model parallelism in b_shape should be a multiple of " \
f"tensor parallelism size,but got {b_shape[model_dim_of_b]=} and {tp_size=}."
b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size,
*b_shape[model_dim_of_b + 1:])
return ShardingMeta(({}, {
model_dim_of_b: tp_axis_name
}), ({
out_model_idx: tp_axis_name
}), {tp_axis_name: tp_mesh_axis}, [a_shape, b_new_shape], [out_shape])
def get_tp_row_sharding_meta(
self,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int,
model_dim_of_b: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model') -> ShardingMeta:
DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, model_dim_of_a,
contracting_dims)
out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert a_shape[model_dim_of_a] % tp_size == 0, \
f"The dimension of model parallelism in a_shape should be a multiple of " \
f"tensor parallelism size,but got {a_shape[model_dim_of_a]=} and {tp_size=}."
assert b_shape[model_dim_of_b] % tp_size == 0, \
f"The dimension of model parallelism in b_shape should be a multiple of " \
f"tensor parallelism size,but got {b_shape[model_dim_of_b]=} and {tp_size=}."
a_new_shape = (*a_shape[:model_dim_of_a], tp_size, a_shape[model_dim_of_a] // tp_size,
*a_shape[model_dim_of_a + 1:])
b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size,
*b_shape[model_dim_of_b + 1:])
return ShardingMeta(({
model_dim_of_a: tp_axis_name
}, {
model_dim_of_b: tp_axis_name
}), ({}), {tp_axis_name: tp_mesh_axis}, [a_new_shape, b_new_shape], [out_shape])
def get_dp_tp_col_sharding_meta(
self,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int, # pylint: disable=unused-argument
model_dim_of_b: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, None,
contracting_dims)
out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims)
out_model_idx = len(out_shape) + 1 - (len(b_shape) - model_dim_of_b)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert a_shape[batch_dim_of_a] % dp_size == 0, \
f"The dimension of batch in a_shape should be a multiple of data parallelism size," \
f" but got {a_shape[batch_dim_of_a]=} and {dp_size=}."
assert b_shape[model_dim_of_b] % tp_size == 0, \
f"The dimension of model parallelism in b_shape should be a multiple of " \
f"tensor parallelism size,but got {b_shape[model_dim_of_b]=} and {tp_size=}."
a_new_shape = (*a_shape[:batch_dim_of_a], dp_size, a_shape[batch_dim_of_a] // dp_size,
*a_shape[batch_dim_of_a + 1:])
b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size,
*b_shape[model_dim_of_b + 1:])
return ShardingMeta(({
batch_dim_of_a: dp_axis_name
}, {
model_dim_of_b: tp_axis_name
}), ({
batch_dim_of_a: dp_axis_name,
out_model_idx: tp_axis_name
}), {
dp_axis_name: dp_mesh_axis,
tp_axis_name: tp_mesh_axis
}, [a_new_shape, b_new_shape], [out_shape])
def get_dp_tp_row_sharding_meta(self,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int,
model_dim_of_b: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, model_dim_of_a,
contracting_dims)
out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert a_shape[batch_dim_of_a] % dp_size == 0, \
f"The dimension of batch in a_shape should be a multiple of data parallelism size," \
f" but got {a_shape[batch_dim_of_a]=} and {dp_size=}."
assert a_shape[model_dim_of_a] % tp_size == 0, \
f"The dimension of model parallelism in a_shape should be a multiple of " \
f"tensor parallelism size,but got {a_shape[model_dim_of_a]=} and {tp_size=}."
assert b_shape[model_dim_of_b] % tp_size == 0, \
f"The dimension of model parallelism in b_shape should be a multiple of " \
f"tensor parallelism size,but {b_shape[model_dim_of_b]=} and {tp_size=}."
a_new_shape = (*a_shape[:batch_dim_of_a], dp_size, a_shape[batch_dim_of_a] // dp_size,
*a_shape[batch_dim_of_a + 1:model_dim_of_a], tp_size,
a_shape[model_dim_of_a] // tp_size, *a_shape[model_dim_of_a + 1:])
b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size,
*b_shape[model_dim_of_b + 1:])
return ShardingMeta(
(
{
batch_dim_of_a:
dp_axis_name,
# "model_dim_of_a+1" is the index to tp_size in a_new_shape
model_dim_of_a + 1:
tp_axis_name
},
{
model_dim_of_b: tp_axis_name
}),
({
batch_dim_of_a: dp_axis_name
}),
{
dp_axis_name: dp_mesh_axis,
tp_axis_name: tp_mesh_axis
},
[a_new_shape, b_new_shape],
[out_shape])
@staticmethod
def _is_supported(
a_shape: Tuple, # pylint: disable=unused-argument
b_shape: Tuple, # pylint: disable=unused-argument
batch_dim_of_a: int,
model_dim_of_a: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
):
        assert batch_dim_of_a not in contracting_dims[0], \
            "batch_dim_of_a should not be one of contracting_dims[0]"
assert batch_dim_of_a >= 0, \
"Only support non-negative value of batch_dim_of_a."
if model_dim_of_a is not None:
assert model_dim_of_a >= 0, \
"Only support non-negative value of model_dim_of_a"
assert model_dim_of_a > batch_dim_of_a, \
"Only support the case that model_dim_of_a > batch_dim_of_a."
@staticmethod
def _infer_output_shape(
a_shape: Tuple,
b_shape: Tuple,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
):
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
return (*a_shape[:min(lhs_contracting_dims)], *b_shape[max(rhs_contracting_dims) + 1:])
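`_infer_output_shape` keeps the `a` dimensions before the first lhs contracting dim and the `b` dimensions after the last rhs contracting dim. A worked example with hypothetical shapes:

```python
# (batch, seq, hidden) x (hidden, out), contracting over hidden.
a_shape = (8, 16, 32)
b_shape = (32, 64)
lhs_contracting_dims, rhs_contracting_dims = (2,), (0,)

# Same expression as _infer_output_shape's return.
out = (*a_shape[:min(lhs_contracting_dims)],
       *b_shape[max(rhs_contracting_dims) + 1:])
assert out == (8, 16, 64)
```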
class ElementwiseShardingMetaGenerator(ShardingMetaGenerator):
"""
ElementwiseShardingMetaGenerator
"""
def get_dp_sharding_meta(
self,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_dp_sharding_meta"""
ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, batch_dim)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
assert input_shape[batch_dim] % dp_size == 0, \
f"The dimension of batch in input_shape should be a multiple of data parallelism " \
f"size, but got {input_shape[batch_dim]=} and {dp_size=}."
input_new_shape = (*input_shape[:batch_dim], dp_size, -1, *input_shape[batch_dim + 1:])
in_axes = [{batch_dim: dp_axis_name}]
input_new_shapes = [input_new_shape]
if other_shape is not None:
input_new_shapes.append(other_shape)
in_axes.append({})
return ShardingMeta(tuple(in_axes), ({
batch_dim: dp_axis_name
}), {dp_axis_name: dp_mesh_axis}, input_new_shapes, [input_shape])
def get_tp_col_sharding_meta(
self,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int, # pylint: disable=unused-argument
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_tp_col_sharding_meta"""
ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, 0)
in_axes = [{}]
input_new_shapes = [input_shape]
if other_shape is not None:
in_axes.append({})
input_new_shapes.append(other_shape)
return ShardingMeta(tuple(in_axes), ({}), {}, input_new_shapes, [input_shape])
def get_tp_row_sharding_meta(
self,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int, # pylint: disable=unused-argument
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_tp_row_sharding_meta"""
ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, 0)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert input_shape[-1] % tp_size == 0, \
f"The last dimension in input_shape should be a multiple of tensor parallelism size," \
f" but got {input_shape[-1]=} and {tp_size=}."
input_new_shape = (*input_shape[:-1], tp_size, -1)
in_axes = [{
# "len(a_new_shape)-2" is the index to tp_size in a_new_shape
len(input_new_shape) - 2:
tp_axis_name
}]
input_new_shapes = [input_new_shape]
if other_shape is not None:
assert other_shape[0] % tp_size == 0, \
f"The first dimension in other_shape should be a multiple of tensor parallelism size," \
f" but got {other_shape[0]=} and {tp_size=}."
other_new_shape = (tp_size, -1)
in_axes.append({0: tp_axis_name})
input_new_shapes.append(other_new_shape)
return ShardingMeta(tuple(in_axes), ({
len(input_new_shape) - 2: tp_axis_name
}), {tp_axis_name: tp_mesh_axis}, input_new_shapes, [input_shape])
def get_dp_tp_col_sharding_meta(self,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_col_sharding_meta"""
return self.get_dp_sharding_meta(input_shape, other_shape, batch_dim, dp_axis_name,
tp_axis_name)
def get_dp_tp_row_sharding_meta(self,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_row_sharding_meta"""
ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, batch_dim)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert input_shape[batch_dim] % dp_size == 0, \
f"The dimension of batch in input_shape should be a multiple of data parallelism" \
f"size, but got {input_shape[batch_dim]=} and {dp_size=}."
assert input_shape[-1] % tp_size == 0, \
f"The last dimension in input_shape should be a multiple of tensor parallelism size," \
f" but got {input_shape[-1]=} and {tp_size=}."
input_new_shape = (*input_shape[:batch_dim], dp_size, -1, *input_shape[batch_dim + 1:-1],
tp_size, input_shape[-1] // tp_size)
in_axes = [{
batch_dim:
dp_axis_name,
# "len(a_new_shape)-2" is the index to tp_size in a_new_shape
len(input_new_shape) - 2:
tp_axis_name
}]
input_new_shapes = [input_new_shape]
other_new_shape = other_shape
if other_shape is not None:
assert other_shape[0] % tp_size == 0, \
f"The first dimension in other_shape should be a multiple of tensor parallelism size," \
f" but got {other_shape[0]=} and {tp_size=}."
other_new_shape = (tp_size, -1)
in_axes.append({0: tp_axis_name})
input_new_shapes.append(other_new_shape)
return ShardingMeta(tuple(in_axes), ({
batch_dim: dp_axis_name,
len(input_new_shape) - 2: tp_axis_name
}), {
dp_axis_name: dp_mesh_axis,
tp_axis_name: tp_mesh_axis
}, input_new_shapes, [input_shape])
@staticmethod
def _is_supported(input_shape: Tuple, other_shape: Tuple, batch_dim: int):
if other_shape is not None:
assert len(other_shape) == 1, "Only 1-dimensional other_shape is supported currently."
assert input_shape[-1] == other_shape[0], \
f"input_shape[-1] should equal other_shape[0], " \
f"but got {input_shape[-1]} and {other_shape[0]}."
assert batch_dim < len(input_shape)-1, \
"batch_dim cannot be the last dim"
class SoftmaxShardingMetaGenerator(ShardingMetaGenerator):
"""
SoftmaxShardingMetaGenerator
"""
def get_dp_sharding_meta(
self,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_dp_sharding_meta"""
SoftmaxShardingMetaGenerator._is_supported(input_shape, dp_dim, tp_dim)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
assert input_shape[dp_dim] % dp_size == 0, \
f"The dimension of batch in input_shape should be a multiple of data parallelism " \
f"size, but got {input_shape[dp_dim]=} and {dp_size=}."
input_new_shape = (*input_shape[:dp_dim], dp_size, -1, *input_shape[dp_dim + 1:])
in_axes = [{dp_dim: dp_axis_name}]
input_new_shapes = [input_new_shape]
out_axes = in_axes[0]
return ShardingMeta(tuple(in_axes), out_axes, {dp_axis_name: dp_mesh_axis},
input_new_shapes, [input_shape])
def get_tp_col_sharding_meta(self,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_tp_col_sharding_meta"""
return SoftmaxShardingMetaGenerator._get_tp_sharding_meta(input_shape, dp_dim, tp_dim,
dp_axis_name, tp_axis_name)
def get_tp_row_sharding_meta(self,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_tp_row_sharding_meta"""
return SoftmaxShardingMetaGenerator._get_tp_sharding_meta(input_shape, dp_dim, tp_dim,
dp_axis_name, tp_axis_name)
def get_dp_tp_col_sharding_meta(self,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_col_sharding_meta"""
return SoftmaxShardingMetaGenerator._get_dptp_sharding_meta(input_shape, dp_dim, tp_dim,
dp_axis_name, tp_axis_name)
def get_dp_tp_row_sharding_meta(self,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_row_sharding_meta"""
return SoftmaxShardingMetaGenerator._get_dptp_sharding_meta(input_shape, dp_dim, tp_dim,
dp_axis_name, tp_axis_name)
@staticmethod
def _is_supported(input_shape: Tuple, dp_dim: int, tp_dim: int):
assert len(input_shape) == 4, "Only 4-D inputs are supported."
assert dp_dim == 0, "Only dp_dim == 0 is supported."
assert tp_dim == 1, "Only tp_dim == 1 is supported."
@staticmethod
def _get_tp_sharding_meta(
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_tp_sharding_meta"""
SoftmaxShardingMetaGenerator._is_supported(input_shape, dp_dim, tp_dim)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert input_shape[tp_dim] % tp_size == 0, \
f"The tensor-parallel dimension of input_shape should be a multiple of tensor " \
f"parallelism size, but got {input_shape[tp_dim]=} and {tp_size=}."
input_new_shape = (*input_shape[:tp_dim], tp_size, -1, *input_shape[tp_dim + 1:])
in_axes = [{tp_dim: tp_axis_name}]
input_new_shapes = [input_new_shape]
out_axes = in_axes[0]
return ShardingMeta(tuple(in_axes), out_axes, {tp_axis_name: tp_mesh_axis},
input_new_shapes, [input_shape])
@staticmethod
def _get_dptp_sharding_meta(input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_sharding_meta"""
SoftmaxShardingMetaGenerator._is_supported(input_shape, dp_dim, tp_dim)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert input_shape[dp_dim] % dp_size == 0, \
f"The dimension of batch in input_shape should be a multiple of data parallelism " \
f"size, but got {input_shape[dp_dim]=} and {dp_size=}."
assert input_shape[tp_dim] % tp_size == 0, \
f"The tensor-parallel dimension of input_shape should be a multiple of tensor " \
f"parallelism size, but got {input_shape[tp_dim]=} and {tp_size=}."
input_new_shape = (*input_shape[:dp_dim], dp_size, input_shape[dp_dim] // dp_size,
*input_shape[dp_dim + 1:tp_dim], tp_size, input_shape[tp_dim] // tp_size,
*input_shape[tp_dim + 1:])
in_axes = [{dp_dim: dp_axis_name, tp_dim + 1: tp_axis_name}]
input_new_shapes = [input_new_shape]
out_axes = in_axes[0]
return ShardingMeta(tuple(in_axes), out_axes, {
dp_axis_name: dp_mesh_axis,
tp_axis_name: tp_mesh_axis
}, input_new_shapes, [input_shape])
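The index arithmetic in `_get_dptp_sharding_meta` (splitting both the batch and head dims, then shifting `tp_dim` by one because the batch split inserts a dimension before it) can be checked with a standalone sketch; `split_for_dp_tp` is a hypothetical helper mirroring that logic with no TE or JAX dependency:

```python
def split_for_dp_tp(input_shape, dp_size, tp_size, dp_dim=0, tp_dim=1):
    """Mimic SoftmaxShardingMetaGenerator._get_dptp_sharding_meta's reshape:
    dp_dim splits into (dp_size, rest) and tp_dim into (tp_size, rest)."""
    assert input_shape[dp_dim] % dp_size == 0
    assert input_shape[tp_dim] % tp_size == 0
    new_shape = (*input_shape[:dp_dim], dp_size, input_shape[dp_dim] // dp_size,
                 *input_shape[dp_dim + 1:tp_dim], tp_size, input_shape[tp_dim] // tp_size,
                 *input_shape[tp_dim + 1:])
    # Splitting dp_dim shifts every later index right by one, hence tp_dim + 1.
    in_axes = {dp_dim: 'data', tp_dim + 1: 'model'}
    return new_shape, in_axes

# A (batch=8, heads=16, seqlen=128, seqlen=128) logits tensor on a 2x4 mesh:
print(split_for_dp_tp((8, 16, 128, 128), dp_size=2, tp_size=4))
# → ((2, 4, 4, 4, 128, 128), {0: 'data', 2: 'model'})
```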
def get_fp8_meta_sharding_meta(stype: ShardingType,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
get_fp8_meta_sharding_meta
"""
return FP8MetaShardingMetaGenerator().get_sharding_meta(stype, num_of_meta, dp_axis_name,
tp_axis_name)
def get_dot_sharding_meta(stype: ShardingType,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int,
model_dim_of_b: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
get_dot_sharding_meta
"""
if stype in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
assert model_dim_of_b <= max(contracting_dims[1]), \
f"The dimension of model parallelism in b_shape should not exceed the max of" \
f" contracting_dims[1], but got {model_dim_of_b=} and {contracting_dims[1]=}."
if stype in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
assert model_dim_of_b > max(contracting_dims[1]), \
f"The dimension of model parallelism in b_shape should be larger than the max of" \
f" contracting_dims[1], but got {model_dim_of_b=} and {contracting_dims[1]=}."
return DotShardingMetaGenerator().get_sharding_meta(stype, a_shape, b_shape, batch_dim_of_a,
model_dim_of_a, model_dim_of_b,
contracting_dims, dp_axis_name,
tp_axis_name)
def get_elementwise_sharding_meta(stype: ShardingType,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
get_elementwise_sharding_meta
"""
return ElementwiseShardingMetaGenerator().get_sharding_meta(stype, input_shape, other_shape,
batch_dim, dp_axis_name,
tp_axis_name)
def get_softmax_sharding_meta(stype: ShardingType,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
get_softmax_sharding_meta
"""
return SoftmaxShardingMetaGenerator().get_sharding_meta(stype, input_shape, dp_dim, tp_dim,
dp_axis_name, tp_axis_name)
def get_fused_attn_sharding_meta(stype: ShardingType,
input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]],
tp_dims: Tuple[Tuple[int, ...]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
get_fused_attn_sharding_meta
"""
return FusedAttnShardingMetaGenerator().get_sharding_meta(stype, input_shapes, output_shapes,
dp_dims, tp_dims, dp_axis_name,
tp_axis_name)
def extend_fsdp_sharding_meta(sharding_meta: ShardingMeta,
weight_fsdp_dim_map: Dict[int, int]) -> Tuple[ShardingMeta, str]:
"""
Extend the given ShardingMeta to be compatible with the FSDP (ZeRO-3) sharding pattern.
.. note::
    This helper assumes the first shape in sharding_meta.input_shapes
    corresponds to the input tensor, so index 0 must be present in
    `weight_fsdp_dim_map`.
Parameters
----------
sharding_meta : ShardingMeta
    The sharding meta object to extend with FSDP.
weight_fsdp_dim_map : Dict[int, int]
    A dict whose keys are indices into sharding_meta.input_shapes and whose
    values are the dimensions along which to apply FSDP sharding.
Returns
-------
updated_sharding_meta : ShardingMeta
    A sharding_meta with the FSDP extension.
fsdp_axis_name : str
    The name of the FSDP named axis for further xmap projection.
"""
assert 0 in weight_fsdp_dim_map, \
"0-idx is required to be in 'weight_fsdp_dim_map' for the input."
mst = infer_major_sharding_type()
if mst is MajorShardingType.SINGLE:
return sharding_meta, ""
gsr = global_shard_resource()
dp_mesh_axis = gsr.dp_resource
fsdp_mesh_axis = gsr.fsdp_resource
if fsdp_mesh_axis == dp_mesh_axis:
return sharding_meta, ""
if fsdp_mesh_axis is None:
return sharding_meta, ""
fsdp_dim_size, _ = _get_mesh_info(fsdp_mesh_axis)
fsdp_axis_name = "fsdp"
def get_idx_to_extend(sharded_indices, target_idx):
idx_to_extend = target_idx
for i in sharded_indices:
if i <= target_idx:
idx_to_extend += 1
return idx_to_extend
def extend_exist_sharding(idx, shape):
remain_size = shape[idx]
assert remain_size == -1 or remain_size % fsdp_dim_size == 0
remain_size = remain_size // fsdp_dim_size
new_shape = tuple([*shape[:idx], fsdp_dim_size, remain_size, *shape[idx + 1:]])
return new_shape
new_input_shapes = []
new_in_axes = []
for i, shape in enumerate(sharding_meta.input_shapes):
idx_to_extend = -1
if i == 0: # Assume first shape corresponds to input
input_dp_dim = weight_fsdp_dim_map[i]
# idx_to_extend = input_dp_dim + 1 if is_dp_enabled(mst) else input_dp_dim
idx_to_extend = get_idx_to_extend(list(sharding_meta.in_axes[i].keys()), input_dp_dim)
new_shape = extend_exist_sharding(idx_to_extend, shape)
# Assume there is only one output and it shares the input's batch sharding.
assert isinstance(sharding_meta.out_axes, dict)
new_out_axes = {}
for key in sharding_meta.out_axes:
if key < idx_to_extend:
new_out_axes[key] = sharding_meta.out_axes[key]
else:
new_out_axes[key + 1] = sharding_meta.out_axes[key]
new_out_axes[idx_to_extend] = fsdp_axis_name
sharding_meta.out_axes = new_out_axes
else:
new_shape = shape
if i in weight_fsdp_dim_map:
idx_to_extend = get_idx_to_extend(list(sharding_meta.in_axes[i].keys()),
weight_fsdp_dim_map[i])
if weight_fsdp_dim_map[i] in sharding_meta.in_axes[i]:
new_shape = extend_exist_sharding(idx_to_extend, shape)
else:
assert shape[idx_to_extend] % fsdp_dim_size == 0
remain_dim_size = shape[idx_to_extend] // fsdp_dim_size
new_shape = tuple([
*shape[:idx_to_extend], fsdp_dim_size, remain_dim_size,
*shape[idx_to_extend + 1:]
])
if idx_to_extend >= 0:
new_ia = {}
for key in sharding_meta.in_axes[i]:
if key < idx_to_extend:
new_ia[key] = sharding_meta.in_axes[i][key]
else:
new_ia[key + 1] = sharding_meta.in_axes[i][key]
new_ia[idx_to_extend] = fsdp_axis_name
else:
new_ia = sharding_meta.in_axes[i]
new_input_shapes.append(new_shape)
new_in_axes.append(new_ia)
sharding_meta.input_shapes = tuple(new_input_shapes)
sharding_meta.in_axes = tuple(new_in_axes)
sharding_meta.axis_resources[fsdp_axis_name] = fsdp_mesh_axis
return sharding_meta, fsdp_axis_name
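The shape-splitting step inside `extend_fsdp_sharding_meta` relies on Python's floor division leaving a `-1` placeholder intact (`-1 // k == -1` for positive `k`). A standalone sketch of `extend_exist_sharding`, extracted here purely for illustration, makes that explicit:

```python
def extend_exist_sharding(idx, shape, fsdp_dim_size):
    """Split shape[idx] into (fsdp_dim_size, remainder), as in
    extend_fsdp_sharding_meta. A -1 placeholder stays -1 because Python's
    floor division rounds toward negative infinity (-1 // k == -1)."""
    remain_size = shape[idx]
    assert remain_size == -1 or remain_size % fsdp_dim_size == 0
    remain_size = remain_size // fsdp_dim_size
    return (*shape[:idx], fsdp_dim_size, remain_size, *shape[idx + 1:])

# A shape already split for DP, (dp=2, -1, hidden=1024), gets FSDP on dim 1:
print(extend_exist_sharding(1, (2, -1, 1024), fsdp_dim_size=4))
# → (2, 4, -1, 1024)
```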
def xmap_runner(func: Callable, in_axes: Tuple[Dict, ...],
out_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]],
axis_resources: Dict, inputs: Tuple):
"""
xmap_runner
"""
assert isinstance(inputs, tuple)
assert isinstance(in_axes, tuple)
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
fake_in_axes = {}
fake_axis_resource = {}
# Fake related setup is a workaround to "NotImplementedError:
# Collectives in manually partitioned computations are only supported
# when all mesh axes are partitioned manually (no partial automatic
# sharding). Make sure that you mention all mesh axes in axis_resources!"
fake_idx_counter = 0
for mesh_axis_names in mesh.axis_names:
if mesh_axis_names not in axis_resources.values():
fake_idx_counter += 1
fake_axis_name = f"{mesh_axis_names}_fake_{fake_idx_counter}"
fake_in_axes[fake_idx_counter] = fake_axis_name
fake_axis_resource[fake_axis_name] = mesh_axis_names
fake_input = jnp.zeros(tuple(64 for _ in range(len(fake_in_axes) + 1)))
xmapped = xmap(lambda func_input, _: func(*func_input),
in_axes=(in_axes, fake_in_axes),
out_axes=out_axes,
axis_resources={
**axis_resources,
**fake_axis_resource
})
output = xmapped(inputs, fake_input)
return output
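The fake-axis workaround in `xmap_runner` can be exercised in isolation; `build_fake_axes` is a hypothetical extraction of the loop above, showing that every mesh axis not already claimed by `axis_resources` gets a fake named axis mapped onto it:

```python
def build_fake_axes(mesh_axis_names, axis_resources):
    """Mimic xmap_runner's workaround: any mesh axis absent from
    axis_resources' values receives a fake named axis so all mesh axes
    are manually partitioned (avoiding xmap's partial-sharding error)."""
    fake_in_axes, fake_axis_resource = {}, {}
    fake_idx_counter = 0
    for mesh_axis_name in mesh_axis_names:
        if mesh_axis_name not in axis_resources.values():
            fake_idx_counter += 1
            fake_axis_name = f"{mesh_axis_name}_fake_{fake_idx_counter}"
            fake_in_axes[fake_idx_counter] = fake_axis_name
            fake_axis_resource[fake_axis_name] = mesh_axis_name
    return fake_in_axes, fake_axis_resource

# 'data' is already used by the 'batch' named axis; only 'model' gets a fake.
print(build_fake_axes(('data', 'model'), {'batch': 'data'}))
# → ({1: 'model_fake_1'}, {'model_fake_1': 'model'})
```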
@@ -18,11 +18,6 @@ from .cpp_extensions import scaled_upper_triang_masked_softmax_bwd
 from .cpp_extensions import ScaledSoftmaxFwdPrimitive
 from .cpp_extensions import ScaledMaskedSoftmaxFwdPrimitive
 from .cpp_extensions import ScaledUpperTriangMaskedSoftmaxFwdPrimitive
-from .sharding import get_softmax_sharding_meta, ShardingType, ShardingMeta
-from .sharding import xmap_runner, extend_fsdp_sharding_meta
-
-jax.config.update('experimental_xmap_spmd_lowering', True)
-jax.config.update('experimental_xmap_spmd_lowering_manual', True)
 
 class SoftmaxType(Enum):
@@ -48,100 +43,47 @@ def is_softmax_kernel_available(softmax_type: SoftmaxType, batch: int, heads: in
     raise NotImplementedError
 
-def softmax(inputs: jnp.ndarray,
+def softmax(logits: jnp.ndarray,
             mask: Optional[jnp.ndarray] = None,
             scale_factor: Optional[float] = 1.0,
-            softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED,
-            sharding_type: ShardingType = ShardingType.SINGLE,
-            dp_dim_index: int = 0,
-            tp_dim_index: int = 1):
+            softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED):
     """
     Softmax wrapper
     """
-    assert dp_dim_index == 0, \
-        "Only softmax support batch dim in the first place currently."
-    assert tp_dim_index == 1, \
-        "Only softmax support head dim in the second place currently."
-    assert mask is None or mask.shape[tp_dim_index] == 1
-    if sharding_type is ShardingType.SINGLE:
-        outputs = _softmax(inputs, mask, scale_factor, softmax_type)
-    else:
-        dp_axis_name = "batch"
-        tp_axis_name = "model"
-        sharding_meta = get_softmax_sharding_meta(sharding_type,
-                                                  inputs.shape,
-                                                  dp_dim=dp_dim_index,
-                                                  tp_dim=tp_dim_index,
-                                                  dp_axis_name=dp_axis_name,
-                                                  tp_axis_name=tp_axis_name)
-        sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
-        inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0])    # 0 for input
-        mask_ = mask
-        mask_in_axis = {}
-        if mask_ is not None:
-            if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
-                # If mask is head broadcastable (heads == 1),
-                # then it equals to DP sharding.
-                mask_sharding_meta = get_softmax_sharding_meta(ShardingType.DP,
-                                                               mask_.shape,
-                                                               dp_dim=dp_dim_index,
-                                                               tp_dim=tp_dim_index,
-                                                               dp_axis_name=dp_axis_name,
-                                                               tp_axis_name=tp_axis_name)
-            else:
-                mask_sharding_meta = ShardingMeta([{}], {}, {}, [mask_.shape], mask_.shape)
-            mask_sharding_meta, _ = extend_fsdp_sharding_meta(mask_sharding_meta, {0: dp_dim_index})
-            mask_ = jnp.reshape(mask_, mask_sharding_meta.input_shapes[0])
-            mask_in_axis = mask_sharding_meta.in_axes[0]
-        partial_softmax = partial(_softmax, scale_factor=scale_factor, softmax_type=softmax_type)
-        in_axes = (sharding_meta.in_axes[0], mask_in_axis)
-        outputs = xmap_runner(partial_softmax, in_axes, sharding_meta.out_axes,
-                              sharding_meta.axis_resources, (inputs_, mask_))
-        outputs = jnp.reshape(outputs, sharding_meta.output_shapes[0])
-    return outputs
+    output = _softmax(logits, mask, scale_factor, softmax_type)
+    return output
 
 @partial(jax.custom_vjp, nondiff_argnums=(2, 3))
-def _softmax(inputs, mask, scale_factor, softmax_type):
-    output, _ = _softmax_fwd(inputs, mask, scale_factor, softmax_type)
+def _softmax(logits, mask, scale_factor, softmax_type):
+    output, _ = _softmax_fwd_rule(logits, mask, scale_factor, softmax_type)
     return output
 
-def _softmax_fwd(inputs, mask, scale_factor, softmax_type):
+def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type):
     if softmax_type is SoftmaxType.SCALED_MASKED:
         assert mask is not None
-        outputs = scaled_masked_softmax_fwd(inputs, mask, scale_factor)
+        output = scaled_masked_softmax_fwd(logits, mask, scale_factor)
     elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
-        outputs = scaled_upper_triang_masked_softmax_fwd(inputs, scale_factor)
+        output = scaled_upper_triang_masked_softmax_fwd(logits, scale_factor)
     else:
-        outputs = scaled_softmax_fwd(inputs, scale_factor)
+        output = scaled_softmax_fwd(logits, scale_factor)
-    return outputs, (outputs, mask)
+    return output, (output,)
 
-def _softmax_bwd(scale_factor, softmax_type, ctx, grad_outputs):
-    softmax_outputs, mask = ctx
+def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz):
+    softmax_output, = ctx
     if softmax_type is SoftmaxType.SCALED_MASKED:
-        assert mask is not None
-        dgrad = scaled_masked_softmax_bwd(grad_outputs, softmax_outputs, scale_factor)
+        dgrad = scaled_masked_softmax_bwd(dz, softmax_output, scale_factor)
     elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
-        dgrad = scaled_upper_triang_masked_softmax_bwd(grad_outputs, softmax_outputs, scale_factor)
+        dgrad = scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, scale_factor)
     else:
-        dgrad = scaled_softmax_bwd(grad_outputs, softmax_outputs, scale_factor)
+        dgrad = scaled_softmax_bwd(dz, softmax_output, scale_factor)
     return (dgrad, None)
 
-_softmax.defvjp(_softmax_fwd, _softmax_bwd)
+_softmax.defvjp(_softmax_fwd_rule, _softmax_bwd_rule)
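A note on why the backward rule only needs the forward output in the residual: the softmax VJP can be written entirely in terms of `y = softmax(x)` as `dx = y * (dz - sum(dz * y))` along the softmax axis, since the Jacobian `diag(y) - y yᵀ` is symmetric. A NumPy sketch (independent of the fused kernels) checks this identity against a finite difference:

```python
import numpy as np

def softmax_fwd(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))  # numerically stable softmax
    return e / e.sum(axis=-1, keepdims=True)

def softmax_bwd(dz, y):
    # VJP expressed using only the saved forward output,
    # mirroring _softmax_bwd_rule's (output,) residual.
    return y * (dz - (dz * y).sum(axis=-1, keepdims=True))

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4))
dz = rng.normal(size=x.shape)
v = rng.normal(size=x.shape)
y = softmax_fwd(x)
g = softmax_bwd(dz, y)
# Central finite difference of <dz, d softmax(x; v)> vs the analytic <g, v>.
eps = 1e-6
num = ((softmax_fwd(x + eps * v) - softmax_fwd(x - eps * v)) / (2 * eps) * dz).sum()
assert abs(num - (g * v).sum()) < 1e-5
```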