Unverified commit 6464ced7, authored by Ming-Xu Huang, committed by GitHub
Browse files

[JAX] FSDP General Support and FP8 Support to Praxis. (#347)



* Initial commit for FSDP
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding support to fsdp xmap sharding
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Specify WeightHParamsCollection of fp8 meta.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Support partial FP8 custom calls with FSDP.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding amax reduction on the fsdp mesh dim.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* clean code
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the wrong batch axis in logic_axis_rules and add sharding_constraint to BMM1
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Support FSDP in fMHA.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix missing all-reduce of wgrads along FSDP axis.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Change default value of fsdp_axis_name to "" to align with others
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix RuntimeError: with_sharding_constraint requires a non-empty mesh
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Slight changes (review feedback)
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Removed unnecessary comments
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Merging input_dp_dim into weight_fsdp_dim_map
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Update transformer_engine/jax/sharding.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 7804d116
...@@ -14,7 +14,7 @@ from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype ...@@ -14,7 +14,7 @@ from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype
from .fp8 import FP8Helper, FP8GemmPackage from .fp8 import FP8Helper, FP8GemmPackage
from .sharding import ShardingType, get_dot_sharding_meta, get_fp8_meta_sharding_meta from .sharding import ShardingType, get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
from .sharding import xmap_runner from .sharding import xmap_runner, extend_fsdp_sharding_meta
jax.config.update('experimental_xmap_spmd_lowering', True) jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True) jax.config.update('experimental_xmap_spmd_lowering_manual', True)
...@@ -49,7 +49,8 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage, ...@@ -49,7 +49,8 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
sharding_type=sharding_type, sharding_type=sharding_type,
dp_axis_name="", dp_axis_name="",
tp_axis_name="") tp_axis_name="",
fsdp_axis_name="")
else: else:
dp_axis_name = "batch" dp_axis_name = "batch"
tp_axis_name = "model" tp_axis_name = "model"
...@@ -64,6 +65,7 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage, ...@@ -64,6 +65,7 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape, sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
dp_dim_index, input_tp_index, kernel_tp_index, dp_dim_index, input_tp_index, kernel_tp_index,
contracting_dims, dp_axis_name, tp_axis_name) contracting_dims, dp_axis_name, tp_axis_name)
sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
kernel_ = jnp.reshape(kernel, sharding_meta.input_shapes[1]) # 1 for kernel kernel_ = jnp.reshape(kernel, sharding_meta.input_shapes[1]) # 1 for kernel
...@@ -80,7 +82,8 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage, ...@@ -80,7 +82,8 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
sharding_type=sharding_type, sharding_type=sharding_type,
dp_axis_name=dp_axis_name, dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name) tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
res = xmap_runner(partial_fp8_dot, (*sharding_meta.in_axes, *fp8_sharding_meta.in_axes), res = xmap_runner(partial_fp8_dot, (*sharding_meta.in_axes, *fp8_sharding_meta.in_axes),
sharding_meta.out_axes, axis_resources, sharding_meta.out_axes, axis_resources,
(inputs_, kernel_, fp8_max, amax, scale, scale_inv)) (inputs_, kernel_, fp8_max, amax, scale, scale_inv))
...@@ -90,11 +93,11 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage, ...@@ -90,11 +93,11 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
return res return res
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11)) @partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12))
def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray, def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: TEDType, bwd_dtype: TEDType, scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: TEDType, bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]], sharding_type: ShardingType, contracting_dims: Tuple[Sequence[int], Sequence[int]], sharding_type: ShardingType,
dp_axis_name: str, tp_axis_name: str): dp_axis_name: str, tp_axis_name: str, fsdp_axis_name: str):
res, _ = _fp8_dot_fwd(inputs, res, _ = _fp8_dot_fwd(inputs,
kernel, kernel,
fp8_maxs, fp8_maxs,
...@@ -106,7 +109,8 @@ def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, am ...@@ -106,7 +109,8 @@ def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, am
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
sharding_type=sharding_type, sharding_type=sharding_type,
dp_axis_name=dp_axis_name, dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name) tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
return res return res
...@@ -122,7 +126,8 @@ def _fp8_dot_fwd( ...@@ -122,7 +126,8 @@ def _fp8_dot_fwd(
contracting_dims, contracting_dims,
sharding_type, sharding_type,
dp_axis_name, # pylint: disable=unused-argument dp_axis_name, # pylint: disable=unused-argument
tp_axis_name): tp_axis_name,
fsdp_axis_name): # pylint: disable=unused-argument
lhs_contracting_dims, rhs_contracting_dims = contracting_dims lhs_contracting_dims, rhs_contracting_dims = contracting_dims
input_shape_pre = inputs.shape[:min(lhs_contracting_dims)] input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
input_shape_suf = inputs.shape[min(lhs_contracting_dims):] input_shape_suf = inputs.shape[min(lhs_contracting_dims):]
...@@ -173,6 +178,7 @@ def _fp8_dot_bwd( ...@@ -173,6 +178,7 @@ def _fp8_dot_bwd(
sharding_type, sharding_type,
dp_axis_name, dp_axis_name,
tp_axis_name, tp_axis_name,
fsdp_axis_name,
ctx, ctx,
g): g):
input_cast_trans, kernel_cast, \ input_cast_trans, kernel_cast, \
...@@ -206,6 +212,10 @@ def _fp8_dot_bwd( ...@@ -206,6 +212,10 @@ def _fp8_dot_bwd(
wgrad = jax.lax.psum(wgrad, dp_axis_name) wgrad = jax.lax.psum(wgrad, dp_axis_name)
amax = jax.lax.pmax(amax, dp_axis_name) amax = jax.lax.pmax(amax, dp_axis_name)
if len(fsdp_axis_name) > 0:
wgrad = jax.lax.psum(wgrad, fsdp_axis_name)
amax = jax.lax.pmax(amax, fsdp_axis_name)
if is_tp_enabled(sharding_type.value[0]): if is_tp_enabled(sharding_type.value[0]):
amax = jax.lax.pmax(amax, tp_axis_name) amax = jax.lax.pmax(amax, tp_axis_name)
......
...@@ -28,7 +28,8 @@ from ..fused_attn import is_fused_attn_kernel_available ...@@ -28,7 +28,8 @@ from ..fused_attn import is_fused_attn_kernel_available
from ..fused_attn import self_fused_attn, cross_fused_attn from ..fused_attn import self_fused_attn, cross_fused_attn
from ..softmax import SoftmaxType from ..softmax import SoftmaxType
from ..sharding import infer_major_sharding_type, infer_sharding_type from ..sharding import infer_major_sharding_type, infer_sharding_type
from ..sharding import global_shard_resource, ShardingType from ..sharding import global_shard_resource, with_sharding_constraint
from ..sharding import ShardingType
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
...@@ -39,6 +40,17 @@ PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Preci ...@@ -39,6 +40,17 @@ PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Preci
Initializer = Callable[[PRNGKey, Shape, DType], Array] Initializer = Callable[[PRNGKey, Shape, DType], Array]
LogicalRules = Sequence[Tuple[str, Union[str, None]]] LogicalRules = Sequence[Tuple[str, Union[str, None]]]
# Logical axis names used by the TE Flax modules. extend_logical_axis_rules
# maps each of them to a mesh resource (or None) taken from
# global_shard_resource().
#
# Names used in sharding constraints on activations:
BATCH_AXES = 'nvte_batch'
SEQLEN_AXES = 'nvte_seqlen'
HEAD_AXES = 'nvte_head'
HIDDEN_AXES = 'nvte_hidden'
HIDDEN_TP_AXES = 'nvte_hidden_tp'
JOINED_AXES = 'nvte_joined'
# Names used for parameter axes (kernel_axes / bias_axes / scale_axes):
W_NO_SHARD_AXES = 'nvte_w_no_shard'
W_FSDP_AXES = 'nvte_w_fsdp'
W_TP_AXES = 'nvte_w_tp'
W_JOINED_AXES = 'nvte_w_joined'
def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]: def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
# Generate broadcast dims for drop_path. # Generate broadcast dims for drop_path.
...@@ -91,10 +103,32 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: ...@@ -91,10 +103,32 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
rules_map[key] = [val] rules_map[key] = [val]
gsr = global_shard_resource() gsr = global_shard_resource()
te_logical_axis_rules = (('batch', gsr.dp_resource), ('embed', None), ('mlp', gsr.tp_resource),
('heads', gsr.tp_resource), ('kv', None), ('qkv_dim', None), batch_dim_rule = []
('kv_dim', None), ('joined_kv', gsr.tp_resource), ('act', None), if gsr.dp_resource is not None:
('relpos_buckets', None), ('length', None)) batch_dim_rule.append(gsr.dp_resource)
if gsr.fsdp_resource is not None and gsr.dp_resource != gsr.fsdp_resource:
batch_dim_rule.append(gsr.fsdp_resource)
if len(batch_dim_rule) <= 0:
batch_dim_rule = None
elif len(batch_dim_rule) == 1:
batch_dim_rule = batch_dim_rule[0]
else:
batch_dim_rule = tuple(batch_dim_rule)
te_logical_axis_rules = (
(BATCH_AXES, batch_dim_rule),
(SEQLEN_AXES, None),
(HEAD_AXES, gsr.tp_resource),
(HIDDEN_AXES, None),
(HIDDEN_TP_AXES, gsr.tp_resource),
(JOINED_AXES, None),
(W_NO_SHARD_AXES, None),
(W_FSDP_AXES, gsr.fsdp_resource),
(W_TP_AXES, gsr.tp_resource),
(W_JOINED_AXES, None),
)
extended_rules = [*rules] extended_rules = [*rules]
for item in te_logical_axis_rules: for item in te_logical_axis_rules:
...@@ -110,6 +144,18 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: ...@@ -110,6 +144,18 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
return tuple(extended_rules) return tuple(extended_rules)
def _with_sharding_constraint(x: Array, logical_axis_names: Shape):
    """Apply a sharding constraint to ``x`` described by logical axis names.

    Every entry of ``logical_axis_names`` is looked up in the rules returned
    by ``extend_logical_axis_rules(tuple())``; the resulting mesh axes are
    packed into a ``PartitionSpec`` that is handed to
    ``with_sharding_constraint``. One logical name is required per dimension
    of ``x``.
    """
    assert len(x.shape) == len(logical_axis_names)

    # Later duplicates override earlier ones, same as an explicit fill loop.
    logical_to_mesh = dict(extend_logical_axis_rules(tuple()))
    pspec = jax.sharding.PartitionSpec(
        *(logical_to_mesh[name] for name in logical_axis_names))
    return with_sharding_constraint(x, pspec)
def _merge_mask(func, *masks: Optional[Array]): def _merge_mask(func, *masks: Optional[Array]):
masks = [m for m in masks if m is not None] masks = [m for m in masks if m is not None]
if not masks: if not masks:
...@@ -167,6 +213,9 @@ def core_attention(query: Array, ...@@ -167,6 +213,9 @@ def core_attention(query: Array,
else: else:
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
attn_weights = _with_sharding_constraint(attn_weights,
(BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))
# When a bias is present, the computation is performed as Softmax(attn_weights * scale + bias). # When a bias is present, the computation is performed as Softmax(attn_weights * scale + bias).
# In this case, the scale can not fused into the Softmax module. # In this case, the scale can not fused into the Softmax module.
if bias is not None: if bias is not None:
...@@ -425,15 +474,13 @@ class MultiHeadAttention(nn.Module): ...@@ -425,15 +474,13 @@ class MultiHeadAttention(nn.Module):
sharding_type=first_sharding_type, sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm, return_layernorm_output=self.apply_residual_connection_post_layernorm,
scale_axes=('embed',), scale_axes=(W_NO_SHARD_AXES,),
kernel_axes=('embed', 'qkv_dim', 'joined_kv'), ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_init=qkv_init, kernel_init=qkv_init,
use_bias=self.use_bias, use_bias=self.use_bias,
bias_init=self.bias_init, bias_init=self.bias_init,
bias_axes=( bias_axes=(W_JOINED_AXES, W_TP_AXES),
'qkv_dim',
'joined_kv',
),
name='qkv', name='qkv',
dtype=self.dtype)(inputs_q) dtype=self.dtype)(inputs_q)
if not use_fused_attn: if not use_fused_attn:
...@@ -449,11 +496,12 @@ class MultiHeadAttention(nn.Module): ...@@ -449,11 +496,12 @@ class MultiHeadAttention(nn.Module):
sharding_type=first_sharding_type, sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm, return_layernorm_output=self.apply_residual_connection_post_layernorm,
scale_axes=('embed',), scale_axes=(W_NO_SHARD_AXES,),
kernel_axes=('embed', 'joined_kv'), ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias, use_bias=self.use_bias,
bias_init=self.bias_init, bias_init=self.bias_init,
bias_axes=('joined_kv',), bias_axes=(W_TP_AXES,),
dtype=self.dtype, dtype=self.dtype,
kernel_init=query_init, kernel_init=query_init,
name='query')(inputs_q) name='query')(inputs_q)
...@@ -461,14 +509,11 @@ class MultiHeadAttention(nn.Module): ...@@ -461,14 +509,11 @@ class MultiHeadAttention(nn.Module):
features=(2, self.num_heads * self.head_dim), features=(2, self.num_heads * self.head_dim),
sharding_type=first_sharding_type, sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=('embed', 'kv_dim', 'joined_kv'), kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_init=kv_init, kernel_init=kv_init,
use_bias=self.use_bias, use_bias=self.use_bias,
bias_init=self.bias_init, bias_init=self.bias_init,
bias_axes=( bias_axes=(W_JOINED_AXES, W_TP_AXES),
'kv_dim',
'joined_kv',
),
name='kv', name='kv',
dtype=self.dtype)(inputs_kv) dtype=self.dtype)(inputs_kv)
if not use_fused_attn: if not use_fused_attn:
...@@ -480,10 +525,10 @@ class MultiHeadAttention(nn.Module): ...@@ -480,10 +525,10 @@ class MultiHeadAttention(nn.Module):
features=self.num_heads * self.head_dim, features=self.num_heads * self.head_dim,
sharding_type=first_sharding_type, sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=('embed', 'joined_kv'), kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias, use_bias=self.use_bias,
bias_init=self.bias_init, bias_init=self.bias_init,
bias_axes=('joined_kv',), bias_axes=(W_TP_AXES,),
dtype=self.dtype) dtype=self.dtype)
query, ln_out = LayerNormDenseGeneral( query, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm, enable_layernorm=not self.output_layernorm,
...@@ -495,11 +540,12 @@ class MultiHeadAttention(nn.Module): ...@@ -495,11 +540,12 @@ class MultiHeadAttention(nn.Module):
sharding_type=first_sharding_type, sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=True, return_layernorm_output=True,
scale_axes=('embed',), scale_axes=(W_NO_SHARD_AXES,),
kernel_axes=('embed', 'joined_kv'), ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias, use_bias=self.use_bias,
bias_init=self.bias_init, bias_init=self.bias_init,
bias_axes=('joined_kv',), bias_axes=(W_TP_AXES,),
dtype=self.dtype, dtype=self.dtype,
kernel_init=query_init, kernel_init=query_init,
name='query')(inputs_q) name='query')(inputs_q)
...@@ -520,12 +566,12 @@ class MultiHeadAttention(nn.Module): ...@@ -520,12 +566,12 @@ class MultiHeadAttention(nn.Module):
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim)) key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim)) value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
qkv_sharding_constraint = \ qkv_sharding_constraint = \
('length', 'batch', 'heads','kv') \ (SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \
if self.transpose_batch_sequence \ if self.transpose_batch_sequence \
else ('batch', 'length', 'heads', 'kv') else (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
query = nn_partitioning.with_sharding_constraint(query, qkv_sharding_constraint) query = _with_sharding_constraint(query, qkv_sharding_constraint)
key = nn_partitioning.with_sharding_constraint(key, qkv_sharding_constraint) key = _with_sharding_constraint(key, qkv_sharding_constraint)
value = nn_partitioning.with_sharding_constraint(value, qkv_sharding_constraint) value = _with_sharding_constraint(value, qkv_sharding_constraint)
if decode: if decode:
is_initialized = self.has_variable('cache', 'cached_key') is_initialized = self.has_variable('cache', 'cached_key')
...@@ -601,9 +647,9 @@ class MultiHeadAttention(nn.Module): ...@@ -601,9 +647,9 @@ class MultiHeadAttention(nn.Module):
if inputs_q is inputs_kv: if inputs_q is inputs_kv:
qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim)) qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
qkv_sharding_constraint = ('batch', 'length', 'qkv_dim', 'heads', 'kv') qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
qkv_proj = nn_partitioning.with_sharding_constraint(qkv_proj, HIDDEN_AXES)
qkv_sharding_constraint) qkv_proj = _with_sharding_constraint(qkv_proj, qkv_sharding_constraint)
x = self_fused_attn(qkv_proj, x = self_fused_attn(qkv_proj,
bias, bias,
mask, mask,
...@@ -618,10 +664,11 @@ class MultiHeadAttention(nn.Module): ...@@ -618,10 +664,11 @@ class MultiHeadAttention(nn.Module):
assert bias is None assert bias is None
query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim)) query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_heads, self.head_dim)) kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_heads, self.head_dim))
q_sharding_constraint = ('batch', 'length', 'heads', 'kv') q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
kv_sharding_constraint = ('batch', 'length', 'kv_dim', 'heads', 'kv') kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
query = nn_partitioning.with_sharding_constraint(query, q_sharding_constraint) HIDDEN_AXES)
kv_proj = nn_partitioning.with_sharding_constraint(kv_proj, kv_sharding_constraint) query = _with_sharding_constraint(query, q_sharding_constraint)
kv_proj = _with_sharding_constraint(kv_proj, kv_sharding_constraint)
x = cross_fused_attn(query, x = cross_fused_attn(query,
kv_proj, kv_proj,
...@@ -668,20 +715,20 @@ class MultiHeadAttention(nn.Module): ...@@ -668,20 +715,20 @@ class MultiHeadAttention(nn.Module):
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
attn_context_sharding_constraint = \ attn_context_sharding_constraint = \
('length', 'batch', 'joined_kv') \ (SEQLEN_AXES, BATCH_AXES, HIDDEN_TP_AXES) \
if self.transpose_batch_sequence \ if self.transpose_batch_sequence \
else ('batch', 'length', 'joined_kv') else (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
x = nn_partitioning.with_sharding_constraint(x, attn_context_sharding_constraint) x = _with_sharding_constraint(x, attn_context_sharding_constraint)
out = DenseGeneral(features=inputs_q.shape[-1], out = DenseGeneral(features=inputs_q.shape[-1],
sharding_type=second_sharding_type, sharding_type=second_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
axis=-1, axis=-1,
kernel_init=self.kernel_init, kernel_init=self.kernel_init,
kernel_axes=('joined_kv', 'embed'), kernel_axes=(W_TP_AXES, W_FSDP_AXES),
use_bias=self.use_bias, use_bias=self.use_bias,
bias_init=self.bias_init, bias_init=self.bias_init,
bias_axes=('embed',), bias_axes=(W_NO_SHARD_AXES,),
dtype=self.dtype, dtype=self.dtype,
name='out')(x) name='out')(x)
return out, residual return out, residual
...@@ -1118,17 +1165,15 @@ class TransformerLayer(nn.Module): ...@@ -1118,17 +1165,15 @@ class TransformerLayer(nn.Module):
intermediate_dropout_rate=self.hidden_dropout, intermediate_dropout_rate=self.hidden_dropout,
intermediate_hidden_dropout_dims=self.hidden_dropout_dims, intermediate_hidden_dropout_dims=self.hidden_dropout_dims,
dtype=self.dtype, dtype=self.dtype,
scale_axes=('embed',), scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_init=self.mlp_kernel_init, kernel_init=self.mlp_kernel_init,
kernel_axes_1=('embed', 'act', 'mlp'), kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_axes_2=('mlp', 'embed'), kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
use_bias=self.use_bias, use_bias=self.use_bias,
bias_init=self.bias_init, bias_init=self.bias_init,
bias_axes_1=( bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
'act', bias_axes_2=(W_NO_SHARD_AXES,),
'mlp',
),
bias_axes_2=('embed',),
name='mlp', name='mlp',
)(mlp_input, deterministic=deterministic) )(mlp_input, deterministic=deterministic)
...@@ -1148,8 +1193,8 @@ class TransformerLayer(nn.Module): ...@@ -1148,8 +1193,8 @@ class TransformerLayer(nn.Module):
z = LayerNorm(layernorm_type=self.layernorm_type, z = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
scale_axes=('embed',), scale_axes=(W_NO_SHARD_AXES,),
bias_axes=('embed',), bias_axes=(W_NO_SHARD_AXES,),
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype, dtype=self.dtype,
sharding_type=ln_sharding_type, sharding_type=ln_sharding_type,
......
...@@ -16,7 +16,7 @@ from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd ...@@ -16,7 +16,7 @@ from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .sharding import get_fused_attn_sharding_meta from .sharding import get_fused_attn_sharding_meta
from .sharding import ShardingType from .sharding import ShardingType
from .sharding import xmap_runner from .sharding import xmap_runner, extend_fsdp_sharding_meta
jax.config.update('experimental_xmap_spmd_lowering', True) jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True) jax.config.update('experimental_xmap_spmd_lowering_manual', True)
...@@ -82,6 +82,7 @@ def self_fused_attn(qkv: jnp.ndarray, ...@@ -82,6 +82,7 @@ def self_fused_attn(qkv: jnp.ndarray,
tp_dims=([3, 1, None, 0], [2]), tp_dims=([3, 1, None, 0], [2]),
dp_axis_name=dp_axis_name, dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name) tp_axis_name=tp_axis_name)
sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})
inputs_ = tuple( inputs_ = tuple(
jnp.reshape(x, new_shape) if x is not None else None jnp.reshape(x, new_shape) if x is not None else None
...@@ -95,9 +96,9 @@ def self_fused_attn(qkv: jnp.ndarray, ...@@ -95,9 +96,9 @@ def self_fused_attn(qkv: jnp.ndarray,
is_training=is_training) is_training=is_training)
output_ = xmap_runner(partial_self_fused_attn, sharding_meta.in_axes, output_ = xmap_runner(partial_self_fused_attn, sharding_meta.in_axes,
sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_) sharding_meta.out_axes, sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes[0]) output = jnp.reshape(output_, sharding_meta.output_shapes)
return output return output
...@@ -202,6 +203,7 @@ def cross_fused_attn(q: jnp.ndarray, ...@@ -202,6 +203,7 @@ def cross_fused_attn(q: jnp.ndarray,
tp_dims=([2, 3, None, None], [2]), tp_dims=([2, 3, None, None], [2]),
dp_axis_name=dp_axis_name, dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name) tp_axis_name=tp_axis_name)
# extend_fsdp_sharding_meta returns a (sharding_meta, fsdp_axis_name) pair;
# only the meta is needed here, so the axis name is discarded.
sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})
inputs_ = tuple( inputs_ = tuple(
jnp.reshape(x, new_shape) if x is not None else None jnp.reshape(x, new_shape) if x is not None else None
...@@ -215,9 +217,9 @@ def cross_fused_attn(q: jnp.ndarray, ...@@ -215,9 +217,9 @@ def cross_fused_attn(q: jnp.ndarray,
is_training=is_training) is_training=is_training)
output_ = xmap_runner(partial_cross_fused_attn, sharding_meta.in_axes, output_ = xmap_runner(partial_cross_fused_attn, sharding_meta.in_axes,
sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_) sharding_meta.out_axes, sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes[0]) output = jnp.reshape(output_, sharding_meta.output_shapes)
return output return output
......
...@@ -18,7 +18,7 @@ from .fp8 import FP8Helper, FP8GemmPackage ...@@ -18,7 +18,7 @@ from .fp8 import FP8Helper, FP8GemmPackage
from .sharding import ShardingType, get_elementwise_sharding_meta from .sharding import ShardingType, get_elementwise_sharding_meta
from .sharding import get_dot_sharding_meta, get_fp8_meta_sharding_meta from .sharding import get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
from .sharding import xmap_runner from .sharding import xmap_runner, extend_fsdp_sharding_meta
jax.config.update('experimental_xmap_spmd_lowering', True) jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True) jax.config.update('experimental_xmap_spmd_lowering_manual', True)
...@@ -61,12 +61,15 @@ def layernorm(inputs: jnp.ndarray, ...@@ -61,12 +61,15 @@ def layernorm(inputs: jnp.ndarray,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
sharding_type=sharding_type, sharding_type=sharding_type,
dp_axis_name="") dp_axis_name="",
fsdp_axis_name="")
else: else:
dp_axis_name = "batch" dp_axis_name = "batch"
tp_axis_name = "model" tp_axis_name = "model"
sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape, sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape,
dp_dim_index, dp_axis_name, tp_axis_name) dp_dim_index, dp_axis_name, tp_axis_name)
sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
gamma_ = jnp.reshape(gamma, sharding_meta.input_shapes[1]) # 1 for gamma gamma_ = jnp.reshape(gamma, sharding_meta.input_shapes[1]) # 1 for gamma
beta_ = beta beta_ = beta
...@@ -82,7 +85,8 @@ def layernorm(inputs: jnp.ndarray, ...@@ -82,7 +85,8 @@ def layernorm(inputs: jnp.ndarray,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
sharding_type=sharding_type, sharding_type=sharding_type,
dp_axis_name=dp_axis_name) dp_axis_name=dp_axis_name,
fsdp_axis_name=fsdp_axis_name)
output = xmap_runner(partial_ln, in_axes, sharding_meta.out_axes, output = xmap_runner(partial_ln, in_axes, sharding_meta.out_axes,
sharding_meta.axis_resources, (inputs_, gamma_, beta_)) sharding_meta.axis_resources, (inputs_, gamma_, beta_))
...@@ -92,11 +96,11 @@ def layernorm(inputs: jnp.ndarray, ...@@ -92,11 +96,11 @@ def layernorm(inputs: jnp.ndarray,
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7)) @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8))
def _layernorm(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon, sharding_type, def _layernorm(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon, sharding_type,
dp_axis_name): dp_axis_name, fsdp_axis_name):
output, _ = _layernorm_fwd(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon, output, _ = _layernorm_fwd(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon,
sharding_type, dp_axis_name) sharding_type, dp_axis_name, fsdp_axis_name)
return output return output
...@@ -108,7 +112,8 @@ def _layernorm_fwd( ...@@ -108,7 +112,8 @@ def _layernorm_fwd(
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
sharding_type, # pylint: disable=unused-argument sharding_type, # pylint: disable=unused-argument
dp_axis_name # pylint: disable=unused-argument dp_axis_name, # pylint: disable=unused-argument
fsdp_axis_name # pylint: disable=unused-argument
): ):
if layernorm_type == 'layernorm': if layernorm_type == 'layernorm':
output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon) output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
...@@ -120,8 +125,8 @@ def _layernorm_fwd( ...@@ -120,8 +125,8 @@ def _layernorm_fwd(
return output, (mu, rsigma, x, gamma) return output, (mu, rsigma, x, gamma)
def _layernorm_bwd(layernorm_type, zero_centered_gamma, epsilon, sharding_type, dp_axis_name, ctx, def _layernorm_bwd(layernorm_type, zero_centered_gamma, epsilon, sharding_type, dp_axis_name,
g): fsdp_axis_name, ctx, g):
mu, rsigma, x, gamma = ctx mu, rsigma, x, gamma = ctx
if layernorm_type == 'layernorm': if layernorm_type == 'layernorm':
...@@ -142,6 +147,11 @@ def _layernorm_bwd(layernorm_type, zero_centered_gamma, epsilon, sharding_type, ...@@ -142,6 +147,11 @@ def _layernorm_bwd(layernorm_type, zero_centered_gamma, epsilon, sharding_type,
grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name) grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name)
if grad_beta is not None: if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, dp_axis_name) grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
if len(fsdp_axis_name) > 0:
grad_gamma = jax.lax.psum(grad_gamma, fsdp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, fsdp_axis_name)
return grad_input, grad_gamma, grad_beta return grad_input, grad_gamma, grad_beta
...@@ -196,13 +206,15 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage, ...@@ -196,13 +206,15 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
epsilon=epsilon, epsilon=epsilon,
sharding_type=sharding_type, sharding_type=sharding_type,
dp_axis_name="", dp_axis_name="",
tp_axis_name="") tp_axis_name="",
fsdp_axis_name="")
else: else:
dp_axis_name = "batch" dp_axis_name = "batch"
tp_axis_name = "model" tp_axis_name = "model"
ln_sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape, ln_sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape,
dp_dim_index, dp_axis_name, tp_axis_name) dp_dim_index, dp_axis_name, tp_axis_name)
ln_sharding_meta, _ = extend_fsdp_sharding_meta(ln_sharding_meta, {0: dp_dim_index})
inputs_ = jnp.reshape(inputs, ln_sharding_meta.input_shapes[0]) # 0 for input inputs_ = jnp.reshape(inputs, ln_sharding_meta.input_shapes[0]) # 0 for input
gamma_ = jnp.reshape(gamma, ln_sharding_meta.input_shapes[1]) # 1 for gamma gamma_ = jnp.reshape(gamma, ln_sharding_meta.input_shapes[1]) # 1 for gamma
beta_ = beta beta_ = beta
...@@ -222,6 +234,8 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage, ...@@ -222,6 +234,8 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
dot_sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape, dot_sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
dp_dim_index, input_tp_index, kernel_tp_index, dp_dim_index, input_tp_index, kernel_tp_index,
contracting_dims, dp_axis_name, tp_axis_name) contracting_dims, dp_axis_name, tp_axis_name)
dot_sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(dot_sharding_meta,
{0: dp_dim_index})
kernel_ = jnp.reshape(kernel, dot_sharding_meta.input_shapes[1]) # 1 for kernel kernel_ = jnp.reshape(kernel, dot_sharding_meta.input_shapes[1]) # 1 for kernel
num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv
...@@ -242,7 +256,8 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage, ...@@ -242,7 +256,8 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
epsilon=epsilon, epsilon=epsilon,
sharding_type=sharding_type, sharding_type=sharding_type,
dp_axis_name=dp_axis_name, dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name) tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
# input, kernel, gamma, beta, fp8_metas # input, kernel, gamma, beta, fp8_metas
in_axes = (ln_sharding_meta.in_axes[0], dot_sharding_meta.in_axes[1], in_axes = (ln_sharding_meta.in_axes[0], dot_sharding_meta.in_axes[1],
...@@ -255,18 +270,18 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage, ...@@ -255,18 +270,18 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16)) @partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, def _layernorm_fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray,
beta: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray, beta: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str, scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str,
fwd_dtype: TEDType, bwd_dtype: TEDType, fwd_dtype: TEDType, bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]], contracting_dims: Tuple[Sequence[int], Sequence[int]],
zero_centered_gamma: bool, epsilon: float, sharding_type: ShardingType, zero_centered_gamma: bool, epsilon: float, sharding_type: ShardingType,
dp_axis_name: str, tp_axis_name: str) -> jnp.ndarray: dp_axis_name: str, tp_axis_name: str, fsdp_axis_name: str) -> jnp.ndarray:
output, _ = _layernorm_fp8_dot_fwd(inputs, kernel, gamma, beta, fp8_maxs, amax, scale, output, _ = _layernorm_fp8_dot_fwd(inputs, kernel, gamma, beta, fp8_maxs, amax, scale,
scale_inv, layernorm_type, fwd_dtype, bwd_dtype, scale_inv, layernorm_type, fwd_dtype, bwd_dtype,
contracting_dims, zero_centered_gamma, epsilon, contracting_dims, zero_centered_gamma, epsilon,
sharding_type, dp_axis_name, tp_axis_name) sharding_type, dp_axis_name, tp_axis_name, fsdp_axis_name)
return output return output
...@@ -287,7 +302,8 @@ def _layernorm_fp8_dot_fwd( ...@@ -287,7 +302,8 @@ def _layernorm_fp8_dot_fwd(
epsilon, epsilon,
sharding_type, sharding_type,
dp_axis_name, # pylint: disable=unused-argument dp_axis_name, # pylint: disable=unused-argument
tp_axis_name): tp_axis_name,
fsdp_axis_name): # pylint: disable=unused-argument
lhs_contracting_dims, rhs_contracting_dims = contracting_dims lhs_contracting_dims, rhs_contracting_dims = contracting_dims
input_shape_pre = inputs.shape[:min(lhs_contracting_dims)] input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
...@@ -362,6 +378,7 @@ def _layernorm_fp8_dot_bwd( ...@@ -362,6 +378,7 @@ def _layernorm_fp8_dot_bwd(
sharding_type, sharding_type,
dp_axis_name, dp_axis_name,
tp_axis_name, tp_axis_name,
fsdp_axis_name,
ctx, ctx,
g): g):
ln_out_, kernel_cast, \ ln_out_, kernel_cast, \
...@@ -422,6 +439,13 @@ def _layernorm_fp8_dot_bwd( ...@@ -422,6 +439,13 @@ def _layernorm_fp8_dot_bwd(
grad_beta = jax.lax.psum(grad_beta, dp_axis_name) grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
amax = jax.lax.pmax(amax, dp_axis_name) amax = jax.lax.pmax(amax, dp_axis_name)
if len(fsdp_axis_name) > 0:
wgrad = jax.lax.psum(wgrad, fsdp_axis_name)
grad_gamma = jax.lax.psum(grad_gamma, fsdp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, fsdp_axis_name)
amax = jax.lax.pmax(amax, fsdp_axis_name)
if is_tp_enabled(sharding_type.value[0]): if is_tp_enabled(sharding_type.value[0]):
amax = jax.lax.pmax(amax, tp_axis_name) amax = jax.lax.pmax(amax, tp_axis_name)
......
...@@ -23,7 +23,7 @@ from .sharding import MajorShardingType, ShardingType ...@@ -23,7 +23,7 @@ from .sharding import MajorShardingType, ShardingType
from .sharding import get_elementwise_sharding_meta from .sharding import get_elementwise_sharding_meta
from .sharding import get_dot_sharding_meta, get_fp8_meta_sharding_meta from .sharding import get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import merge_axis_resources, infer_sharding_type from .sharding import merge_axis_resources, infer_sharding_type
from .sharding import xmap_runner from .sharding import xmap_runner, extend_fsdp_sharding_meta
from .layernorm import canonicalize_layernorm_type from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8GemmPackage from .fp8 import FP8Helper, FP8GemmPackage
...@@ -54,6 +54,7 @@ def geglu( ...@@ -54,6 +54,7 @@ def geglu(
sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, None, sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, None,
dp_dim_index, dp_axis_name, tp_axis_name) dp_dim_index, dp_axis_name, tp_axis_name)
sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
...@@ -133,7 +134,7 @@ def fp8_ln_mlp( ...@@ -133,7 +134,7 @@ def fp8_ln_mlp(
if major_sharding_type is MajorShardingType.SINGLE: if major_sharding_type is MajorShardingType.SINGLE:
res = _fp8_mlp(inputs, ln_scale, ln_bias, kernel_1, kernel_2, fp8_max, amax, scale, res = _fp8_mlp(inputs, ln_scale, ln_bias, kernel_1, kernel_2, fp8_max, amax, scale,
scale_inv, layernorm_type, activations, zero_centered_gamma, epsilon, scale_inv, layernorm_type, activations, zero_centered_gamma, epsilon,
fwd_dtype, bwd_dtype, contracting_dims, major_sharding_type, "", "") fwd_dtype, bwd_dtype, contracting_dims, major_sharding_type, "", "", "")
else: else:
dp_axis_name = "batch" dp_axis_name = "batch"
tp_axis_name = "model" tp_axis_name = "model"
...@@ -143,12 +144,15 @@ def fp8_ln_mlp( ...@@ -143,12 +144,15 @@ def fp8_ln_mlp(
ln_sharding_meta = get_elementwise_sharding_meta(first_part_st, inputs.shape, ln_sharding_meta = get_elementwise_sharding_meta(first_part_st, inputs.shape,
ln_scale.shape, dp_dim_index, dp_axis_name, ln_scale.shape, dp_dim_index, dp_axis_name,
tp_axis_name) tp_axis_name)
ln_sharding_meta, _ = extend_fsdp_sharding_meta(ln_sharding_meta, {0: dp_dim_index})
input_tp_index = len(inputs.shape) - 1 input_tp_index = len(inputs.shape) - 1
first_dot_sharding_meta = get_dot_sharding_meta(first_part_st, inputs.shape, kernel_1.shape, first_dot_sharding_meta = get_dot_sharding_meta(first_part_st, inputs.shape, kernel_1.shape,
dp_dim_index, input_tp_index, 2, dp_dim_index, input_tp_index, 2,
contracting_dims, dp_axis_name, contracting_dims, dp_axis_name,
tp_axis_name) tp_axis_name)
first_dot_sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(
first_dot_sharding_meta, {0: dp_dim_index})
second_input_shape = (*first_dot_sharding_meta.output_shapes[0][:-2], second_input_shape = (*first_dot_sharding_meta.output_shapes[0][:-2],
first_dot_sharding_meta.output_shapes[0][-1]) first_dot_sharding_meta.output_shapes[0][-1])
second_dot_sharding_meta = get_dot_sharding_meta(second_part_st, second_input_shape, second_dot_sharding_meta = get_dot_sharding_meta(second_part_st, second_input_shape,
...@@ -156,6 +160,8 @@ def fp8_ln_mlp( ...@@ -156,6 +160,8 @@ def fp8_ln_mlp(
len(second_input_shape) - 1, 0, len(second_input_shape) - 1, 0,
contracting_dims, dp_axis_name, contracting_dims, dp_axis_name,
tp_axis_name) tp_axis_name)
second_dot_sharding_meta, _ = extend_fsdp_sharding_meta(second_dot_sharding_meta,
{0: dp_dim_index})
num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv
fp8_sharding_meta = get_fp8_meta_sharding_meta(first_part_st, num_of_fp8_meta_kind, fp8_sharding_meta = get_fp8_meta_sharding_meta(first_part_st, num_of_fp8_meta_kind,
...@@ -187,7 +193,8 @@ def fp8_ln_mlp( ...@@ -187,7 +193,8 @@ def fp8_ln_mlp(
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
major_sharding_type=major_sharding_type, major_sharding_type=major_sharding_type,
dp_axis_name=dp_axis_name, dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name) tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
in_axes = (ln_sharding_meta.in_axes[0], ln_sharding_meta.in_axes[1], ln_bias_in_axis, in_axes = (ln_sharding_meta.in_axes[0], ln_sharding_meta.in_axes[1], ln_bias_in_axis,
first_dot_sharding_meta.in_axes[1], second_dot_sharding_meta.in_axes[1], first_dot_sharding_meta.in_axes[1], second_dot_sharding_meta.in_axes[1],
*fp8_sharding_meta.in_axes) *fp8_sharding_meta.in_axes)
...@@ -200,14 +207,15 @@ def fp8_ln_mlp( ...@@ -200,14 +207,15 @@ def fp8_ln_mlp(
return res return res
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18)) @partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray, def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray, kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str, scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str,
activations: Sequence[Union[str, Callable]], zero_centered_gamma: bool, epsilon: float, activations: Sequence[Union[str, Callable]], zero_centered_gamma: bool, epsilon: float,
fwd_dtype: TEDType, bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int], fwd_dtype: TEDType, bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int],
Sequence[int]], Sequence[int]],
major_sharding_type: MajorShardingType, dp_axis_name: str, tp_axis_name: str): major_sharding_type: MajorShardingType, dp_axis_name: str, tp_axis_name: str,
fsdp_axis_name: str):
res, _ = _fp8_mlp_fwd(inputs, res, _ = _fp8_mlp_fwd(inputs,
ln_scale, ln_scale,
ln_bias, ln_bias,
...@@ -226,7 +234,8 @@ def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray, ...@@ -226,7 +234,8 @@ def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
major_sharding_type=major_sharding_type, major_sharding_type=major_sharding_type,
dp_axis_name=dp_axis_name, dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name) tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
return res return res
...@@ -249,7 +258,8 @@ def _fp8_mlp_fwd( ...@@ -249,7 +258,8 @@ def _fp8_mlp_fwd(
contracting_dims, contracting_dims,
major_sharding_type, major_sharding_type,
dp_axis_name, # pylint: disable=unused-argument dp_axis_name, # pylint: disable=unused-argument
tp_axis_name): tp_axis_name,
fsdp_axis_name): # pylint: disable=unused-argument
if activations != ('gelu', 'linear'): if activations != ('gelu', 'linear'):
raise NotImplementedError("activations only support ('gelu', 'linear') for now.") raise NotImplementedError("activations only support ('gelu', 'linear') for now.")
lhs_contracting_dims, rhs_contracting_dims = contracting_dims lhs_contracting_dims, rhs_contracting_dims = contracting_dims
...@@ -352,6 +362,7 @@ def _fp8_mlp_bwd( ...@@ -352,6 +362,7 @@ def _fp8_mlp_bwd(
major_sharding_type, major_sharding_type,
dp_axis_name, dp_axis_name,
tp_axis_name, tp_axis_name,
fsdp_axis_name,
ctx, ctx,
g): g):
inputs_, ln_out, mu, rsigma, gamma, \ inputs_, ln_out, mu, rsigma, gamma, \
...@@ -431,6 +442,14 @@ def _fp8_mlp_bwd( ...@@ -431,6 +442,14 @@ def _fp8_mlp_bwd(
grad_beta = jax.lax.psum(grad_beta, dp_axis_name) grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
amax = jax.lax.pmax(amax, dp_axis_name) amax = jax.lax.pmax(amax, dp_axis_name)
if len(fsdp_axis_name) > 0:
wgrad_1 = jax.lax.psum(wgrad_1, fsdp_axis_name)
wgrad_2 = jax.lax.psum(wgrad_2, fsdp_axis_name)
grad_gamma = jax.lax.psum(grad_gamma, fsdp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, fsdp_axis_name)
amax = jax.lax.pmax(amax, fsdp_axis_name)
if major_sharding_type in (MajorShardingType.TP, MajorShardingType.DPTP): if major_sharding_type in (MajorShardingType.TP, MajorShardingType.DPTP):
amax = jax.lax.pmax(amax, tp_axis_name) amax = jax.lax.pmax(amax, tp_axis_name)
......
...@@ -9,10 +9,11 @@ from typing import Callable, Iterable, Sequence, Tuple, Union ...@@ -9,10 +9,11 @@ from typing import Callable, Iterable, Sequence, Tuple, Union
from praxis import pax_fiddle from praxis import pax_fiddle
from praxis.base_layer import init_var from praxis.base_layer import init_var
from praxis.base_layer import BaseLayer, WeightInit, WeightHParams from praxis.base_layer import BaseLayer, WeightInit, WeightHParams, WeightHParamsCollection
from praxis.layers import flax_adapter from praxis.layers import flax_adapter
from praxis.pytypes import JTensor from praxis.pytypes import JTensor
from ..fp8 import FP8Helper
from ..flax.module import DenseGeneral, LayerNormDenseGeneral from ..flax.module import DenseGeneral, LayerNormDenseGeneral
from ..flax.module import LayerNorm as flax_LayerNorm from ..flax.module import LayerNorm as flax_LayerNorm
from ..flax.module import LayerNormMLP as flax_LayerNormMLP from ..flax.module import LayerNormMLP as flax_LayerNormMLP
...@@ -45,9 +46,18 @@ class TransformerEngineBaseLayer(BaseLayer): ...@@ -45,9 +46,18 @@ class TransformerEngineBaseLayer(BaseLayer):
def create_layer(self, name, flax_module_cls): def create_layer(self, name, flax_module_cls):
"""create_layer""" """create_layer"""
fp8_collection_map = {
FP8Helper.FP8_COLLECTION_NAME: [
WeightHParamsCollection.SKIP_LP_REGULARIZATION,
WeightHParamsCollection.NON_TRAINABLE,
WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION
]
}
flax_module_p = pax_fiddle.Config(flax_adapter.FlaxModuleAdapter, flax_module_p = pax_fiddle.Config(flax_adapter.FlaxModuleAdapter,
module_factory_method=flax_module_cls, module_factory_method=flax_module_cls,
logical_axes_rules=self.logical_axes_rules, logical_axes_rules=self.logical_axes_rules,
var_collection_map=fp8_collection_map,
ici_mesh_shape=self.ici_mesh_shape, ici_mesh_shape=self.ici_mesh_shape,
dcn_mesh_shape=self.dcn_mesh_shape, dcn_mesh_shape=self.dcn_mesh_shape,
mesh_axis_names=self.mesh_axis_names) mesh_axis_names=self.mesh_axis_names)
......
...@@ -14,6 +14,7 @@ from jax.interpreters import pxla ...@@ -14,6 +14,7 @@ from jax.interpreters import pxla
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax.experimental.maps import xmap from jax.experimental.maps import xmap
from jax.sharding import PartitionSpec
jax.config.update('experimental_xmap_spmd_lowering', True) jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True) jax.config.update('experimental_xmap_spmd_lowering_manual', True)
...@@ -28,6 +29,17 @@ def _get_mesh_info(resource: str): ...@@ -28,6 +29,17 @@ def _get_mesh_info(resource: str):
return mesh.shape[resource], resource return mesh.shape[resource], resource
def with_sharding_constraint(x: jnp.ndarray, pspec: PartitionSpec) -> jnp.ndarray:
    """
    A wrapper function to jax.lax.with_sharding_constraint to
    support the case that Mesh is empty.

    Parameters
    ----------
    x : jnp.ndarray
        the array to constrain.
    pspec : PartitionSpec
        the partition spec forwarded to jax.lax.with_sharding_constraint.

    Returns
    -------
    constrained_x : jnp.ndarray
        `x` itself when the current physical mesh is empty, otherwise the result of
        jax.lax.with_sharding_constraint(x, pspec).
    """
    mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
    # Without a mesh there is nothing to constrain against, so this is a no-op.
    if mesh.empty:
        return x
    return jax.lax.with_sharding_constraint(x, pspec)
@dataclass @dataclass
class ShardingResource: class ShardingResource:
""" """
...@@ -45,6 +57,7 @@ class ShardingResource: ...@@ -45,6 +57,7 @@ class ShardingResource:
""" """
dp_resource: str = None dp_resource: str = None
tp_resource: str = None tp_resource: str = None
fsdp_resource: str = None
_GLOBAL_SHARD_RESOURCE = ShardingResource() _GLOBAL_SHARD_RESOURCE = ShardingResource()
...@@ -421,6 +434,11 @@ class FusedAttnShardingMetaGenerator(ShardingMetaGenerator): ...@@ -421,6 +434,11 @@ class FusedAttnShardingMetaGenerator(ShardingMetaGenerator):
out_axis[tp_dim] = tp_axis_name out_axis[tp_dim] = tp_axis_name
out_axes.append(out_axis) out_axes.append(out_axis)
assert len(out_axes) == 1, "Only allow single output at this moment."
assert len(output_new_shapes) == 1, "Only allow single output at this moment."
out_axes = out_axes[0]
output_new_shapes = output_new_shapes[0]
axis_resources = {} axis_resources = {}
if dp_axis_name is not None: if dp_axis_name is not None:
axis_resources[dp_axis_name] = dp_mesh_axis axis_resources[dp_axis_name] = dp_mesh_axis
...@@ -1015,6 +1033,119 @@ def get_fused_attn_sharding_meta(stype: ShardingType, ...@@ -1015,6 +1033,119 @@ def get_fused_attn_sharding_meta(stype: ShardingType,
tp_axis_name) tp_axis_name)
def extend_fsdp_sharding_meta(sharding_meta: ShardingMeta,
                              weight_fsdp_dim_map: Dict[int, int]) -> Tuple[ShardingMeta, str]:
    """
    Extending the given ShardingMeta to be compatible with FSDP (ZeRO3) sharding pattern.

    .. note::
        The extending helper assumes the first shape in sharding_meta.input_shapes
        corresponding to the input tensor. Please be sure that 0-idx is in
        `weight_fsdp_dim_map`.

    .. note::
        `sharding_meta` is updated in place (input_shapes, in_axes, out_axes and
        axis_resources) and the same object is returned.

    Parameters
    ----------
    sharding_meta : ShardingMeta
        the sharding meta object to extend with FSDP.
    weight_fsdp_dim_map: Dict[int, int]
        The dict, which key is idx of sharding_meta.input_shapes and value is the dimension
        to extend FSDP. It must contain key 0 (the input); entries of
        sharding_meta.input_shapes whose idx is not in the dict are left unchanged.

    Returns
    -------
    updated_sharding_meta : ShardingMeta
        a sharding_meta with the FSDP extension. It is returned untouched when the major
        sharding type is SINGLE, when no FSDP resource is set, or when the FSDP resource
        is the same as the DP resource.
    fsdp_axis_name: str
        The name of FSDP named axis for further xmap projection, or an empty string
        when sharding_meta was not extended.
    """
    assert 0 in weight_fsdp_dim_map, \
        "0-idx is required to be in 'weight_fsdp_dim_map' for the input."

    mst = infer_major_sharding_type()
    if mst is MajorShardingType.SINGLE:
        return sharding_meta, ""

    gsr = global_shard_resource()
    dp_mesh_axis = gsr.dp_resource
    fsdp_mesh_axis = gsr.fsdp_resource

    # Nothing to extend when FSDP is not configured or shares its mesh axis with DP.
    if fsdp_mesh_axis is None or fsdp_mesh_axis == dp_mesh_axis:
        return sharding_meta, ""

    fsdp_dim_size, _ = _get_mesh_info(fsdp_mesh_axis)
    fsdp_axis_name = "fsdp"

    def get_idx_to_extend(sharded_indices, target_idx):
        # Every already-sharded index at or before the target shifts the target by one
        # position in the reshaped tensor.
        return target_idx + sum(1 for sharded_idx in sharded_indices
                                if sharded_idx <= target_idx)

    def split_dim(shape, idx, allow_unknown):
        # Replace shape[idx] by the pair (fsdp_dim_size, shape[idx] // fsdp_dim_size).
        # When allow_unknown is set, a -1 placeholder is accepted; floor division keeps
        # it at -1 for a positive fsdp_dim_size.
        dim_size = shape[idx]
        assert (allow_unknown and dim_size == -1) or dim_size % fsdp_dim_size == 0
        return (*shape[:idx], fsdp_dim_size, dim_size // fsdp_dim_size, *shape[idx + 1:])

    def insert_fsdp_axis(axes, idx):
        # Shift the named axes located at or after idx by one and put the FSDP axis at idx.
        shifted = {(key if key < idx else key + 1): name for key, name in axes.items()}
        shifted[idx] = fsdp_axis_name
        return shifted

    new_input_shapes = []
    new_in_axes = []
    for i, shape in enumerate(sharding_meta.input_shapes):
        in_axes = sharding_meta.in_axes[i]
        idx_to_extend = -1
        new_shape = shape

        if i == 0:    # Assume first shape corresponds to input
            idx_to_extend = get_idx_to_extend(in_axes, weight_fsdp_dim_map[i])
            new_shape = split_dim(shape, idx_to_extend, allow_unknown=True)

            # assume one output only and have the same batch sharding like input
            assert isinstance(sharding_meta.out_axes, dict)
            sharding_meta.out_axes = insert_fsdp_axis(sharding_meta.out_axes, idx_to_extend)
        elif i in weight_fsdp_dim_map:
            fsdp_dim = weight_fsdp_dim_map[i]
            idx_to_extend = get_idx_to_extend(in_axes, fsdp_dim)
            # The -1 placeholder is only accepted on a dimension that is already sharded.
            new_shape = split_dim(shape, idx_to_extend, allow_unknown=fsdp_dim in in_axes)

        if idx_to_extend >= 0:
            new_ia = insert_fsdp_axis(in_axes, idx_to_extend)
        else:
            new_ia = in_axes

        new_input_shapes.append(new_shape)
        new_in_axes.append(new_ia)

    sharding_meta.input_shapes = tuple(new_input_shapes)
    sharding_meta.in_axes = tuple(new_in_axes)

    sharding_meta.axis_resources[fsdp_axis_name] = fsdp_mesh_axis
    return sharding_meta, fsdp_axis_name
def xmap_runner(func: Callable, in_axes: Tuple[Dict, ...], def xmap_runner(func: Callable, in_axes: Tuple[Dict, ...],
out_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]], out_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]],
axis_resources: Dict, inputs: Tuple): axis_resources: Dict, inputs: Tuple):
...@@ -1032,10 +1163,12 @@ def xmap_runner(func: Callable, in_axes: Tuple[Dict, ...], ...@@ -1032,10 +1163,12 @@ def xmap_runner(func: Callable, in_axes: Tuple[Dict, ...],
# Collectives in manually partitioned computations are only supported # Collectives in manually partitioned computations are only supported
# when all mesh axes are partitioned manually (no partial automatic # when all mesh axes are partitioned manually (no partial automatic
# sharding). Make sure that you mention all mesh axes in axis_resources!" # sharding). Make sure that you mention all mesh axes in axis_resources!"
for i, mesh_axis_names in enumerate(mesh.axis_names): fake_idx_counter = 0
for mesh_axis_names in mesh.axis_names:
if mesh_axis_names not in axis_resources.values(): if mesh_axis_names not in axis_resources.values():
fake_axis_name = f"{mesh_axis_names}_fake_{i}" fake_idx_counter += 1
fake_in_axes[i] = fake_axis_name fake_axis_name = f"{mesh_axis_names}_fake_{fake_idx_counter}"
fake_in_axes[fake_idx_counter] = fake_axis_name
fake_axis_resource[fake_axis_name] = mesh_axis_names fake_axis_resource[fake_axis_name] = mesh_axis_names
fake_input = jnp.zeros(tuple(64 for _ in range(len(fake_in_axes) + 1))) fake_input = jnp.zeros(tuple(64 for _ in range(len(fake_in_axes) + 1)))
......
...@@ -18,8 +18,8 @@ from .cpp_extensions import scaled_upper_triang_masked_softmax_bwd ...@@ -18,8 +18,8 @@ from .cpp_extensions import scaled_upper_triang_masked_softmax_bwd
from .cpp_extensions import ScaledSoftmaxFwdPrimitive from .cpp_extensions import ScaledSoftmaxFwdPrimitive
from .cpp_extensions import ScaledMaskedSoftmaxFwdPrimitive from .cpp_extensions import ScaledMaskedSoftmaxFwdPrimitive
from .cpp_extensions import ScaledUpperTriangMaskedSoftmaxFwdPrimitive from .cpp_extensions import ScaledUpperTriangMaskedSoftmaxFwdPrimitive
from .sharding import get_softmax_sharding_meta, ShardingType from .sharding import get_softmax_sharding_meta, ShardingType, ShardingMeta
from .sharding import xmap_runner from .sharding import xmap_runner, extend_fsdp_sharding_meta
jax.config.update('experimental_xmap_spmd_lowering', True) jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True) jax.config.update('experimental_xmap_spmd_lowering_manual', True)
...@@ -78,6 +78,8 @@ def softmax(inputs: jnp.ndarray, ...@@ -78,6 +78,8 @@ def softmax(inputs: jnp.ndarray,
dp_axis_name=dp_axis_name, dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name) tp_axis_name=tp_axis_name)
sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
mask_ = mask mask_ = mask
mask_in_axis = {} mask_in_axis = {}
...@@ -92,6 +94,10 @@ def softmax(inputs: jnp.ndarray, ...@@ -92,6 +94,10 @@ def softmax(inputs: jnp.ndarray,
tp_dim=tp_dim_index, tp_dim=tp_dim_index,
dp_axis_name=dp_axis_name, dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name) tp_axis_name=tp_axis_name)
else:
mask_sharding_meta = ShardingMeta([{}], {}, {}, [mask_.shape], mask_.shape)
mask_sharding_meta, _ = extend_fsdp_sharding_meta(mask_sharding_meta, {0: dp_dim_index})
mask_ = jnp.reshape(mask_, mask_sharding_meta.input_shapes[0]) mask_ = jnp.reshape(mask_, mask_sharding_meta.input_shapes[0])
mask_in_axis = mask_sharding_meta.in_axes[0] mask_in_axis = mask_sharding_meta.in_axes[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment