Unverified Commit 6464ced7 authored by Ming-Xu Huang, committed by GitHub

[JAX] FSDP General Support and FP8 Support to Praxis. (#347)



* Initial commit for FSDP
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding support for FSDP xmap sharding
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Specify WeightHParamsCollection of fp8 meta.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Support partial FP8 custom calls with FSDP.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding amax reduction on the FSDP mesh dim.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Clean code
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the wrong batch axis in logic_axis_rules and add sharding_constraint to BMM1
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Support FSDP in fMHA.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix missing all-reduce of wgrads along the FSDP axis.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Change default value of fsdp_axis_name to  for aligning with others
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix RuntimeError: with_sharding_constraint requires a non-empty
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Slight changes (review feedback)
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Removed unnecessary comments
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Merging input_dp_dim into weight_fsdp_dim_map
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Update transformer_engine/jax/sharding.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 7804d116
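A minimal usage sketch of what this change enables (not part of the diff itself): the mesh axis names, device count, and the step of installing the resource as the library's global shard resource are illustrative assumptions.

    import jax
    import numpy as np
    from jax.sharding import Mesh
    from transformer_engine.jax.sharding import ShardingResource

    # Assumes eight devices, arranged as (data, fsdp, model); adjust to your topology.
    devices = np.asarray(jax.devices()).reshape(2, 2, 2)
    mesh = Mesh(devices, axis_names=('data', 'fsdp', 'model'))

    # The new fsdp_resource field names the mesh axis used for ZeRO-3 style weight sharding.
    resource = ShardingResource(dp_resource='data', tp_resource='model', fsdp_resource='fsdp')
    # The resource is then installed as the global shard resource (via the library's
    # setter/context manager, outside the scope of this diff) before running the
    # TE JAX layers under `with mesh:`.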
......@@ -14,7 +14,7 @@ from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype
from .fp8 import FP8Helper, FP8GemmPackage
from .sharding import ShardingType, get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
from .sharding import xmap_runner
from .sharding import xmap_runner, extend_fsdp_sharding_meta
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
......@@ -49,7 +49,8 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name="",
tp_axis_name="")
tp_axis_name="",
fsdp_axis_name="")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
......@@ -64,6 +65,7 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
dp_dim_index, input_tp_index, kernel_tp_index,
contracting_dims, dp_axis_name, tp_axis_name)
sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
kernel_ = jnp.reshape(kernel, sharding_meta.input_shapes[1]) # 1 for kernel
......@@ -80,7 +82,8 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
res = xmap_runner(partial_fp8_dot, (*sharding_meta.in_axes, *fp8_sharding_meta.in_axes),
sharding_meta.out_axes, axis_resources,
(inputs_, kernel_, fp8_max, amax, scale, scale_inv))
......@@ -90,11 +93,11 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
return res
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11))
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12))
def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: TEDType, bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]], sharding_type: ShardingType,
dp_axis_name: str, tp_axis_name: str):
dp_axis_name: str, tp_axis_name: str, fsdp_axis_name: str):
res, _ = _fp8_dot_fwd(inputs,
kernel,
fp8_maxs,
......@@ -106,7 +109,8 @@ def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, am
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
return res
......@@ -122,7 +126,8 @@ def _fp8_dot_fwd(
contracting_dims,
sharding_type,
dp_axis_name, # pylint: disable=unused-argument
tp_axis_name):
tp_axis_name,
fsdp_axis_name): # pylint: disable=unused-argument
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
input_shape_suf = inputs.shape[min(lhs_contracting_dims):]
......@@ -173,6 +178,7 @@ def _fp8_dot_bwd(
sharding_type,
dp_axis_name,
tp_axis_name,
fsdp_axis_name,
ctx,
g):
input_cast_trans, kernel_cast, \
......@@ -206,6 +212,10 @@ def _fp8_dot_bwd(
wgrad = jax.lax.psum(wgrad, dp_axis_name)
amax = jax.lax.pmax(amax, dp_axis_name)
if len(fsdp_axis_name) > 0:
wgrad = jax.lax.psum(wgrad, fsdp_axis_name)
amax = jax.lax.pmax(amax, fsdp_axis_name)
if is_tp_enabled(sharding_type.value[0]):
amax = jax.lax.pmax(amax, tp_axis_name)
......
......@@ -28,7 +28,8 @@ from ..fused_attn import is_fused_attn_kernel_available
from ..fused_attn import self_fused_attn, cross_fused_attn
from ..softmax import SoftmaxType
from ..sharding import infer_major_sharding_type, infer_sharding_type
from ..sharding import global_shard_resource, ShardingType
from ..sharding import global_shard_resource, with_sharding_constraint
from ..sharding import ShardingType
PRNGKey = Any
Shape = Tuple[int, ...]
......@@ -39,6 +40,17 @@ PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Preci
Initializer = Callable[[PRNGKey, Shape, DType], Array]
LogicalRules = Sequence[Tuple[str, Union[str, None]]]
BATCH_AXES = 'nvte_batch'
SEQLEN_AXES = 'nvte_seqlen'
HEAD_AXES = 'nvte_head'
HIDDEN_AXES = 'nvte_hidden'
HIDDEN_TP_AXES = 'nvte_hidden_tp'
JOINED_AXES = 'nvte_joined'
W_NO_SHARD_AXES = 'nvte_w_no_shard'
W_FSDP_AXES = 'nvte_w_fsdp'
W_TP_AXES = 'nvte_w_tp'
W_JOINED_AXES = 'nvte_w_joined'
def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
# Generate broadcast dims for drop_path.
......@@ -91,10 +103,32 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
rules_map[key] = [val]
gsr = global_shard_resource()
te_logical_axis_rules = (('batch', gsr.dp_resource), ('embed', None), ('mlp', gsr.tp_resource),
('heads', gsr.tp_resource), ('kv', None), ('qkv_dim', None),
('kv_dim', None), ('joined_kv', gsr.tp_resource), ('act', None),
('relpos_buckets', None), ('length', None))
batch_dim_rule = []
if gsr.dp_resource is not None:
batch_dim_rule.append(gsr.dp_resource)
if gsr.fsdp_resource is not None and gsr.dp_resource != gsr.fsdp_resource:
batch_dim_rule.append(gsr.fsdp_resource)
if len(batch_dim_rule) <= 0:
batch_dim_rule = None
elif len(batch_dim_rule) == 1:
batch_dim_rule = batch_dim_rule[0]
else:
batch_dim_rule = tuple(batch_dim_rule)
te_logical_axis_rules = (
(BATCH_AXES, batch_dim_rule),
(SEQLEN_AXES, None),
(HEAD_AXES, gsr.tp_resource),
(HIDDEN_AXES, None),
(HIDDEN_TP_AXES, gsr.tp_resource),
(JOINED_AXES, None),
(W_NO_SHARD_AXES, None),
(W_FSDP_AXES, gsr.fsdp_resource),
(W_TP_AXES, gsr.tp_resource),
(W_JOINED_AXES, None),
)
extended_rules = [*rules]
for item in te_logical_axis_rules:
......@@ -110,6 +144,18 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
return tuple(extended_rules)
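# Illustration (an assumption, not part of this change): with a global ShardingResource of
# ShardingResource(dp_resource='data', tp_resource='model', fsdp_resource='fsdp'),
# extend_logical_axis_rules(()) yields rules that include
#     ('nvte_batch', ('data', 'fsdp')),   # activation batch dim: DP x FSDP
#     ('nvte_head', 'model'),             # attention heads: tensor parallel
#     ('nvte_w_fsdp', 'fsdp'),            # weight dims sharded by FSDP (ZeRO-3)
#     ('nvte_w_tp', 'model'),             # weight dims sharded by TP
#     ('nvte_w_no_shard', None)           # replicated weight dims
# so the Flax logical axis names above map onto physical mesh axes.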
def _with_sharding_constraint(x: Array, logical_axis_names: Shape):
assert len(x.shape) == len(logical_axis_names)
rules = extend_logical_axis_rules(tuple())
rules_dict = {}
for key, value in rules:
rules_dict[key] = value
mesh_axis_names = [rules_dict[name] for name in logical_axis_names]
pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
return with_sharding_constraint(x, pspec)
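# Illustration (an assumption, not part of this change): with the same rules,
#     _with_sharding_constraint(x, (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES))
# looks up each logical name and builds
#     PartitionSpec(('data', 'fsdp'), None, 'model', None)
# before deferring to the mesh-aware with_sharding_constraint wrapper, which is a
# no-op when no Mesh is active.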
def _merge_mask(func, *masks: Optional[Array]):
masks = [m for m in masks if m is not None]
if not masks:
......@@ -167,6 +213,9 @@ def core_attention(query: Array,
else:
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
attn_weights = _with_sharding_constraint(attn_weights,
(BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))
# When a bias is present, the computation is performed as Softmax(attn_weights * scale + bias).
# In this case, the scale cannot be fused into the Softmax module.
if bias is not None:
......@@ -425,15 +474,13 @@ class MultiHeadAttention(nn.Module):
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
scale_axes=('embed',),
kernel_axes=('embed', 'qkv_dim', 'joined_kv'),
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_init=qkv_init,
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(
'qkv_dim',
'joined_kv',
),
bias_axes=(W_JOINED_AXES, W_TP_AXES),
name='qkv',
dtype=self.dtype)(inputs_q)
if not use_fused_attn:
......@@ -449,11 +496,12 @@ class MultiHeadAttention(nn.Module):
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
scale_axes=('embed',),
kernel_axes=('embed', 'joined_kv'),
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=('joined_kv',),
bias_axes=(W_TP_AXES,),
dtype=self.dtype,
kernel_init=query_init,
name='query')(inputs_q)
......@@ -461,14 +509,11 @@ class MultiHeadAttention(nn.Module):
features=(2, self.num_heads * self.head_dim),
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=('embed', 'kv_dim', 'joined_kv'),
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_init=kv_init,
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(
'kv_dim',
'joined_kv',
),
bias_axes=(W_JOINED_AXES, W_TP_AXES),
name='kv',
dtype=self.dtype)(inputs_kv)
if not use_fused_attn:
......@@ -480,10 +525,10 @@ class MultiHeadAttention(nn.Module):
features=self.num_heads * self.head_dim,
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=('embed', 'joined_kv'),
kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=('joined_kv',),
bias_axes=(W_TP_AXES,),
dtype=self.dtype)
query, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm,
......@@ -495,11 +540,12 @@ class MultiHeadAttention(nn.Module):
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=True,
scale_axes=('embed',),
kernel_axes=('embed', 'joined_kv'),
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=('joined_kv',),
bias_axes=(W_TP_AXES,),
dtype=self.dtype,
kernel_init=query_init,
name='query')(inputs_q)
......@@ -520,12 +566,12 @@ class MultiHeadAttention(nn.Module):
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
qkv_sharding_constraint = \
('length', 'batch', 'heads','kv') \
(SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \
if self.transpose_batch_sequence \
else ('batch', 'length', 'heads', 'kv')
query = nn_partitioning.with_sharding_constraint(query, qkv_sharding_constraint)
key = nn_partitioning.with_sharding_constraint(key, qkv_sharding_constraint)
value = nn_partitioning.with_sharding_constraint(value, qkv_sharding_constraint)
else (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
query = _with_sharding_constraint(query, qkv_sharding_constraint)
key = _with_sharding_constraint(key, qkv_sharding_constraint)
value = _with_sharding_constraint(value, qkv_sharding_constraint)
if decode:
is_initialized = self.has_variable('cache', 'cached_key')
......@@ -601,9 +647,9 @@ class MultiHeadAttention(nn.Module):
if inputs_q is inputs_kv:
qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
qkv_sharding_constraint = ('batch', 'length', 'qkv_dim', 'heads', 'kv')
qkv_proj = nn_partitioning.with_sharding_constraint(qkv_proj,
qkv_sharding_constraint)
qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES)
qkv_proj = _with_sharding_constraint(qkv_proj, qkv_sharding_constraint)
x = self_fused_attn(qkv_proj,
bias,
mask,
......@@ -618,10 +664,11 @@ class MultiHeadAttention(nn.Module):
assert bias is None
query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_heads, self.head_dim))
q_sharding_constraint = ('batch', 'length', 'heads', 'kv')
kv_sharding_constraint = ('batch', 'length', 'kv_dim', 'heads', 'kv')
query = nn_partitioning.with_sharding_constraint(query, q_sharding_constraint)
kv_proj = nn_partitioning.with_sharding_constraint(kv_proj, kv_sharding_constraint)
q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES)
query = _with_sharding_constraint(query, q_sharding_constraint)
kv_proj = _with_sharding_constraint(kv_proj, kv_sharding_constraint)
x = cross_fused_attn(query,
kv_proj,
......@@ -668,20 +715,20 @@ class MultiHeadAttention(nn.Module):
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
attn_context_sharding_constraint = \
('length', 'batch', 'joined_kv') \
(SEQLEN_AXES, BATCH_AXES, HIDDEN_TP_AXES) \
if self.transpose_batch_sequence \
else ('batch', 'length', 'joined_kv')
x = nn_partitioning.with_sharding_constraint(x, attn_context_sharding_constraint)
else (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
x = _with_sharding_constraint(x, attn_context_sharding_constraint)
out = DenseGeneral(features=inputs_q.shape[-1],
sharding_type=second_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=('joined_kv', 'embed'),
kernel_axes=(W_TP_AXES, W_FSDP_AXES),
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=('embed',),
bias_axes=(W_NO_SHARD_AXES,),
dtype=self.dtype,
name='out')(x)
return out, residual
......@@ -1118,17 +1165,15 @@ class TransformerLayer(nn.Module):
intermediate_dropout_rate=self.hidden_dropout,
intermediate_hidden_dropout_dims=self.hidden_dropout_dims,
dtype=self.dtype,
scale_axes=('embed',),
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_init=self.mlp_kernel_init,
kernel_axes_1=('embed', 'act', 'mlp'),
kernel_axes_2=('mlp', 'embed'),
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes_1=(
'act',
'mlp',
),
bias_axes_2=('embed',),
bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
bias_axes_2=(W_NO_SHARD_AXES,),
name='mlp',
)(mlp_input, deterministic=deterministic)
......@@ -1148,8 +1193,8 @@ class TransformerLayer(nn.Module):
z = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
scale_axes=('embed',),
bias_axes=('embed',),
scale_axes=(W_NO_SHARD_AXES,),
bias_axes=(W_NO_SHARD_AXES,),
transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype,
sharding_type=ln_sharding_type,
......
......@@ -16,7 +16,7 @@ from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .sharding import get_fused_attn_sharding_meta
from .sharding import ShardingType
from .sharding import xmap_runner
from .sharding import xmap_runner, extend_fsdp_sharding_meta
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
......@@ -82,6 +82,7 @@ def self_fused_attn(qkv: jnp.ndarray,
tp_dims=([3, 1, None, 0], [2]),
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})
inputs_ = tuple(
jnp.reshape(x, new_shape) if x is not None else None
......@@ -95,9 +96,9 @@ def self_fused_attn(qkv: jnp.ndarray,
is_training=is_training)
output_ = xmap_runner(partial_self_fused_attn, sharding_meta.in_axes,
sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_)
sharding_meta.out_axes, sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes[0])
output = jnp.reshape(output_, sharding_meta.output_shapes)
return output
......@@ -202,6 +203,7 @@ def cross_fused_attn(q: jnp.ndarray,
tp_dims=([2, 3, None, None], [2]),
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})
inputs_ = tuple(
jnp.reshape(x, new_shape) if x is not None else None
......@@ -215,9 +217,9 @@ def cross_fused_attn(q: jnp.ndarray,
is_training=is_training)
output_ = xmap_runner(partial_cross_fused_attn, sharding_meta.in_axes,
sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_)
sharding_meta.out_axes, sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes[0])
output = jnp.reshape(output_, sharding_meta.output_shapes)
return output
......
......@@ -18,7 +18,7 @@ from .fp8 import FP8Helper, FP8GemmPackage
from .sharding import ShardingType, get_elementwise_sharding_meta
from .sharding import get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
from .sharding import xmap_runner
from .sharding import xmap_runner, extend_fsdp_sharding_meta
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
......@@ -61,12 +61,15 @@ def layernorm(inputs: jnp.ndarray,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type,
dp_axis_name="")
dp_axis_name="",
fsdp_axis_name="")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape,
dp_dim_index, dp_axis_name, tp_axis_name)
sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
gamma_ = jnp.reshape(gamma, sharding_meta.input_shapes[1]) # 1 for gamma
beta_ = beta
......@@ -82,7 +85,8 @@ def layernorm(inputs: jnp.ndarray,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name)
dp_axis_name=dp_axis_name,
fsdp_axis_name=fsdp_axis_name)
output = xmap_runner(partial_ln, in_axes, sharding_meta.out_axes,
sharding_meta.axis_resources, (inputs_, gamma_, beta_))
......@@ -92,11 +96,11 @@ def layernorm(inputs: jnp.ndarray,
return output
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7))
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8))
def _layernorm(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon, sharding_type,
dp_axis_name):
dp_axis_name, fsdp_axis_name):
output, _ = _layernorm_fwd(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon,
sharding_type, dp_axis_name)
sharding_type, dp_axis_name, fsdp_axis_name)
return output
......@@ -108,7 +112,8 @@ def _layernorm_fwd(
zero_centered_gamma,
epsilon,
sharding_type, # pylint: disable=unused-argument
dp_axis_name # pylint: disable=unused-argument
dp_axis_name, # pylint: disable=unused-argument
fsdp_axis_name # pylint: disable=unused-argument
):
if layernorm_type == 'layernorm':
output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
......@@ -120,8 +125,8 @@ def _layernorm_fwd(
return output, (mu, rsigma, x, gamma)
def _layernorm_bwd(layernorm_type, zero_centered_gamma, epsilon, sharding_type, dp_axis_name, ctx,
g):
def _layernorm_bwd(layernorm_type, zero_centered_gamma, epsilon, sharding_type, dp_axis_name,
fsdp_axis_name, ctx, g):
mu, rsigma, x, gamma = ctx
if layernorm_type == 'layernorm':
......@@ -142,6 +147,11 @@ def _layernorm_bwd(layernorm_type, zero_centered_gamma, epsilon, sharding_type,
grad_gamma = jax.lax.psum(grad_gamma, dp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
if len(fsdp_axis_name) > 0:
grad_gamma = jax.lax.psum(grad_gamma, fsdp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, fsdp_axis_name)
return grad_input, grad_gamma, grad_beta
......@@ -196,13 +206,15 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
epsilon=epsilon,
sharding_type=sharding_type,
dp_axis_name="",
tp_axis_name="")
tp_axis_name="",
fsdp_axis_name="")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
ln_sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, gamma.shape,
dp_dim_index, dp_axis_name, tp_axis_name)
ln_sharding_meta, _ = extend_fsdp_sharding_meta(ln_sharding_meta, {0: dp_dim_index})
inputs_ = jnp.reshape(inputs, ln_sharding_meta.input_shapes[0]) # 0 for input
gamma_ = jnp.reshape(gamma, ln_sharding_meta.input_shapes[1]) # 1 for gamma
beta_ = beta
......@@ -222,6 +234,8 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
dot_sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
dp_dim_index, input_tp_index, kernel_tp_index,
contracting_dims, dp_axis_name, tp_axis_name)
dot_sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(dot_sharding_meta,
{0: dp_dim_index})
kernel_ = jnp.reshape(kernel, dot_sharding_meta.input_shapes[1]) # 1 for kernel
num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv
......@@ -242,7 +256,8 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
epsilon=epsilon,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
# input, kernel, gamma, beta, fp8_metas
in_axes = (ln_sharding_meta.in_axes[0], dot_sharding_meta.in_axes[1],
......@@ -255,18 +270,18 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
return output
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16))
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray,
beta: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str,
fwd_dtype: TEDType, bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
zero_centered_gamma: bool, epsilon: float, sharding_type: ShardingType,
dp_axis_name: str, tp_axis_name: str) -> jnp.ndarray:
dp_axis_name: str, tp_axis_name: str, fsdp_axis_name: str) -> jnp.ndarray:
output, _ = _layernorm_fp8_dot_fwd(inputs, kernel, gamma, beta, fp8_maxs, amax, scale,
scale_inv, layernorm_type, fwd_dtype, bwd_dtype,
contracting_dims, zero_centered_gamma, epsilon,
sharding_type, dp_axis_name, tp_axis_name)
sharding_type, dp_axis_name, tp_axis_name, fsdp_axis_name)
return output
......@@ -287,7 +302,8 @@ def _layernorm_fp8_dot_fwd(
epsilon,
sharding_type,
dp_axis_name, # pylint: disable=unused-argument
tp_axis_name):
tp_axis_name,
fsdp_axis_name): # pylint: disable=unused-argument
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
......@@ -362,6 +378,7 @@ def _layernorm_fp8_dot_bwd(
sharding_type,
dp_axis_name,
tp_axis_name,
fsdp_axis_name,
ctx,
g):
ln_out_, kernel_cast, \
......@@ -422,6 +439,13 @@ def _layernorm_fp8_dot_bwd(
grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
amax = jax.lax.pmax(amax, dp_axis_name)
if len(fsdp_axis_name) > 0:
wgrad = jax.lax.psum(wgrad, fsdp_axis_name)
grad_gamma = jax.lax.psum(grad_gamma, fsdp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, fsdp_axis_name)
amax = jax.lax.pmax(amax, fsdp_axis_name)
if is_tp_enabled(sharding_type.value[0]):
amax = jax.lax.pmax(amax, tp_axis_name)
......
......@@ -23,7 +23,7 @@ from .sharding import MajorShardingType, ShardingType
from .sharding import get_elementwise_sharding_meta
from .sharding import get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import merge_axis_resources, infer_sharding_type
from .sharding import xmap_runner
from .sharding import xmap_runner, extend_fsdp_sharding_meta
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8GemmPackage
......@@ -54,6 +54,7 @@ def geglu(
sharding_meta = get_elementwise_sharding_meta(sharding_type, inputs.shape, None,
dp_dim_index, dp_axis_name, tp_axis_name)
sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
......@@ -133,7 +134,7 @@ def fp8_ln_mlp(
if major_sharding_type is MajorShardingType.SINGLE:
res = _fp8_mlp(inputs, ln_scale, ln_bias, kernel_1, kernel_2, fp8_max, amax, scale,
scale_inv, layernorm_type, activations, zero_centered_gamma, epsilon,
fwd_dtype, bwd_dtype, contracting_dims, major_sharding_type, "", "")
fwd_dtype, bwd_dtype, contracting_dims, major_sharding_type, "", "", "")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
......@@ -143,12 +144,15 @@ def fp8_ln_mlp(
ln_sharding_meta = get_elementwise_sharding_meta(first_part_st, inputs.shape,
ln_scale.shape, dp_dim_index, dp_axis_name,
tp_axis_name)
ln_sharding_meta, _ = extend_fsdp_sharding_meta(ln_sharding_meta, {0: dp_dim_index})
input_tp_index = len(inputs.shape) - 1
first_dot_sharding_meta = get_dot_sharding_meta(first_part_st, inputs.shape, kernel_1.shape,
dp_dim_index, input_tp_index, 2,
contracting_dims, dp_axis_name,
tp_axis_name)
first_dot_sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(
first_dot_sharding_meta, {0: dp_dim_index})
second_input_shape = (*first_dot_sharding_meta.output_shapes[0][:-2],
first_dot_sharding_meta.output_shapes[0][-1])
second_dot_sharding_meta = get_dot_sharding_meta(second_part_st, second_input_shape,
......@@ -156,6 +160,8 @@ def fp8_ln_mlp(
len(second_input_shape) - 1, 0,
contracting_dims, dp_axis_name,
tp_axis_name)
second_dot_sharding_meta, _ = extend_fsdp_sharding_meta(second_dot_sharding_meta,
{0: dp_dim_index})
num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv
fp8_sharding_meta = get_fp8_meta_sharding_meta(first_part_st, num_of_fp8_meta_kind,
......@@ -187,7 +193,8 @@ def fp8_ln_mlp(
contracting_dims=contracting_dims,
major_sharding_type=major_sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
in_axes = (ln_sharding_meta.in_axes[0], ln_sharding_meta.in_axes[1], ln_bias_in_axis,
first_dot_sharding_meta.in_axes[1], second_dot_sharding_meta.in_axes[1],
*fp8_sharding_meta.in_axes)
......@@ -200,14 +207,15 @@ def fp8_ln_mlp(
return res
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str,
activations: Sequence[Union[str, Callable]], zero_centered_gamma: bool, epsilon: float,
fwd_dtype: TEDType, bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int],
Sequence[int]],
major_sharding_type: MajorShardingType, dp_axis_name: str, tp_axis_name: str):
major_sharding_type: MajorShardingType, dp_axis_name: str, tp_axis_name: str,
fsdp_axis_name: str):
res, _ = _fp8_mlp_fwd(inputs,
ln_scale,
ln_bias,
......@@ -226,7 +234,8 @@ def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
contracting_dims=contracting_dims,
major_sharding_type=major_sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
tp_axis_name=tp_axis_name,
fsdp_axis_name=fsdp_axis_name)
return res
......@@ -249,7 +258,8 @@ def _fp8_mlp_fwd(
contracting_dims,
major_sharding_type,
dp_axis_name, # pylint: disable=unused-argument
tp_axis_name):
tp_axis_name,
fsdp_axis_name): # pylint: disable=unused-argument
if activations != ('gelu', 'linear'):
raise NotImplementedError("activations only support ('gelu', 'linear') for now.")
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
......@@ -352,6 +362,7 @@ def _fp8_mlp_bwd(
major_sharding_type,
dp_axis_name,
tp_axis_name,
fsdp_axis_name,
ctx,
g):
inputs_, ln_out, mu, rsigma, gamma, \
......@@ -431,6 +442,14 @@ def _fp8_mlp_bwd(
grad_beta = jax.lax.psum(grad_beta, dp_axis_name)
amax = jax.lax.pmax(amax, dp_axis_name)
if len(fsdp_axis_name) > 0:
wgrad_1 = jax.lax.psum(wgrad_1, fsdp_axis_name)
wgrad_2 = jax.lax.psum(wgrad_2, fsdp_axis_name)
grad_gamma = jax.lax.psum(grad_gamma, fsdp_axis_name)
if grad_beta is not None:
grad_beta = jax.lax.psum(grad_beta, fsdp_axis_name)
amax = jax.lax.pmax(amax, fsdp_axis_name)
if major_sharding_type in (MajorShardingType.TP, MajorShardingType.DPTP):
amax = jax.lax.pmax(amax, tp_axis_name)
......
......@@ -9,10 +9,11 @@ from typing import Callable, Iterable, Sequence, Tuple, Union
from praxis import pax_fiddle
from praxis.base_layer import init_var
from praxis.base_layer import BaseLayer, WeightInit, WeightHParams
from praxis.base_layer import BaseLayer, WeightInit, WeightHParams, WeightHParamsCollection
from praxis.layers import flax_adapter
from praxis.pytypes import JTensor
from ..fp8 import FP8Helper
from ..flax.module import DenseGeneral, LayerNormDenseGeneral
from ..flax.module import LayerNorm as flax_LayerNorm
from ..flax.module import LayerNormMLP as flax_LayerNormMLP
......@@ -45,9 +46,18 @@ class TransformerEngineBaseLayer(BaseLayer):
def create_layer(self, name, flax_module_cls):
"""create_layer"""
fp8_collection_map = {
FP8Helper.FP8_COLLECTION_NAME: [
WeightHParamsCollection.SKIP_LP_REGULARIZATION,
WeightHParamsCollection.NON_TRAINABLE,
WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION
]
}
flax_module_p = pax_fiddle.Config(flax_adapter.FlaxModuleAdapter,
module_factory_method=flax_module_cls,
logical_axes_rules=self.logical_axes_rules,
var_collection_map=fp8_collection_map,
ici_mesh_shape=self.ici_mesh_shape,
dcn_mesh_shape=self.dcn_mesh_shape,
mesh_axis_names=self.mesh_axis_names)
......
......@@ -14,6 +14,7 @@ from jax.interpreters import pxla
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap
from jax.sharding import PartitionSpec
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
......@@ -28,6 +29,17 @@ def _get_mesh_info(resource: str):
return mesh.shape[resource], resource
def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
"""
A wrapper around jax.lax.with_sharding_constraint that supports the
case where the Mesh is empty.
"""
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh.empty:
return x
return jax.lax.with_sharding_constraint(x, pspec)
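# Illustrative usage (an assumption, not part of this change): the wrapper can be
# called unconditionally; outside any Mesh context it simply returns x unchanged.
#     x = jnp.zeros((4, 8))
#     y = with_sharding_constraint(x, PartitionSpec('data', None))
# Inside a Mesh whose axes include 'data', this lowers to
# jax.lax.with_sharding_constraint(x, PartitionSpec('data', None)).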
@dataclass
class ShardingResource:
"""
......@@ -45,6 +57,7 @@ class ShardingResource:
"""
dp_resource: str = None
tp_resource: str = None
fsdp_resource: str = None
_GLOBAL_SHARD_RESOURCE = ShardingResource()
......@@ -421,6 +434,11 @@ class FusedAttnShardingMetaGenerator(ShardingMetaGenerator):
out_axis[tp_dim] = tp_axis_name
out_axes.append(out_axis)
assert len(out_axes) == 1, "Only allow single output at this moment."
assert len(output_new_shapes) == 1, "Only allow single output at this moment."
out_axes = out_axes[0]
output_new_shapes = output_new_shapes[0]
axis_resources = {}
if dp_axis_name is not None:
axis_resources[dp_axis_name] = dp_mesh_axis
......@@ -1015,6 +1033,119 @@ def get_fused_attn_sharding_meta(stype: ShardingType,
tp_axis_name)
def extend_fsdp_sharding_meta(sharding_meta: ShardingMeta,
weight_fsdp_dim_map: Dict[int, int]) -> Tuple[ShardingMeta, str]:
"""
Extend the given ShardingMeta to be compatible with the FSDP (ZeRO-3) sharding pattern.
.. note::
This helper assumes the first shape in sharding_meta.input_shapes
corresponds to the input tensor. Be sure that index 0 is in
`weight_fsdp_dim_map`.
Parameters
----------
sharding_meta : ShardingMeta
the sharding meta object to extend with FSDP.
weight_fsdp_dim_map: Dict[int, int]
A dict whose keys are indices into sharding_meta.input_shapes and whose values are
the dimensions to extend with FSDP.
Returns
-------
updated_sharding_meta : ShardingMeta
a sharding_meta with the FSDP extension.
fsdp_axis_name: str
The name of the FSDP named axis for the subsequent xmap projection.
"""
assert 0 in weight_fsdp_dim_map, \
"0-idx is required to be in 'weight_fsdp_dim_map' for the input."
mst = infer_major_sharding_type()
if mst is MajorShardingType.SINGLE:
return sharding_meta, ""
gsr = global_shard_resource()
dp_mesh_axis = gsr.dp_resource
fsdp_mesh_axis = gsr.fsdp_resource
if fsdp_mesh_axis == dp_mesh_axis:
return sharding_meta, ""
if fsdp_mesh_axis is None:
return sharding_meta, ""
fsdp_dim_size, _ = _get_mesh_info(fsdp_mesh_axis)
fsdp_axis_name = "fsdp"
def get_idx_to_extend(sharded_indices, target_idx):
idx_to_extend = target_idx
for i in sharded_indices:
if i <= target_idx:
idx_to_extend += 1
return idx_to_extend
def extend_exist_sharding(idx, shape):
remain_size = shape[idx]
assert remain_size == -1 or remain_size % fsdp_dim_size == 0
remain_size = remain_size // fsdp_dim_size
new_shape = tuple([*shape[:idx], fsdp_dim_size, remain_size, *shape[idx + 1:]])
return new_shape
new_input_shapes = []
new_in_axes = []
for i, shape in enumerate(sharding_meta.input_shapes):
idx_to_extend = -1
if i == 0: # Assume first shape corresponds to input
input_dp_dim = weight_fsdp_dim_map[i]
# idx_to_extend = input_dp_dim + 1 if is_dp_enabled(mst) else input_dp_dim
idx_to_extend = get_idx_to_extend(list(sharding_meta.in_axes[i].keys()), input_dp_dim)
new_shape = extend_exist_sharding(idx_to_extend, shape)
# assume a single output with the same batch sharding as the input
assert isinstance(sharding_meta.out_axes, dict)
new_out_axes = {}
for key in sharding_meta.out_axes:
if key < idx_to_extend:
new_out_axes[key] = sharding_meta.out_axes[key]
else:
new_out_axes[key + 1] = sharding_meta.out_axes[key]
new_out_axes[idx_to_extend] = fsdp_axis_name
sharding_meta.out_axes = new_out_axes
else:
new_shape = shape
if i in weight_fsdp_dim_map:
idx_to_extend = get_idx_to_extend(list(sharding_meta.in_axes[i].keys()),
weight_fsdp_dim_map[i])
if weight_fsdp_dim_map[i] in sharding_meta.in_axes[i]:
new_shape = extend_exist_sharding(idx_to_extend, shape)
else:
assert shape[idx_to_extend] % fsdp_dim_size == 0
remain_dim_size = shape[idx_to_extend] // fsdp_dim_size
new_shape = tuple([
*shape[:idx_to_extend], fsdp_dim_size, remain_dim_size,
*shape[idx_to_extend + 1:]
])
if idx_to_extend >= 0:
new_ia = {}
for key in sharding_meta.in_axes[i]:
if key < idx_to_extend:
new_ia[key] = sharding_meta.in_axes[i][key]
else:
new_ia[key + 1] = sharding_meta.in_axes[i][key]
new_ia[idx_to_extend] = fsdp_axis_name
else:
new_ia = sharding_meta.in_axes[i]
new_input_shapes.append(new_shape)
new_in_axes.append(new_ia)
sharding_meta.input_shapes = tuple(new_input_shapes)
sharding_meta.in_axes = tuple(new_in_axes)
sharding_meta.axis_resources[fsdp_axis_name] = fsdp_mesh_axis
return sharding_meta, fsdp_axis_name
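# Illustrative usage (an assumption, mirroring the call sites in dot.py/layernorm.py;
# argument values are placeholders):
#     sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
#                                           0, len(inputs.shape) - 1, 0,
#                                           contracting_dims, "batch", "model")
#     sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(sharding_meta, {0: 0})
# The input's batch dimension is re-split into (dp, fsdp, remainder), in_axes/out_axes
# are shifted to name the new dimension, and axis_resources gains {"fsdp": fsdp_resource},
# so xmap_runner maps the named "fsdp" axis onto the FSDP mesh axis. The returned
# fsdp_axis_name ("" when FSDP is not configured) is forwarded to the backward passes,
# which psum wgrads and pmax amaxes over that axis.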
def xmap_runner(func: Callable, in_axes: Tuple[Dict, ...],
out_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]],
axis_resources: Dict, inputs: Tuple):
......@@ -1032,10 +1163,12 @@ def xmap_runner(func: Callable, in_axes: Tuple[Dict, ...],
# Collectives in manually partitioned computations are only supported
# when all mesh axes are partitioned manually (no partial automatic
# sharding). Make sure that you mention all mesh axes in axis_resources!"
for i, mesh_axis_names in enumerate(mesh.axis_names):
fake_idx_counter = 0
for mesh_axis_names in mesh.axis_names:
if mesh_axis_names not in axis_resources.values():
fake_axis_name = f"{mesh_axis_names}_fake_{i}"
fake_in_axes[i] = fake_axis_name
fake_idx_counter += 1
fake_axis_name = f"{mesh_axis_names}_fake_{fake_idx_counter}"
fake_in_axes[fake_idx_counter] = fake_axis_name
fake_axis_resource[fake_axis_name] = mesh_axis_names
fake_input = jnp.zeros(tuple(64 for _ in range(len(fake_in_axes) + 1)))
......
......@@ -18,8 +18,8 @@ from .cpp_extensions import scaled_upper_triang_masked_softmax_bwd
from .cpp_extensions import ScaledSoftmaxFwdPrimitive
from .cpp_extensions import ScaledMaskedSoftmaxFwdPrimitive
from .cpp_extensions import ScaledUpperTriangMaskedSoftmaxFwdPrimitive
from .sharding import get_softmax_sharding_meta, ShardingType
from .sharding import xmap_runner
from .sharding import get_softmax_sharding_meta, ShardingType, ShardingMeta
from .sharding import xmap_runner, extend_fsdp_sharding_meta
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
......@@ -78,6 +78,8 @@ def softmax(inputs: jnp.ndarray,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
mask_ = mask
mask_in_axis = {}
......@@ -92,8 +94,12 @@ def softmax(inputs: jnp.ndarray,
tp_dim=tp_dim_index,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
mask_ = jnp.reshape(mask_, mask_sharding_meta.input_shapes[0])
mask_in_axis = mask_sharding_meta.in_axes[0]
else:
mask_sharding_meta = ShardingMeta([{}], {}, {}, [mask_.shape], mask_.shape)
mask_sharding_meta, _ = extend_fsdp_sharding_meta(mask_sharding_meta, {0: dp_dim_index})
mask_ = jnp.reshape(mask_, mask_sharding_meta.input_shapes[0])
mask_in_axis = mask_sharding_meta.in_axes[0]
partial_softmax = partial(_softmax, scale_factor=scale_factor, softmax_type=softmax_type)
......