Unverified Commit a3ec6a54 authored by Jeng Bai-Cheng, committed by GitHub

add building workflow for TE/Jax (#53)



* add building workflow for jax modules
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* replace bit_cast with reinterpret_cast
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* add nvtx to cmake check list
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor layernorm fwd
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor rmsnorm fwd
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor layernorm_bwd
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* set pytorch as default in setup.py
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* rename extension from *.cc to *.cpp

cpplint cannot recognize *.cc files, so rename the extension
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor style to align with TE/PyTorch
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* add pybinding, unittest and qa
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix license
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* disable c-extension-no-member and no-name-in-module
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* add dataclass to avoid pylint error
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update transformer_engine/__init__.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update tests/jax/test_custom_call_shape.py

fix typo
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update tests/jax/test_custom_call_shape.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* fix conflict due to PR62
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix c-extension-no-member and no-name-in-module

1. add transformer_engine_jax into extension-pkg-whitelist
2. convert pylintrc from CRLF to LF format
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update setup.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* remove pylint:disable and refactor import order
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

---------
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d8a2f352
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX layernorm modules"""
from typing import Tuple, Sequence
from functools import partial, reduce
import operator
import jax
import jax.numpy as jnp
from transformer_engine_jax import DType as TEDType
from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype
from .cpp_extensions import transpose
from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
from .fp8 import FP8Helper, FP8GemmPackage
from .sharding import ShardingType, get_elementwise_sharding_meta
from .sharding import get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
from .sharding import xmap_runner
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
def canonicalize_layernorm_type(x):
'''
Canonicalize the layernorm type
'''
canonicalized = x.lower().strip().replace('-', '').replace('_', '')
assert canonicalized in ['layernorm', 'rmsnorm']
return canonicalized
def layernorm(inputs: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
layernorm_type: str,
epsilon: float = 1e-6,
sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0):
"""
Layernorm wrapper
"""
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"layernorm does not support row-split tensor parallelism currently."
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm':
assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
if sharding_type is ShardingType.SINGLE:
output = _layernorm(inputs,
gamma,
beta,
layernorm_type=layernorm_type,
sharding_type=sharding_type,
dp_axis_name="",
epsilon=epsilon)
else:
dp_axis_name = "batch"
tp_axis_name = "model"
sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape,
dp_dim_index, dp_axis_name, tp_axis_name)
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
gamma_ = jnp.reshape(gamma, sharding_meta.input_shapes[1]) # 1 for gamma
beta_ = beta
beta_in_axis = {}
if beta_ is not None:
beta_ = jnp.reshape(beta_, sharding_meta.input_shapes[1]) # 1 for beta
beta_in_axis = sharding_meta.in_axes[1]
in_axes = (*sharding_meta.in_axes, beta_in_axis)
partial_ln = partial(_layernorm,
layernorm_type=layernorm_type,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
epsilon=epsilon)
output = xmap_runner(partial_ln, in_axes, sharding_meta.out_axes,
sharding_meta.axis_resources, (inputs_, gamma_, beta_))
output = jnp.reshape(output, sharding_meta.output_shapes[0])
return output
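# Example usage (a minimal sketch; the batch/sequence shapes and the hidden size of
# 1024 are illustrative assumptions, not taken from this repository):
#
#   x = jnp.zeros((32, 128, 1024), dtype=jnp.bfloat16)
#   gamma = jnp.ones((1024,), dtype=jnp.bfloat16)
#   beta = jnp.zeros((1024,), dtype=jnp.bfloat16)
#   y = layernorm(x, gamma, beta, layernorm_type='layernorm',
#                 sharding_type=ShardingType.SINGLE)
#   # For 'rmsnorm', pass beta=None.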
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
def _layernorm(x, gamma, beta, layernorm_type, sharding_type, dp_axis_name, epsilon=1e-6):
output, _ = _layernorm_fwd(x, gamma, beta, layernorm_type, sharding_type, dp_axis_name, epsilon)
return output
def _layernorm_fwd(
x,
gamma,
beta,
layernorm_type,
sharding_type, # pylint: disable=unused-argument
dp_axis_name, # pylint: disable=unused-argument
epsilon):
if layernorm_type == 'layernorm':
output, mu, rsigma = layernorm_fwd(x, gamma, beta, epsilon)
else:
output, rsigma = rmsnorm_fwd(x, gamma, epsilon)
mu = None
return output, (mu, rsigma, x, gamma)
def _layernorm_bwd(layernorm_type, sharding_type, dp_axis_name, epsilon, ctx, g):
mu, rsigma, x, gamma = ctx
if layernorm_type == 'layernorm':
grad_input, grad_gamma, grad_beta = layernorm_bwd(g, mu, rsigma, x, gamma, epsilon=epsilon)
else:
grad_input, grad_gamma = rmsnorm_bwd(g, rsigma, x, gamma, epsilon=epsilon)
grad_beta = None
if is_dp_enabled(sharding_type.value[0]):
grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
return grad_input, grad_gamma, grad_beta
_layernorm.defvjp(_layernorm_fwd, _layernorm_bwd)
def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
gamma: jnp.ndarray,
beta: jnp.ndarray,
layernorm_type: str,
amax_history_idx: int,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0,
epsilon: float = 1e-6) -> jnp.ndarray:
"""
LN + fp8 dot fusion wrapper
"""
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"layernorm_fp8_dot does not support row-split tensor parallelism currently."
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm':
assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
assert fp8_gemm_pkg.num_of_gemm == 1
inputs = fp8_gemm_pkg.inputs
kernel = fp8_gemm_pkg.kernels[0]
fp8_max = fp8_gemm_pkg.fp8_max
amax = fp8_gemm_pkg.amax
scale = fp8_gemm_pkg.scale
scale_inv = fp8_gemm_pkg.scale_inv
if sharding_type is ShardingType.SINGLE:
output = _layernorm_fp8_dot(inputs,
kernel,
gamma,
beta,
fp8_max,
amax,
scale,
scale_inv,
layernorm_type,
amax_history_idx,
fwd_dtype,
bwd_dtype,
contracting_dims,
sharding_type=sharding_type,
dp_axis_name="",
tp_axis_name="",
epsilon=epsilon)
else:
dp_axis_name = "batch"
tp_axis_name = "model"
ln_sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape,
dp_dim_index, dp_axis_name, tp_axis_name)
inputs_ = jnp.reshape(inputs, ln_sharding_meta.input_shapes[0]) # 0 for input
gamma_ = jnp.reshape(gamma, ln_sharding_meta.input_shapes[1]) # 1 for gamma
beta_ = beta
beta_in_axis = {}
if beta_ is not None:
beta_ = jnp.reshape(beta_, ln_sharding_meta.input_shapes[1]) # 1 for beta
beta_in_axis = ln_sharding_meta.in_axes[1]
kernel_tp_index = None
# TODO (Ming Huang): Should we add a new argument to support general sharding to kernel? # pylint: disable=fixme
if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
kernel_tp_index = len(kernel.shape) - 1
elif sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
kernel_tp_index = 0
input_tp_index = len(inputs.shape) - 1
dot_sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
dp_dim_index, input_tp_index, kernel_tp_index,
contracting_dims, dp_axis_name, tp_axis_name)
kernel_ = jnp.reshape(kernel, dot_sharding_meta.input_shapes[1]) # 1 for kernel
num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv
fp8_sharding_meta = get_fp8_meta_sharding_meta(sharding_type, num_of_fp8_meta_kind,
dp_axis_name, tp_axis_name)
axis_resource = merge_axis_resources([
ln_sharding_meta.axis_resources, dot_sharding_meta.axis_resources,
fp8_sharding_meta.axis_resources
])
partial_ln_fp8_dot = partial(_layernorm_fp8_dot,
layernorm_type=layernorm_type,
amax_history_idx=amax_history_idx,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name,
epsilon=epsilon)
# input, kernel, gamma, beta, fp8_metas
in_axes = (ln_sharding_meta.in_axes[0], dot_sharding_meta.in_axes[1],
ln_sharding_meta.in_axes[1], beta_in_axis, *fp8_sharding_meta.in_axes)
output = xmap_runner(partial_ln_fp8_dot, in_axes, dot_sharding_meta.out_axes, axis_resource,
(inputs_, kernel_, gamma_, beta_, fp8_max, amax, scale, scale_inv))
output = jnp.reshape(output, dot_sharding_meta.output_shapes[0])
return output
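# Example usage (a minimal sketch; `pkg` is assumed to be an FP8GemmPackage that
# carries one GEMM's input, kernel and FP8 metadata -- see fp8.py for how it is
# assembled -- and the DType enum is assumed to expose kFloat8E4M3/kFloat8E5M2
# members for the forward/backward FP8 formats):
#
#   out = layernorm_fp8_dot(pkg, gamma, beta, 'layernorm',
#                           amax_history_idx=0,
#                           fwd_dtype=TEDType.kFloat8E4M3,
#                           bwd_dtype=TEDType.kFloat8E5M2)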
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16))
def _layernorm_fp8_dot(inputs: jnp.ndarray,
kernel: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
fp8_maxs: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
layernorm_type: str,
amax_history_idx: int,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
sharding_type: ShardingType,
dp_axis_name: str,
tp_axis_name: str,
epsilon: float = 1e-6) -> jnp.ndarray:
output, _ = _layernorm_fp8_dot_fwd(inputs, kernel, gamma, beta, fp8_maxs, amax, scale,
scale_inv, layernorm_type, amax_history_idx, fwd_dtype,
bwd_dtype, contracting_dims, sharding_type, dp_axis_name,
tp_axis_name, epsilon)
return output
def _layernorm_fp8_dot_fwd(
inputs,
kernel,
gamma,
beta,
fp8_maxs,
amax,
scale,
scale_inv,
layernorm_type,
amax_history_idx, # pylint: disable=unused-argument
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims,
sharding_type,
dp_axis_name, # pylint: disable=unused-argument
tp_axis_name,
epsilon):
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
input_shape_suf = inputs.shape[min(lhs_contracting_dims):]
kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
kernel_shape_suf = kernel.shape[max(rhs_contracting_dims) + 1:]
input_contracting_size = reduce(operator.mul, input_shape_suf)
kernel_contracting_size = reduce(operator.mul, kernel_shape_pre)
assert input_contracting_size == kernel_contracting_size
gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
input_amax = amax[gemm_input_idx]
input_scale = scale[gemm_input_idx]
input_scale_inv = scale_inv[gemm_input_idx]
if layernorm_type == 'layernorm':
ln_out, mu, rsigma, input_amax = layernorm_fwd_fp8(inputs,
gamma,
beta,
input_amax,
input_scale,
input_scale_inv,
epsilon=epsilon)
else:
ln_out, rsigma, input_amax = rmsnorm_fwd_fp8(inputs,
gamma,
input_amax,
input_scale,
input_scale_inv,
epsilon=epsilon)
mu = None
assert inputs.shape == ln_out.shape
ln_out_ = jnp.reshape(ln_out, (-1, input_contracting_size))
kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1))
kernel_amax = amax[gemm_kernel_idx]
kernel_scale = scale[gemm_kernel_idx]
kernel_scale_inv = scale_inv[gemm_kernel_idx]
kernel_cast, kernel_cast_trans, kernel_amax = cast_transpose(kernel_, kernel_amax, kernel_scale,
kernel_scale_inv, fwd_dtype)
output = gemm(kernel_cast_trans, kernel_scale_inv, fwd_dtype, True, ln_out_, input_scale_inv,
fwd_dtype, False, jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP)
if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
output = jax.lax.psum(output, tp_axis_name)
# (input_shape_pre, input_shape_suf)
# x (kernel_shape_pre, kernel_shape_suf)
# = (input_shape_pre, kernel_shape_suf)
output_shape = input_shape_pre + kernel_shape_suf
output = jnp.reshape(output, output_shape)
ctx = (ln_out_, kernel_cast, fp8_maxs, amax, scale, scale_inv, input_amax, kernel_amax,
inputs.shape, kernel.shape, mu, rsigma, inputs, gamma)
return output, ctx
def _layernorm_fp8_dot_bwd(
layernorm_type,
amax_history_idx,
fwd_dtype,
bwd_dtype,
contracting_dims, # pylint: disable=unused-argument
sharding_type,
dp_axis_name,
tp_axis_name,
epsilon,
ctx,
g):
ln_out_, kernel_cast, \
fp8_maxs, amax, scale, scale_inv, \
input_amax, kernel_amax, \
inputs_shape, kernel_shape, \
mu, rsigma, inputs, gamma = ctx
gemm_input_idx, gemm_kernel_idx, gemm_grad_idx = \
FP8Helper.get_fp8_meta_indices(0)
grad_amax = amax[gemm_grad_idx]
grad_scale = scale[gemm_grad_idx]
grad_scale_inv = scale_inv[gemm_grad_idx]
ln_out_trans = transpose(ln_out_, fwd_dtype)
g = jnp.reshape(g, (ln_out_trans.shape[1], -1))
# cast and transpose the grad_output
grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv,
bwd_dtype)
input_scale_inv = scale_inv[gemm_input_idx]
wgrad = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype, True, ln_out_trans, input_scale_inv,
fwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD)
kernel_scale_inv = scale_inv[gemm_kernel_idx]
dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)
dgrad = jnp.reshape(dgrad, inputs_shape)
if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
dgrad = jax.lax.psum(dgrad, tp_axis_name)
if layernorm_type == 'layernorm':
grad_input, grad_gamma, grad_beta = layernorm_bwd(dgrad, mu, rsigma, inputs, gamma, epsilon)
else:
grad_input, grad_gamma = rmsnorm_bwd(dgrad, rsigma, inputs, gamma, epsilon=epsilon)
grad_beta = None
amax = amax.at[gemm_input_idx, amax_history_idx].set(input_amax[0])
amax = amax.at[gemm_kernel_idx, amax_history_idx].set(kernel_amax[0])
amax = amax.at[gemm_grad_idx, amax_history_idx].set(grad_amax[0])
if is_dp_enabled(sharding_type.value[0]):
wgrad = jax.lax.psum(wgrad, dp_axis_name)
grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
amax = jax.lax.pmax(amax, dp_axis_name)
if is_tp_enabled(sharding_type.value[0]):
amax = jax.lax.pmax(amax, tp_axis_name)
wgrad = jnp.reshape(wgrad, kernel_shape)
return grad_input, wgrad, \
grad_gamma, grad_beta, \
fp8_maxs, amax, scale, scale_inv
_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd, _layernorm_fp8_dot_bwd)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX MLP modules"""
from typing import Tuple, Sequence, Union, Callable
from functools import partial, reduce
import operator
import jax
import jax.numpy as jnp
from jax.interpreters import pxla
from transformer_engine_jax import DType as TEDType
from .cpp_extensions import jax_dtype_to_te_dtype
from .cpp_extensions import transpose, cast_transpose
from .cpp_extensions import gated_gelu, gated_gelu_fp8
from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
from .cpp_extensions import gemm
from .sharding import MajorShardingType, ShardingType
from .sharding import get_elementwise_sharding_meta
from .sharding import get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import merge_axis_resources, infer_sharding_type
from .sharding import xmap_runner
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8GemmPackage
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
thread_resources = pxla.thread_resources
def geglu(
inputs: jnp.ndarray,
contracting_dims: Sequence[int] = (-1,),
sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0, # pylint: disable=unused-argument
):
"""
Gated gelu
"""
input_shape_suf_size = reduce(operator.mul, inputs.shape[min(contracting_dims):])
assert input_shape_suf_size % 2 == 0
output_shape = (*inputs.shape[:min(contracting_dims)], input_shape_suf_size // 2)
if sharding_type is ShardingType.SINGLE:
output = _geglu(inputs, contracting_dims)
else:
dp_axis_name = "batch"
tp_axis_name = "model"
sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, None,
dp_dim_index, dp_axis_name, tp_axis_name)
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
partial_geglu = partial(_geglu, contracting_dims=contracting_dims)
output = xmap_runner(partial_geglu, sharding_meta.in_axes, sharding_meta.out_axes,
sharding_meta.axis_resources, (inputs_,))
output = jnp.reshape(output, output_shape)
return output
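# Example usage (a minimal sketch; the last dimension is assumed to hold the two
# GELU gates back to back, e.g. 2 * ffn_hidden = 8192, so it must be even):
#
#   x = jnp.zeros((32, 128, 8192), dtype=jnp.bfloat16)
#   y = geglu(x)    # y.shape == (32, 128, 4096)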
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _geglu(inputs: jnp.ndarray, contracting_dims: Sequence[int] = (-1,)):
geglu_output, _ = _geglu_fwd(inputs, contracting_dims)
return geglu_output
def _geglu_fwd(inputs, contracting_dims):
inputs_real_shape = (*inputs.shape[:min(contracting_dims)],
reduce(operator.mul, inputs.shape[min(contracting_dims):]))
inputs_ = jnp.reshape(inputs, inputs_real_shape)
geglu_output = gated_gelu(inputs_)
geglu_output = jnp.expand_dims(geglu_output, min(contracting_dims))
return geglu_output, (inputs_, inputs.shape)
def _geglu_bwd(contracting_dims, ctx, g):
inputs_, inputs_shape = ctx
g = jnp.squeeze(g, min(contracting_dims))
assert inputs_.dtype == g.dtype
dgelu = dgated_gelu(g, inputs_)
dgelu = jnp.reshape(dgelu, inputs_shape)
return (dgelu,)
_geglu.defvjp(_geglu_fwd, _geglu_bwd)
def fp8_ln_mlp(
fp8_gemm_pkg: FP8GemmPackage,
ln_scale: jnp.ndarray,
ln_bias: jnp.ndarray,
layernorm_type: str,
amax_history_idx: int,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
epsilon: float = 1e-6,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
major_sharding_type: MajorShardingType = MajorShardingType.SINGLE,
dp_dim_index: int = 0, # pylint: disable=unused-argument
activations: Sequence[Union[str, Callable]] = ('gelu', 'linear')
) -> jnp.ndarray:
"""
FP8 layernorm MLP wrapper
(LN + Dense + act + Dense)
"""
assert fp8_gemm_pkg.num_of_gemm == 2
inputs = fp8_gemm_pkg.inputs
kernel_1 = fp8_gemm_pkg.kernels[0]
kernel_2 = fp8_gemm_pkg.kernels[1]
fp8_max = fp8_gemm_pkg.fp8_max
amax = fp8_gemm_pkg.amax
scale = fp8_gemm_pkg.scale
scale_inv = fp8_gemm_pkg.scale_inv
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm':
assert ln_bias is None, "ln_bias should be None if layernorm_type is 'rmsnorm'"
assert activations == ('gelu', 'linear')
if major_sharding_type is MajorShardingType.SINGLE:
res = _fp8_mlp(inputs, ln_scale, ln_bias, kernel_1, kernel_2, fp8_max, amax, scale,
scale_inv, layernorm_type, amax_history_idx, activations, epsilon, fwd_dtype,
bwd_dtype, contracting_dims, major_sharding_type, "", "")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
first_part_st, second_part_st = infer_sharding_type(major_sharding_type)
ln_sharding_meta = get_elementwise_sharding_meta(first_part_st, inputs.shape,
ln_scale.shape, dp_dim_index, dp_axis_name,
tp_axis_name)
input_tp_index = len(inputs.shape) - 1
first_dot_sharding_meta = get_dot_sharding_meta(first_part_st, inputs.shape, kernel_1.shape,
dp_dim_index, input_tp_index, 2,
contracting_dims, dp_axis_name,
tp_axis_name)
second_input_shape = (*first_dot_sharding_meta.output_shapes[0][:-2],
first_dot_sharding_meta.output_shapes[0][-1])
second_dot_sharding_meta = get_dot_sharding_meta(second_part_st, second_input_shape,
kernel_2.shape, dp_dim_index,
len(second_input_shape) - 1, 0,
contracting_dims, dp_axis_name,
tp_axis_name)
num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv
fp8_sharding_meta = get_fp8_meta_sharding_meta(first_part_st, num_of_fp8_meta_kind,
dp_axis_name, tp_axis_name)
inputs_ = jnp.reshape(inputs, ln_sharding_meta.input_shapes[0]) # 0 for input
ln_scale_ = jnp.reshape(ln_scale, ln_sharding_meta.input_shapes[1]) # 1 for gamma
ln_bias_ = ln_bias
ln_bias_in_axis = {}
if ln_bias_ is not None:
ln_bias_ = jnp.reshape(ln_bias_, ln_sharding_meta.input_shapes[1]) # 1 for beta
ln_bias_in_axis = ln_sharding_meta.in_axes[1]
kernel_1_ = jnp.reshape(kernel_1, first_dot_sharding_meta.input_shapes[1]) # 1 for kernel
kernel_2_ = jnp.reshape(kernel_2,
second_dot_sharding_meta.input_shapes[1]) # 1 for kernel
axis_resource = merge_axis_resources([
ln_sharding_meta.axis_resources, first_dot_sharding_meta.axis_resources,
second_dot_sharding_meta.axis_resources, fp8_sharding_meta.axis_resources
])
partial_fp8_mlp = partial(_fp8_mlp,
layernorm_type=layernorm_type,
amax_history_idx=amax_history_idx,
activations=activations,
epsilon=epsilon,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
major_sharding_type=major_sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
in_axes = (ln_sharding_meta.in_axes[0], ln_sharding_meta.in_axes[1], ln_bias_in_axis,
first_dot_sharding_meta.in_axes[1], second_dot_sharding_meta.in_axes[1],
*fp8_sharding_meta.in_axes)
res = xmap_runner(
partial_fp8_mlp, in_axes, second_dot_sharding_meta.out_axes, axis_resource,
(inputs_, ln_scale_, ln_bias_, kernel_1_, kernel_2_, fp8_max, amax, scale, scale_inv))
res = jnp.reshape(res, second_dot_sharding_meta.output_shapes[0])
return res
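# Example usage (a minimal sketch; `pkg` is assumed to be an FP8GemmPackage that
# carries the MLP input, the two kernels and the shared FP8 metadata -- see
# fp8.py -- with the FP8 formats assumed to be E4M3 forward / E5M2 backward):
#
#   out = fp8_ln_mlp(pkg, ln_scale, ln_bias, 'layernorm',
#                    amax_history_idx=0,
#                    fwd_dtype=TEDType.kFloat8E4M3,
#                    bwd_dtype=TEDType.kFloat8E5M2)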
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str, amax_history_idx: int,
activations: Sequence[Union[str, Callable]], epsilon: float, fwd_dtype: TEDType,
bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int], Sequence[int]],
major_sharding_type: MajorShardingType, dp_axis_name: str, tp_axis_name: str):
res, _ = _fp8_mlp_fwd(inputs,
ln_scale,
ln_bias,
kernel_1,
kernel_2,
fp8_maxs,
amax,
scale,
scale_inv,
layernorm_type,
amax_history_idx,
activations,
epsilon,
fwd_dtype,
bwd_dtype,
contracting_dims=contracting_dims,
major_sharding_type=major_sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
return res
def _fp8_mlp_fwd(
inputs,
gamma,
beta,
kernel_1,
kernel_2,
fp8_maxs,
amax,
scale,
scale_inv,
layernorm_type,
amax_history_idx, # pylint: disable=unused-argument
activations,
epsilon,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims,
major_sharding_type,
dp_axis_name, # pylint: disable=unused-argument
tp_axis_name):
if activations != ('gelu', 'linear'):
raise NotImplementedError("activations only support ('gelu', 'linear') for now.")
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
input_shape_suf = inputs.shape[min(lhs_contracting_dims):]
kernel_1_shape_pre = kernel_1.shape[:max(rhs_contracting_dims) + 1]
kernel_1_shape_suf = kernel_1.shape[max(rhs_contracting_dims) + 1:]
kernel_2_shape_pre = kernel_2.shape[:max(rhs_contracting_dims) + 1]
kernel_2_shape_suf = kernel_2.shape[max(rhs_contracting_dims) + 1:]
input_contracting_size = reduce(operator.mul, input_shape_suf)
kernel_1_pre_size = reduce(operator.mul, kernel_1_shape_pre)
kernel_1_suf_size = reduce(operator.mul, kernel_1_shape_suf)
kernel_2_pre_size = reduce(operator.mul, kernel_2_shape_pre)
assert input_contracting_size == kernel_1_pre_size
assert kernel_1_suf_size == kernel_2_pre_size * len(activations)
inputs_ = jnp.reshape(inputs, (-1, input_contracting_size))
kernel_1_ = jnp.reshape(kernel_1, (kernel_1_pre_size, -1))
kernel_2_ = jnp.reshape(kernel_2, (kernel_2_pre_size, -1))
gemm1_input_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
input_amax = amax[gemm1_input_idx]
input_scale = scale[gemm1_input_idx]
input_scale_inv = scale_inv[gemm1_input_idx]
if layernorm_type == 'layernorm':
ln_out, mu, rsigma, ln_out_amax = layernorm_fwd_fp8(inputs_,
gamma,
beta,
input_amax,
input_scale,
input_scale_inv,
epsilon=epsilon)
else:
ln_out, rsigma, ln_out_amax = rmsnorm_fwd_fp8(inputs_,
gamma,
input_amax,
input_scale,
input_scale_inv,
epsilon=epsilon)
mu = None
kernel_1_amax = amax[gemm1_kernel_idx]
kernel_1_scale = scale[gemm1_kernel_idx]
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
kernel_1_cast, kernel_1_cast_trans, kernel_1_amax = cast_transpose(
kernel_1_, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)
dense_1_output = gemm(kernel_1_cast_trans, kernel_1_scale_inv, fwd_dtype, True, ln_out,
scale_inv[gemm1_input_idx], fwd_dtype, False,
jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP)
gemm2_input_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
kernel_2_amax = amax[gemm2_kernel_idx]
kernel_2_scale = scale[gemm2_kernel_idx]
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
kernel_2_cast, kernel_2_cast_trans, kernel_2_amax = cast_transpose(
kernel_2_, kernel_2_amax, kernel_2_scale, kernel_2_scale_inv, fwd_dtype)
dense_1_out_amax = amax[gemm2_input_idx]
dense_1_out_scale = scale[gemm2_input_idx]
dense_1_out_scale_inv = scale_inv[gemm2_input_idx]
gated_gelu_output_cast, gated_gelu_amax = gated_gelu_fp8(dense_1_output, dense_1_out_amax,
dense_1_out_scale,
dense_1_out_scale_inv, fwd_dtype)
res = gemm(kernel_2_cast_trans, kernel_2_scale_inv, fwd_dtype, True,
gated_gelu_output_cast, dense_1_out_scale_inv, fwd_dtype, False,
jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP)
if major_sharding_type in (MajorShardingType.TP, MajorShardingType.DPTP):
res = jax.lax.psum(res, tp_axis_name)
# (input_shape_pre, input_shape_suf)
# x (kernel_1_shape_pre, kernel_1_shape_suf)
# x (kernel_2_shape_pre, kernel_2_shape_suf)
# = (input_shape_pre, kernel_2_shape_suf)
output_shape = input_shape_pre + kernel_2_shape_suf
res = jnp.reshape(res, output_shape)
ctx = (inputs_, ln_out, mu, rsigma, gamma, dense_1_output, gated_gelu_output_cast,
kernel_1_cast, kernel_2_cast, fp8_maxs, amax, scale, scale_inv, ln_out_amax,
gated_gelu_amax, kernel_1_amax, kernel_2_amax, inputs.shape, kernel_1.shape,
kernel_2.shape)
return res, ctx
def _fp8_mlp_bwd(
layernorm_type,
amax_history_idx,
activations, # pylint: disable=unused-argument
epsilon,
fwd_dtype,
bwd_dtype,
contracting_dims, # pylint: disable=unused-argument
major_sharding_type,
dp_axis_name,
tp_axis_name,
ctx,
g):
inputs_, ln_out, mu, rsigma, gamma, \
dense_1_output, gated_gelu_output_cast, \
kernel_1_cast, kernel_2_cast, \
fp8_maxs, amax, scale, scale_inv, \
ln_out_amax, gated_gelu_amax, kernel_1_amax, kernel_2_amax, \
input_shape, kernel_1_shape, kernel_2_shape = ctx
g = jnp.reshape(g, (ln_out.shape[0], -1))
gemm2_input_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)
grad_amax = amax[gemm2_grad_idx]
grad_scale = scale[gemm2_grad_idx]
grad_scale_inv = scale_inv[gemm2_grad_idx]
grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv,
bwd_dtype)
gated_gelu_output_cast_trans = transpose(gated_gelu_output_cast, fwd_dtype)
gemm2_input_scale_inv = scale_inv[gemm2_input_idx]
wgrad_2 = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype, True,
gated_gelu_output_cast_trans, gemm2_input_scale_inv, fwd_dtype, False,
jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD)
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
dgrad_2 = gemm(kernel_2_cast, kernel_2_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)
gemm1_input_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)
dgrad_2_amax = amax[gemm1_grad_idx]
dgrad_2_scale = scale[gemm1_grad_idx]
dgrad_2_scale_inv = scale_inv[gemm1_grad_idx]
dgelu, dgelu_trans, dgelu_amax = dgated_gelu_cast_transpose(dgrad_2, dense_1_output,
dgrad_2_amax, dgrad_2_scale,
dgrad_2_scale_inv, bwd_dtype)
ln_out_trans = transpose(ln_out, fwd_dtype)
gemm1_input_scale_inv = scale_inv[gemm1_input_idx]
wgrad_1 = gemm(dgelu_trans, dgrad_2_scale_inv, bwd_dtype, True,
ln_out_trans, gemm1_input_scale_inv, fwd_dtype, False,
jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD)
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
dgrad_1 = gemm(kernel_1_cast, kernel_1_scale_inv, fwd_dtype, True, dgelu, dgrad_2_scale_inv,
bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)
if major_sharding_type in (MajorShardingType.TP, MajorShardingType.DPTP):
dgrad_1 = jax.lax.psum(dgrad_1, tp_axis_name)
if layernorm_type == 'layernorm':
grad_input, grad_gamma, grad_beta = layernorm_bwd(dgrad_1,
mu,
rsigma,
inputs_,
gamma,
epsilon=epsilon)
else:
grad_input, grad_gamma = rmsnorm_bwd(dgrad_1, rsigma, inputs_, gamma, epsilon=epsilon)
grad_beta = None
amax = amax.at[gemm1_input_idx, amax_history_idx].set(ln_out_amax[0])
amax = amax.at[gemm1_kernel_idx, amax_history_idx].set(kernel_1_amax[0])
amax = amax.at[gemm1_grad_idx, amax_history_idx].set(dgelu_amax[0])
amax = amax.at[gemm2_input_idx, amax_history_idx].set(gated_gelu_amax[0])
amax = amax.at[gemm2_kernel_idx, amax_history_idx].set(kernel_2_amax[0])
amax = amax.at[gemm2_grad_idx, amax_history_idx].set(grad_amax[0])
if major_sharding_type in (MajorShardingType.DP, MajorShardingType.DPTP):
wgrad_1 = jax.lax.psum(wgrad_1, dp_axis_name)
wgrad_2 = jax.lax.psum(wgrad_2, dp_axis_name)
grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
amax = jax.lax.pmax(amax, dp_axis_name)
if major_sharding_type in (MajorShardingType.TP, MajorShardingType.DPTP):
amax = jax.lax.pmax(amax, tp_axis_name)
grad_input = jnp.reshape(grad_input, input_shape)
wgrad_1 = jnp.reshape(wgrad_1, kernel_1_shape)
wgrad_2 = jnp.reshape(wgrad_2, kernel_2_shape)
return grad_input, grad_gamma, grad_beta, \
wgrad_1, wgrad_2, \
fp8_maxs, amax, scale, scale_inv
_fp8_mlp.defvjp(_fp8_mlp_fwd, _fp8_mlp_bwd)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Sharding Meta for xmap with CustomCall
"""
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Union, Tuple, Dict, Callable, Sequence
from jax.interpreters import pxla
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
_PXLA_THREAD_RESOURCES = pxla.thread_resources
def _get_mesh_info(resource: str):
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
assert resource in mesh.axis_names, \
f"{resource} is not in the axis_names of Mesh {mesh}."
return mesh.shape[resource], resource
@dataclass
class ShardingResource:
"""
A data container indicating which Mesh axis is used for data parallelism and
which for tensor parallelism.
Parameters
----------
dp_resource : str, default = None
Axis name in the Mesh along which the batch is sharded.
If None, data parallelism is disabled.
tp_resource : str, default = None
Axis name in the Mesh along which model tensors are split.
If None, tensor parallelism is disabled.
"""
dp_resource: str = None
tp_resource: str = None
_GLOBAL_SHARD_RESOURCE = ShardingResource()
@contextmanager
def global_shard_guard(resource: ShardingResource):
"""
A context manager to switch the global ShardingResource
"""
global _GLOBAL_SHARD_RESOURCE
prev_gsr = _GLOBAL_SHARD_RESOURCE
try:
_GLOBAL_SHARD_RESOURCE = resource
yield
finally:
_GLOBAL_SHARD_RESOURCE = prev_gsr
def global_shard_resource() -> ShardingResource:
"""
A getter of the global ShardingResource
"""
return _GLOBAL_SHARD_RESOURCE
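# Example usage (a minimal sketch, assuming 8 local devices arranged as a 2x4 mesh
# with axes named 'data' and 'model'; the mesh setup itself is plain JAX):
#
#   import numpy as np
#   from jax.experimental.maps import Mesh
#   devices = np.asarray(jax.devices()).reshape(2, 4)
#   with Mesh(devices, ('data', 'model')), \
#        global_shard_guard(ShardingResource(dp_resource='data', tp_resource='model')):
#       ...  # TE/JAX calls issued here pick up the 'data'/'model' mesh axes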
class MajorShardingType(Enum):
"""
The major sharding type to indicate sharding pattern.
`SINGLE` means single process training.
`DP` means data parallel training.
`TP` means tensor parallel training.
`DPTP` means data and tensor parallel training.
"""
SINGLE = 0
DP = 1
TP = 2
DPTP = 3
class ShardingType(Enum):
"""
The sharding type to indicate sharding pattern.
`SINGLE` means no sharding.
`DP` means sharding along data parallelism.
`TP_COL` means sharding along column-split tensor parallelism.
`TP_ROW` means sharding along row-split tensor parallelism.
`DP_TP_COL` means sharding along data and column-split tensor parallelism.
`DP_TP_ROW` means sharding along data and row-split tensor parallelism.
"""
SINGLE = (MajorShardingType.SINGLE, "single")
DP = (MajorShardingType.DP, "dp")
TP_COL = (MajorShardingType.TP, "tp_col")
TP_ROW = (MajorShardingType.TP, "tp_row")
DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")
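# Note: each ShardingType value is a (MajorShardingType, str) pair, so
# ShardingType.DP_TP_COL.value[0] is MajorShardingType.DPTP; callers pass value[0]
# to is_dp_enabled()/is_tp_enabled() to decide which cross-device reductions to apply.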
def infer_major_sharding_type() -> MajorShardingType:
"""
Infer MajorShardingType from _GLOBAL_SHARD_RESOURCE
"""
gsr = global_shard_resource()
resources = [gsr.dp_resource, gsr.tp_resource]
for idx, rs in enumerate(resources):
try:
size, _ = _get_mesh_info(rs)
if size <= 1:
resources[idx] = None
except AssertionError as _:
resources[idx] = None
dp_resource = resources[0]
tp_resource = resources[1]
if dp_resource is not None and \
tp_resource is not None :
return MajorShardingType.DPTP
if dp_resource is not None:
return MajorShardingType.DP
if tp_resource is not None:
return MajorShardingType.TP
return MajorShardingType.SINGLE
def infer_sharding_type(major_st: MajorShardingType = None) -> Tuple[ShardingType, ShardingType]:
"""
Infer ShardingType via given MajorShardingType
"""
if major_st is None:
major_st = infer_major_sharding_type()
if major_st is MajorShardingType.DP:
return ShardingType.DP, ShardingType.DP
if major_st is MajorShardingType.TP:
return ShardingType.TP_COL, ShardingType.TP_ROW
if major_st is MajorShardingType.DPTP:
return ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW
return ShardingType.SINGLE, ShardingType.SINGLE
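# Example (sketch): with both dp and tp mesh axes active, infer_sharding_type()
# returns (ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW), i.e. the first GEMM of
# a block is column-split and the second is row-split.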
def is_dp_enabled(mst: MajorShardingType) -> bool:
"""
is_dp_enabled
"""
return mst in (MajorShardingType.DP, MajorShardingType.DPTP)
def is_tp_enabled(mst: MajorShardingType) -> bool:
"""
is_tp_enabled
"""
return mst in (MajorShardingType.TP, MajorShardingType.DPTP)
def merge_axis_resources(ars: Tuple[Dict]) -> Dict:
"""
merge_axis_resources
"""
output = {}
for ar in ars:
for key in ar:
if key not in output:
output[key] = ar[key]
else:
assert output[key] == ar[key]
return output
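# Example (sketch):
#   merge_axis_resources(({'batch': 'data'}, {'model': 'mp'}, {'batch': 'data'}))
#   -> {'batch': 'data', 'model': 'mp'}
# Two entries mapping the same axis name to different resources would trip the
# assertion above.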
@dataclass
class ShardingMeta:
"""ShardingMeta"""
in_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]]
out_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]]
axis_resources: Dict
input_shapes: Tuple[Tuple[int, ...]]
output_shapes: Tuple[Tuple[int, ...]]
class ShardingMetaGenerator:
"""
ShardingMetaGenerator
"""
def __init__(self):
def get_single_sharding_meta(*argv, **kwargs) -> ShardingMeta: # pylint: disable=unused-argument
return None
self.sharding_type_meta_map = {
ShardingType.SINGLE: get_single_sharding_meta,
ShardingType.DP: self.get_dp_sharding_meta,
ShardingType.TP_COL: self.get_tp_col_sharding_meta,
ShardingType.TP_ROW: self.get_tp_row_sharding_meta,
ShardingType.DP_TP_COL: self.get_dp_tp_col_sharding_meta,
ShardingType.DP_TP_ROW: self.get_dp_tp_row_sharding_meta
}
def get_sharding_meta(self, stype: ShardingType, *argv, **kwargs) -> ShardingMeta:
"""get_sharding_meta"""
return self.sharding_type_meta_map[stype](*argv, **kwargs)
def get_dp_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_sharding_meta"""
raise NotImplementedError
def get_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_tp_col_sharding_meta"""
raise NotImplementedError
def get_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_tp_row_sharding_meta"""
raise NotImplementedError
def get_dp_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_tp_col_sharding_meta"""
raise NotImplementedError
def get_dp_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_tp_row_sharding_meta"""
raise NotImplementedError
class FP8MetaShardingMetaGenerator(ShardingMetaGenerator):
"""
FP8MetaShardingMetaGenerator
"""
def get_dp_sharding_meta(self,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.DP,
num_of_meta, dp_axis_name,
tp_axis_name)
def get_tp_col_sharding_meta(self,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.TP,
num_of_meta, dp_axis_name,
tp_axis_name)
def get_tp_row_sharding_meta(self,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.TP,
num_of_meta, dp_axis_name,
tp_axis_name)
def get_dp_tp_col_sharding_meta(self,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.DPTP,
num_of_meta, dp_axis_name,
tp_axis_name)
def get_dp_tp_row_sharding_meta(self,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
return FP8MetaShardingMetaGenerator._generate_sharding_meta(MajorShardingType.DPTP,
num_of_meta, dp_axis_name,
tp_axis_name)
@staticmethod
def _stack_axes_meta(num_of_meta: int, mapping: Dict) -> Tuple:
return tuple(mapping for _ in range(num_of_meta))
@staticmethod
def _generate_sharding_meta(type_: MajorShardingType,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
axis_resource = {}
if is_dp_enabled(type_):
axis_resource[dp_axis_name] = global_shard_resource().dp_resource
if is_tp_enabled(type_):
axis_resource[tp_axis_name] = global_shard_resource().tp_resource
return ShardingMeta(FP8MetaShardingMetaGenerator._stack_axes_meta(num_of_meta, {}),
FP8MetaShardingMetaGenerator._stack_axes_meta(num_of_meta, {}),
axis_resource, (), ())
class DotShardingMetaGenerator(ShardingMetaGenerator):
"""
DotShardingMetaGenerator
"""
def get_dp_sharding_meta(
self,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int, # pylint: disable=unused-argument
model_dim_of_b: int, # pylint: disable=unused-argument
contracting_dims: Tuple[Sequence[int], Sequence[int]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, None,
contracting_dims)
out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims)
out_batch_dim = batch_dim_of_a
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
assert a_shape[batch_dim_of_a] % dp_size == 0, \
f"The dimension of batch in a_shape should be a multiple of data parallelism size," \
f" but got {a_shape[batch_dim_of_a]=} and {dp_size=}."
a_new_shape = (*a_shape[:batch_dim_of_a], dp_size, -1, *a_shape[batch_dim_of_a + 1:])
return ShardingMeta(({
batch_dim_of_a: dp_axis_name
}, {}), ({
out_batch_dim: dp_axis_name
}), {dp_axis_name: dp_mesh_axis}, [a_new_shape, b_shape], [out_shape])
def get_tp_col_sharding_meta(
self,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int, # pylint: disable=unused-argument
model_dim_of_b: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model') -> ShardingMeta:
DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, None,
contracting_dims)
out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims)
out_model_idx = len(out_shape) - (len(b_shape) - model_dim_of_b)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert b_shape[model_dim_of_b] % tp_size == 0, \
f"The dimension of model parallelism in b_shape should be a multiple of " \
f"tensor parallelism size,but got {b_shape[model_dim_of_b]=} and {tp_size=}."
b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size,
*b_shape[model_dim_of_b + 1:])
return ShardingMeta(({}, {
model_dim_of_b: tp_axis_name
}), ({
out_model_idx: tp_axis_name
}), {tp_axis_name: tp_mesh_axis}, [a_shape, b_new_shape], [out_shape])
def get_tp_row_sharding_meta(
self,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int,
model_dim_of_b: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model') -> ShardingMeta:
DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, model_dim_of_a,
contracting_dims)
out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert a_shape[model_dim_of_a] % tp_size == 0, \
f"The dimension of model parallelism in a_shape should be a multiple of " \
f"tensor parallelism size,but got {a_shape[model_dim_of_a]=} and {tp_size=}."
assert b_shape[model_dim_of_b] % tp_size == 0, \
f"The dimension of model parallelism in b_shape should be a multiple of " \
f"tensor parallelism size,but got {b_shape[model_dim_of_b]=} and {tp_size=}."
a_new_shape = (*a_shape[:model_dim_of_a], tp_size, a_shape[model_dim_of_a] // tp_size,
*a_shape[model_dim_of_a + 1:])
b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size,
*b_shape[model_dim_of_b + 1:])
return ShardingMeta(({
model_dim_of_a: tp_axis_name
}, {
model_dim_of_b: tp_axis_name
}), ({}), {tp_axis_name: tp_mesh_axis}, [a_new_shape, b_new_shape], [out_shape])
def get_dp_tp_col_sharding_meta(
self,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int, # pylint: disable=unused-argument
model_dim_of_b: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, None,
contracting_dims)
out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims)
out_model_idx = len(out_shape) + 1 - (len(b_shape) - model_dim_of_b)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert a_shape[batch_dim_of_a] % dp_size == 0, \
f"The dimension of batch in a_shape should be a multiple of data parallelism size," \
f" but got {a_shape[batch_dim_of_a]=} and {dp_size=}."
assert b_shape[model_dim_of_b] % tp_size == 0, \
f"The dimension of model parallelism in b_shape should be a multiple of " \
f"tensor parallelism size,but got {b_shape[model_dim_of_b]=} and {tp_size=}."
a_new_shape = (*a_shape[:batch_dim_of_a], dp_size, a_shape[batch_dim_of_a] // dp_size,
*a_shape[batch_dim_of_a + 1:])
b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size,
*b_shape[model_dim_of_b + 1:])
return ShardingMeta(({
batch_dim_of_a: dp_axis_name
}, {
model_dim_of_b: tp_axis_name
}), ({
batch_dim_of_a: dp_axis_name,
out_model_idx: tp_axis_name
}), {
dp_axis_name: dp_mesh_axis,
tp_axis_name: tp_mesh_axis
}, [a_new_shape, b_new_shape], [out_shape])
def get_dp_tp_row_sharding_meta(self,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int,
model_dim_of_b: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
DotShardingMetaGenerator._is_supported(a_shape, b_shape, batch_dim_of_a, model_dim_of_a,
contracting_dims)
out_shape = DotShardingMetaGenerator._infer_output_shape(a_shape, b_shape, contracting_dims)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert a_shape[batch_dim_of_a] % dp_size == 0, \
f"The dimension of batch in a_shape should be a multiple of data parallelism size," \
f" but got {a_shape[batch_dim_of_a]=} and {dp_size=}."
assert a_shape[model_dim_of_a] % tp_size == 0, \
f"The dimension of model parallelism in a_shape should be a multiple of " \
f"tensor parallelism size,but got {a_shape[model_dim_of_a]=} and {tp_size=}."
assert b_shape[model_dim_of_b] % tp_size == 0, \
f"The dimension of model parallelism in b_shape should be a multiple of " \
f"tensor parallelism size,but {b_shape[model_dim_of_b]=} and {tp_size=}."
a_new_shape = (*a_shape[:batch_dim_of_a], dp_size, a_shape[batch_dim_of_a] // dp_size,
*a_shape[batch_dim_of_a + 1:model_dim_of_a], tp_size,
a_shape[model_dim_of_a] // tp_size, *a_shape[model_dim_of_a + 1:])
b_new_shape = (*b_shape[:model_dim_of_b], tp_size, b_shape[model_dim_of_b] // tp_size,
*b_shape[model_dim_of_b + 1:])
return ShardingMeta(
(
{
batch_dim_of_a:
dp_axis_name,
# "model_dim_of_a+1" is the index to tp_size in a_new_shape
model_dim_of_a + 1:
tp_axis_name
},
{
model_dim_of_b: tp_axis_name
}),
({
batch_dim_of_a: dp_axis_name
}),
{
dp_axis_name: dp_mesh_axis,
tp_axis_name: tp_mesh_axis
},
[a_new_shape, b_new_shape],
[out_shape])
@staticmethod
def _is_supported(
a_shape: Tuple, # pylint: disable=unused-argument
b_shape: Tuple, # pylint: disable=unused-argument
batch_dim_of_a: int,
model_dim_of_a: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
):
assert batch_dim_of_a not in contracting_dims[0], \
"batch_dim_of_a should be one of contracting_dims[0]"
assert batch_dim_of_a >= 0, \
"Only support non-negative value of batch_dim_of_a."
if model_dim_of_a is not None:
assert model_dim_of_a >= 0, \
"Only support non-negative value of model_dim_of_a"
assert model_dim_of_a > batch_dim_of_a, \
"Only support the case that model_dim_of_a > batch_dim_of_a."
@staticmethod
def _infer_output_shape(
a_shape: Tuple,
b_shape: Tuple,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
):
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
return (*a_shape[:min(lhs_contracting_dims)], *b_shape[max(rhs_contracting_dims) + 1:])
class ElementwiseShardingMetaGenerator(ShardingMetaGenerator):
"""
ElementwiseShardingMetaGenerator
"""
def get_dp_sharding_meta(
self,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_dp_sharding_meta"""
ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, batch_dim)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
assert input_shape[batch_dim] % dp_size == 0, \
f"The dimension of batch in input_shape should be a multiple of data parallelism " \
f"size, but got {input_shape[batch_dim]=} and {dp_size=}."
input_new_shape = (*input_shape[:batch_dim], dp_size, -1, *input_shape[batch_dim + 1:])
in_axes = [{batch_dim: dp_axis_name}]
input_new_shapes = [input_new_shape]
if other_shape is not None:
input_new_shapes.append(other_shape)
in_axes.append({})
return ShardingMeta(tuple(in_axes), ({
batch_dim: dp_axis_name
}), {dp_axis_name: dp_mesh_axis}, input_new_shapes, [input_shape])
def get_tp_col_sharding_meta(
self,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int, # pylint: disable=unused-argument
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_tp_col_sharding_meta"""
ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, 0)
in_axes = [{}]
input_new_shapes = [input_shape]
if other_shape is not None:
in_axes.append({})
input_new_shapes.append(other_shape)
return ShardingMeta(tuple(in_axes), ({}), {}, input_new_shapes, [input_shape])
def get_tp_row_sharding_meta(
self,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int, # pylint: disable=unused-argument
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_tp_row_sharding_meta"""
ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, 0)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert input_shape[-1] % tp_size == 0, \
f"The last dimension in input_shape should be a multiple of tensor parallelism size," \
f" but got {input_shape[-1]=} and {tp_size=}."
input_new_shape = (*input_shape[:-1], tp_size, -1)
in_axes = [{
# "len(a_new_shape)-2" is the index to tp_size in a_new_shape
len(input_new_shape) - 2:
tp_axis_name
}]
input_new_shapes = [input_new_shape]
if other_shape is not None:
assert other_shape[0] % tp_size == 0, \
f"The first dimension in other_shape should be a multiple of tensor parallelism size," \
f" but got {other_shape[0]=} and {tp_size=}."
other_new_shape = (tp_size, -1)
in_axes.append({0: tp_axis_name})
input_new_shapes.append(other_new_shape)
return ShardingMeta(tuple(in_axes), ({
len(input_new_shape) - 2: tp_axis_name
}), {tp_axis_name: tp_mesh_axis}, input_new_shapes, [input_shape])
def get_dp_tp_col_sharding_meta(self,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_col_sharding_meta"""
return self.get_dp_sharding_meta(input_shape, other_shape, batch_dim, dp_axis_name,
tp_axis_name)
def get_dp_tp_row_sharding_meta(self,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_row_sharding_meta"""
ElementwiseShardingMetaGenerator._is_supported(input_shape, other_shape, batch_dim)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert input_shape[batch_dim] % dp_size == 0, \
f"The dimension of batch in input_shape should be a multiple of data parallelism" \
f"size, but got {input_shape[batch_dim]=} and {dp_size=}."
assert input_shape[-1] % tp_size == 0, \
f"The last dimension in input_shape should be a multiple of tensor parallelism size," \
f" but got {input_shape[-1]=} and {tp_size=}."
input_new_shape = (*input_shape[:batch_dim], dp_size, -1, *input_shape[batch_dim + 1:-1],
tp_size, input_shape[-1] // tp_size)
in_axes = [{
batch_dim:
dp_axis_name,
# "len(a_new_shape)-2" is the index to tp_size in a_new_shape
len(input_new_shape) - 2:
tp_axis_name
}]
input_new_shapes = [input_new_shape]
other_new_shape = other_shape
if other_shape is not None:
assert other_shape[0] % tp_size == 0, \
f"The first dimension in other_shape should be a multiple of tensor parallelism size," \
f" but got {other_shape[0]=} and {tp_size=}."
other_new_shape = (tp_size, -1)
in_axes.append({0: tp_axis_name})
input_new_shapes.append(other_new_shape)
return ShardingMeta(tuple(in_axes), ({
batch_dim: dp_axis_name,
len(input_new_shape) - 2: tp_axis_name
}), {
dp_axis_name: dp_mesh_axis,
tp_axis_name: tp_mesh_axis
}, input_new_shapes, [input_shape])
@staticmethod
def _is_supported(input_shape: Tuple, other_shape: Tuple, batch_dim: int):
if other_shape is not None:
assert len(other_shape) == 1, "Only support 1 dimension of other_shapes currently."
assert input_shape[-1] == other_shape[0], \
f"input_shape[-1] should equal to oshape[0], " \
f"but got {input_shape[-1]} and {other_shape[0]}."
assert batch_dim < len(input_shape)-1, \
"batch_dim cannot be the latest dim"
class SoftmaxShardingMetaGenerator(ShardingMetaGenerator):
"""
SoftmaxShardingMetaGenerator
"""
def get_dp_sharding_meta(
self,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_dp_sharding_meta"""
SoftmaxShardingMetaGenerator._is_supported(input_shape, dp_dim, tp_dim)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
assert input_shape[dp_dim] % dp_size == 0, \
f"The dimension of batch in input_shape should be a multiple of data parallelism " \
f"size, but got {input_shape[dp_dim]=} and {dp_size=}."
input_new_shape = (*input_shape[:dp_dim], dp_size, -1, *input_shape[dp_dim + 1:])
in_axes = [{dp_dim: dp_axis_name}]
input_new_shapes = [input_new_shape]
return ShardingMeta(tuple(in_axes), ({
dp_dim: dp_axis_name
}), {dp_axis_name: dp_mesh_axis}, input_new_shapes, [input_shape])
def get_tp_col_sharding_meta(self,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_tp_col_sharding_meta"""
return SoftmaxShardingMetaGenerator._get_tp_sharding_meta(input_shape, dp_dim, tp_dim,
dp_axis_name, tp_axis_name)
def get_tp_row_sharding_meta(self,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_tp_row_sharding_meta"""
return SoftmaxShardingMetaGenerator._get_tp_sharding_meta(input_shape, dp_dim, tp_dim,
dp_axis_name, tp_axis_name)
def get_dp_tp_col_sharding_meta(self,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_col_sharding_meta"""
return SoftmaxShardingMetaGenerator._get_dptp_sharding_meta(input_shape, dp_dim, tp_dim,
dp_axis_name, tp_axis_name)
def get_dp_tp_row_sharding_meta(self,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_row_sharding_meta"""
return SoftmaxShardingMetaGenerator._get_dptp_sharding_meta(input_shape, dp_dim, tp_dim,
dp_axis_name, tp_axis_name)
@staticmethod
def _is_supported(input_shape: Tuple, dp_dim: int, tp_dim: int):
assert len(input_shape) == 4
assert dp_dim == 0
assert tp_dim == 1
@staticmethod
def _get_tp_sharding_meta(
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_tp_sharding_meta"""
SoftmaxShardingMetaGenerator._is_supported(input_shape, dp_dim, tp_dim)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert input_shape[tp_dim] % tp_size == 0, \
f"The dimension of tensor parallel in input_shape should be a multiple of data " \
f"parallelism size, but got {input_shape[tp_dim]=} and {tp_size=}."
input_new_shape = (*input_shape[:tp_dim], tp_size, -1, *input_shape[tp_dim + 1:])
in_axes = [{tp_dim: tp_axis_name}]
input_new_shapes = [input_new_shape]
return ShardingMeta(tuple(in_axes), ({
tp_dim: tp_axis_name
}), {tp_axis_name: tp_mesh_axis}, input_new_shapes, [input_shape])
@staticmethod
def _get_dptp_sharding_meta(input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_sharding_meta"""
SoftmaxShardingMetaGenerator._is_supported(input_shape, dp_dim, tp_dim)
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
assert input_shape[dp_dim] % dp_size == 0, \
f"The dimension of batch in input_shape should be a multiple of data parallelism " \
f"size, but got {input_shape[dp_dim]=} and {dp_size=}."
assert input_shape[tp_dim] % tp_size == 0, \
f"The dimension of tensor parallel in input_shape should be a multiple of data " \
f"parallelism size, but got {input_shape[tp_dim]=} and {tp_size=}."
input_new_shape = (*input_shape[:dp_dim], dp_size, input_shape[dp_dim] // dp_size,
*input_shape[dp_dim + 1:tp_dim], tp_size, input_shape[tp_dim] // tp_size,
*input_shape[tp_dim + 1:])
in_axes = [{dp_dim: dp_axis_name, tp_dim + 1: tp_axis_name}]
input_new_shapes = [input_new_shape]
out_axes = in_axes
return ShardingMeta(tuple(in_axes), out_axes, {
dp_axis_name: dp_mesh_axis,
tp_axis_name: tp_mesh_axis
}, input_new_shapes, [input_shape])
def get_fp8_meta_sharding_meta(stype: ShardingType,
num_of_meta: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
get_fp8_meta_sharding_meta
"""
return FP8MetaShardingMetaGenerator().get_sharding_meta(stype, num_of_meta, dp_axis_name,
tp_axis_name)
def get_dot_sharding_meta(stype: ShardingType,
a_shape: Tuple,
b_shape: Tuple,
batch_dim_of_a: int,
model_dim_of_a: int,
model_dim_of_b: int,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
get_dot_sharding_meta
"""
if stype in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
assert model_dim_of_b <= max(contracting_dims[1]), \
f"The dimension of model parallelism in b_shape should be smaller than the max of" \
f" contracting_dims[1], but got {model_dim_of_b=} and {contracting_dims[1]=}."
if stype in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
assert model_dim_of_b > max(contracting_dims[1]), \
f"The dimension of model parallelism in b_shape should be larger than the max of" \
f" contracting_dims[1], but got {model_dim_of_b=} and {contracting_dims[1]=}."
return DotShardingMetaGenerator().get_sharding_meta(stype, a_shape, b_shape, batch_dim_of_a,
model_dim_of_a, model_dim_of_b,
contracting_dims, dp_axis_name,
tp_axis_name)
def get_elementwise_sharding_meta(stype: ShardingType,
input_shape: Tuple,
other_shape: Tuple,
batch_dim: int,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
get_elementwise_sharding_meta
"""
return ElementwiseShardingMetaGenerator().get_sharding_meta(stype, input_shape, other_shape,
batch_dim, dp_axis_name,
tp_axis_name)
def get_softmax_sharding_meta(stype: ShardingType,
input_shape: Tuple,
dp_dim: int = 0,
tp_dim: int = 1,
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
get_softmax_sharding_meta
"""
return SoftmaxShardingMetaGenerator().get_sharding_meta(stype, input_shape, dp_dim, tp_dim,
dp_axis_name, tp_axis_name)
def xmap_runner(func: Callable, in_axes: Tuple[Dict, ...],
out_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]],
axis_resources: Dict, inputs: Tuple):
"""
xmap_runner
"""
assert isinstance(inputs, tuple)
assert isinstance(in_axes, tuple)
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
fake_in_axes = {}
fake_axis_resource = {}
# The fake-axis setup below is a workaround for "NotImplementedError:
# Collectives in manually partitioned computations are only supported
# when all mesh axes are partitioned manually (no partial automatic
# sharding). Make sure that you mention all mesh axes in axis_resources!"
for i, mesh_axis_names in enumerate(mesh.axis_names):
if mesh_axis_names not in axis_resources.values():
fake_axis_name = f"{mesh_axis_names}_fake_{i}"
fake_in_axes[i] = fake_axis_name
fake_axis_resource[fake_axis_name] = mesh_axis_names
fake_input = jnp.zeros(tuple(64 for _ in range(len(fake_in_axes) + 1)))
xmapped = xmap(lambda func_input, _: func(*func_input),
in_axes=(in_axes, fake_in_axes),
out_axes=out_axes,
axis_resources={
**axis_resources,
**fake_axis_resource
})
output = xmapped(inputs, fake_input)
return output
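# Example usage (a minimal sketch mirroring how the layernorm/mlp wrappers call
# this helper; `fn`, its static argument and the inputs are placeholders):
#
#   from functools import partial
#   partial_fn = partial(fn, some_static_arg=...)
#   out = xmap_runner(partial_fn, sharding_meta.in_axes, sharding_meta.out_axes,
#                     sharding_meta.axis_resources, (arg0, arg1))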