Unverified Commit 71e51eae authored by Ming-Xu Huang, committed by GitHub

[JAX] Migrating from Xmap to Custom Partitioning for All Custom Calls (#472)



* Refactor sharding.py in preparation for the custom_partitioning migration
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating both FWD and BWD of LayerNorm/RMSNorm from xmap to custom_partitioning.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating both FWD and BWD of all kinds of softmax from xmap to custom_partitioning.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix the wrong order of parameters to LN/RMSN bwd in ln_mlp_fp8.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* WAR to LN/RMSN_fp8 before migrating to CP.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix the wrong order of parameters in the bwd of LN/RMSN_fp8.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Apply review feedback
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Force no sharding on the hidden dim in Norm ops and add a warning msg.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Reuse fwd_rule in VJP functions
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating both FWD and BWD of self-fused-attn from xmap to custom_partitioning.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating both FWD and BWD of cross-fused-attn from xmap to custom_partitioning.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Add gelu and dgelu.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Reuse fwd_rule in VJP functions for attentions
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Apply native FP8 dtypes to fp8.py
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating cast_and_transpose from xmap to custom_partitioning
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating transpose from xmap to custom_partitioning
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Apply XLA pattern matching to perform FP8 GEMM.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrate layernorm_fp8 to custom_partitioning.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Unify code style
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Extend support for Transpose with FP8
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Implementing layernorm_fp8_dot based on migrated custom calls.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Rename variables and publish NVTE_FP8_COLLECTION_NAME
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Replace Q/DQ custom calls with native XLA implementations
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrate gelu_fp to custom_partitioning.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Minor fix
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Support custom calls with multiple dims
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Support general dot indices in _fp8_dot_impl
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Implementing layernrom_geglu_fp8_mlp
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Remove GEMM custom calls
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Remove xmap-related code
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix typo and add query-function to FP8MetaPackage
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix some bugs in custom calls
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix CT's bugs
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Update UTs/examples to adapt to the API changes.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Unify kernel initialization in MLP.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Modify per code review feedback
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Update README and add deprecation warning to *ShardingType
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Canonicalize the dtype
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding assertion for unsupported batch dims.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding doc/examples to _multidim_transpose
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Apply dtype-based rtol/atol to UTs
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Deprecate QKV_INTERLEAVED enum
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Skip test_distributed_custom_ops.py
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the wrong sharding of bias in SelfAttn
Signed-off-by: Ming Huang <mingh@nvidia.com>

* WAR to fix the wrong cu_seqlen of MHA when DP/FSDP is enabled
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding distributed ops unit-tests
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding license to test_distributed_*
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Apply review feedback
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Use total bytes involved in collective ops as the criterion.
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
Co-authored-by: Donglin Yang <dongliny@nvidia.com>
parent 7976bd00
......@@ -4,227 +4,167 @@
"""JAX te modules"""
from typing import Tuple, Sequence
from functools import partial, reduce
import operator
from functools import partial
import jax
import jax.numpy as jnp
from transformer_engine_jax import DType as TEDType
from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype
from .fp8 import FP8Helper, FP8GemmPackage
from .sharding import ShardingType, get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
from .sharding import xmap_runner, extend_fsdp_sharding_meta
from .cpp_extensions import cast_transpose
from .fp8 import FP8Helper, FP8MetaPackage
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
def type_safe_dot_general(
x,
kernel,
fp8_meta_pkg: FP8MetaPackage = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,))
) -> jnp.ndarray:
"""
Type safe dot_general, including FP8.
"""
if fp8_meta_pkg is None:
kernel = jnp.asarray(kernel, x.dtype)
return jax.lax.dot_general(x, kernel, (contracting_dims, ((), ())))
fp8_max = fp8_meta_pkg.fp8_max
amax = fp8_meta_pkg.amax
scale = fp8_meta_pkg.scale
scale_inv = fp8_meta_pkg.scale_inv
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
return _fp8_dot(x, kernel, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype,
contracting_dims)
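
A minimal usage sketch of the new entry point (shapes and values are hypothetical; it relies on this module's own imports): with no FP8MetaPackage, the call degrades to a plain dot_general in the activation dtype.

x = jnp.ones((8, 128), jnp.bfloat16)
kernel = jnp.ones((128, 256), jnp.float32)
y = type_safe_dot_general(x, kernel)    # kernel is cast to x.dtype, result shape (8, 256)
# With FP8 enabled, pass fp8_meta_pkg=FP8MetaPackage(...) to route through _fp8_dot instead.
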
def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0) -> jnp.ndarray:
def quantize(x, q_dtype, scale):
"""
FP8 dot wrapper
Quantize with scale.
"""
assert fp8_gemm_pkg.num_of_gemm == 1
inputs = fp8_gemm_pkg.inputs
kernel = fp8_gemm_pkg.kernels[0]
fp8_max = fp8_gemm_pkg.fp8_max
amax = fp8_gemm_pkg.amax
scale = fp8_gemm_pkg.scale
scale_inv = fp8_gemm_pkg.scale_inv
if sharding_type is ShardingType.SINGLE:
res = _fp8_dot(inputs,
kernel,
fp8_max,
amax,
scale,
scale_inv,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name="",
tp_axis_name="",
fsdp_axis_name="")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
kernel_tp_index = None
# TODO (Ming Huang): Should we add a new argument to support general sharding to kernel? # pylint: disable=fixme
if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
kernel_tp_index = len(kernel.shape) - 1
elif sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
kernel_tp_index = 0
input_tp_index = len(inputs.shape) - 1
sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
dp_dim_index, input_tp_index, kernel_tp_index,
contracting_dims, dp_axis_name, tp_axis_name)
sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
kernel_ = jnp.reshape(kernel, sharding_meta.input_shapes[1]) # 1 for kernel
num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv
fp8_sharding_meta = get_fp8_meta_sharding_meta(sharding_type, num_of_fp8_meta_kind,
dp_axis_name, tp_axis_name)
axis_resources = merge_axis_resources(
[sharding_meta.axis_resources, fp8_sharding_meta.axis_resources])
partial_fp8_dot = partial(_fp8_dot,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
res = xmap_runner(partial_fp8_dot, (*sharding_meta.in_axes, *fp8_sharding_meta.in_axes),
sharding_meta.out_axes, axis_resources,
(inputs_, kernel_, fp8_max, amax, scale, scale_inv))
res = jnp.reshape(res, sharding_meta.output_shapes[0])
return res
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12))
def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: TEDType, bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]], sharding_type: ShardingType,
dp_axis_name: str, tp_axis_name: str, fsdp_axis_name: str):
res, _ = _fp8_dot_fwd(inputs,
kernel,
fp8_maxs,
amax,
scale,
scale_inv,
fwd_dtype,
bwd_dtype,
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
return res
dtype_max = (jnp.finfo(q_dtype).max).astype(x.dtype)
scale = scale.astype(x.dtype)
clipped_scaled_x = jnp.clip((x * scale), -dtype_max, dtype_max)
return clipped_scaled_x.astype(q_dtype)
def _fp8_dot_fwd(
inputs,
def dequantize(x, dq_dtype, scale_inv):
"""
Dequantize with scale_inv.
"""
return x.astype(dq_dtype) * scale_inv.astype(dq_dtype)
# Apply jit to guarantee correctness of FP8 GEMM.
@partial(jax.jit, static_argnums=(4, 5))
def fp8_dot_impl(
q_lhs: jnp.ndarray,
q_rhs: jnp.ndarray,
lhs_scale_inv: jnp.ndarray,
rhs_scale_inv: jnp.ndarray,
ctype: jnp.dtype, # computing type
contracting_dims: Tuple[Sequence[int], Sequence[int]]):
"""
FP8 GEMM for XLA pattern match
"""
dim_nums = (contracting_dims, ((), ()))
lhs = dequantize(q_lhs, ctype, lhs_scale_inv)
rhs = dequantize(q_rhs, ctype, rhs_scale_inv)
return jax.lax.dot_general(lhs, rhs, dim_nums)
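
A hedged end-to-end sketch of the intended flow (hypothetical values; assumes a JAX build exposing the float8 dtypes): both operands are quantized, and the jitted dequantize + dot_general above is what XLA is expected to pattern-match into an FP8 GEMM.

x = jnp.ones((16, 32), jnp.bfloat16)
w = jnp.ones((32, 64), jnp.bfloat16)
x_scale = jnp.float32(1.0)
w_scale = jnp.float32(1.0)
q_x = quantize(x, jnp.float8_e4m3fn, x_scale)
q_w = quantize(w, jnp.float8_e4m3fn, w_scale)
# scale_inv = 1/scale; computation dtype and contracting dims are static arguments.
out = fp8_dot_impl(q_x, q_w, 1.0 / x_scale, 1.0 / w_scale, jnp.bfloat16, ((1,), (0,)))  # (16, 64)
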
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
def _fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
contracting_dims: Tuple[Sequence[int], Sequence[int]]):
output, _ = _fp8_dot_fwd_rule(x, kernel, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype,
contracting_dims)
return output
def _fp8_dot_fwd_rule(
x,
kernel,
fp8_maxs,
fp8_max,
amax,
scale,
scale_inv,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims,
sharding_type,
dp_axis_name, # pylint: disable=unused-argument
tp_axis_name,
fsdp_axis_name): # pylint: disable=unused-argument
contracting_dims):
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
input_shape_suf = inputs.shape[min(lhs_contracting_dims):]
x_shape_suf = x.shape[min(lhs_contracting_dims):]
kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
kernel_shape_suf = kernel.shape[max(rhs_contracting_dims) + 1:]
input_contracting_size = reduce(operator.mul, input_shape_suf)
kernel_contracting_size = reduce(operator.mul, kernel_shape_pre)
assert input_contracting_size == kernel_contracting_size
inputs_ = jnp.reshape(inputs, (-1, input_contracting_size))
kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1))
assert x_shape_suf == kernel_shape_pre
amax = FP8Helper.update_amax_history(amax)
gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
input_amax = amax[gemm_input_idx, 0:1]
input_scale = scale[gemm_input_idx]
input_scale_inv = scale_inv[gemm_input_idx]
input_cast, input_cast_trans, input_amax = cast_transpose(inputs_, input_amax, input_scale,
input_scale_inv, fwd_dtype)
x_amax = amax[gemm_x_idx, 0:1]
x_scale = scale[gemm_x_idx]
x_scale_inv = scale_inv[gemm_x_idx]
casted_x, casted_xt, updated_x_amax = \
cast_transpose(x, x_amax, x_scale, x_scale_inv, fwd_dtype, static_axis_boundary=-1,
transpose_axis_boundary=min(lhs_contracting_dims))
kernel_amax = amax[gemm_kernel_idx, 0:1]
kernel_scale = scale[gemm_kernel_idx]
kernel_scale_inv = scale_inv[gemm_kernel_idx]
kernel_cast, kernel_cast_trans, kernel_amax = cast_transpose(kernel_, kernel_amax, kernel_scale,
kernel_scale_inv, fwd_dtype)
res = gemm(kernel_cast_trans, kernel_scale_inv, fwd_dtype, True, input_cast, input_scale_inv,
fwd_dtype, False, jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP)
if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
res = jax.lax.psum(res, tp_axis_name)
casted_kerenl, casted_kerenl_t, updated_kernel_amax = \
cast_transpose(kernel, kernel_amax, kernel_scale, kernel_scale_inv,
fwd_dtype, static_axis_boundary=-1,
transpose_axis_boundary=(max(rhs_contracting_dims) + 1))
# (input_shape_pre, input_shape_suf)
# x (kernel_shape_pre, kernel_shape_suf)
# = (input_shape_pre, kernel_shape_suf)
output_shape = input_shape_pre + kernel_shape_suf
res = jnp.reshape(res, output_shape)
rhs_t_contracting_dims = tuple(range(kernel.ndim - len(rhs_contracting_dims), kernel.ndim))
output = fp8_dot_impl(casted_x, casted_kerenl_t, x_scale_inv, kernel_scale_inv, x.dtype,
(lhs_contracting_dims, rhs_t_contracting_dims))
ctx = (input_cast_trans, kernel_cast, fp8_maxs, amax, scale, scale_inv, input_amax, kernel_amax,
inputs.shape, kernel.shape)
return res, ctx
ctx = (casted_xt, casted_kerenl, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_kernel_amax, x.shape, kernel.shape)
return output, ctx
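
For readers new to the multi-dim cast_transpose used above, a small pure-jnp illustration of how transpose_axis_boundary is understood here (an assumption inferred from how this rule consumes the outputs, not a spec of the custom call): with boundary b, the transposed tensor carries axes [b:] in front of axes [:b].

x = jnp.zeros((4, 8, 16, 32))
b = 2                                     # e.g. min(lhs_contracting_dims)
xt_ref = jnp.transpose(x, (2, 3, 0, 1))   # shape (16, 32, 4, 8) == x.shape[b:] + x.shape[:b]
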
def _fp8_dot_bwd(
fwd_dtype,
bwd_dtype,
contracting_dims, # pylint: disable=unused-argument
sharding_type,
dp_axis_name,
tp_axis_name,
fsdp_axis_name,
ctx,
g):
input_cast_trans, kernel_cast, \
fp8_maxs, amax, scale, scale_inv, \
input_amax, kernel_amax, \
inputs_shape, kernel_shape = ctx
gemm_input_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # pylint: disable=unused-argument
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
casted_xt, casted_kerenl, fp8_max, amax, scale, scale_inv, \
updated_x_amax, updated_kernel_amax, x_shape, kernel_shape = ctx
gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
grad_amax = amax[gemm_grad_idx, 0:1]
grad_scale = scale[gemm_grad_idx]
grad_scale_inv = scale_inv[gemm_grad_idx]
g = jnp.reshape(g, (input_cast_trans.shape[1], -1))
grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv,
bwd_dtype)
input_scale_inv = scale_inv[gemm_input_idx]
wgrad = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype,
True, input_cast_trans, input_scale_inv, fwd_dtype, False,
jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD)
casted_grad, casted_grad_t, updated_grad_amax = \
cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv,
bwd_dtype, static_axis_boundary=-1,
transpose_axis_boundary=min(lhs_contracting_dims))
kernel_scale_inv = scale_inv[gemm_kernel_idx]
dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)
amax = amax.at[gemm_input_idx, 0].set(input_amax[0])
amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0])
amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0])
xt_constracting_dim = tuple(range(len(lhs_contracting_dims), len(x_shape)))
gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim))
x_scale_inv = scale_inv[gemm_x_idx]
wgrad = fp8_dot_impl(casted_xt, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
(xt_constracting_dim, gt_constracting_dim))
if is_dp_enabled(sharding_type.value[0]):
wgrad = jax.lax.psum(wgrad, dp_axis_name)
amax = jax.lax.pmax(amax, dp_axis_name)
if len(fsdp_axis_name) > 0:
wgrad = jax.lax.psum(wgrad, fsdp_axis_name)
amax = jax.lax.pmax(amax, fsdp_axis_name)
g_constracting_dim = tuple(
range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim))
k_constracting_dim = tuple(range(len(rhs_contracting_dims), len(kernel_shape)))
kernel_scale_inv = scale_inv[gemm_kernel_idx]
dgrad = fp8_dot_impl(casted_grad, casted_kerenl, grad_scale_inv, kernel_scale_inv, grad.dtype,
(g_constracting_dim, k_constracting_dim))
if is_tp_enabled(sharding_type.value[0]):
amax = jax.lax.pmax(amax, tp_axis_name)
amax = amax.at[gemm_x_idx, 0].set(updated_x_amax[0])
amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0])
amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])
if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
dgrad = jax.lax.psum(dgrad, tp_axis_name)
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
dgrad = jnp.reshape(dgrad, inputs_shape)
wgrad = jnp.reshape(wgrad, kernel_shape)
return dgrad, wgrad, fp8_maxs, amax, scale, scale_inv
return dgrad, wgrad, fp8_max, amax, scale, scale_inv
_fp8_dot.defvjp(_fp8_dot_fwd, _fp8_dot_bwd)
_fp8_dot.defvjp(_fp8_dot_fwd_rule, _fp8_dot_bwd_rule)
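
The primal above simply calls the VJP forward rule and drops the residuals ("Reuse fwd_rule in VJP functions"); the same pattern in miniature, with a hypothetical function f:

@jax.custom_vjp
def f(x):
    out, _ = f_fwd(x)                 # reuse the forward rule, discard residuals
    return out

def f_fwd(x):
    out = jnp.tanh(x)
    return out, (out,)

def f_bwd(ctx, g):
    out, = ctx
    return (g * (1.0 - out * out),)   # d/dx tanh(x) = 1 - tanh(x)^2

f.defvjp(f_fwd, f_bwd)
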
......@@ -6,6 +6,7 @@ Wrapper module for Transformer related layers with FP8 support.
"""
import functools
import operator
import warnings
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union
import jax.numpy as jnp
......@@ -16,14 +17,12 @@ from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from ..dot import fp8_dot
from ..fp8 import FP8GemmPackage, FP8Helper
from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..mlp import fp8_ln_mlp, geglu
from ..sharding import infer_sharding_type
from ..mlp import layernrom_geglu_fp8_mlp, geglu
from ..softmax import is_softmax_kernel_available
from ..sharding import MajorShardingType, ShardingType
from ..softmax import softmax, SoftmaxType
PRNGKey = Any
......@@ -119,16 +118,10 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
Scalar for the input to softmax.
softmax_type : SoftmaxType, default = SoftmaxType.SCALED
Indicate the type of softmax.
Optimization parameters
-----------------------
sharding_type : ShardingType, default = ShardingType.SINGLE
Indicate the sharding pattern.
"""
scale_factor: float = 1.0
softmax_type: SoftmaxType = SoftmaxType.SCALED
sharding_type: ShardingType = ShardingType.SINGLE
@nn.compact
def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
......@@ -149,8 +142,7 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
if self.softmax_type is not SoftmaxType.SCALED_MASKED:
mask_ = None
outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type,
self.sharding_type)
outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
else:
attention_bias = None
if mask is not None:
......@@ -168,8 +160,7 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
# and kernel is unavailable, then try on pure scaled softmax custom calls.
if is_softmax_kernel_available(SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen,
dtype):
outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED,
self.sharding_type)
outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
else:
outputs = jax_nn.softmax(logits * self.scale_factor)
......@@ -242,8 +233,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
Indicate whether the batch and sequence-length axes of the input
tensors are swapped. If set to True, the input tensors should be in
(seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
sharding_type : ShardingType, default = ShardingType.SINGLE
Indicate the sharding pattern.
"""
epsilon: float = 1e-6
layernorm_type: str = 'layernorm'
......@@ -254,7 +243,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
bias_axes: Tuple[str, ...] = ('embed',)
dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
sharding_type: ShardingType = ShardingType.SINGLE
sharding_type = None
def __post_init__(self):
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
......@@ -276,6 +265,8 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
outputs : jax.numpy.ndarray
Output tensors.
"""
warnings.warn("sharding_type of LayerNorm would be removed in the near feature",
DeprecationWarning)
features = x.shape[-1]
scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
......@@ -286,9 +277,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
ln_bias,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0)
epsilon=self.epsilon)
class TransformerEngineBase(nn.Module):
......@@ -329,17 +318,15 @@ class TransformerEngineBase(nn.Module):
return fp8_max.value, fp8_metas_amax.value, fp8_metas_scale.value, fp8_metas_scale_inv.value
@staticmethod
def get_fp8_gemm_package(num_of_gemm: int, inputs: jnp.ndarray,
kernels: List[jnp.ndarray]) -> FP8GemmPackage:
def get_fp8_meta_package(num_of_gemm: int) -> FP8MetaPackage:
"""
Get the FP8 metas
"""
assert num_of_gemm == len(kernels)
fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
TransformerEngineBase.get_fp8_metas(num_of_gemm)
return FP8GemmPackage(num_of_gemm, inputs, kernels, fp8_max, fp8_metas_amax,
fp8_metas_scale, fp8_metas_scale_inv)
return FP8MetaPackage(num_of_gemm, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
class DenseGeneral(TransformerEngineBase):
......@@ -376,8 +363,6 @@ class DenseGeneral(TransformerEngineBase):
Indicate whether the batch and sequence-length axes of the input
tensors are swapped. If set to True, the input tensors should be in
(seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
sharding_type : ShardingType, default = ShardingType.SINGLE
Indicate the sharding pattern.
"""
features: Union[Iterable[int], int]
......@@ -389,7 +374,7 @@ class DenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
sharding_type: ShardingType = ShardingType.SINGLE
sharding_type = None
def __post_init__(self):
if self.kernel_init is None:
......@@ -411,6 +396,9 @@ class DenseGeneral(TransformerEngineBase):
outputs : jax.numpy.ndarray
Output tensors.
"""
warnings.warn("sharding_type of DenseGeneral would be removed in the near feature",
DeprecationWarning)
features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis)
......@@ -438,18 +426,15 @@ class DenseGeneral(TransformerEngineBase):
bias = None
contract_ind = tuple(range(0, len(axis)))
fp8_gemm_pkg = None
if FP8Helper.is_fp8_enabled():
fp8_gemm_package = \
TransformerEngineBase.get_fp8_gemm_package(1, inputs, [kernel])
y = fp8_dot(fp8_gemm_package,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind),
sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0)
else:
kernel = jnp.asarray(kernel, self.dtype)
y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
fp8_gemm_pkg = \
TransformerEngineBase.get_fp8_meta_package(1)
y = type_safe_dot_general(inputs,
kernel,
fp8_meta_pkg=fp8_gemm_pkg,
contracting_dims=(axis, contract_ind))
if bias is not None:
bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
......@@ -528,8 +513,6 @@ class LayerNormDenseGeneral(TransformerEngineBase):
depth_scaling: float, default = None
The factor to scale the output from `DenseGeneral`. It should be a float
value or None. When None is set, then no scaling is applied.
sharding_type : ShardingType, default = ShardingType.SINGLE
Indicate the sharding pattern.
"""
features: Union[Iterable[int], int]
......@@ -551,7 +534,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
depth_scaling: float = None
sharding_type: ShardingType = ShardingType.SINGLE
sharding_type = None
def __post_init__(self):
if self.kernel_init is None:
......@@ -578,12 +561,16 @@ class LayerNormDenseGeneral(TransformerEngineBase):
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
warnings.warn("sharding_type of LayerNormDenseGeneral would be removed in the near feature",
DeprecationWarning)
ln_output = None
fuse_layernorm = FP8Helper.is_fp8_enabled(
) and not self.return_layernorm_output and self.enable_layernorm
if self.enable_layernorm:
assert self.axis == -1    # Only support axis == -1 at this moment
features = inputs.shape[-1]
scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
......@@ -597,9 +584,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
ln_bias,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0)
epsilon=self.epsilon)
else:
assert not self.return_layernorm_output
y = inputs
......@@ -627,30 +612,25 @@ class LayerNormDenseGeneral(TransformerEngineBase):
contract_ind = tuple(range(0, len(axis)))
fp8_meta_package = None
if FP8Helper.is_fp8_enabled():
fp8_gemm_package = \
TransformerEngineBase.get_fp8_gemm_package(1, y, [kernel])
fp8_meta_package = \
TransformerEngineBase.get_fp8_meta_package(1)
if not fuse_layernorm:
z = fp8_dot(fp8_gemm_package,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind),
sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0)
else:
z = layernorm_fp8_dot(fp8_gemm_package,
if fuse_layernorm:
z = layernorm_fp8_dot(y,
kernel,
scale,
ln_bias,
fp8_meta_package,
self.layernorm_type,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind),
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0)
epsilon=self.epsilon)
else:
kernel = jnp.asarray(kernel, self.dtype)
z = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ())))
z = type_safe_dot_general(y,
kernel,
fp8_meta_pkg=fp8_meta_package,
contracting_dims=(axis, contract_ind))
bias = None
if self.use_bias:
......@@ -758,8 +738,6 @@ class LayerNormMLP(TransformerEngineBase):
Indicate whether the batch and sequence-length axes of the input
tensors are swapped. If set to True, the input tensors should be in
(seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
major_sharding_type : MajorShardingType, default = MajorShardingType.SINGLE
Indicate the sharding pattern.
"""
intermediate_dim: int = 2048
......@@ -776,10 +754,7 @@ class LayerNormMLP(TransformerEngineBase):
kernel_axes_2: Tuple[str, ...] = ('mlp', 'embed')
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
bias_axes_1: Tuple[str, ...] = (
'act',
'mlp',
)
bias_axes_1: Tuple[str, ...] = ('act', 'mlp')
bias_axes_2: Tuple[str, ...] = ('embed',)
return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ('relu',)
......@@ -789,7 +764,7 @@ class LayerNormMLP(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
major_sharding_type: MajorShardingType = MajorShardingType.SINGLE
major_sharding_type = None
def __post_init__(self):
if self.kernel_init is None:
......@@ -818,19 +793,32 @@ class LayerNormMLP(TransformerEngineBase):
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
warnings.warn("major_sharding_type of LayerNormMLP would be removed in the near feature",
DeprecationWarning)
ln_output = None
fuse_layernorm = FP8Helper.is_fp8_enabled(
) and not self.return_layernorm_output and self.enable_layernorm
def is_geglu(acts):
geglu_act_pool = [('gelu', 'linear'), ('linear', 'gelu')]
normalize_acts = []
for act in acts:
if not isinstance(act, str):
return False
normalize_acts.append(act.lower())
return normalize_acts in geglu_act_pool
use_fused_ln_mlp = fuse_layernorm \
and (not self.use_bias) and self.activations == ('gelu', 'linear') \
and (not self.use_bias) and is_geglu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3)
first_sharding_type, second_sharding_type = infer_sharding_type(self.major_sharding_type)
# LayerNorm
if self.enable_layernorm:
assert self.axis == -1 # Only support axis == -1 at this moment
features = inputs.shape[-1]
scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
......@@ -844,9 +832,7 @@ class LayerNormMLP(TransformerEngineBase):
ln_bias,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
sharding_type=first_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0)
epsilon=self.epsilon)
else:
assert not self.return_layernorm_output
y = inputs
......@@ -864,13 +850,17 @@ class LayerNormMLP(TransformerEngineBase):
return jnp.stack(kernels, axis=stack_axis, dtype=jnp.float32)
num_of_gemm = 2
if use_fused_ln_mlp:
fp8_meta_package = None
if FP8Helper.is_fp8_enabled():
fp8_meta_package = \
TransformerEngineBase.get_fp8_meta_package(num_of_gemm)
num_activations = len(self.activations)
axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, inputs.ndim)
axis = _normalize_axes(axis, y.ndim)
intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
kernel_1_shape = tuple(inputs.shape[ax] for ax in axis) + intermediate_dim
kernel_1_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
kernel_1 = nn_partitioning.param_with_axes('wi_kernel',
kernel_1_init,
......@@ -892,79 +882,35 @@ class LayerNormMLP(TransformerEngineBase):
kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
contract_ind = tuple(range(0, len(axis)))
fp8_gemm_package = \
TransformerEngineBase.get_fp8_gemm_package(num_of_gemm, y, [kernel_1, kernel_2])
out = fp8_ln_mlp(fp8_gemm_package,
if use_fused_ln_mlp:
assert self.axis == -1    # Only support axis == -1 at this moment
out = layernrom_geglu_fp8_mlp(y,
scale,
ln_bias,
ln_bias, [kernel_1, kernel_2],
fp8_meta_package,
self.layernorm_type,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
contracting_dims=(axis, contract_ind),
major_sharding_type=self.major_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0,
activations=self.activations)
epsilon=self.epsilon)
else: # not use_fused_ln_mlp
def fp8_meta_generator():
fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = (None, None, None,
None)
if FP8Helper.is_fp8_enabled():
fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
TransformerEngineBase.get_fp8_metas(num_of_gemm)
return fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
fp8_meta_generator()
# DenseGeneral 1
activations = []
num_activations = len(self.activations)
axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, y.ndim)
intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
kernel_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
kernel = nn_partitioning.param_with_axes('wi_kernel',
kernel_1_init,
num_activations,
-2,
kernel_1_each_shape,
jnp.float32,
axes=self.kernel_axes_1)
kernel = jnp.reshape(kernel, kernel_shape)
contract_ind = tuple(range(0, len(axis)))
if FP8Helper.is_fp8_enabled():
fp8_gemm_package = FP8GemmPackage(
1, y, [kernel], fp8_max[:FP8Helper.NUM_META_PER_GEMM, :],
fp8_metas_amax[:FP8Helper.NUM_META_PER_GEMM, :],
fp8_metas_scale[:FP8Helper.NUM_META_PER_GEMM, :],
fp8_metas_scale_inv[:FP8Helper.NUM_META_PER_GEMM, :])
if not fuse_layernorm:
x = fp8_dot(fp8_gemm_package,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind),
sharding_type=first_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0)
else:
x = layernorm_fp8_dot(fp8_gemm_package,
gemm1_fp8_meta_package = None if fp8_meta_package is None \
else fp8_meta_package.get_package_by_gemm_idx(0)
if fuse_layernorm:
x = layernorm_fp8_dot(y,
kernel_1,
scale,
ln_bias,
gemm1_fp8_meta_package,
self.layernorm_type,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind),
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
sharding_type=first_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0)
else: # not enable fp8
kernel = jnp.asarray(kernel, self.dtype)
x = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ())))
epsilon=self.epsilon)
else:
x = type_safe_dot_general(y,
kernel_1,
fp8_meta_pkg=gemm1_fp8_meta_package,
contracting_dims=(axis, contract_ind))
bias = None
if self.use_bias:
......@@ -977,11 +923,9 @@ class LayerNormMLP(TransformerEngineBase):
bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
x += jnp.reshape(bias, bias_shape)
if self.activations == ('gelu', 'linear'):
z = geglu(x,
contracting_dims=(-2, -1),
sharding_type=second_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0)
activations = []
if is_geglu(self.activations):
z = geglu(x)
else:
x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(self.activations):
......@@ -996,37 +940,13 @@ class LayerNormMLP(TransformerEngineBase):
z, deterministic=deterministic)
# DenseGeneral 2
hidden_size = inputs.shape[-1]
hidden_size_tuple = _canonicalize_tuple(hidden_size)
axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, z.ndim)
kernel_shape = tuple(z.shape[ax] for ax in axis) + hidden_size_tuple
kernel_param_shape = (np.prod([z.shape[ax] for ax in axis]), np.prod(hidden_size_tuple))
kernel = nn_partitioning.param_with_axes('wo_kernel',
self.kernel_init,
kernel_param_shape,
jnp.float32,
axes=self.kernel_axes_2)
kernel = jnp.reshape(kernel, kernel_shape)
gemm2_fp8_meta_package = None if fp8_meta_package is None \
else fp8_meta_package.get_package_by_gemm_idx(1)
contract_ind = tuple(range(0, len(axis)))
if FP8Helper.is_fp8_enabled():
fp8_gemm_package = FP8GemmPackage(
1, z, [kernel], fp8_max[FP8Helper.NUM_META_PER_GEMM:, :],
fp8_metas_amax[FP8Helper.NUM_META_PER_GEMM:, :],
fp8_metas_scale[FP8Helper.NUM_META_PER_GEMM:, :],
fp8_metas_scale_inv[FP8Helper.NUM_META_PER_GEMM:, :])
out = fp8_dot(fp8_gemm_package,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind),
sharding_type=second_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0)
else:
kernel = jnp.asarray(kernel, self.dtype)
out = lax.dot_general(z, kernel, ((axis, contract_ind), ((), ())))
out = type_safe_dot_general(z,
kernel_2,
fp8_meta_pkg=gemm2_fp8_meta_package,
contracting_dims=(axis, contract_ind))
bias = None
if self.use_bias:
......
......@@ -27,9 +27,8 @@ from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from ..fused_attn import is_fused_attn_kernel_available
from ..fused_attn import self_fused_attn, cross_fused_attn
from ..softmax import SoftmaxType
from ..sharding import infer_major_sharding_type, infer_sharding_type
from ..sharding import global_shard_resource, with_sharding_constraint
from ..sharding import ShardingType
from ..sharding import global_mesh_resource, num_of_devices
from ..sharding import with_sharding_constraint
PRNGKey = Any
Shape = Tuple[int, ...]
......@@ -102,7 +101,7 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
else:
rules_map[key] = [val]
gsr = global_shard_resource()
gsr = global_mesh_resource()
batch_dim_rule = []
if gsr.dp_resource is not None:
......@@ -186,7 +185,6 @@ def core_attention(query: Array,
scale_factor: float,
transpose_batch_sequence: bool,
softmax_type: SoftmaxType = SoftmaxType.SCALED,
softmax_sharding_type: ShardingType = ShardingType.SINGLE,
mask: Optional[Array] = None,
bias: Optional[Array] = None,
dropout_rng: Optional[PRNGKey] = None,
......@@ -226,9 +224,7 @@ def core_attention(query: Array,
fused_scale_factor = scale_factor
attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=fused_scale_factor,
sharding_type=softmax_sharding_type)(attn_weights, mask,
bias).astype(dtype)
scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype)
if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate
......@@ -482,8 +478,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
f"Fused attention is not enabled. Because " \
f"{reason}fall back to unfused attention.")
first_sharding_type, second_sharding_type = infer_sharding_type()
residual = inputs_q
if self.fuse_qkv:
if is_self_attn:
......@@ -494,7 +488,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon,
axis=-1,
features=(3, self.num_heads * self.head_dim),
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
scale_axes=(W_NO_SHARD_AXES,),
......@@ -516,7 +509,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon,
axis=-1,
features=self.num_heads * self.head_dim,
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
scale_axes=(W_NO_SHARD_AXES,),
......@@ -530,7 +522,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
name='query')(inputs_q)
kv_proj = DenseGeneral(axis=-1,
features=(2, self.num_heads * self.head_dim),
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_init=kv_init,
......@@ -546,7 +537,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
DenseGeneral,
axis=-1,
features=self.num_heads * self.head_dim,
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias,
......@@ -560,7 +550,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon,
axis=-1,
features=self.num_heads * self.head_dim,
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=True,
scale_axes=(W_NO_SHARD_AXES,),
......@@ -648,7 +637,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
seed = None
if dropout_rng is not None:
seed = jax.random.split(dropout_rng, len(jax.devices()))
seed = jax.random.split(dropout_rng, num_of_devices())
# ensure the old key never used
del dropout_rng
......@@ -665,8 +654,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor,
dropout_probability=self.dropout_rate,
is_training=not deterministic,
sharding_type=first_sharding_type)
is_training=not deterministic)
else:
assert bias is None
query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
......@@ -685,8 +673,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor,
dropout_probability=self.dropout_rate,
is_training=not deterministic,
sharding_type=first_sharding_type)
is_training=not deterministic)
else:
def convert_to_softmax_type(attn_mask_type, mask):
......@@ -710,7 +697,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
softmax_type=softmax_type,
softmax_sharding_type=first_sharding_type,
mask=mask,
bias=bias,
dropout_rng=dropout_rng,
......@@ -728,7 +714,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
x = _with_sharding_constraint(x, attn_context_sharding_constraint)
out = DenseGeneral(features=inputs_q.shape[-1],
sharding_type=second_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
axis=-1,
kernel_init=self.kernel_init,
......@@ -1175,7 +1160,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
major_sharding_type=infer_major_sharding_type(),
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
intermediate_dim=self.mlp_hidden_size,
......@@ -1208,7 +1192,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
z = z + residual
if self.output_layernorm:
ln_sharding_type, _ = infer_sharding_type()
z = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
......@@ -1216,7 +1199,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_NO_SHARD_AXES,),
transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype,
sharding_type=ln_sharding_type,
name="output_layer_norm")(z)
return z
......@@ -6,7 +6,7 @@ Helper module for fp8 meta management
"""
from contextlib import contextmanager
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union
import jax
import jax.numpy as jnp
......@@ -17,7 +17,7 @@ from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import get_cuda_version, get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import ShardingResource
from transformer_engine.jax.sharding import MeshResource
_is_fp8_available = None
_reason_for_no_fp8 = ""
......@@ -59,37 +59,29 @@ def is_fp8_available(gpu_id=None) -> Tuple[bool, str]:
def _format2dtypes(format_: Format):
if format_ == Format.E4M3:
return DType.kFloat8E4M3, DType.kFloat8E4M3
return jnp.float8_e4m3fn, jnp.float8_e4m3fn
if format_ == Format.E5M2:
return DType.kFloat8E5M2, DType.kFloat8E5M2
return jnp.float8_e5m2, jnp.float8_e5m2
if format_ == Format.HYBRID:
return DType.kFloat8E4M3, DType.kFloat8E5M2
return DType.kBFloat16, DType.kBFloat16
return jnp.float8_e4m3fn, jnp.float8_e5m2
return jnp.bfloat16, jnp.bfloat16
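
A small check of the native-dtype mapping introduced here (a sketch that assumes the recipe Format and the jnp float8 dtypes are available as imported in this file):

fwd_dtype, bwd_dtype = _format2dtypes(Format.HYBRID)
assert (fwd_dtype, bwd_dtype) == (jnp.float8_e4m3fn, jnp.float8_e5m2)
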
class FP8GemmPackage:
class FP8MetaPackage:
"""
A container that contains all required data for
FP8 GEMM
A container that holds all required FP8 metadata
"""
def __init__(
self,
num_of_gemm: int,
inputs: jnp.ndarray,
kernels: List[jnp.ndarray],
fp8_max: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
) -> None:
total_num_of_meta = num_of_gemm * FP8Helper.NUM_META_PER_GEMM
self._num_of_gemm = num_of_gemm
self._inputs = inputs
assert len(kernels) == self._num_of_gemm
self._kernels = kernels
total_num_of_meta = self._num_of_gemm * FP8Helper.NUM_META_PER_GEMM
assert fp8_max.shape[0] == total_num_of_meta
self._fp8_max = fp8_max
assert amax.shape[0] == total_num_of_meta
......@@ -106,20 +98,6 @@ class FP8GemmPackage:
"""
return self._num_of_gemm
@property
def inputs(self) -> jnp.ndarray:
"""
inputs of this package
"""
return self._inputs
@property
def kernels(self) -> List[jnp.ndarray]:
"""
kernels of this package
"""
return self._kernels
@property
def fp8_max(self) -> jnp.ndarray:
"""
......@@ -148,6 +126,19 @@ class FP8GemmPackage:
"""
return self._scale_inv
def get_package_by_gemm_idx(self, gemm_idx):
"""
Get a sub package by gemm_idx
"""
assert self.num_of_gemm > gemm_idx
meta_start_idx = gemm_idx * FP8Helper.NUM_META_PER_GEMM
meta_end_idx = (gemm_idx + 1) * FP8Helper.NUM_META_PER_GEMM
return FP8MetaPackage(1, self.fp8_max[meta_start_idx:meta_end_idx],
self.amax[meta_start_idx:meta_end_idx],
self.scale[meta_start_idx:meta_end_idx],
self.scale_inv[meta_start_idx:meta_end_idx])
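
A hedged usage sketch (array contents are hypothetical): a two-GEMM package is sliced into per-GEMM packages, which is how the MLP module hands metas to each dot.

n = 2 * FP8Helper.NUM_META_PER_GEMM
pkg = FP8MetaPackage(2,
                     jnp.ones((n, 1)),         # fp8_max
                     jnp.zeros((n, 1024)),     # amax history
                     jnp.ones((n, 1)),         # scale
                     jnp.ones((n, 1)))         # scale_inv
gemm1_pkg = pkg.get_package_by_gemm_idx(1)     # FP8MetaPackage holding the metas of the 2nd GEMM
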
class AmaxComputeAlgo(Enum):
"""AmaxComputeAlgo."""
......@@ -155,6 +146,9 @@ class AmaxComputeAlgo(Enum):
MOST_RECENT = "most_recent"
NVTE_FP8_COLLECTION_NAME = "fp8_meta_collection"
class FP8Helper:
"""
FP8 helper to manage the FP8 meta
......@@ -162,8 +156,8 @@ class FP8Helper:
INITIALIZED = False
MARGIN: float = 0.0
FP8_FORMAT: Format = Format.HYBRID
FWD_DTYPE: DType = DType.kFloat8E4M3
BWD_DTYPE: DType = DType.kFloat8E5M2
FWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[0]
BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1]
UPDATE_FP8META_INTERVAL: int = 1
AMAX_HISTORY_LEN: int = 1024
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
......@@ -171,7 +165,7 @@ class FP8Helper:
INPUT_META_IDX_PER_GEMM: int = 0
KERNEL_META_IDX_PER_GEMM: int = 1
GRAD_META_IDX_PER_GEMM: int = 2
FP8_COLLECTION_NAME: str = "fp8_meta_collection"
FP8_COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
FP8_AMAX_NAME: str = "fp8_meta_amax"
FP8_SCALE_NAME: str = "fp8_meta_scale"
FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv"
......@@ -216,21 +210,12 @@ class FP8Helper:
FP8Helper.INITIALIZED = False
FP8Helper.MARGIN = 0.0
FP8Helper.FP8_FORMAT = Format.HYBRID
FP8Helper.FWD_DTYPE = DType.kFloat8E4M3
FP8Helper.BWD_DTYPE = DType.kFloat8E5M2
FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
_format2dtypes(FP8Helper.FP8_FORMAT)
FP8Helper.UPDATE_FP8META_INTERVAL = 1
FP8Helper.AMAX_HISTORY_LEN = 1024
FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
@staticmethod
def update_amax_history(amax_buffers: jnp.ndarray) -> jnp.ndarray:
"""
Update the amax history
"""
updated_amax_buffers = jnp.roll(amax_buffers, -1, 1)
updated_amax_buffers = updated_amax_buffers.at[:, 0].set(0)
return updated_amax_buffers
@staticmethod
def update_collections(new: Collection, original: Collection) -> Collection:
"""
......@@ -270,8 +255,8 @@ class FP8Helper:
Generate the FP8 max array
"""
num_of_gemm = num_of_meta // FP8Helper.NUM_META_PER_GEMM
fp8_max_fwd = FP8Helper.FP8_FORMAT.value.max_fwd
fp8_max_bwd = FP8Helper.FP8_FORMAT.value.max_bwd
fp8_max_fwd = jnp.finfo(FP8Helper.FWD_DTYPE).max
fp8_max_bwd = jnp.finfo(FP8Helper.BWD_DTYPE).max
fp8_max_per_gemm = []
for i in range(FP8Helper.NUM_META_PER_GEMM):
val = fp8_max_bwd if i == FP8Helper.GRAD_META_IDX_PER_GEMM \
......@@ -318,11 +303,40 @@ class FP8Helper:
return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)
@staticmethod
def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray:
"""
Update the amax history
"""
updated_amax = jnp.roll(amax, -1, -1)
updated_amax = updated_amax.at[..., 0].set(0)
return updated_amax
@staticmethod
@jax.jit
def update_fp8_scale(fp8_max: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray) -> jnp.ndarray:
"""
Calculate fp8 scale and scale_inv based on given amax.
"""
if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(amax, axis=-1, keepdims=True)
else:
amax = amax[..., 0:1]
sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
scale = sf
scale_inv = 1 / sf
return scale, scale_inv
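
A minimal delayed-scaling step using the two helpers above (hypothetical shapes: one GEMM with its 3 metas and an amax history of 16):

fp8_max = jnp.full((3, 1), 448.0)            # e4m3 max; the grad row would use the e5m2 max instead
amax = jnp.zeros((3, 16))
scale = jnp.ones((3, 1))
amax = FP8Helper.update_amax_history(amax)   # roll the history and zero the newest slot
amax = amax.at[:, 0].set(2.0)                # pretend this step observed amax == 2
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
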
@contextmanager
def fp8_autocast(enabled: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
sharding_resource: Optional[ShardingResource] = None) -> None:
mesh_resource: Optional[MeshResource] = None) -> None:
r"""
Context manager for FP8 usage.
......@@ -334,9 +348,9 @@ def fp8_autocast(enabled: bool = False,
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
sharding_resource=ShardingResource(dp_mesh_axis_name, tp_mesh_axis_name)
mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)
with fp8_autocast(enabled=True, sharding_resource=sharding_resource):
with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
rules = extend_logical_axis_rules(tuple())
transformer = TransformerLayer()
......@@ -356,7 +370,7 @@ def fp8_autocast(enabled: bool = False,
Whether or not to enable fp8
fp8_recipe: recipe.DelayedScaling, default = None
Recipe used for FP8 training.
sharding_resource: ShardingResource, default = None
mesh_resource: MeshResource, default = None
Specify the mesh axes for data and tensor parallelism to shard along.
If set to None, then no data or tensor parallelism will be used.
......@@ -373,11 +387,11 @@ def fp8_autocast(enabled: bool = False,
"DelayedScaling override_linear_precision isn't supported by TE/JAX.")
assert fp8_recipe.reduce_amax, ("DelayedScaling reduce_amax should be enabled for TE/JAX.")
if sharding_resource is None:
sharding_resource = ShardingResource()
if mesh_resource is None:
mesh_resource = MeshResource()
try:
with global_shard_guard(sharding_resource):
with global_shard_guard(mesh_resource):
if enabled:
fp8_available, reason_for_no_fp8 = is_fp8_available()
assert fp8_available, reason_for_no_fp8
......
......@@ -15,12 +15,6 @@ from transformer_engine_jax import NVTE_QKV_Layout
from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .sharding import get_fused_attn_sharding_meta
from .sharding import ShardingType
from .sharding import xmap_runner, extend_fsdp_sharding_meta
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
class AttnBiasType(Enum):
......@@ -54,23 +48,15 @@ def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type,
head_dim).is_fused_attn_kernel_available()
def self_fused_attn(qkv: jnp.ndarray,
bias: jnp.ndarray,
mask: jnp.ndarray,
seed: jnp.ndarray,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
sharding_type: ShardingType = ShardingType.SINGLE):
def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Self fused attention wrapper
"""
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"self_fused_attn does not support row-split tensor parallelism currently."
assert attn_mask_type is not AttnMaskType.NO_MASK, \
"Currently not support AttnMaskType.NO_MASK."
if sharding_type is ShardingType.SINGLE:
output = _self_fused_attn(qkv,
bias,
mask,
......@@ -80,36 +66,6 @@ def self_fused_attn(qkv: jnp.ndarray,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
else:
dp_axis_name = "batch"
tp_axis_name = "model"
inputs = [qkv, bias, mask, seed]
batch, seqlen, _, num_head, head_dim = qkv.shape
output_shape = [batch, seqlen, num_head, head_dim]
sharding_meta = get_fused_attn_sharding_meta(
sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
dp_dims=([0, None, 0, 0], [0]),
tp_dims=([3, 1, None, 0], [2]),
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})
inputs_ = tuple(
jnp.reshape(x, new_shape) if x is not None else None
for x, new_shape in zip(inputs, sharding_meta.input_shapes))
partial_self_fused_attn = partial(_self_fused_attn,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output_ = xmap_runner(partial_self_fused_attn, sharding_meta.in_axes,
sharding_meta.out_axes, sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes)
return output
......@@ -118,81 +74,61 @@ def self_fused_attn(qkv: jnp.ndarray,
def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
output, _ = _self_fused_attn_fwd(qkv,
bias,
mask,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output
output, _ = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training)
return output
def _self_fused_attn_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
cu_seqlen = jnp.cumsum(seqlen)
cu_seqlen = jnp.hstack((0, cu_seqlen))
def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
squeezed_mask = mask[:, :, :, 0]
output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
bias,
cu_seqlen,
squeezed_mask,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, (qkv, softmax_aux, rng_state, output, cu_seqlen)
return output, (qkv, softmax_aux, rng_state, output, squeezed_mask)
def _self_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, grad):
qkv, softmax_aux, rng_state, output, cu_seqlen = ctx
doutput = grad
def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
qkv, softmax_aux, rng_state, output, squeezed_mask = ctx
grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
softmax_aux,
rng_state,
output,
doutput,
cu_seqlen,
dz,
squeezed_mask,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
return grad_qkv, grad_bias, None, None
_self_fused_attn.defvjp(_self_fused_attn_fwd, _self_fused_attn_bwd)
_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule)
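For orientation, a minimal sketch of how the migrated wrapper is expected to be invoked after this change; the shapes, enum members, seed layout and keyword names below are illustrative assumptions, not part of this diff.
# Illustrative sketch only -- all concrete values are assumptions.
import jax.numpy as jnp
qkv = jnp.zeros((2, 512, 3, 16, 64), dtype=jnp.bfloat16)    # (batch, seqlen, 3, heads, head_dim)
mask = jnp.zeros((2, 1, 512, 512), dtype=jnp.uint8)         # 0 = attend (see squeezed_mask above)
seed = jnp.zeros((2,), dtype=jnp.uint32)                    # dropout RNG state (layout assumed)
out = self_fused_attn(qkv, None, mask, seed,                # bias omitted under NO_BIAS (assumption)
                      attn_bias_type=AttnBiasType.NO_BIAS,
                      attn_mask_type=AttnMaskType.PADDING_MASK,
                      scaling_factor=1.0 / 64 ** 0.5,
                      dropout_probability=0.0,
                      is_training=True)                     # expected output: (2, 512, 16, 64)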
def cross_fused_attn(q: jnp.ndarray,
kv: jnp.ndarray,
mask: jnp.ndarray,
seed: jnp.ndarray,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
sharding_type: ShardingType = ShardingType.SINGLE):
def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Cross multi-head attention wrapper
"""
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"cross_fused_attn does not support row-split tensor parallelism currently."
if sharding_type is ShardingType.SINGLE:
output = _cross_fused_attn(q,
kv,
mask,
@@ -202,35 +138,6 @@ def cross_fused_attn(q: jnp.ndarray,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
else:
dp_axis_name = "batch"
tp_axis_name = "model"
inputs = [q, kv, mask, seed]
output_shape = q.shape
sharding_meta = get_fused_attn_sharding_meta(
sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
dp_dims=([0, 0, 0, None], [0]),
tp_dims=([2, 3, None, None], [2]),
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})
inputs_ = tuple(
jnp.reshape(x, new_shape) if x is not None else None
for x, new_shape in zip(inputs, sharding_meta.input_shapes))
partial_cross_fused_attn = partial(_cross_fused_attn,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output_ = xmap_runner(partial_cross_fused_attn, sharding_meta.in_axes,
sharding_meta.out_axes, sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes)
return output
@@ -240,54 +147,40 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed:
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
output, _ = _cross_fused_attn_fwd(q,
kv,
mask,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output, _ = _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training)
return output
def _cross_fused_attn_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
def _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
q_cu_seqlen = jnp.cumsum(q_seqlen)
q_cu_seqlen = jnp.hstack((0, q_cu_seqlen))
kv_seqlen = jnp.sum(mask[:, :, 0, :] == 0, axis=(-1, -2), dtype=jnp.int32)
kv_cu_seqlen = jnp.cumsum(kv_seqlen)
kv_cu_seqlen = jnp.hstack((0, kv_cu_seqlen))
q_squeezed_mask = mask[:, :, :, 0]
kv_squeezed_mask = mask[:, :, 0, :]
output, softmax_aux = cross_fused_attn_fwd(q,
kv,
q_cu_seqlen,
kv_cu_seqlen,
q_squeezed_mask,
kv_squeezed_mask,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, (softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen)
return output, (softmax_aux, q, kv, q_squeezed_mask, kv_squeezed_mask)
def _cross_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, grad):
softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen = ctx
doutput = grad
def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
softmax_aux, q, kv, q_squeezed_mask, kv_squeezed_mask = ctx
grad_q, grad_kv = cross_fused_attn_bwd(q,
kv,
softmax_aux,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
dz,
q_squeezed_mask,
kv_squeezed_mask,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
@@ -297,4 +190,4 @@ def _cross_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropou
return grad_q, grad_kv, None, None
_cross_fused_attn.defvjp(_cross_fused_attn_fwd, _cross_fused_attn_bwd)
_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)
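Analogously, a hedged sketch of a cross-attention call matching the signature shown above; shapes and enum members are assumptions.
# Illustrative sketch only -- all concrete values are assumptions.
q = jnp.zeros((2, 128, 16, 64), dtype=jnp.bfloat16)      # (batch, q_seqlen, heads, head_dim)
kv = jnp.zeros((2, 512, 2, 16, 64), dtype=jnp.bfloat16)  # (batch, kv_seqlen, 2, heads, head_dim)
mask = jnp.zeros((2, 1, 128, 512), dtype=jnp.uint8)      # 0 = attend
seed = jnp.zeros((2,), dtype=jnp.uint32)
out = cross_fused_attn(q, kv, mask, seed,
                       attn_bias_type=AttnBiasType.NO_BIAS,
                       attn_mask_type=AttnMaskType.PADDING_MASK,
                       scaling_factor=1.0 / 64 ** 0.5,
                       dropout_probability=0.0,
                       is_training=True)                  # expected output: q.shape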
@@ -3,25 +3,15 @@
# See LICENSE for license information.
"""JAX layernorm modules"""
from typing import Tuple, Sequence
from functools import partial, reduce
import operator
from functools import partial
import jax
import jax.numpy as jnp
from transformer_engine_jax import DType as TEDType
from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype
from .cpp_extensions import transpose
from .cpp_extensions import cast_transpose, transpose
from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
from .fp8 import FP8Helper, FP8GemmPackage
from .sharding import ShardingType, get_elementwise_sharding_meta
from .sharding import get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
from .sharding import xmap_runner, extend_fsdp_sharding_meta
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
from .dot import fp8_dot_impl
from .fp8 import FP8Helper, FP8MetaPackage
def canonicalize_layernorm_type(x):
@@ -38,421 +28,241 @@ def layernorm(inputs: jnp.ndarray,
beta: jnp.ndarray,
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0):
epsilon: float = 1e-6):
"""
Layernorm wrapper
LN/RMSNorm wrapper
Only supports layernorm_type in ['layernorm', 'rmsnorm']
"""
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"layernorm does not support row-split tensor parallelism currently."
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm':
assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
if sharding_type is ShardingType.SINGLE:
output = _layernorm(inputs,
gamma,
beta,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type,
dp_axis_name="",
fsdp_axis_name="")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape,
dp_dim_index, dp_axis_name, tp_axis_name)
sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
gamma_ = jnp.reshape(gamma, sharding_meta.input_shapes[1]) # 1 for gamma
beta_ = beta
beta_in_axis = {}
if beta_ is not None:
beta_ = jnp.reshape(beta_, sharding_meta.input_shapes[1]) # 1 for beta
beta_in_axis = sharding_meta.in_axes[1]
in_axes = (*sharding_meta.in_axes, beta_in_axis)
partial_ln = partial(_layernorm,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
fsdp_axis_name=fsdp_axis_name)
output = xmap_runner(partial_ln, in_axes, sharding_meta.out_axes,
sharding_meta.axis_resources, (inputs_, gamma_, beta_))
output = jnp.reshape(output, sharding_meta.output_shapes[0])
epsilon=epsilon)
return output
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8))
def _layernorm(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon, sharding_type,
dp_axis_name, fsdp_axis_name):
output, _ = _layernorm_fwd(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon,
sharding_type, dp_axis_name, fsdp_axis_name)
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _layernorm(x,
gamma,
beta,
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6):
output, _ = _layernorm_fwd_rule(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon)
return output
def _layernorm_fwd(
x,
def _layernorm_fwd_rule(x,
gamma,
beta,
layernorm_type,
zero_centered_gamma,
epsilon,
sharding_type, # pylint: disable=unused-argument
dp_axis_name, # pylint: disable=unused-argument
fsdp_axis_name # pylint: disable=unused-argument
):
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6):
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'layernorm':
output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
else:
elif layernorm_type == 'rmsnorm':
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
output, rsigma = rmsnorm_fwd(x, gamma, epsilon)
mu = None
return output, (mu, rsigma, x, gamma)
else:
raise ValueError(f"{layernorm_type=} is not supported.")
return output, (x, mu, rsigma, gamma)
def _layernorm_bwd(layernorm_type, zero_centered_gamma, epsilon, sharding_type, dp_axis_name,
fsdp_axis_name, ctx, g):
mu, rsigma, x, gamma = ctx
def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
x, mu, rsigma, gamma = ctx
if layernorm_type == 'layernorm':
grad_input, grad_gamma, grad_beta = layernorm_bwd(g,
dx, dgamma, dbeta = layernorm_bwd(dz,
x,
mu,
rsigma,
x,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
elif layernorm_type == 'rmsnorm':
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
grad_input, grad_gamma = rmsnorm_bwd(g, rsigma, x, gamma, epsilon=epsilon)
grad_beta = None
if is_dp_enabled(sharding_type.value[0]):
grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
if len(fsdp_axis_name) > 0:
grad_gamma = jax.lax.psum(grad_gamma, fsdp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, fsdp_axis_name)
dx, dgamma = rmsnorm_bwd(dz, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
else:
raise ValueError(f"{layernorm_type=} is not supported.")
return grad_input, grad_gamma, grad_beta
return dx, dgamma, dbeta
_layernorm.defvjp(_layernorm_fwd, _layernorm_bwd)
_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)
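A small usage sketch of the simplified wrapper; the shapes and dtypes are hypothetical.
# Illustrative sketch only.
import jax.numpy as jnp
x = jnp.ones((8, 128, 1024), dtype=jnp.float32)
gamma = jnp.ones((1024,), dtype=jnp.float32)
beta = jnp.zeros((1024,), dtype=jnp.float32)
y_ln = layernorm(x, gamma, beta, layernorm_type='layernorm')
y_rms = layernorm(x, gamma, None, layernorm_type='rmsnorm')   # beta must be None for rmsnorm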
def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
def layernorm_fp8_dot(x: jnp.ndarray,
kernel: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
fp8_meta_pkg: FP8MetaPackage,
layernorm_type: str,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0) -> jnp.ndarray:
epsilon: float = 1e-6) -> jnp.ndarray:
"""
LN + fp8 dot fusion wrapper
Layernorm + FP8 GEMM
"""
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"layernorm_fp8_dot does not support row-split tensor parallelism currently."
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm':
assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
assert fp8_gemm_pkg.num_of_gemm == 1
inputs = fp8_gemm_pkg.inputs
kernel = fp8_gemm_pkg.kernels[0]
fp8_max = fp8_gemm_pkg.fp8_max
amax = fp8_gemm_pkg.amax
scale = fp8_gemm_pkg.scale
scale_inv = fp8_gemm_pkg.scale_inv
if sharding_type is ShardingType.SINGLE:
output = _layernorm_fp8_dot(inputs,
kernel,
gamma,
beta,
fp8_max,
amax,
scale,
scale_inv,
layernorm_type,
fwd_dtype,
bwd_dtype,
contracting_dims,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type,
dp_axis_name="",
tp_axis_name="",
fsdp_axis_name="")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
ln_sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape,
dp_dim_index, dp_axis_name, tp_axis_name)
ln_sharding_meta, _ = extend_fsdp_sharding_meta(ln_sharding_meta, {0: dp_dim_index})
inputs_ = jnp.reshape(inputs, ln_sharding_meta.input_shapes[0]) # 0 for input
gamma_ = jnp.reshape(gamma, ln_sharding_meta.input_shapes[1]) # 1 for gamma
beta_ = beta
beta_in_axis = {}
if beta_ is not None:
beta_ = jnp.reshape(beta_, ln_sharding_meta.input_shapes[1]) # 1 for beta
beta_in_axis = ln_sharding_meta.in_axes[1]
kernel_tp_index = None
# TODO (Ming Huang): Should we add a new argument to support general sharding to kernel? # pylint: disable=fixme
if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
kernel_tp_index = len(kernel.shape) - 1
elif sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
kernel_tp_index = 0
input_tp_index = len(inputs.shape) - 1
dot_sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
dp_dim_index, input_tp_index, kernel_tp_index,
contracting_dims, dp_axis_name, tp_axis_name)
dot_sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(dot_sharding_meta,
{0: dp_dim_index})
kernel_ = jnp.reshape(kernel, dot_sharding_meta.input_shapes[1]) # 1 for kernel
num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv
fp8_sharding_meta = get_fp8_meta_sharding_meta(sharding_type, num_of_fp8_meta_kind,
dp_axis_name, tp_axis_name)
axis_resource = merge_axis_resources([
ln_sharding_meta.axis_resources, dot_sharding_meta.axis_resources,
fp8_sharding_meta.axis_resources
])
partial_ln_fp8_dot = partial(_layernorm_fp8_dot,
layernorm_type=layernorm_type,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
# input, kernel, gamma, beta, fp8_metas
in_axes = (ln_sharding_meta.in_axes[0], dot_sharding_meta.in_axes[1],
ln_sharding_meta.in_axes[1], beta_in_axis, *fp8_sharding_meta.in_axes)
output = xmap_runner(partial_ln_fp8_dot, in_axes, dot_sharding_meta.out_axes, axis_resource,
(inputs_, kernel_, gamma_, beta_, fp8_max, amax, scale, scale_inv))
output = jnp.reshape(output, dot_sharding_meta.output_shapes[0])
fp8_max = fp8_meta_pkg.fp8_max
amax = fp8_meta_pkg.amax
scale = fp8_meta_pkg.scale
scale_inv = fp8_meta_pkg.scale_inv
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
output = _layernorm_fp8_dot(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
layernorm_type, fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon)
return output
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray,
beta: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str,
fwd_dtype: TEDType, bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
zero_centered_gamma: bool, epsilon: float, sharding_type: ShardingType,
dp_axis_name: str, tp_axis_name: str, fsdp_axis_name: str) -> jnp.ndarray:
output, _ = _layernorm_fp8_dot_fwd(inputs, kernel, gamma, beta, fp8_maxs, amax, scale,
scale_inv, layernorm_type, fwd_dtype, bwd_dtype,
contracting_dims, zero_centered_gamma, epsilon,
sharding_type, dp_axis_name, tp_axis_name, fsdp_axis_name)
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12))
def _layernorm_fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
fp8_max: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray, layernorm_type: str, fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype, zero_centered_gamma: bool, epsilon: float):
output, _ = _layernorm_fp8_dot_fwd_rule(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
layernorm_type, fwd_dtype, bwd_dtype,
zero_centered_gamma, epsilon)
return output
def _layernorm_fp8_dot_fwd(
inputs,
def _layernorm_fp8_dot_fwd_rule(
x,
kernel,
gamma,
beta,
fp8_maxs,
fp8_max,
amax,
scale,
scale_inv,
layernorm_type,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims,
zero_centered_gamma,
epsilon,
sharding_type,
dp_axis_name, # pylint: disable=unused-argument
tp_axis_name,
fsdp_axis_name): # pylint: disable=unused-argument
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
input_shape_suf = inputs.shape[min(lhs_contracting_dims):]
kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
kernel_shape_suf = kernel.shape[max(rhs_contracting_dims) + 1:]
input_contracting_size = reduce(operator.mul, input_shape_suf)
kernel_contracting_size = reduce(operator.mul, kernel_shape_pre)
assert input_contracting_size == kernel_contracting_size
epsilon):
x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0]
amax = FP8Helper.update_amax_history(amax)
gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
x_amax = amax[gemm_x_idx, 0:1]
x_scale = scale[gemm_x_idx]
x_scale_inv = scale_inv[gemm_x_idx]
input_amax = amax[gemm_input_idx, 0:1]
input_scale = scale[gemm_input_idx]
input_scale_inv = scale_inv[gemm_input_idx]
if layernorm_type == 'layernorm':
ln_out, mu, rsigma, input_amax = layernorm_fwd_fp8(inputs,
ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
x,
gamma,
beta,
input_amax,
input_scale,
input_scale_inv,
x_amax,
x_scale,
x_scale_inv,
out_dtype=fwd_dtype,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
ln_out, rsigma, input_amax = rmsnorm_fwd_fp8(inputs,
ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
gamma,
input_amax,
input_scale,
input_scale_inv,
x_amax,
x_scale,
x_scale_inv,
out_dtype=fwd_dtype,
epsilon=epsilon)
mu = None
assert inputs.shape == ln_out.shape
ln_out_ = jnp.reshape(ln_out, (-1, input_contracting_size))
kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1))
assert x.shape == ln_out.shape
kernel_amax = amax[gemm_kernel_idx, 0:1]
kernel_scale = scale[gemm_kernel_idx]
kernel_scale_inv = scale_inv[gemm_kernel_idx]
kernel_cast, kernel_cast_trans, kernel_amax = cast_transpose(kernel_, kernel_amax, kernel_scale,
kernel_scale_inv, fwd_dtype)
output = gemm(kernel_cast_trans, kernel_scale_inv, fwd_dtype, True, ln_out_, input_scale_inv,
fwd_dtype, False, jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP)
# Kernel in (hidden_in, hidden_out...)
casted_kerenl, casted_kerenl_t, updated_kernel_amax = \
cast_transpose(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype,
static_axis_boundary=-1, transpose_axis_boundary=1)
if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
output = jax.lax.psum(output, tp_axis_name)
# (batch..., hidden_in) x (hidden_in, hidden_out...)
kt_contracting_dims = (kernel.ndim - 1,)
output = fp8_dot_impl(ln_out, casted_kerenl_t, x_scale_inv, kernel_scale_inv, x.dtype,
(x_contracting_dims, kt_contracting_dims))
# (input_shape_pre, input_shape_suf)
# x (kernel_shape_pre, kernel_shape_suf)
# = (input_shape_pre, kernel_shape_suf)
output_shape = input_shape_pre + kernel_shape_suf
output = jnp.reshape(output, output_shape)
ctx = (ln_out, casted_kerenl, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims,
k_contracting_dims)
ctx = (ln_out_, kernel_cast, fp8_maxs, amax, scale, scale_inv, input_amax, kernel_amax,
inputs.shape, kernel.shape, mu, rsigma, inputs, gamma)
return output, ctx
def _layernorm_fp8_dot_bwd(
def _layernorm_fp8_dot_bwd_rule(
layernorm_type,
fwd_dtype,
fwd_dtype, # pylint: disable=unused-argument
bwd_dtype,
contracting_dims, # pylint: disable=unused-argument
zero_centered_gamma,
epsilon,
sharding_type,
dp_axis_name,
tp_axis_name,
fsdp_axis_name,
ctx,
g):
ln_out_, kernel_cast, \
fp8_maxs, amax, scale, scale_inv, \
input_amax, kernel_amax, \
inputs_shape, kernel_shape, \
mu, rsigma, inputs, gamma = ctx
grad):
ln_out_, casted_kerenl, fp8_max, amax, scale, scale_inv, \
updated_x_amax, updated_kernel_amax, \
x_shape, kernel_shape, mu, rsigma, x, gamma, \
x_contracting_dims, k_contracting_dims = ctx
ln_out_t = transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1)
gemm_input_idx, gemm_kernel_idx, gemm_grad_idx = \
FP8Helper.get_fp8_meta_indices(0)
gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
grad_amax = amax[gemm_grad_idx, 0:1]
grad_scale = scale[gemm_grad_idx]
grad_scale_inv = scale_inv[gemm_grad_idx]
ln_out_trans = transpose(ln_out_, fwd_dtype)
g = jnp.reshape(g, (ln_out_trans.shape[1], -1))
casted_grad, casted_grad_t, updated_grad_amax = \
cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
static_axis_boundary=-1, transpose_axis_boundary=min(x_contracting_dims))
# cast and transpose the grad_output
grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv,
bwd_dtype)
input_scale_inv = scale_inv[gemm_input_idx]
wgrad = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype, True, ln_out_trans, input_scale_inv,
fwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD)
xt_constracting_dim = tuple(range(len(x_contracting_dims), len(x_shape)))
gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim))
x_scale_inv = scale_inv[gemm_x_idx]
wgrad = fp8_dot_impl(ln_out_t, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
(xt_constracting_dim, gt_constracting_dim))
g_constracting_dim = tuple(
range(grad.ndim - len(kernel_shape) + len(k_contracting_dims), grad.ndim))
k_constracting_dim = tuple(range(len(k_contracting_dims), len(kernel_shape)))
kernel_scale_inv = scale_inv[gemm_kernel_idx]
dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)
dgrad = jnp.reshape(dgrad, inputs_shape)
if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
dgrad = jax.lax.psum(dgrad, tp_axis_name)
dgrad = fp8_dot_impl(casted_grad, casted_kerenl, grad_scale_inv, kernel_scale_inv, grad.dtype,
(g_constracting_dim, k_constracting_dim))
if layernorm_type == 'layernorm':
grad_input, grad_gamma, grad_beta = layernorm_bwd(dgrad,
dx, dgamma, dbeta = layernorm_bwd(dgrad,
x,
mu,
rsigma,
inputs,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
grad_input, grad_gamma = rmsnorm_bwd(dgrad, rsigma, inputs, gamma, epsilon=epsilon)
grad_beta = None
amax = amax.at[gemm_input_idx, 0].set(input_amax[0])
amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0])
amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0])
if is_dp_enabled(sharding_type.value[0]):
wgrad = jax.lax.psum(wgrad, dp_axis_name)
grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
amax = jax.lax.pmax(amax, dp_axis_name)
dx, dgamma = rmsnorm_bwd(dgrad, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
if len(fsdp_axis_name) > 0:
wgrad = jax.lax.psum(wgrad, fsdp_axis_name)
grad_gamma = jax.lax.psum(grad_gamma, fsdp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, fsdp_axis_name)
amax = jax.lax.pmax(amax, fsdp_axis_name)
amax = amax.at[gemm_x_idx, 0].set(updated_x_amax[0])
amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0])
amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])
if is_tp_enabled(sharding_type.value[0]):
amax = jax.lax.pmax(amax, tp_axis_name)
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
wgrad = jnp.reshape(wgrad, kernel_shape)
return grad_input, wgrad, \
grad_gamma, grad_beta, \
fp8_maxs, amax, scale, scale_inv
return dx, wgrad, \
dgamma, dbeta, \
fp8_max, amax, scale, scale_inv
_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd, _layernorm_fp8_dot_bwd)
_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd_rule, _layernorm_fp8_dot_bwd_rule)
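To make the contracting-dimension bookkeeping in the backward rule easier to follow, here is a shape-only sketch that substitutes plain lax.dot_general for fp8_dot_impl (an illustrative assumption; the real primitive additionally applies the scale_inv factors). Sizes are hypothetical.
# Shape-only sketch; (batch..., hidden_in) = (8, 128, 1024), hidden_out = 4096.
import jax.numpy as jnp
from jax import lax
x_ = jnp.zeros((8, 128, 1024))                      # ln_out
k_ = jnp.zeros((1024, 4096))                        # kernel
g_ = jnp.zeros((8, 128, 4096))                      # incoming grad
x_t = jnp.moveaxis(x_, -1, 0)                       # ln_out_t:      (hidden_in, batch...)
g_t = jnp.moveaxis(g_, -1, 0)                       # casted_grad_t: (hidden_out, batch...)
wgrad = lax.dot_general(x_t, g_t, (((1, 2), (1, 2)), ((), ())))   # -> (1024, 4096) == kernel.shape
dgrad = lax.dot_general(g_, k_, (((2,), (1,)), ((), ())))         # -> (8, 128, 1024) == x.shape
assert wgrad.shape == k_.shape and dgrad.shape == x_.shape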
@@ -3,462 +3,307 @@
# See LICENSE for license information.
"""JAX MLP modules"""
from typing import Tuple, Sequence, Union, Callable
from functools import partial, reduce
import operator
from typing import List
from functools import partial
import jax
import jax.numpy as jnp
from jax.interpreters import pxla
from transformer_engine_jax import DType as TEDType
from .cpp_extensions import jax_dtype_to_te_dtype
from .cpp_extensions import transpose, cast_transpose
from .cpp_extensions import gated_gelu, gated_gelu_fp8
from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
from .cpp_extensions import gemm
from .sharding import MajorShardingType, ShardingType
from .sharding import get_elementwise_sharding_meta
from .sharding import get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import merge_axis_resources, infer_sharding_type
from .sharding import xmap_runner, extend_fsdp_sharding_meta
from .dot import fp8_dot_impl
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8GemmPackage
from .fp8 import FP8Helper, FP8MetaPackage
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
thread_resources = pxla.thread_resources
def geglu(
inputs: jnp.ndarray,
contracting_dims: Sequence[int] = (-1,),
sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0, # pylint: disable=unused-argument
):
def geglu(x: jnp.ndarray):
"""
Gated gelu
"""
input_shape_suf_size = reduce(operator.mul, inputs.shape[min(contracting_dims):])
assert input_shape_suf_size % 2 == 0
output_shape = (*inputs.shape[:min(contracting_dims)], input_shape_suf_size // 2)
if sharding_type is ShardingType.SINGLE:
output = _geglu(inputs, contracting_dims)
else:
dp_axis_name = "batch"
tp_axis_name = "model"
assert x.shape[-2] == 2 # Linear + GeLU
sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, None,
dp_dim_index, dp_axis_name, tp_axis_name)
sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
output = _geglu(x)
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
partial_geglu = partial(_geglu, contracting_dims=contracting_dims)
output = xmap_runner(partial_geglu, sharding_meta.in_axes, sharding_meta.out_axes,
sharding_meta.axis_resources, (inputs_,))
output = jnp.reshape(output, output_shape)
return output
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _geglu(inputs: jnp.ndarray, contracting_dims: Sequence[int] = (-1,)):
@partial(jax.custom_vjp)
def _geglu(x: jnp.ndarray):
geglu_output, _ = _geglu_fwd(inputs, contracting_dims)
geglu_output, _ = _geglu_fwd_rule(x)
return geglu_output
def _geglu_fwd(inputs, contracting_dims):
inputs_real_shape = (*inputs.shape[:min(contracting_dims)],
reduce(operator.mul, inputs.shape[min(contracting_dims):]))
inputs_ = jnp.reshape(inputs, inputs_real_shape)
geglu_output = gated_gelu(inputs_)
geglu_output = jnp.expand_dims(geglu_output, min(contracting_dims))
return geglu_output, (inputs_, inputs.shape)
def _geglu_fwd_rule(x):
geglu_output = gated_gelu(x)
return geglu_output, (x,)
def _geglu_bwd(contracting_dims, ctx, g):
inputs_, inputs_shape = ctx
g = jnp.squeeze(g, min(contracting_dims))
assert inputs_.dtype == g.dtype
def _geglu_bwd_rule(ctx, g):
x, = ctx
assert x.dtype == g.dtype
dgelu = dgated_gelu(g, inputs_)
dgelu = jnp.reshape(dgelu, inputs_shape)
dgelu = dgated_gelu(g, x)
dgelu = jnp.reshape(dgelu, x.shape)
return (dgelu,)
_geglu.defvjp(_geglu_fwd, _geglu_bwd)
_geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule)
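A minimal sketch of the migrated geglu entry point; shapes are hypothetical, with the gating pair stacked on axis -2 as the assert above requires.
# Illustrative sketch only.
x = jnp.zeros((4, 128, 2, 4096), dtype=jnp.float32)
y = geglu(x)                                   # expected shape: (4, 128, 4096)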
def fp8_ln_mlp(
fp8_gemm_pkg: FP8GemmPackage,
ln_scale: jnp.ndarray,
ln_bias: jnp.ndarray,
def layernrom_geglu_fp8_mlp(x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
kernels: List[jnp.ndarray],
fp8_gemm_pkg: FP8MetaPackage,
layernorm_type: str,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
major_sharding_type: MajorShardingType = MajorShardingType.SINGLE,
dp_dim_index: int = 0, # pylint: disable=unused-argument
activations: Sequence[Union[str, Callable]] = ('gelu', 'linear')
) -> jnp.ndarray:
epsilon: float = 1e-6) -> jnp.ndarray:
"""
FP8 layernorm MLP wrapper
(LN + Dense + act + Dense)
Layernorm + GEMM1 + GeGLU + GEMM2
"""
assert fp8_gemm_pkg.num_of_gemm == 2
inputs = fp8_gemm_pkg.inputs
kernel_1 = fp8_gemm_pkg.kernels[0]
kernel_2 = fp8_gemm_pkg.kernels[1]
assert len(kernels) == 2
assert fp8_gemm_pkg.num_of_gemm == len(kernels)
kernel_1 = kernels[0]
kernel_2 = kernels[1]
fp8_max = fp8_gemm_pkg.fp8_max
amax = fp8_gemm_pkg.amax
scale = fp8_gemm_pkg.scale
scale_inv = fp8_gemm_pkg.scale_inv
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm':
assert ln_bias is None, "ln_bias should be None if layernorm_type is 'rmsnorm'"
assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
assert activations == ('gelu', 'linear')
if major_sharding_type is MajorShardingType.SINGLE:
res = _fp8_mlp(inputs, ln_scale, ln_bias, kernel_1, kernel_2, fp8_max, amax, scale,
scale_inv, layernorm_type, activations, zero_centered_gamma, epsilon,
fwd_dtype, bwd_dtype, contracting_dims, major_sharding_type, "", "", "")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
first_part_st, second_part_st = infer_sharding_type(major_sharding_type)
ln_sharding_meta = get_elementwise_sharding_meta(first_part_st, inputs.shape,
ln_scale.shape, dp_dim_index, dp_axis_name,
tp_axis_name)
ln_sharding_meta, _ = extend_fsdp_sharding_meta(ln_sharding_meta, {0: dp_dim_index})
input_tp_index = len(inputs.shape) - 1
first_dot_sharding_meta = get_dot_sharding_meta(first_part_st, inputs.shape, kernel_1.shape,
dp_dim_index, input_tp_index, 2,
contracting_dims, dp_axis_name,
tp_axis_name)
first_dot_sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(
first_dot_sharding_meta, {0: dp_dim_index})
second_input_shape = (*first_dot_sharding_meta.output_shapes[0][:-2],
first_dot_sharding_meta.output_shapes[0][-1])
second_dot_sharding_meta = get_dot_sharding_meta(second_part_st, second_input_shape,
kernel_2.shape, dp_dim_index,
len(second_input_shape) - 1, 0,
contracting_dims, dp_axis_name,
tp_axis_name)
second_dot_sharding_meta, _ = extend_fsdp_sharding_meta(second_dot_sharding_meta,
{0: dp_dim_index})
num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv
fp8_sharding_meta = get_fp8_meta_sharding_meta(first_part_st, num_of_fp8_meta_kind,
dp_axis_name, tp_axis_name)
inputs_ = jnp.reshape(inputs, ln_sharding_meta.input_shapes[0]) # 0 for input
ln_scale_ = jnp.reshape(ln_scale, ln_sharding_meta.input_shapes[1]) # 1 for gamma
ln_bias_ = ln_bias
ln_bias_in_axis = {}
if ln_bias_ is not None:
ln_bias_ = jnp.reshape(ln_bias_, ln_sharding_meta.input_shapes[1]) # 1 for beta
ln_bias_in_axis = ln_sharding_meta.in_axes[1]
kernel_1_ = jnp.reshape(kernel_1, first_dot_sharding_meta.input_shapes[1]) # 1 for kernel
kernel_2_ = jnp.reshape(kernel_2,
second_dot_sharding_meta.input_shapes[1]) # 1 for kernel
axis_resource = merge_axis_resources([
ln_sharding_meta.axis_resources, first_dot_sharding_meta.axis_resources,
second_dot_sharding_meta.axis_resources, fp8_sharding_meta.axis_resources
])
partial_fp8_mlp = partial(_fp8_mlp,
layernorm_type=layernorm_type,
activations=activations,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
major_sharding_type=major_sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
in_axes = (ln_sharding_meta.in_axes[0], ln_sharding_meta.in_axes[1], ln_bias_in_axis,
first_dot_sharding_meta.in_axes[1], second_dot_sharding_meta.in_axes[1],
*fp8_sharding_meta.in_axes)
res = xmap_runner(
partial_fp8_mlp, in_axes, second_dot_sharding_meta.out_axes, axis_resource,
(inputs_, ln_scale_, ln_bias_, kernel_1_, kernel_2_, fp8_max, amax, scale, scale_inv))
res = jnp.reshape(res, second_dot_sharding_meta.output_shapes[0])
return res
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str,
activations: Sequence[Union[str, Callable]], zero_centered_gamma: bool, epsilon: float,
fwd_dtype: TEDType, bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int],
Sequence[int]],
major_sharding_type: MajorShardingType, dp_axis_name: str, tp_axis_name: str,
fsdp_axis_name: str):
res, _ = _fp8_mlp_fwd(inputs,
ln_scale,
ln_bias,
kernel_1,
kernel_2,
fp8_maxs,
amax,
scale,
scale_inv,
layernorm_type,
activations,
zero_centered_gamma,
epsilon,
fwd_dtype,
bwd_dtype,
contracting_dims=contracting_dims,
major_sharding_type=major_sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
return res
output = _layernrom_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale,
scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
zero_centered_gamma, epsilon)
return output
def _fp8_mlp_fwd(
inputs,
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13))
def _layernrom_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray,
amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str,
zero_centered_gamma: bool, epsilon: float):
output, _ = _layernrom_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax,
scale, scale_inv, fwd_dtype, bwd_dtype,
layernorm_type, zero_centered_gamma, epsilon)
return output
def _layernrom_geglu_fp8_mlp_fwd_rule(
x,
gamma,
beta,
kernel_1,
kernel_2,
fp8_maxs,
fp8_max,
amax,
scale,
scale_inv,
layernorm_type,
activations,
zero_centered_gamma,
epsilon,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims,
major_sharding_type,
dp_axis_name, # pylint: disable=unused-argument
tp_axis_name,
fsdp_axis_name): # pylint: disable=unused-argument
if activations != ('gelu', 'linear'):
raise NotImplementedError("activations only support ('gelu', 'linear') for now.")
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
input_shape_suf = inputs.shape[min(lhs_contracting_dims):]
kernel_1_shape_pre = kernel_1.shape[:max(rhs_contracting_dims) + 1]
kernel_1_shape_suf = kernel_1.shape[max(rhs_contracting_dims) + 1:]
kernel_2_shape_pre = kernel_2.shape[:max(rhs_contracting_dims) + 1]
kernel_2_shape_suf = kernel_2.shape[max(rhs_contracting_dims) + 1:]
input_contracting_size = reduce(operator.mul, input_shape_suf)
kernel_1_pre_size = reduce(operator.mul, kernel_1_shape_pre)
kernel_1_suf_size = reduce(operator.mul, kernel_1_shape_suf)
kernel_2_pre_size = reduce(operator.mul, kernel_2_shape_pre)
assert input_contracting_size == kernel_1_pre_size
assert kernel_1_suf_size == kernel_2_pre_size * len(activations)
inputs_ = jnp.reshape(inputs, (-1, input_contracting_size))
kernel_1_ = jnp.reshape(kernel_1, (kernel_1_pre_size, -1))
kernel_2_ = jnp.reshape(kernel_2, (kernel_2_pre_size, -1))
layernorm_type,
zero_centered_gamma,
epsilon):
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (Hidden_in, 2, Hidden_out)
# Kernel_2 should be in shape of (Hidden_in, Hidden_out)
assert len(kernel_1.shape) == 3
assert kernel_1.shape[-2] == 2
assert len(kernel_2.shape) == 2
x_contracting_dims = (len(x.shape) - 1,)
xt_batch_dims = tuple(range(1, x.ndim))
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
assert kernel_1.shape[-1] == kernel_2.shape[0]
amax = FP8Helper.update_amax_history(amax)
gemm1_input_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
x_amax = amax[gemm1_x_idx, 0:1]
x_scale = scale[gemm1_x_idx]
x_scale_inv = scale_inv[gemm1_x_idx]
input_amax = amax[gemm1_input_idx, 0:1]
input_scale = scale[gemm1_input_idx]
input_scale_inv = scale_inv[gemm1_input_idx]
if layernorm_type == 'layernorm':
ln_out, mu, rsigma, ln_out_amax = layernorm_fwd_fp8(inputs_,
ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
x,
gamma,
beta,
input_amax,
input_scale,
input_scale_inv,
x_amax,
x_scale,
x_scale_inv,
out_dtype=fwd_dtype,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
ln_out, rsigma, ln_out_amax = rmsnorm_fwd_fp8(inputs_,
ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
gamma,
input_amax,
input_scale,
input_scale_inv,
x_amax,
x_scale,
x_scale_inv,
out_dtype=fwd_dtype,
epsilon=epsilon)
mu = None
assert x.shape == ln_out.shape
kernel_1_amax = amax[gemm1_kernel_idx, 0:1]
kernel_1_scale = scale[gemm1_kernel_idx]
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
kernel_1_cast, kernel_1_cast_trans, kernel_1_amax = cast_transpose(
kernel_1_, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)
dense_1_output = gemm(kernel_1_cast_trans, kernel_1_scale_inv, fwd_dtype, True, ln_out,
scale_inv[gemm1_input_idx], fwd_dtype, False,
jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP)
gemm2_input_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
casted_kerenl_1, casted_kerenl_1_t, updated_kernel_1_amax = \
cast_transpose(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype,
static_axis_boundary=-1, transpose_axis_boundary=-2)
# (batch..., hidden_in) x (2, hidden_out, hidden_in)
dot_1_output = fp8_dot_impl(ln_out, casted_kerenl_1_t, x_scale_inv, kernel_1_scale_inv, x.dtype,
(x_contracting_dims, (2,)))
gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
geglu_out_amax = amax[gemm2_x_idx, 0:1]
geglu_out_scale = scale[gemm2_x_idx]
geglu_out_scale_inv = scale_inv[gemm2_x_idx]
# (batch..., 2, hidden_out) -> (batch..., hidden_out)
casted_geglu_out, updated_geglu_amax = gated_gelu_fp8(dot_1_output, geglu_out_amax,
geglu_out_scale, geglu_out_scale_inv,
fwd_dtype)
kernel_2_amax = amax[gemm2_kernel_idx, 0:1]
kernel_2_scale = scale[gemm2_kernel_idx]
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
kernel_2_cast, kernel_2_cast_trans, kernel_2_amax = cast_transpose(
kernel_2_, kernel_2_amax, kernel_2_scale, kernel_2_scale_inv, fwd_dtype)
dense_1_out_amax = amax[gemm2_input_idx, 0:1]
dense_1_out_scale = scale[gemm2_input_idx]
dense_1_out_scale_inv = scale_inv[gemm2_input_idx]
gated_gelu_output_cast, gated_gelu_amax = gated_gelu_fp8(dense_1_output, dense_1_out_amax,
dense_1_out_scale,
dense_1_out_scale_inv, fwd_dtype)
res = gemm(kernel_2_cast_trans, kernel_2_scale_inv, fwd_dtype, True,
gated_gelu_output_cast, dense_1_out_scale_inv, fwd_dtype, False,
jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP)
if major_sharding_type in (MajorShardingType.TP, MajorShardingType.DPTP):
res = jax.lax.psum(res, tp_axis_name)
# (input_shape_pre, input_shape_suf)
# x (kernel_1_shape_pre, kernel_1_shape_suf)
# x (kernel_2_shape_pre, kernel_2_shape_suf)
# = (input_shape_pre, kernel_2_shape_suf)
output_shape = input_shape_pre + kernel_2_shape_suf
res = jnp.reshape(res, output_shape)
ctx = (inputs_, ln_out, mu, rsigma, gamma, dense_1_output, gated_gelu_output_cast,
kernel_1_cast, kernel_2_cast, fp8_maxs, amax, scale, scale_inv, ln_out_amax,
gated_gelu_amax, kernel_1_amax, kernel_2_amax, inputs.shape, kernel_1.shape,
kernel_2.shape)
return res, ctx
def _fp8_mlp_bwd(
casted_kerenl_2, casted_kerenl_2_t, updated_kernel_2_amax = \
cast_transpose(kernel_2, kernel_2_amax, kernel_2_scale, kernel_2_scale_inv, fwd_dtype,
static_axis_boundary=-1, transpose_axis_boundary=-1)
# (batch..., hidden_in) x (hidden_out, hidden_in)
dot_2_output = fp8_dot_impl(casted_geglu_out, casted_kerenl_2_t, geglu_out_scale_inv,
kernel_2_scale_inv, x.dtype, (x_contracting_dims, (1,)))
ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kerenl_1,
casted_kerenl_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax,
updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims)
return dot_2_output, ctx
def _layernrom_geglu_fp8_mlp_bwd_rule(
fwd_dtype, # pylint: disable=unused-argument
bwd_dtype,
layernorm_type,
activations, # pylint: disable=unused-argument
zero_centered_gamma,
epsilon,
fwd_dtype,
bwd_dtype,
contracting_dims, # pylint: disable=unused-argument
major_sharding_type,
dp_axis_name,
tp_axis_name,
fsdp_axis_name,
ctx,
g):
inputs_, ln_out, mu, rsigma, gamma, \
dense_1_output, gated_gelu_output_cast, \
kernel_1_cast, kernel_2_cast, \
fp8_maxs, amax, scale, scale_inv, \
ln_out_amax, gated_gelu_amax, kernel_1_amax, kernel_2_amax, \
input_shape, kernel_1_shape, kernel_2_shape = ctx
grad):
x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, \
casted_kerenl_1, casted_kerenl_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
updated_geglu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
x_contracting_dims, xt_batch_dims = ctx
g = jnp.reshape(g, (ln_out.shape[0], -1))
gemm2_input_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)
gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)
grad_amax = amax[gemm2_grad_idx, 0:1]
grad_scale = scale[gemm2_grad_idx]
grad_scale_inv = scale_inv[gemm2_grad_idx]
grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv,
bwd_dtype)
gated_gelu_output_cast_trans = transpose(gated_gelu_output_cast, fwd_dtype)
casted_grad, casted_grad_t, updated_grad_amax = \
cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
static_axis_boundary=-1, transpose_axis_boundary=-1)
casted_geglu_out_t = transpose(casted_geglu_out,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
gemm2_input_scale_inv = scale_inv[gemm2_input_idx]
wgrad_2 = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype, True,
gated_gelu_output_cast_trans, gemm2_input_scale_inv, fwd_dtype, False,
jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD)
# (hidden, batch...) x (hidden, batch...)
gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
wgrad_2 = fp8_dot_impl(casted_geglu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv,
grad.dtype, (xt_batch_dims, xt_batch_dims))
# (batch..., hidden_out) x (hidden_in, hidden_out)
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
dgrad_2 = gemm(kernel_2_cast, kernel_2_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)
dgrad_2 = fp8_dot_impl(casted_grad, casted_kerenl_2, grad_scale_inv, kernel_2_scale_inv,
grad.dtype, (x_contracting_dims, (1,)))
gemm1_input_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)
gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)
dgrad_2_amax = amax[gemm1_grad_idx, 0:1]
dgrad_2_scale = scale[gemm1_grad_idx]
dgrad_2_scale_inv = scale_inv[gemm1_grad_idx]
dgelu, dgelu_trans, dgelu_amax = dgated_gelu_cast_transpose(dgrad_2, dense_1_output,
dgrad_2_amax, dgrad_2_scale,
dgrad_2_scale_inv, bwd_dtype)
ln_out_trans = transpose(ln_out, fwd_dtype)
dgeglu_amax = amax[gemm1_grad_idx, 0:1]
dgeglu_scale = scale[gemm1_grad_idx]
dgeglu_scale_inv = scale_inv[gemm1_grad_idx]
gemm1_input_scale_inv = scale_inv[gemm1_input_idx]
wgrad_1 = gemm(dgelu_trans, dgrad_2_scale_inv, bwd_dtype, True,
ln_out_trans, gemm1_input_scale_inv, fwd_dtype, False,
jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD)
casted_dgeglu, casted_dgeglu_t, updated_dgeglu_amax = dgated_gelu_cast_transpose(
dgrad_2,
dot_1_output,
dgeglu_amax,
dgeglu_scale,
dgeglu_scale_inv,
bwd_dtype,
static_axis_boundary=-1)
ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)
# (hidden, batch...) x (2, hidden, batch...)
xt_batch_dims_plus_act_dim = tuple(i + 1 for i in xt_batch_dims)
gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgeglu_t, gemm1_x_scale_inv, dgeglu_scale_inv,
grad.dtype, (xt_batch_dims, xt_batch_dims_plus_act_dim))
# (batch..., 2, hidden_out) x (hidden_in, 2, hidden_out)
x_contracting_dims_plus_act_dim = (min(x_contracting_dims),) + tuple(
i + 1 for i in x_contracting_dims)
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
dgrad_1 = gemm(kernel_1_cast, kernel_1_scale_inv, fwd_dtype, True, dgelu, dgrad_2_scale_inv,
bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)
if major_sharding_type in (MajorShardingType.TP, MajorShardingType.DPTP):
dgrad_1 = jax.lax.psum(dgrad_1, tp_axis_name)
dgrad_1 = fp8_dot_impl(casted_dgeglu, casted_kerenl_1, dgeglu_scale_inv, kernel_1_scale_inv,
grad.dtype, (x_contracting_dims_plus_act_dim, (
1,
2,
)))
if layernorm_type == 'layernorm':
grad_input, grad_gamma, grad_beta = layernorm_bwd(dgrad_1,
dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
x,
mu,
rsigma,
inputs_,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
grad_input, grad_gamma = rmsnorm_bwd(dgrad_1, rsigma, inputs_, gamma, epsilon=epsilon)
grad_beta = None
amax = amax.at[gemm1_input_idx, 0].set(ln_out_amax[0])
amax = amax.at[gemm1_kernel_idx, 0].set(kernel_1_amax[0])
amax = amax.at[gemm1_grad_idx, 0].set(dgelu_amax[0])
amax = amax.at[gemm2_input_idx, 0].set(gated_gelu_amax[0])
amax = amax.at[gemm2_kernel_idx, 0].set(kernel_2_amax[0])
amax = amax.at[gemm2_grad_idx, 0].set(grad_amax[0])
if major_sharding_type in (MajorShardingType.DP, MajorShardingType.DPTP):
wgrad_1 = jax.lax.psum(wgrad_1, dp_axis_name)
wgrad_2 = jax.lax.psum(wgrad_2, dp_axis_name)
grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
amax = jax.lax.pmax(amax, dp_axis_name)
if len(fsdp_axis_name) > 0:
wgrad_1 = jax.lax.psum(wgrad_1, fsdp_axis_name)
wgrad_2 = jax.lax.psum(wgrad_2, fsdp_axis_name)
grad_gamma = jax.lax.psum(grad_gamma, fsdp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, fsdp_axis_name)
amax = jax.lax.pmax(amax, fsdp_axis_name)
if major_sharding_type in (MajorShardingType.TP, MajorShardingType.DPTP):
amax = jax.lax.pmax(amax, tp_axis_name)
grad_input = jnp.reshape(grad_input, input_shape)
wgrad_1 = jnp.reshape(wgrad_1, kernel_1_shape)
wgrad_2 = jnp.reshape(wgrad_2, kernel_2_shape)
return grad_input, grad_gamma, grad_beta, \
wgrad_1, wgrad_2, \
fp8_maxs, amax, scale, scale_inv
_fp8_mlp.defvjp(_fp8_mlp_fwd, _fp8_mlp_bwd)
dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
amax = amax.at[gemm1_grad_idx, 0].set(updated_dgeglu_amax[0])
amax = amax.at[gemm2_x_idx, 0].set(updated_geglu_amax[0])
amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax[0])
amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
return dx, dgamma, dbeta, wgrad_1, wgrad_2, \
fp8_max, amax, scale, scale_inv
_layernrom_geglu_fp8_mlp.defvjp(_layernrom_geglu_fp8_mlp_fwd_rule,
_layernrom_geglu_fp8_mlp_bwd_rule)
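For readers tracking shapes through Layernorm + GEMM1 + GeGLU + GEMM2, a plain-jnp sketch of the same data flow; the FP8 casts and the layernorm are simplified away, and the gating order is an assumption.
# Shape-only sketch with hypothetical sizes; not the FP8 code path.
import jax
import jax.numpy as jnp
x_ = jnp.zeros((8, 128, 1024))                       # (batch..., hidden_in)
k1 = jnp.zeros((1024, 2, 4096))                      # kernel_1: (hidden_in, 2, hidden_out)
k2 = jnp.zeros((4096, 1024))                         # kernel_2: (hidden_out, hidden_in)
h = jnp.einsum('bsh,hko->bsko', x_, k1)              # GEMM1 -> (8, 128, 2, 4096)
h = jax.nn.gelu(h[..., 0, :]) * h[..., 1, :]         # GeGLU gating (order assumed) -> (8, 128, 4096)
y = jnp.einsum('bso,oi->bsi', h, k2)                 # GEMM2 -> (8, 128, 1024)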
@@ -49,7 +49,7 @@ class TransformerEngineBaseLayer(BaseLayer):
fp8_collection_map = {
FP8Helper.FP8_COLLECTION_NAME: [
WeightHParamsCollection.SKIP_LP_REGULARIZATION,
WeightHParamsCollection.NON_TRAINABLE,
WeightHParamsCollection.OVERWRITE_WITH_GRADIENT,
WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION
]
}
@@ -92,8 +92,7 @@ class LayerNorm(TransformerEngineBaseLayer):
"ln_bias", self.bias_init),
bias_axes=self.bias_axes,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
sharding_type=self.sharding_type)
transpose_batch_sequence=self.transpose_batch_sequence)
self.create_layer("layer_norm", ln_cls)
@@ -115,8 +114,7 @@ class FusedSoftmax(TransformerEngineBaseLayer):
fused_softmax_cls = partial(Softmax,
scale_factor=self.scale_factor,
softmax_type=self.softmax_type,
sharding_type=self.sharding_type)
softmax_type=self.softmax_type)
self.create_layer("fused_softmax", fused_softmax_cls)
@@ -151,8 +149,7 @@ class Linear(TransformerEngineBaseLayer):
bias_axes=self.bias_axes,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
sharding_type=self.sharding_type)
transpose_batch_sequence=self.transpose_batch_sequence)
self.create_layer("linear", dense_general_cls)
@@ -208,8 +205,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
depth_scaling=self.depth_scaling,
sharding_type=self.sharding_type)
depth_scaling=self.depth_scaling)
self.create_layer("ln_linear", ln_dense_general_cls)
@@ -273,8 +269,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
intermediate_hidden_dropout_dims=self.intermediate_hidden_dropout_dims,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
major_sharding_type=self.major_sharding_type)
transpose_batch_sequence=self.transpose_batch_sequence)
self.create_layer("ln_mlp", ln_mlp_cls)
@@ -8,17 +8,12 @@ Sharding Meta for xmap with CustomCall
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from itertools import repeat
from typing import Union, Tuple, Dict, Callable, Sequence
from typing import Callable
from jax.interpreters import pxla
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap
from jax.sharding import PartitionSpec
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
_PXLA_THREAD_RESOURCES = pxla.thread_resources
@@ -29,6 +24,24 @@ def _get_mesh_info(resource: str):
return mesh.shape[resource], resource
def get_all_mesh_axes():
"""
Get the names of all mesh axes
"""
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
return mesh.axis_names
def get_padded_spec(spec, ndim):
"""
Pad a sharding spec with None entries up to the given rank (ndim), for partitioning rules
"""
if spec is None:
return (None,) * ndim
assert len(spec) <= ndim
return spec + (None,) * (ndim - len(spec))
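A quick illustration of the helper above; the spec values are hypothetical.
# get_padded_spec fills a partial spec up to the operand rank.
assert get_padded_spec(('data', None), ndim=4) == ('data', None, None, None)
assert get_padded_spec(None, ndim=3) == (None, None, None)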
def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
"""
A wrapper function to jax.lax.with_sharding_constraint to
@@ -40,8 +53,25 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
return jax.lax.with_sharding_constraint(x, pspec)
def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str):
"""
A wrapper function to invoke lax.p* operations, like psum.
"""
if mesh_resource is not None:
_, resource = _get_mesh_info(mesh_resource)
return ops(x, resource)
return x
def num_of_devices():
"""
Get total number of detected devices
"""
return len(jax.devices())
@dataclass
class ShardingResource:
class MeshResource:
"""
A data container to indicate which axis in Mesh for data parallelism and
which for tensor parallelism.
@@ -54,39 +84,73 @@ class ShardingResource:
tp_resource : str, default = None
The axis name in Mesh used to split the hidden dimensions along.
If it is None, then tensor parallelism is disabled.
fsdp_resource : str, default = None
The axis name in Mesh used to split the batch and weights along.
If it is None, then fully-sharded data parallelism is disabled.
pp_resource : str, default = None
The axis name in Mesh used to split model layers along.
If it is None, then pipeline parallelism is disabled.
"""
dp_resource: str = None
tp_resource: str = None
fsdp_resource: str = None
pp_resource: str = None
_GLOBAL_SHARD_RESOURCE = ShardingResource()
_GLOBAL_MESH_RESOURCE = MeshResource()
@contextmanager
def global_shard_guard(resource: ShardingResource):
def global_shard_guard(resource: MeshResource):
"""
A context manager to switch the global ShardingResource
A context manager to switch the global MeshResource
"""
global _GLOBAL_SHARD_RESOURCE
prev_gsr = _GLOBAL_SHARD_RESOURCE
global _GLOBAL_MESH_RESOURCE
prev_gmr = _GLOBAL_MESH_RESOURCE
try:
_GLOBAL_SHARD_RESOURCE = resource
_GLOBAL_MESH_RESOURCE = resource
yield
finally:
_GLOBAL_SHARD_RESOURCE = prev_gsr
_GLOBAL_MESH_RESOURCE = prev_gmr
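A usage sketch of the renamed container together with a device mesh; the axis names 'data' and 'model' and the mesh shape are assumptions.
# Illustrative sketch only; assumes the device count is divisible by 2.
import numpy as np
import jax
from jax.sharding import Mesh
devices = np.asarray(jax.devices()).reshape(-1, 2)
with Mesh(devices, ('data', 'model')):
    with global_shard_guard(MeshResource(dp_resource='data', tp_resource='model')):
        ...   # run the jit-compiled step; custom calls read the axis names from here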
def global_mesh_resource() -> MeshResource:
"""
A getter of the global MeshResource
"""
return _GLOBAL_MESH_RESOURCE
def global_shard_resource() -> ShardingResource:
def all_reduce_sum_along_dp_fsdp(x: jnp.array):
"""
A getter of the global ShardingResource
All-Reduce (Sum) along DP and FSDP mesh axes.
"""
return _GLOBAL_SHARD_RESOURCE
x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource)
return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource)
def all_reduce_max_along_all_axes_except_PP(x: jnp.array):
"""
All-Reduce (Max) along all mesh axes except the pipeline-parallel axis.
"""
all_axes = get_all_mesh_axes()
for axis in all_axes:
if axis != global_mesh_resource().pp_resource:
x = lax_paral_op(x, jax.lax.pmax, axis)
return x
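These reduction helpers are intended to run inside mesh-mapped code; assuming dp_resource='data' and no fsdp_resource, the sum helper behaves like the hand-written form below.
# Illustrative equivalence only (assumes dp_resource='data', fsdp_resource=None).
def _sum_over_dp(local_value):
    return jax.lax.psum(local_value, 'data')   # what all_reduce_sum_along_dp_fsdp reduces over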
# Deprecating Items ---------------------------------------------------------------
ShardingResource = MeshResource
global_shard_resource = global_mesh_resource
class MajorShardingType(Enum):
r"""
The major sharding type to indicate sharding pattern.
.. warning::
MajorShardingType will be deprecated in the near future.
Values
----------
@@ -108,6 +172,8 @@ class MajorShardingType(Enum):
class ShardingType(Enum):
"""
The sharding type to indicate sharding pattern.
.. warning::
ShardingType will be deprecated in the near future.
Values
----------
@@ -130,1058 +196,3 @@ class ShardingType(Enum):
TP_ROW = (MajorShardingType.TP, "tp_row")
DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")
def infer_major_sharding_type() -> MajorShardingType:
"""
Infer MajorShardingType from _GLOBAL_SHARD_RESOURCE
"""
gsr = global_shard_resource()
resources = [gsr.dp_resource, gsr.tp_resource, gsr.fsdp_resource]
for idx, rs in enumerate(resources):
try:
size, _ = _get_mesh_info(rs)
if size <= 1:
resources[idx] = None
except AssertionError as _:
resources[idx] = None
dp_resource = resources[0]
tp_resource = resources[1]
fsdp_resource = resources[2]
def dp_enabled():
return (fsdp_resource is not None) or (dp_resource is not None)
if dp_enabled() and tp_resource is not None:
return MajorShardingType.DPTP
if dp_enabled():
return MajorShardingType.DP
if tp_resource is not None:
return MajorShardingType.TP
return MajorShardingType.SINGLE
def infer_sharding_type(major_st: MajorShardingType = None) -> Tuple[ShardingType, ShardingType]:
"""
Infer ShardingType via given MajorShardingType
"""
if major_st is None:
major_st = infer_major_sharding_type()
if major_st is MajorShardingType.DP:
return ShardingType.DP, ShardingType.DP
if major_st is MajorShardingType.TP:
return ShardingType.TP_COL, ShardingType.TP_ROW
if major_st is MajorShardingType.DPTP:
return ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW
return ShardingType.SINGLE, ShardingType.SINGLE
def is_dp_enabled(mst: MajorShardingType) -> bool:
"""
is_dp_enabled
"""
return mst in (MajorShardingType.DP, MajorShardingType.DPTP)
def is_tp_enabled(mst: MajorShardingType) -> bool:
"""
is_tp_enabled
"""
return mst in (MajorShardingType.TP, MajorShardingType.DPTP)
def merge_axis_resources(ars: Tuple[Dict]) -> Dict:
"""
merge_axis_resources
"""
output = {}
for ar in ars:
for key in ar:
if key not in output:
output[key] = ar[key]
else:
assert output[key] == ar[key]
return output
@dataclass
class ShardingMeta:
"""ShardingMeta"""
in_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]]
out_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]]
axis_resources: Dict
input_shapes: Tuple[Tuple[int, ...]]
output_shapes: Tuple[Tuple[int, ...]]
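# Example (illustrative, mesh axis names assumed): a DP-only ShardingMeta for one (batch=64,
# hidden=1024) input on a 2-way data-parallel mesh. The batch dimension is split into
# (dp_size, batch // dp_size) and index 0 of the reshaped input maps to the named axis 'data':
#
#     ShardingMeta(in_axes=({0: 'data'},),
#                  out_axes={0: 'data'},
#                  axis_resources={'data': 'dp'},
#                  input_shapes=((2, 32, 1024),),
#                  output_shapes=((64, 1024),))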
class ShardingMetaGenerator:
"""
ShardingMetaGenerator
"""
def __init__(self):
def get_single_sharding_meta(*argv, **kwargs) -> ShardingMeta: # pylint: disable=unused-argument
return None
self.sharding_type_meta_map = {
ShardingType.SINGLE: get_single_sharding_meta,
ShardingType.DP: self.get_dp_sharding_meta,
ShardingType.TP_COL: self.get_tp_col_sharding_meta,
ShardingType.TP_ROW: self.get_tp_row_sharding_meta,
ShardingType.DP_TP_COL: self.get_dp_tp_col_sharding_meta,
ShardingType.DP_TP_ROW: self.get_dp_tp_row_sharding_meta
}
def get_sharding_meta(self, stype: ShardingType, *argv, **kwargs) -> ShardingMeta:
"""get_sharding_meta"""
return self.sharding_type_meta_map[stype](*argv, **kwargs)
def get_dp_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_sharding_meta"""
raise NotImplementedError
def get_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_tp_col_sharding_meta"""
raise NotImplementedError
def get_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_tp_row_sharding_meta"""
raise NotImplementedError
def get_dp_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_tp_col_sharding_meta"""
raise NotImplementedError
def get_dp_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_tp_row_sharding_meta"""
raise NotImplementedError
class FP8MetaShardingMetaGenerator(ShardingMetaGenerator):
"""
FP8MetaShardingMetaGenerator
"""
def get_dp_sharding_meta(self,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.DP,
num_of_meta, dp_axis_name,
tp_axis_name)
def get_tp_col_sharding_meta(self,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.TP,
num_of_meta, dp_axis_name,
tp_axis_name)
def get_tp_row_sharding_meta(self,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.TP,
num_of_meta, dp_axis_name,
tp_axis_name)
def get_dp_tp_col_sharding_meta(self,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.DPTP,
num_of_meta, dp_axis_name,
tp_axis_name)
def get_dp_tp_row_sharding_meta(self,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.DPTP,
num_of_meta, dp_axis_name,
tp_axis_name)
@staticmethod
def _stack_axes_meta(num_of_meta: int, mapping: Dict) -> Tuple:
return tuple(mapping for _ in range(num_of_meta))
@staticmethod
def _generate_sharding_meta(type_: MajorShardingType,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
axis_resource = {}
if is_dp_enabled(type_):
axis_resource[dp_axis_name] = global_shard_resource().dp_resource
if is_tp_enabled(type_):
axis_resource[tp_axis_name] = global_shard_resource().tp_resource
return ShardingMeta(FP8MetaShardingMetaGenerator._stack_axes_meta(num_of_meta, {}),
FP8MetaShardingMetaGenerator._stack_axes_meta(num_of_meta, {}),
axis_resource, (), ())
class FusedAttnShardingMetaGenerator(ShardingMetaGenerator):
"""
FusedAttnShardingMetaGenerator
"""
def get_dp_sharding_meta(
self,
input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]],
tp_dims: Tuple[Tuple[int, ...]], # pylint: disable=unused-argument
dp_axis_name: str = 'data',
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_dp_sharding_meta"""
dummy_tp_dims = [repeat(None), repeat(None)]
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(input_shapes, output_shapes,
dp_dims, dummy_tp_dims,
dp_axis_name, None)
def get_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_tp_col_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_tp_sharding_meta(*argv, **kwargs)
def get_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_tp_row_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_tp_sharding_meta(*argv, **kwargs)
def get_dp_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_tp_col_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(*argv, **kwargs)
def get_dp_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_tp_row_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(*argv, **kwargs)
@staticmethod
def _get_tp_sharding_meta(
input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]], # pylint: disable=unused-argument
tp_dims: Tuple[Tuple[int, ...]],
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_tp_sharding_meta"""
dummy_dp_dims = [repeat(None), repeat(None)]
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(input_shapes, output_shapes,
dummy_dp_dims, tp_dims, None,
tp_axis_name)
@staticmethod
def _get_dptp_sharding_meta(input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]],
tp_dims: Tuple[Tuple[int, ...]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_sharding_meta"""
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
input_dp_dims, output_dp_dims = dp_dims
input_tp_dims, output_tp_dims = tp_dims
input_new_shapes = []
in_axes = []
for input_shape, dp_dim, tp_dim in zip(input_shapes, input_dp_dims, input_tp_dims):
in_axis = {}
if dp_dim is not None and input_shape is not None:
in_axis[dp_dim] = dp_axis_name
assert input_shape[dp_dim] % dp_size == 0, \
f"The dimension of batch in input_shape should be a multiple of " \
f"data parallelism size, but got {input_shape[dp_dim]=} and {dp_size=}."
input_shape = (*input_shape[:dp_dim], dp_size, input_shape[dp_dim] // dp_size,
*input_shape[dp_dim + 1:])
# the input shape has been expanded for dp_dim, tp_dim should +1 if tp_dim >= dp_dim
if tp_dim is not None and tp_dim >= dp_dim:
tp_dim = tp_dim + 1
if tp_dim is not None and input_shape is not None:
in_axis[tp_dim] = tp_axis_name
assert input_shape[tp_dim] % tp_size == 0, \
f"The dimension of tensor parallel in input_shape should be a multiple of " \
f"tensor parallelism size, but got {input_shape[tp_dim]=} and {tp_size=}."
input_shape = (*input_shape[:tp_dim], tp_size, input_shape[tp_dim] // tp_size,
*input_shape[tp_dim + 1:])
in_axes.append(in_axis)
input_new_shapes.append(input_shape)
output_new_shapes = output_shapes
out_axes = []
for dp_dim, tp_dim in zip(output_dp_dims, output_tp_dims):
out_axis = {}
if dp_dim is not None:
out_axis[dp_dim] = dp_axis_name
if tp_dim is not None and tp_dim >= dp_dim:
tp_dim = tp_dim + 1
if tp_dim is not None:
out_axis[tp_dim] = tp_axis_name
out_axes.append(out_axis)
        assert len(out_axes) == 1, "Only a single output is supported at the moment."
        assert len(output_new_shapes) == 1, "Only a single output is supported at the moment."
out_axes = out_axes[0]
output_new_shapes = output_new_shapes[0]
axis_resources = {}
if dp_axis_name is not None:
axis_resources[dp_axis_name] = dp_mesh_axis
if tp_axis_name is not None:
axis_resources[tp_axis_name] = tp_mesh_axis
return ShardingMeta(tuple(in_axes), out_axes, axis_resources, input_new_shapes,
output_new_shapes)
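    # Example (illustrative): a fused-attention QKV input of shape (batch, seqlen, 3, heads, dim)
    # with dp_dim=0 and tp_dim=3 on a mesh with dp_size=2 and tp_size=4 is reshaped to
    # (2, batch // 2, seqlen, 3, 4, heads // 4, dim) with in_axis {0: 'data', 4: 'model'};
    # the tp index is shifted by one because the dp dimension is expanded first.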
class DotShardingMetaGenerator(ShardingMetaGenerator):
"""
DotShardingMetaGenerator
"""
def get_dp_sharding_meta(
self,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int, # pylint: disable=unused-argument
model_dim_of_b: int, # pylint: disable=unused-argument
contracting_dims: Tuple[Sequence[int], Sequence[int]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, None,
contracting_dims)
out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims)
out_batch_dim = batch_dim_of_a
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
assert a_shape[batch_dim_of_a] % dp_size == 0, \
f"The dimension of batch in a_shape should be a multiple of data parallelism size," \
f" but got {a_shape[batch_dim_of_a]=} and {dp_size=}."
a_new_shape = (*a_shape[:batch_dim_of_a], dp_size, -1, *a_shape[batch_dim_of_a + 1:])
return ShardingMeta(({
batch_dim_of_a: dp_axis_name
}, {}), ({
out_batch_dim: dp_axis_name
}), {dp_axis_name: dp_mesh_axis}, [a_new_shape, b_shape], [out_shape])
def get_tp_col_sharding_meta(
self,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int, # pylint: disable=unused-argument
model_dim_of_b: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model') -> ShardingMeta:
DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, None,
contracting_dims)
out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims)
out_model_idx = len(out_shape) - (len(b_shape) - model_dim_of_b)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert b_shape[model_dim_of_b] % tp_size == 0, \
f"The dimension of model parallelism in b_shape should be a multiple of " \
f"tensor parallelism size,but got {b_shape[model_dim_of_b]=} and {tp_size=}."
b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size,
*b_shape[model_dim_of_b + 1:])
return ShardingMeta(({}, {
model_dim_of_b: tp_axis_name
}), ({
out_model_idx: tp_axis_name
}), {tp_axis_name: tp_mesh_axis}, [a_shape, b_new_shape], [out_shape])
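    # Example (illustrative): column-parallel GEMM with a_shape=(32, 128, 1024),
    # b_shape=(1024, 4096), model_dim_of_b=1, tp_size=4 and the default contracting_dims
    # ((-1,), (0,)): b is reshaped to (1024, 4, 1024), in_axes is ({}, {1: 'model'}) and
    # out_model_idx = len(out_shape) - (len(b_shape) - model_dim_of_b) = 3 - 1 = 2,
    # so the output axis mapping is {2: 'model'}.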
def get_tp_row_sharding_meta(
self,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int,
model_dim_of_b: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model') -> ShardingMeta:
DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, model_dim_of_a,
contracting_dims)
out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert a_shape[model_dim_of_a] % tp_size == 0, \
f"The dimension of model parallelism in a_shape should be a multiple of " \
f"tensor parallelism size,but got {a_shape[model_dim_of_a]=} and {tp_size=}."
assert b_shape[model_dim_of_b] % tp_size == 0, \
f"The dimension of model parallelism in b_shape should be a multiple of " \
f"tensor parallelism size,but got {b_shape[model_dim_of_b]=} and {tp_size=}."
a_new_shape = (*a_shape[:model_dim_of_a], tp_size, a_shape[model_dim_of_a] // tp_size,
*a_shape[model_dim_of_a + 1:])
b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size,
*b_shape[model_dim_of_b + 1:])
return ShardingMeta(({
model_dim_of_a: tp_axis_name
}, {
model_dim_of_b: tp_axis_name
}), ({}), {tp_axis_name: tp_mesh_axis}, [a_new_shape, b_new_shape], [out_shape])
def get_dp_tp_col_sharding_meta(
self,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int, # pylint: disable=unused-argument
model_dim_of_b: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, None,
contracting_dims)
out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims)
out_model_idx = len(out_shape) + 1 - (len(b_shape) - model_dim_of_b)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert a_shape[batch_dim_of_a] % dp_size == 0, \
f"The dimension of batch in a_shape should be a multiple of data parallelism size," \
f" but got {a_shape[batch_dim_of_a]=} and {dp_size=}."
assert b_shape[model_dim_of_b] % tp_size == 0, \
f"The dimension of model parallelism in b_shape should be a multiple of " \
f"tensor parallelism size,but got {b_shape[model_dim_of_b]=} and {tp_size=}."
a_new_shape = (*a_shape[:batch_dim_of_a], dp_size, a_shape[batch_dim_of_a] // dp_size,
*a_shape[batch_dim_of_a + 1:])
b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size,
*b_shape[model_dim_of_b + 1:])
return ShardingMeta(({
batch_dim_of_a: dp_axis_name
}, {
model_dim_of_b: tp_axis_name
}), ({
batch_dim_of_a: dp_axis_name,
out_model_idx: tp_axis_name
}), {
dp_axis_name: dp_mesh_axis,
tp_axis_name: tp_mesh_axis
}, [a_new_shape, b_new_shape], [out_shape])
def get_dp_tp_row_sharding_meta(self,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int,
model_dim_of_b: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, model_dim_of_a,
contracting_dims)
out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert a_shape[batch_dim_of_a] % dp_size == 0, \
f"The dimension of batch in a_shape should be a multiple of data parallelism size," \
f" but got {a_shape[batch_dim_of_a]=} and {dp_size=}."
assert a_shape[model_dim_of_a] % tp_size == 0, \
f"The dimension of model parallelism in a_shape should be a multiple of " \
f"tensor parallelism size,but got {a_shape[model_dim_of_a]=} and {tp_size=}."
assert b_shape[model_dim_of_b] % tp_size == 0, \
f"The dimension of model parallelism in b_shape should be a multiple of " \
f"tensor parallelism size,but {b_shape[model_dim_of_b]=} and {tp_size=}."
a_new_shape = (*a_shape[:batch_dim_of_a], dp_size, a_shape[batch_dim_of_a] // dp_size,
*a_shape[batch_dim_of_a + 1:model_dim_of_a], tp_size,
a_shape[model_dim_of_a] // tp_size, *a_shape[model_dim_of_a + 1:])
b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size,
*b_shape[model_dim_of_b + 1:])
return ShardingMeta(
(
{
batch_dim_of_a:
dp_axis_name,
# "model_dim_of_a+1" is the index to tp_size in a_new_shape
model_dim_of_a + 1:
tp_axis_name
},
{
model_dim_of_b: tp_axis_name
}),
({
batch_dim_of_a: dp_axis_name
}),
{
dp_axis_name: dp_mesh_axis,
tp_axis_name: tp_mesh_axis
},
[a_new_shape, b_new_shape],
[out_shape])
@staticmethod
def _is_supported(
a_shape: Tuple, # pylint: disable=unused-argument
b_shape: Tuple, # pylint: disable=unused-argument
batch_dim_of_a: int,
model_dim_of_a: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
):
assert batch_dim_of_a not in contracting_dims[0], \
"batch_dim_of_a should be one of contracting_dims[0]"
assert batch_dim_of_a >= 0, \
"Only support non-negative value of batch_dim_of_a."
if model_dim_of_a is not None:
assert model_dim_of_a >= 0, \
"Only support non-negative value of model_dim_of_a"
assert model_dim_of_a > batch_dim_of_a, \
"Only support the case that model_dim_of_a > batch_dim_of_a."
@staticmethod
def _infer_output_shape(
a_shape: Tuple,
b_shape: Tuple,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
):
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
return (*a_shape[:min(lhs_contracting_dims)], *b_shape[max(rhs_contracting_dims) + 1:])
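    # Example (illustrative): with the default contracting_dims ((-1,), (0,)),
    # a_shape=(32, 128, 1024) and b_shape=(1024, 4096) give
    # (*a_shape[:-1], *b_shape[1:]) == (32, 128, 4096).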
class ElementwiseShardingMetaGenerator(ShardingMetaGenerator):
"""
ElementwiseShardingMetaGenerator
"""
def get_dp_sharding_meta(
self,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_dp_sharding_meta"""
ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, batch_dim)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
assert input_shape[batch_dim] % dp_size == 0, \
f"The dimension of batch in input_shape should be a multiple of data parallelism " \
f"size, but got {input_shape[batch_dim]=} and {dp_size=}."
input_new_shape = (*input_shape[:batch_dim], dp_size, -1, *input_shape[batch_dim + 1:])
in_axes = [{batch_dim: dp_axis_name}]
input_new_shapes = [input_new_shape]
if other_shape is not None:
input_new_shapes.append(other_shape)
in_axes.append({})
return ShardingMeta(tuple(in_axes), ({
batch_dim: dp_axis_name
}), {dp_axis_name: dp_mesh_axis}, input_new_shapes, [input_shape])
def get_tp_col_sharding_meta(
self,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int, # pylint: disable=unused-argument
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_tp_col_sharding_meta"""
ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, 0)
in_axes = [{}]
input_new_shapes = [input_shape]
if other_shape is not None:
in_axes.append({})
input_new_shapes.append(other_shape)
return ShardingMeta(tuple(in_axes), ({}), {}, input_new_shapes, [input_shape])
def get_tp_row_sharding_meta(
self,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int, # pylint: disable=unused-argument
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_tp_row_sharding_meta"""
ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, 0)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert input_shape[-1] % tp_size == 0, \
f"The last dimension in input_shape should be a multiple of tensor parallelism size," \
f" but got {input_shape[-1]=} and {tp_size=}."
input_new_shape = (*input_shape[:-1], tp_size, -1)
in_axes = [{
# "len(a_new_shape)-2" is the index to tp_size in a_new_shape
len(input_new_shape) - 2:
tp_axis_name
}]
input_new_shapes = [input_new_shape]
if other_shape is not None:
assert other_shape[0] % tp_size == 0, \
f"The first dimension in other_shape should be a multiple of tensor parallelism size," \
f" but got {other_shape[0]=} and {tp_size=}."
other_new_shape = (tp_size, -1)
in_axes.append({0: tp_axis_name})
input_new_shapes.append(other_new_shape)
return ShardingMeta(tuple(in_axes), ({
len(input_new_shape) - 2: tp_axis_name
}), {tp_axis_name: tp_mesh_axis}, input_new_shapes, [input_shape])
def get_dp_tp_col_sharding_meta(self,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_col_sharding_meta"""
return self.get_dp_sharding_meta(input_shape, other_shape, batch_dim, dp_axis_name,
tp_axis_name)
def get_dp_tp_row_sharding_meta(self,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_row_sharding_meta"""
ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, batch_dim)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert input_shape[batch_dim] % dp_size == 0, \
f"The dimension of batch in input_shape should be a multiple of data parallelism" \
f"size, but got {input_shape[batch_dim]=} and {dp_size=}."
assert input_shape[-1] % tp_size == 0, \
f"The last dimension in input_shape should be a multiple of tensor parallelism size," \
f" but got {input_shape[-1]=} and {tp_size=}."
input_new_shape = (*input_shape[:batch_dim], dp_size, -1, *input_shape[batch_dim + 1:-1],
tp_size, input_shape[-1] // tp_size)
in_axes = [{
batch_dim:
dp_axis_name,
# "len(a_new_shape)-2" is the index to tp_size in a_new_shape
len(input_new_shape) - 2:
tp_axis_name
}]
input_new_shapes = [input_new_shape]
other_new_shape = other_shape
if other_shape is not None:
assert other_shape[0] % tp_size == 0, \
f"The first dimension in other_shape should be a multiple of tensor parallelism size," \
f" but got {other_shape[0]=} and {tp_size=}."
other_new_shape = (tp_size, -1)
in_axes.append({0: tp_axis_name})
input_new_shapes.append(other_new_shape)
return ShardingMeta(tuple(in_axes), ({
batch_dim: dp_axis_name,
len(input_new_shape) - 2: tp_axis_name
}), {
dp_axis_name: dp_mesh_axis,
tp_axis_name: tp_mesh_axis
}, input_new_shapes, [input_shape])
@staticmethod
def _is_supported(input_shape: Tuple, other_shape: Tuple, batch_dim: int):
if other_shape is not None:
            assert len(other_shape) == 1, "Only 1-dimensional other_shape is currently supported."
            assert input_shape[-1] == other_shape[0], \
                f"input_shape[-1] should be equal to other_shape[0], " \
                f"but got {input_shape[-1]} and {other_shape[0]}."
assert batch_dim < len(input_shape)-1, \
"batch_dim cannot be the latest dim"
class SoftmaxShardingMetaGenerator(ShardingMetaGenerator):
"""
SoftmaxShardingMetaGenerator
"""
def get_dp_sharding_meta(
self,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_dp_sharding_meta"""
SoftmaxShardingMetaGenerator._is_supported(input_shape, dp_dim, tp_dim)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
assert input_shape[dp_dim] % dp_size == 0, \
f"The dimension of batch in input_shape should be a multiple of data parallelism " \
f"size, but got {input_shape[dp_dim]=} and {dp_size=}."
input_new_shape = (*input_shape[:dp_dim], dp_size, -1, *input_shape[dp_dim + 1:])
in_axes = [{dp_dim: dp_axis_name}]
input_new_shapes = [input_new_shape]
out_axes = in_axes[0]
return ShardingMeta(tuple(in_axes), out_axes, {dp_axis_name: dp_mesh_axis},
input_new_shapes, [input_shape])
def get_tp_col_sharding_meta(self,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_tp_col_sharding_meta"""
return SoftmaxShardingMetaGenerator._get_tp_sharding_meta(input_shape, dp_dim, tp_dim,
dp_axis_name, tp_axis_name)
def get_tp_row_sharding_meta(self,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_tp_row_sharding_meta"""
return SoftmaxShardingMetaGenerator._get_tp_sharding_meta(input_shape, dp_dim, tp_dim,
dp_axis_name, tp_axis_name)
def get_dp_tp_col_sharding_meta(self,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_col_sharding_meta"""
return SoftmaxShardingMetaGenerator._get_dptp_sharding_meta(input_shape, dp_dim, tp_dim,
dp_axis_name, tp_axis_name)
def get_dp_tp_row_sharding_meta(self,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_row_sharding_meta"""
return SoftmaxShardingMetaGenerator._get_dptp_sharding_meta(input_shape, dp_dim, tp_dim,
dp_axis_name, tp_axis_name)
@staticmethod
def _is_supported(input_shape: Tuple, dp_dim: int, tp_dim: int):
assert len(input_shape) == 4
assert dp_dim == 0
assert tp_dim == 1
@staticmethod
def _get_tp_sharding_meta(
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_tp_sharding_meta"""
SoftmaxShardingMetaGenerator._is_supported(input_shape, dp_dim, tp_dim)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert input_shape[tp_dim] % tp_size == 0, \
f"The dimension of tensor parallel in input_shape should be a multiple of data " \
f"parallelism size, but got {input_shape[tp_dim]=} and {tp_size=}."
input_new_shape = (*input_shape[:tp_dim], tp_size, -1, *input_shape[tp_dim + 1:])
in_axes = [{tp_dim: tp_axis_name}]
input_new_shapes = [input_new_shape]
out_axes = in_axes[0]
return ShardingMeta(tuple(in_axes), out_axes, {tp_axis_name: tp_mesh_axis},
input_new_shapes, [input_shape])
@staticmethod
def _get_dptp_sharding_meta(input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_sharding_meta"""
SoftmaxShardingMetaGenerator._is_supported(input_shape, dp_dim, tp_dim)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert input_shape[dp_dim] % dp_size == 0, \
f"The dimension of batch in input_shape should be a multiple of data parallelism " \
f"size, but got {input_shape[dp_dim]=} and {dp_size=}."
assert input_shape[tp_dim] % tp_size == 0, \
f"The dimension of tensor parallel in input_shape should be a multiple of data " \
f"parallelism size, but got {input_shape[tp_dim]=} and {tp_size=}."
input_new_shape = (*input_shape[:dp_dim], dp_size, input_shape[dp_dim] // dp_size,
*input_shape[dp_dim + 1:tp_dim], tp_size, input_shape[tp_dim] // tp_size,
*input_shape[tp_dim + 1:])
in_axes = [{dp_dim: dp_axis_name, tp_dim + 1: tp_axis_name}]
input_new_shapes = [input_new_shape]
out_axes = in_axes[0]
return ShardingMeta(tuple(in_axes), out_axes, {
dp_axis_name: dp_mesh_axis,
tp_axis_name: tp_mesh_axis
}, input_new_shapes, [input_shape])
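    # Example (illustrative): a softmax input of shape (batch, heads, q_seqlen, k_seqlen) with
    # dp_dim=0, tp_dim=1, dp_size=2 and tp_size=4 is reshaped to
    # (2, batch // 2, 4, heads // 4, q_seqlen, k_seqlen); both in_axes and out_axes become
    # {0: 'data', 2: 'model'} (the tp index is shifted by one after the dp split).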
def get_fp8_meta_sharding_meta(stype: ShardingType,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
get_fp8_meta_sharding_meta
"""
return FP8MetaShardingMetaGenerator().get_sharding_meta(stype, num_of_meta, dp_axis_name,
tp_axis_name)
def get_dot_sharding_meta(stype: ShardingType,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int,
model_dim_of_b: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
get_dot_sharding_meta
"""
if stype in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
assert model_dim_of_b <= max(contracting_dims[1]), \
f"The dimension of model parallelism in b_shape should be smaller than the max of" \
f" contracting_dims[1], but got {model_dim_of_b=} and {contracting_dims[1]=}."
if stype in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
assert model_dim_of_b > max(contracting_dims[1]), \
f"The dimension of model parallelism in b_shape should be larger than the max of" \
f" contracting_dims[1], but got {model_dim_of_b=} and {contracting_dims[1]=}."
return DotShardingMetaGenerator().get_sharding_meta(stype, a_shape, b_shape, batch_dim_of_a,
model_dim_of_a, model_dim_of_b,
contracting_dims, dp_axis_name,
tp_axis_name)
def get_elementwise_sharding_meta(stype: ShardingType,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
get_elementwise_sharding_meta
"""
return ElementwiseShardingMetaGenerator().get_sharding_meta(stype, input_shape, other_shape,
batch_dim, dp_axis_name,
tp_axis_name)
def get_softmax_sharding_meta(stype: ShardingType,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
get_softmax_sharding_meta
"""
return SoftmaxShardingMetaGenerator().get_sharding_meta(stype, input_shape, dp_dim, tp_dim,
dp_axis_name, tp_axis_name)
def get_fused_attn_sharding_meta(stype: ShardingType,
input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]],
tp_dims: Tuple[Tuple[int, ...]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
    get_fused_attn_sharding_meta
"""
return FusedAttnShardingMetaGenerator().get_sharding_meta(stype, input_shapes, output_shapes,
dp_dims, tp_dims, dp_axis_name,
tp_axis_name)
def extend_fsdp_sharding_meta(sharding_meta: ShardingMeta,
weight_fsdp_dim_map: Dict[int, int]) -> Tuple[ShardingMeta, str]:
"""
    Extend the given ShardingMeta to be compatible with the FSDP (ZeRO-3) sharding pattern.
.. note::
        This helper assumes the first shape in sharding_meta.input_shapes corresponds
        to the input tensor. Make sure that index 0 is present in
        `weight_fsdp_dim_map`.
Parameters
----------
sharding_meta : ShardingMeta
the sharding meta object to extend with FSDP.
weight_fsdp_dim_map: Dict[int, int]
        A dict whose keys are indices into sharding_meta.input_shapes and whose values are
        the dimensions along which to extend the FSDP sharding.
Returns
-------
updated_sharding_meta : ShardingMeta
        the sharding_meta extended with FSDP.
    fsdp_axis_name: str
        The name of the FSDP named axis, used for the subsequent xmap projection.
"""
assert 0 in weight_fsdp_dim_map, \
"0-idx is required to be in 'weight_fsdp_dim_map' for the input."
mst = infer_major_sharding_type()
if mst is MajorShardingType.SINGLE:
return sharding_meta, ""
gsr = global_shard_resource()
dp_mesh_axis = gsr.dp_resource
fsdp_mesh_axis = gsr.fsdp_resource
if fsdp_mesh_axis == dp_mesh_axis:
return sharding_meta, ""
if fsdp_mesh_axis is None:
return sharding_meta, ""
fsdp_dim_size, _ = _get_mesh_info(fsdp_mesh_axis)
fsdp_axis_name = "fsdp"
def get_idx_to_extend(sharded_indices, target_idx):
idx_to_extend = target_idx
for i in sharded_indices:
if i <= target_idx:
idx_to_extend += 1
return idx_to_extend
def extend_exist_sharding(idx, shape):
remain_size = shape[idx]
assert remain_size == -1 or remain_size % fsdp_dim_size == 0
remain_size = remain_size // fsdp_dim_size
new_shape = tuple([*shape[:idx], fsdp_dim_size, remain_size, *shape[idx + 1:]])
return new_shape
new_input_shapes = []
new_in_axes = []
for i, shape in enumerate(sharding_meta.input_shapes):
idx_to_extend = -1
if i == 0: # Assume first shape corresponds to input
input_dp_dim = weight_fsdp_dim_map[i]
# idx_to_extend = input_dp_dim + 1 if is_dp_enabled(mst) else input_dp_dim
idx_to_extend = get_idx_to_extend(list(sharding_meta.in_axes[i].keys()), input_dp_dim)
new_shape = extend_exist_sharding(idx_to_extend, shape)
# assume one output only and have the same batch sharding like input
assert isinstance(sharding_meta.out_axes, dict)
new_out_axes = {}
for key in sharding_meta.out_axes:
if key < idx_to_extend:
new_out_axes[key] = sharding_meta.out_axes[key]
else:
new_out_axes[key + 1] = sharding_meta.out_axes[key]
new_out_axes[idx_to_extend] = fsdp_axis_name
sharding_meta.out_axes = new_out_axes
else:
new_shape = shape
if i in weight_fsdp_dim_map:
idx_to_extend = get_idx_to_extend(list(sharding_meta.in_axes[i].keys()),
weight_fsdp_dim_map[i])
if weight_fsdp_dim_map[i] in sharding_meta.in_axes[i]:
new_shape = extend_exist_sharding(idx_to_extend, shape)
else:
assert shape[idx_to_extend] % fsdp_dim_size == 0
remain_dim_size = shape[idx_to_extend] // fsdp_dim_size
new_shape = tuple([
*shape[:idx_to_extend], fsdp_dim_size, remain_dim_size,
*shape[idx_to_extend + 1:]
])
if idx_to_extend >= 0:
new_ia = {}
for key in sharding_meta.in_axes[i]:
if key < idx_to_extend:
new_ia[key] = sharding_meta.in_axes[i][key]
else:
new_ia[key + 1] = sharding_meta.in_axes[i][key]
new_ia[idx_to_extend] = fsdp_axis_name
else:
new_ia = sharding_meta.in_axes[i]
new_input_shapes.append(new_shape)
new_in_axes.append(new_ia)
sharding_meta.input_shapes = tuple(new_input_shapes)
sharding_meta.in_axes = tuple(new_in_axes)
sharding_meta.axis_resources[fsdp_axis_name] = fsdp_mesh_axis
return sharding_meta, fsdp_axis_name
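# Illustrative sketch (not from the original commit): starting from a DP-only ShardingMeta whose
# input was reshaped to (dp, batch // dp, hidden) with in_axes ({0: 'data'},), calling
# extend_fsdp_sharding_meta(meta, {0: 0}) on a mesh with a separate 'fsdp' axis of size f splits the
# per-device batch again into (dp, f, batch // (dp * f), hidden), turns the input and output axes
# into {0: 'data', 1: 'fsdp'}, records {'fsdp': <fsdp mesh axis>} in axis_resources, and returns
# "fsdp" as the axis name for the subsequent xmap projection.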
def xmap_runner(func: Callable, in_axes: Tuple[Dict, ...],
out_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]],
axis_resources: Dict, inputs: Tuple):
"""
xmap_runner
"""
assert isinstance(inputs, tuple)
assert isinstance(in_axes, tuple)
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
fake_in_axes = {}
fake_axis_resource = {}
    # The fake-axis setup below is a workaround for "NotImplementedError:
# Collectives in manually partitioned computations are only supported
# when all mesh axes are partitioned manually (no partial automatic
# sharding). Make sure that you mention all mesh axes in axis_resources!"
fake_idx_counter = 0
for mesh_axis_names in mesh.axis_names:
if mesh_axis_names not in axis_resources.values():
fake_idx_counter += 1
fake_axis_name = f"{mesh_axis_names}_fake_{fake_idx_counter}"
fake_in_axes[fake_idx_counter] = fake_axis_name
fake_axis_resource[fake_axis_name] = mesh_axis_names
fake_input = jnp.zeros(tuple(64 for _ in range(len(fake_in_axes) + 1)))
xmapped = xmap(lambda func_input, _: func(*func_input),
in_axes=(in_axes, fake_in_axes),
out_axes=out_axes,
axis_resources={
**axis_resources,
**fake_axis_resource
})
output = xmapped(inputs, fake_input)
return output
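# Illustrative sketch (assumed names, not from the original commit): dispatching a kernel with a
# previously built ShardingMeta. The caller reshapes inputs to the expanded per-axis shapes, runs
# the xmapped function, then restores the logical output shape:
#
#     x_ = jnp.reshape(x, sharding_meta.input_shapes[0])
#     out = xmap_runner(partial(some_kernel, **static_kwargs),
#                       sharding_meta.in_axes, sharding_meta.out_axes,
#                       sharding_meta.axis_resources, (x_,))
#     out = jnp.reshape(out, sharding_meta.output_shapes[0])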
@@ -18,11 +18,6 @@ from .cpp_extensions import scaled_upper_triang_masked_softmax_bwd
from .cpp_extensions import ScaledSoftmaxFwdPrimitive
from .cpp_extensions import ScaledMaskedSoftmaxFwdPrimitive
from .cpp_extensions import ScaledUpperTriangMaskedSoftmaxFwdPrimitive
from .sharding import get_softmax_sharding_meta, ShardingType, ShardingMeta
from .sharding import xmap_runner, extend_fsdp_sharding_meta
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
class SoftmaxType(Enum):
@@ -48,100 +43,47 @@ def is_softmax_kernel_available(softmax_type: SoftmaxType, batch: int, heads: in
raise NotImplementedError
def softmax(inputs: jnp.ndarray,
def softmax(logits: jnp.ndarray,
mask: Optional[jnp.ndarray] = None,
scale_factor: Optional[float] = 1.0,
softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED,
sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0,
tp_dim_index: int = 1):
softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED):
"""
Softmax wrapper
"""
assert dp_dim_index == 0, \
"Only softmax support batch dim in the first place currently."
assert tp_dim_index == 1, \
"Only softmax support head dim in the second place currently."
assert mask is None or mask.shape[tp_dim_index] == 1
if sharding_type is ShardingType.SINGLE:
outputs = _softmax(inputs, mask, scale_factor, softmax_type)
else:
dp_axis_name = "batch"
tp_axis_name = "model"
sharding_meta = get_softmax_sharding_meta(sharding_type,
inputs.shape,
dp_dim=dp_dim_index,
tp_dim=tp_dim_index,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
mask_ = mask
mask_in_axis = {}
if mask_ is not None:
if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
# If mask is head broadcastable (heads == 1),
# then it equals to DP sharding.
mask_sharding_meta = get_softmax_sharding_meta(ShardingType.DP,
mask_.shape,
dp_dim=dp_dim_index,
tp_dim=tp_dim_index,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
else:
mask_sharding_meta = ShardingMeta([{}], {}, {}, [mask_.shape], mask_.shape)
mask_sharding_meta, _ = extend_fsdp_sharding_meta(mask_sharding_meta, {0: dp_dim_index})
mask_ = jnp.reshape(mask_, mask_sharding_meta.input_shapes[0])
mask_in_axis = mask_sharding_meta.in_axes[0]
partial_softmax = partial(_softmax, scale_factor=scale_factor, softmax_type=softmax_type)
in_axes = (sharding_meta.in_axes[0], mask_in_axis)
outputs = xmap_runner(partial_softmax, in_axes, sharding_meta.out_axes,
sharding_meta.axis_resources, (inputs_, mask_))
outputs = jnp.reshape(outputs, sharding_meta.output_shapes[0])
return outputs
output = _softmax(logits, mask, scale_factor, softmax_type)
return output
@partial(jax.custom_vjp, nondiff_argnums=(2, 3))
def _softmax(inputs, mask, scale_factor, softmax_type):
output, _ = _softmax_fwd(inputs, mask, scale_factor, softmax_type)
def _softmax(logits, mask, scale_factor, softmax_type):
output, _ = _softmax_fwd_rule(logits, mask, scale_factor, softmax_type)
return output
def _softmax_fwd(inputs, mask, scale_factor, softmax_type):
def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type):
if softmax_type is SoftmaxType.SCALED_MASKED:
assert mask is not None
outputs = scaled_masked_softmax_fwd(inputs, mask, scale_factor)
output = scaled_masked_softmax_fwd(logits, mask, scale_factor)
elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
outputs = scaled_upper_triang_masked_softmax_fwd(inputs, scale_factor)
output = scaled_upper_triang_masked_softmax_fwd(logits, scale_factor)
else:
outputs = scaled_softmax_fwd(inputs, scale_factor)
output = scaled_softmax_fwd(logits, scale_factor)
return outputs, (outputs, mask)
return output, (output,)
def _softmax_bwd(scale_factor, softmax_type, ctx, grad_outputs):
softmax_outputs, mask = ctx
def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz):
softmax_output, = ctx
if softmax_type is SoftmaxType.SCALED_MASKED:
assert mask is not None
dgrad = scaled_masked_softmax_bwd(grad_outputs, softmax_outputs, scale_factor)
dgrad = scaled_masked_softmax_bwd(dz, softmax_output, scale_factor)
elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
dgrad = scaled_upper_triang_masked_softmax_bwd(grad_outputs, softmax_outputs, scale_factor)
dgrad = scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, scale_factor)
else:
dgrad = scaled_softmax_bwd(grad_outputs, softmax_outputs, scale_factor)
dgrad = scaled_softmax_bwd(dz, softmax_output, scale_factor)
return (dgrad, None)
_softmax.defvjp(_softmax_fwd, _softmax_bwd)
_softmax.defvjp(_softmax_fwd_rule, _softmax_bwd_rule)
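# Illustrative usage sketch (shapes and names assumed, not part of the original commit): after the
# migration to custom_partitioning the wrapper takes no sharding arguments; partitioning is resolved
# from the sharding of the logits by the underlying primitives.
#
#     # logits: [batch, heads, q_seqlen, k_seqlen]
#     probs = softmax(logits,
#                     mask=None,
#                     scale_factor=1.0 / math.sqrt(head_dim),
#                     softmax_type=SoftmaxType.SCALED)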