Unverified Commit 80222dc0 authored by Ming-Xu Huang, committed by GitHub

[JAX] Enhance Dropout in TransformerLayer. (#444)



* [JAX] Enhance Dropout in TransformerLayer.

1. Fixed the missing setup of the dropout RNG key in TransformerLayer and
   LayerNormMLP.
2. Allowed a separate dropout rate for FC1's output, independent of the other
   hidden dropouts.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix wrong fp8 scale in _update_fp8_metas_impl
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix typo
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 958e1889
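Editor's note: the sketch below is not part of this commit. It is a minimal, plain-Flax illustration (made-up module, shapes, and rates) of the two behaviours the change touches: flax.linen.Dropout only draws a mask when the named RNG stream it reads from is supplied to Module.apply via `rngs`, and the dropout after FC1 can now use its own rate, separate from hidden_dropout.

```python
# Minimal flax-only sketch (illustrative, not TE code): two dropout ops with
# separate rates, both drawing from the 'dropout' RNG collection.
import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyMLP(nn.Module):
    hidden_dropout: float = 0.1        # rate applied to the block output
    intermediate_dropout: float = 0.3  # separate rate applied after FC1

    @nn.compact
    def __call__(self, x, deterministic: bool):
        x = nn.Dense(32)(x)  # FC1
        x = nn.Dropout(rate=self.intermediate_dropout,
                       rng_collection='dropout')(x, deterministic=deterministic)
        x = nn.Dense(8)(x)   # FC2
        x = nn.Dropout(rate=self.hidden_dropout,
                       rng_collection='dropout')(x, deterministic=deterministic)
        return x

mlp = TinyMLP()
x = jnp.ones((4, 16))
variables = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
# Without rngs={'dropout': ...}, a non-deterministic apply() fails with a
# missing-RNG error -- the symptom the first bullet above fixes.
y = mlp.apply(variables, x, deterministic=False,
              rngs={'dropout': jax.random.PRNGKey(1)})
```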
@@ -72,11 +72,10 @@ class TestFP8Helper(unittest.TestCase):
amax = np.array(amax)
scale = np.array(scale)
-exp = np.floor(np.log2(fp8_max / amax)) - FP8Helper.MARGIN
-sf = np.round(np.power(2, np.abs(exp)))
-sf = np.where(amax > 0.0, sf, scale)
-sf = np.where(np.isfinite(amax), sf, scale)
-return np.where(exp < 0, 1 / sf, sf)
+sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
+sf = jnp.where(amax > 0.0, sf, scale)
+sf = jnp.where(jnp.isfinite(amax), sf, scale)
+return sf
amax_meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_LEN)
scale_meta_shape = (num_of_meta, 1)
@@ -167,6 +167,7 @@ class TestEncoderLayer:
if k == 'dropout_rate':
te_layer_attrs['attention_dropout'] = v
te_layer_attrs['hidden_dropout'] = v
te_layer_attrs['intermediate_dropout'] = v
elif k == 'fuse_mlp_wi':
continue
else:
@@ -174,6 +175,7 @@ class TestEncoderLayer:
ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
layer_cls = partial(TransformerLayer,
hidden_dropout_dims=(sequence_dim,),
intermediate_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.ENCODER,
self_attn_mask_type='padding',
dtype=dtype,
@@ -212,6 +214,7 @@ class TestEncoderLayer:
if k == 'dropout_rate':
te_layer_attrs['attention_dropout'] = v
te_layer_attrs['hidden_dropout'] = v
te_layer_attrs['intermediate_dropout'] = v
elif k == 'fuse_mlp_wi':
continue
else:
@@ -219,6 +222,7 @@ class TestEncoderLayer:
ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
layer_cls = partial(TransformerLayer,
hidden_dropout_dims=(sequence_dim,),
intermediate_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.ENCODER,
self_attn_mask_type='padding',
dtype=dtype,
@@ -381,6 +385,7 @@ class TestDecoderLayer:
if k == 'dropout_rate':
te_layer_attrs['attention_dropout'] = v
te_layer_attrs['hidden_dropout'] = v
te_layer_attrs['intermediate_dropout'] = v
elif k == 'fuse_mlp_wi':
continue
else:
@@ -388,6 +393,7 @@ class TestDecoderLayer:
ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
layer_cls = partial(TransformerLayer,
hidden_dropout_dims=(sequence_dim,),
intermediate_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.DECODER,
dtype=dtype,
**te_layer_attrs)
@@ -426,6 +432,7 @@ class TestDecoderLayer:
if k == 'dropout_rate':
te_layer_attrs['attention_dropout'] = v
te_layer_attrs['hidden_dropout'] = v
te_layer_attrs['intermediate_dropout'] = v
elif k == 'fuse_mlp_wi':
continue
else:
@@ -433,6 +440,7 @@ class TestDecoderLayer:
ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
layer_cls = partial(TransformerLayer,
hidden_dropout_dims=(sequence_dim,),
intermediate_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.DECODER,
dtype=dtype,
**te_layer_attrs)
@@ -957,6 +957,7 @@ class TestTransformer(TestLayer):
layernorm_type = attrs[TransformerLayerAttr.LN_TYPE]
hidden_dropout = 0.0
attention_dropout = 0.0
intermediate_dropout = 0.0
mlp_activations = attrs[TransformerLayerAttr.ACTIVATION]
kernel_init = WeightInit.Gaussian(1.0)
use_bias = attrs[TransformerLayerAttr.USE_BIAS]
@@ -991,6 +992,7 @@ class TestTransformer(TestLayer):
layernorm_type=layernorm_type,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
intermediate_dropout=intermediate_dropout,
mlp_activations=mlp_activations,
use_bias=use_bias,
bias_init=bias_init,
@@ -1007,6 +1009,7 @@ class TestTransformer(TestLayer):
layernorm_type=layernorm_type,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
intermediate_dropout=intermediate_dropout,
mlp_activations=mlp_activations,
mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mha_kernel", kernel_init),
@@ -739,6 +739,8 @@ class LayerNormMLP(TransformerEngineBase):
activations: Sequence[Union[str, Callable]], default = ('relu',)
The sequence of activation functions to apply after the first linear transformation.
Each activation has its own transformation layer.
intermediate_dropout_rng_name: str, default = 'dropout'
The key in the RNGs dict, given via flax.linen.Module.apply, that is used for generating Dropout masks.
intermediate_dropout_rate: float, default = 0.1
Dropout probability for the dropout op after the :attr:`activations`.
intermediate_hidden_dropout_dims: Sequence[int], default = ()
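For context on the new intermediate_dropout_rng_name attribute documented above: the value is simply the name of the RNG collection the dropout reads from, so Module.apply must receive a key under that same name. A small plain-Flax sketch follows (the module name, sizes, and the 'mlp_dropout' name are illustrative assumptions, not TE code):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Fc1WithNamedDropout(nn.Module):
    # Analogue of intermediate_dropout_rng_name: which rngs entry to read.
    dropout_rng_name: str = 'mlp_dropout'

    @nn.compact
    def __call__(self, x, deterministic: bool):
        x = nn.Dense(64)(x)
        return nn.Dropout(rate=0.1, rng_collection=self.dropout_rng_name)(
            x, deterministic=deterministic)

m = Fc1WithNamedDropout()
x = jnp.ones((2, 16))
variables = m.init(jax.random.PRNGKey(0), x, deterministic=True)
# The key under rngs must match dropout_rng_name ('mlp_dropout' here).
y = m.apply(variables, x, deterministic=False,
            rngs={'mlp_dropout': jax.random.PRNGKey(1)})
```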
@@ -779,6 +781,7 @@ class LayerNormMLP(TransformerEngineBase):
bias_axes_2: Tuple[str, ...] = ('embed',)
return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ('relu',)
intermediate_dropout_rng_name: str = 'dropout'
intermediate_dropout_rate: float = 0.1
intermediate_hidden_dropout_dims: Sequence[int] = ()
axis: Union[Iterable[int], int] = -1
@@ -985,7 +988,8 @@ class LayerNormMLP(TransformerEngineBase):
z = jnp.reshape(z, (*z.shape[:-2], -1))
z = nn.Dropout(rate=self.intermediate_dropout_rate,
-broadcast_dims=self.intermediate_hidden_dropout_dims)(
+broadcast_dims=self.intermediate_hidden_dropout_dims,
+rng_collection=self.intermediate_dropout_rng_name)(
z, deterministic=deterministic)
# DenseGeneral 2
@@ -883,6 +883,10 @@ class TransformerLayer(nn.Module):
Dimensions that will share the same dropout mask for hidden
attention_dropout: float, default = 0.1
Dropout probability for the dropout op during multi-head attention.
intermediate_dropout: float, default = 0.1
Dropout probability for the dropout op after the FC1 layer.
intermediate_dropout_dims: Sequence[int], default = ()
Dimensions that will share the same dropout mask for the hidden state after the FC1 layer.
dropout_rng_name: str, default = 'dropout'
The key in the RNGs dict, given via flax.linen.Module.apply, that is used for
generating Dropout masks in the Multi-Head Attention.
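The intermediate_dropout_dims entry above works like hidden_dropout_dims: the listed axes are forwarded to nn.Dropout as broadcast_dims, so a single mask is shared along them. A standalone check (axis numbers and shapes are illustrative, not TE code):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 8, 4))  # [batch, seqlen, hidden]
# broadcast_dims=(1,) means one mask is drawn per (batch, hidden) position
# and reused across the whole sequence axis.
drop = nn.Dropout(rate=0.5, broadcast_dims=(1,))
y = drop.apply({}, x, deterministic=False,
               rngs={'dropout': jax.random.PRNGKey(0)})
# Every sequence position sees the same keep/drop pattern.
assert bool(jnp.all(y[:, 0, :] == y[:, 3, :]))
```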
@@ -963,6 +967,8 @@ class TransformerLayer(nn.Module):
hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1
intermediate_dropout: float = 0.1
intermediate_dropout_dims: Sequence[int] = ()
dropout_rng_name: str = 'dropout'
mha_kernel_init: Initializer = None
mlp_kernel_init: Initializer = None
@@ -1078,6 +1084,8 @@
else:
mha_name = 'self_attention'
inputs = _with_sharding_constraint(inputs, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))
# [batch, length, emb_dim] -> [batch, length, emb_dim]
x, residual = MultiHeadAttention(
num_heads=self.num_attention_heads,
@@ -1113,14 +1121,15 @@
assert -x_shape_len <= dims < x_shape_len
return nn.Dropout(rate=self.hidden_dropout,
-broadcast_dims=self.hidden_dropout_dims)(x,
-deterministic=deterministic)
+broadcast_dims=self.hidden_dropout_dims,
+rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)
x = hidden_dropout(x, deterministic)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
x = nn.Dropout(rate=self.drop_path,
-broadcast_dims=drop_path_shape)(x, deterministic=deterministic)
+broadcast_dims=drop_path_shape,
+rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)
x = x + residual
mlp_input = x
@@ -1156,6 +1165,8 @@
y = hidden_dropout(y, deterministic)
mlp_input = y + residual
mlp_input = _with_sharding_constraint(mlp_input, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))
# MlpBlock
residual = mlp_input
z, ln_out = LayerNormMLP(
@@ -1167,8 +1178,9 @@
return_layernorm_output=self.apply_residual_connection_post_layernorm,
intermediate_dim=self.mlp_hidden_size,
activations=self.mlp_activations,
-intermediate_dropout_rate=self.hidden_dropout,
-intermediate_hidden_dropout_dims=self.hidden_dropout_dims,
+intermediate_dropout_rng_name=self.dropout_rng_name,
+intermediate_dropout_rate=self.intermediate_dropout,
+intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
dtype=self.dtype,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
@@ -310,11 +310,11 @@ class FP8Helper:
amax = fp8_meta_arrays[fp8_amax_idx][..., 0:1]
scale = fp8_meta_arrays[fp8_scale_idx]
-sf = (fp8_max / amax) / (2 ** FP8Helper.MARGIN)
+sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
-fp8_meta_arrays[fp8_scale_idx] = scale
-fp8_meta_arrays[fp8_scale_inv_idx] = 1 / scale
+fp8_meta_arrays[fp8_scale_idx] = sf
+fp8_meta_arrays[fp8_scale_inv_idx] = 1 / sf
return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)
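A standalone worked example of the corrected scale update in the hunk above (the fp8_max value, a MARGIN of 0, and the sample amax entries are illustrative assumptions): the new scale is (fp8_max / amax) / 2**MARGIN, the previous scale is kept whenever amax is zero or non-finite, and scale_inv is its reciprocal.

```python
import jax.numpy as jnp

fp8_max = 448.0   # e.g. the E4M3 maximum; illustrative
margin = 0.0      # stand-in for FP8Helper.MARGIN
amax = jnp.array([2.0, 0.0, jnp.inf])   # last two are "unusable" amax values
scale = jnp.array([1.0, 1.0, 1.0])      # previously stored scales

sf = (fp8_max / amax) / (2 ** margin)
sf = jnp.where(amax > 0.0, sf, scale)          # amax == 0 -> keep old scale
sf = jnp.where(jnp.isfinite(amax), sf, scale)  # amax inf/nan -> keep old scale
print(sf)        # [224.   1.   1.]
print(1.0 / sf)  # the matching scale_inv values
```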
@@ -137,6 +137,8 @@ class TransformerLayer(TransformerEngineBaseLayer):
hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1
intermediate_dropout: float = 0.1
intermediate_dropout_dims: Sequence[int] = ()
dropout_rng_name: str = 'dropout'
mlp_activations: Sequence[str] = ('relu',)
use_bias: bool = False
@@ -190,6 +192,8 @@ class TransformerLayer(TransformerEngineBaseLayer):
hidden_dropout=self.hidden_dropout,
hidden_dropout_dims=self.hidden_dropout_dims,
attention_dropout=self.attention_dropout,
intermediate_dropout=self.intermediate_dropout,
intermediate_dropout_dims=self.intermediate_dropout_dims,
dropout_rng_name=self.dropout_rng_name,
mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mha_kernel", self.params_init),