Unverified Commit 80222dc0 authored by Ming-Xu Huang, committed by GitHub

[JAX] Enhance Dropout in TransformerLayer. (#444)



* [JAX] Enhance Dropout in TransformerLayer.

1. Fixed the missing setup of the dropout RNG key in TransformerLayer and
   LayerNormMLP (a minimal flax sketch of the mechanism follows below).
2. Allowed a separate dropout rate for FC1's output, independent of the other
   hidden dropouts.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix wrong fp8 scale in _update_fp8_metas_impl
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix typo
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 958e1889
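
The first fix forwards the configured dropout RNG name down to flax's `nn.Dropout` (the diff below adds `rng_collection=...` to each inner dropout). The following minimal flax.linen sketch, independent of Transformer Engine, illustrates the mechanism: the `rng_collection` given to `nn.Dropout` must match a key in the `rngs` dict passed to `flax.linen.Module.apply`. The toy `Block` module, shapes, and rates are made up for illustration.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    """Toy block with two dropouts drawing from the same named RNG stream."""

    @nn.compact
    def __call__(self, x, deterministic: bool):
        # Both dropouts pull their mask RNG from the 'dropout' collection,
        # analogous to what TransformerLayer/LayerNormMLP now forward
        # explicitly via rng_collection=<dropout_rng_name>.
        x = nn.Dropout(rate=0.1, rng_collection='dropout')(x, deterministic=deterministic)
        x = nn.Dropout(rate=0.1, rng_collection='dropout')(x, deterministic=deterministic)
        return x


params_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
model = Block()
x = jnp.ones((2, 8, 16))
variables = model.init({'params': params_key, 'dropout': dropout_key}, x, deterministic=True)

# The key under rngs= must match the Dropout's rng_collection; otherwise
# flax cannot draw a mask when dropout is actually applied.
y = model.apply(variables, x, deterministic=False, rngs={'dropout': dropout_key})
```
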
@@ -72,11 +72,10 @@ class TestFP8Helper(unittest.TestCase):
             amax = np.array(amax)
             scale = np.array(scale)
-            exp = np.floor(np.log2(fp8_max / amax)) - FP8Helper.MARGIN
-            sf = np.round(np.power(2, np.abs(exp)))
-            sf = np.where(amax > 0.0, sf, scale)
-            sf = np.where(np.isfinite(amax), sf, scale)
-            return np.where(exp < 0, 1 / sf, sf)
+            sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
+            sf = jnp.where(amax > 0.0, sf, scale)
+            sf = jnp.where(jnp.isfinite(amax), sf, scale)
+            return sf

         amax_meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_LEN)
         scale_meta_shape = (num_of_meta, 1)
......
@@ -167,6 +167,7 @@ class TestEncoderLayer:
             if k == 'dropout_rate':
                 te_layer_attrs['attention_dropout'] = v
                 te_layer_attrs['hidden_dropout'] = v
+                te_layer_attrs['intermediate_dropout'] = v
             elif k == 'fuse_mlp_wi':
                 continue
             else:
@@ -174,6 +175,7 @@ class TestEncoderLayer:
         ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
         layer_cls = partial(TransformerLayer,
                             hidden_dropout_dims=(sequence_dim,),
+                            intermediate_dropout_dims=(sequence_dim,),
                             layer_type=TransformerLayerType.ENCODER,
                             self_attn_mask_type='padding',
                             dtype=dtype,
@@ -212,6 +214,7 @@ class TestEncoderLayer:
             if k == 'dropout_rate':
                 te_layer_attrs['attention_dropout'] = v
                 te_layer_attrs['hidden_dropout'] = v
+                te_layer_attrs['intermediate_dropout'] = v
             elif k == 'fuse_mlp_wi':
                 continue
             else:
@@ -219,6 +222,7 @@ class TestEncoderLayer:
         ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
         layer_cls = partial(TransformerLayer,
                             hidden_dropout_dims=(sequence_dim,),
+                            intermediate_dropout_dims=(sequence_dim,),
                             layer_type=TransformerLayerType.ENCODER,
                             self_attn_mask_type='padding',
                             dtype=dtype,
@@ -381,6 +385,7 @@ class TestDecoderLayer:
             if k == 'dropout_rate':
                 te_layer_attrs['attention_dropout'] = v
                 te_layer_attrs['hidden_dropout'] = v
+                te_layer_attrs['intermediate_dropout'] = v
             elif k == 'fuse_mlp_wi':
                 continue
             else:
@@ -388,6 +393,7 @@ class TestDecoderLayer:
         ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
         layer_cls = partial(TransformerLayer,
                             hidden_dropout_dims=(sequence_dim,),
+                            intermediate_dropout_dims=(sequence_dim,),
                             layer_type=TransformerLayerType.DECODER,
                             dtype=dtype,
                             **te_layer_attrs)
@@ -426,6 +432,7 @@ class TestDecoderLayer:
             if k == 'dropout_rate':
                 te_layer_attrs['attention_dropout'] = v
                 te_layer_attrs['hidden_dropout'] = v
+                te_layer_attrs['intermediate_dropout'] = v
             elif k == 'fuse_mlp_wi':
                 continue
             else:
@@ -433,6 +440,7 @@ class TestDecoderLayer:
         ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
         layer_cls = partial(TransformerLayer,
                             hidden_dropout_dims=(sequence_dim,),
+                            intermediate_dropout_dims=(sequence_dim,),
                             layer_type=TransformerLayerType.DECODER,
                             dtype=dtype,
                             **te_layer_attrs)
......
@@ -957,6 +957,7 @@ class TestTransformer(TestLayer):
         layernorm_type = attrs[TransformerLayerAttr.LN_TYPE]
         hidden_dropout = 0.0
         attention_dropout = 0.0
+        intermediate_dropout = 0.0
         mlp_activations = attrs[TransformerLayerAttr.ACTIVATION]
         kernel_init = WeightInit.Gaussian(1.0)
         use_bias = attrs[TransformerLayerAttr.USE_BIAS]
@@ -991,6 +992,7 @@ class TestTransformer(TestLayer):
             layernorm_type=layernorm_type,
             hidden_dropout=hidden_dropout,
             attention_dropout=attention_dropout,
+            intermediate_dropout=intermediate_dropout,
             mlp_activations=mlp_activations,
             use_bias=use_bias,
             bias_init=bias_init,
@@ -1007,6 +1009,7 @@ class TestTransformer(TestLayer):
             layernorm_type=layernorm_type,
             hidden_dropout=hidden_dropout,
             attention_dropout=attention_dropout,
+            intermediate_dropout=intermediate_dropout,
             mlp_activations=mlp_activations,
             mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                 "mha_kernel", kernel_init),
......
@@ -739,6 +739,8 @@ class LayerNormMLP(TransformerEngineBase):
     activations: Sequence[Union[str, Callable]], default = ('relu',)
         The sequence of activation functions to apply after the first linear transformation.
         Each activation has its own transformation layer.
+    intermediate_dropout_rng_name: str, default = 'dropout'
+        The key in the RNGs given via flax.linen.Module.apply that is used to generate dropout masks.
     intermediate_dropout_rate: float, default = 0.1
         Dropout probability for the dropout op after the :attr:`activations`.
     intermediate_hidden_dropout_dims: Sequence[int], default = ()
@@ -779,6 +781,7 @@ class LayerNormMLP(TransformerEngineBase):
     bias_axes_2: Tuple[str, ...] = ('embed',)
     return_layernorm_output: bool = True
     activations: Sequence[Union[str, Callable]] = ('relu',)
+    intermediate_dropout_rng_name: str = 'dropout'
     intermediate_dropout_rate: float = 0.1
     intermediate_hidden_dropout_dims: Sequence[int] = ()
     axis: Union[Iterable[int], int] = -1
@@ -985,7 +988,8 @@ class LayerNormMLP(TransformerEngineBase):
         z = jnp.reshape(z, (*z.shape[:-2], -1))

         z = nn.Dropout(rate=self.intermediate_dropout_rate,
-                       broadcast_dims=self.intermediate_hidden_dropout_dims)(
+                       broadcast_dims=self.intermediate_hidden_dropout_dims,
+                       rng_collection=self.intermediate_dropout_rng_name)(
                            z, deterministic=deterministic)

         # DenseGeneral 2
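
For reference, a small sketch of what `broadcast_dims` does in `flax.linen.Dropout`: the listed dimensions share a single dropout mask, which is what `intermediate_hidden_dropout_dims` (and the tests' `intermediate_dropout_dims=(sequence_dim,)`) feed into. The rate, shapes, and axis choice below are illustrative only.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Share the dropout mask along the sequence axis (axis 1 here).
drop = nn.Dropout(rate=0.5, broadcast_dims=(1,), rng_collection='dropout')

x = jnp.ones((2, 8, 4))  # [batch, seqlen, hidden]
y = drop.apply({}, x, deterministic=False, rngs={'dropout': jax.random.PRNGKey(0)})

# Every sequence position sees the same keep/drop pattern per (batch, hidden).
print(bool(jnp.all(y == y[:, 0:1, :])))  # True
```
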
......
@@ -883,6 +883,10 @@ class TransformerLayer(nn.Module):
         Dimensions that will share the same dropout mask for hidden
     attention_dropout: float, default = 0.1
         Dropout probability for the dropout op during multi-head attention.
+    intermediate_dropout: float, default = 0.1
+        Dropout probability for the dropout op after the FC1 layer.
+    intermediate_dropout_dims: Sequence[int], default = ()
+        Dimensions that will share the same dropout mask for hidden after the FC1 layer.
     dropout_rng_name: str, default = 'dropout'
         The key in given RNGs via flax.linen.Module.apply that for
         generating Dropout masks in the Multi-Head Attention.
@@ -963,6 +967,8 @@ class TransformerLayer(nn.Module):
     hidden_dropout: float = 0.1
     hidden_dropout_dims: Sequence[int] = ()
     attention_dropout: float = 0.1
+    intermediate_dropout: float = 0.1
+    intermediate_dropout_dims: Sequence[int] = ()
     dropout_rng_name: str = 'dropout'
     mha_kernel_init: Initializer = None
     mlp_kernel_init: Initializer = None
@@ -1078,6 +1084,8 @@ class TransformerLayer(nn.Module):
         else:
             mha_name = 'self_attention'

+        inputs = _with_sharding_constraint(inputs, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))
+
         # [batch, length, emb_dim] -> [batch, length, emb_dim]
         x, residual = MultiHeadAttention(
             num_heads=self.num_attention_heads,
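
The hunk above also pins the layer input to a fixed sharding before attention. Below is a generic JAX sketch of what such a constraint does; TE's `_with_sharding_constraint` is an internal helper that resolves the logical axes (`BATCH_AXES`, `SEQLEN_AXES`, `HIDDEN_AXES`) against the active mesh, whereas the mesh axis names and shapes here are made-up assumptions for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical mesh with a data-parallel and a tensor-parallel axis.
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), axis_names=('data', 'model'))

@jax.jit
def constrained_identity(x):
    # Ask the compiler to keep activations sharded as
    # [batch -> 'data', seqlen -> replicated, hidden -> 'model'] at this point.
    spec = PartitionSpec('data', None, 'model')
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))

y = constrained_identity(jnp.ones((8, 128, 1024)))
```
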
@@ -1113,14 +1121,15 @@ class TransformerLayer(nn.Module):
                 assert -x_shape_len <= dims < x_shape_len

             return nn.Dropout(rate=self.hidden_dropout,
-                              broadcast_dims=self.hidden_dropout_dims)(x,
-                                                                       deterministic=deterministic)
+                              broadcast_dims=self.hidden_dropout_dims,
+                              rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)

         x = hidden_dropout(x, deterministic)
         if self.drop_path > 0.0:
             drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
             x = nn.Dropout(rate=self.drop_path,
-                           broadcast_dims=drop_path_shape)(x, deterministic=deterministic)
+                           broadcast_dims=drop_path_shape,
+                           rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)

         x = x + residual
         mlp_input = x
@@ -1156,6 +1165,8 @@ class TransformerLayer(nn.Module):
             y = hidden_dropout(y, deterministic)
             mlp_input = y + residual

+        mlp_input = _with_sharding_constraint(mlp_input, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))
+
         # MlpBlock
         residual = mlp_input
         z, ln_out = LayerNormMLP(
@@ -1167,8 +1178,9 @@ class TransformerLayer(nn.Module):
             return_layernorm_output=self.apply_residual_connection_post_layernorm,
             intermediate_dim=self.mlp_hidden_size,
             activations=self.mlp_activations,
-            intermediate_dropout_rate=self.hidden_dropout,
-            intermediate_hidden_dropout_dims=self.hidden_dropout_dims,
+            intermediate_dropout_rng_name=self.dropout_rng_name,
+            intermediate_dropout_rate=self.intermediate_dropout,
+            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
             dtype=self.dtype,
             scale_axes=(W_NO_SHARD_AXES,),
             ln_bias_axes=(W_NO_SHARD_AXES,),
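
With the wiring above, the FC1-output dropout is no longer tied to `hidden_dropout`. A hedged construction sketch mirroring the test changes earlier in this commit; the import path and the `sequence_dim` value are assumptions and may differ between releases:

```python
from functools import partial
# Assumed import path; check the installed Transformer Engine version.
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType

sequence_dim = 1  # placeholder: index of the sequence axis

encoder_cls = partial(
    TransformerLayer,
    layer_type=TransformerLayerType.ENCODER,
    attention_dropout=0.1,
    hidden_dropout=0.1,
    hidden_dropout_dims=(sequence_dim,),
    intermediate_dropout=0.0,                  # FC1-output dropout, now independent
    intermediate_dropout_dims=(sequence_dim,),
    dropout_rng_name='dropout',                # key expected in apply(..., rngs={...})
)
```
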
......
@@ -310,11 +310,11 @@ class FP8Helper:
         amax = fp8_meta_arrays[fp8_amax_idx][..., 0:1]
         scale = fp8_meta_arrays[fp8_scale_idx]

-        sf = (fp8_max / amax) / (2 ** FP8Helper.MARGIN)
+        sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
         sf = jnp.where(amax > 0.0, sf, scale)
         sf = jnp.where(jnp.isfinite(amax), sf, scale)
-        fp8_meta_arrays[fp8_scale_idx] = scale
-        fp8_meta_arrays[fp8_scale_inv_idx] = 1 / scale
+        fp8_meta_arrays[fp8_scale_idx] = sf
+        fp8_meta_arrays[fp8_scale_inv_idx] = 1 / sf

         return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)
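
The corrected update stores the freshly computed scale factor `sf`; previously the stale `scale` was written back, so the FP8 scale never changed. A standalone numeric sketch of the same math, with made-up `fp8_max`, margin, and amax values:

```python
import jax.numpy as jnp

fp8_max = 448.0   # e.g. the max representable value of an FP8 format (illustrative)
margin = 0.0      # stands in for FP8Helper.MARGIN
amax = jnp.array([[2.0], [0.0]])     # per-tensor amax (second entry is degenerate)
scale = jnp.array([[1.0], [1.0]])    # previous scales

sf = (fp8_max / amax) / (2**margin)
sf = jnp.where(amax > 0.0, sf, scale)          # keep the old scale if amax == 0
sf = jnp.where(jnp.isfinite(amax), sf, scale)  # ... or if amax is inf/nan

new_scale = sf            # the fix: store sf, not the stale `scale`
new_scale_inv = 1 / sf
print(new_scale)          # [[224.], [1.]]
```
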
......
@@ -137,6 +137,8 @@ class TransformerLayer(TransformerEngineBaseLayer):
     hidden_dropout: float = 0.1
     hidden_dropout_dims: Sequence[int] = ()
     attention_dropout: float = 0.1
+    intermediate_dropout: float = 0.1
+    intermediate_dropout_dims: Sequence[int] = ()
     dropout_rng_name: str = 'dropout'
     mlp_activations: Sequence[str] = ('relu',)
     use_bias: bool = False
@@ -190,6 +192,8 @@ class TransformerLayer(TransformerEngineBaseLayer):
             hidden_dropout=self.hidden_dropout,
             hidden_dropout_dims=self.hidden_dropout_dims,
             attention_dropout=self.attention_dropout,
+            intermediate_dropout=self.intermediate_dropout,
+            intermediate_dropout_dims=self.intermediate_dropout_dims,
             dropout_rng_name=self.dropout_rng_name,
             mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                 "mha_kernel", self.params_init),
......