Unverified commit f56e4fd0, authored by Ming-Xu Huang, committed by GitHub

Fix Bugs of TE/JAX (#119)



* Support transpose_batch_sequence when decode=True
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix bugs:

1. Fix missing dropout_dims in LayerNormMLP.
2. Fix broadcast issues in decode mode.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix wrong masks in decode mode.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix wrong assert condition in TransformerLayer.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix amax not being reset to 0 at each step.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Enhance rule-conflict checking and docs.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix code formatting.
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
Signed-off-by: Ming Huang <mingh@nvidia.com>
parent 2d72c11f
@@ -38,9 +38,13 @@ MESH_CONFIG = [((4,), ("dp",), ShardingType.DP), ((4,), ("tp",), ShardingType.TP
((4,), ("tp",), ShardingType.TP_ROW), ((2, 2), ("dp", "tp"), ShardingType.DP_TP_COL),
((2, 2), ("dp", "tp"), ShardingType.DP_TP_ROW)]
LOGICAL_RULES = [[(('a1', None), ('a2', 'ma2')), False],
[(('a1', None), ('a2', 'ma2'), ('a3', ('ma31', 'ma32'))), True],
[(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True]]
LOGICAL_RULES = [
[(('a1', None), ('a2', 'ma2')), False],
[(('a1', None), ('a2', 'ma2'), ('a3', ('ma31', 'ma32'))), True],
[(('a1', None), ('a2', 'ma2'), ('a3', 'ma31'), ('a3', 'ma32')), False],
[(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True],
[(('a1', None), ('a2', 'ma2'), ('a2', 'ma1'), ('batch', 'model'), ('batch', 'data')), True],
]
SRS = [
ShardingResource(),
ShardingResource('data', None),
......
@@ -321,8 +321,9 @@ class MlpBlock(nn.Module):
# Take elementwise product of above intermediate activations.
x = functools.reduce(operator.mul, activations)
dropout_broadcast_dims = (0,) if self.transpose_batch_sequence else (1,)
# Apply dropout and final dense output projection.
x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))(
x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=dropout_broadcast_dims)(
x, deterministic=deterministic) # Broadcast along length.
if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'mlp'))
......
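For context, a minimal Flax sketch (not part of the diff; the TinyMlp module and shapes are made up) of why the dropout broadcast dimension has to follow the batch/sequence layout: with transpose_batch_sequence=True the activations are [sequence, batch, hidden], so the mask is broadcast over dim 0, otherwise over dim 1.

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


class TinyMlp(nn.Module):
    """Illustrative only: picks the dropout broadcast axis from the layout."""
    dropout_rate: float = 0.1
    transpose_batch_sequence: bool = True

    @nn.compact
    def __call__(self, x, deterministic=False):
        x = nn.Dense(16)(x)
        # Sequence axis is dim 0 for [seq, batch, hidden], dim 1 for [batch, seq, hidden].
        seq_dim = (0,) if self.transpose_batch_sequence else (1,)
        return nn.Dropout(rate=self.dropout_rate, broadcast_dims=seq_dim)(
            x, deterministic=deterministic)


x = jnp.ones((8, 2, 16))  # [seq, batch, hidden]
model = TinyMlp()
variables = model.init(jax.random.PRNGKey(0), x, deterministic=True)
y = model.apply(variables, x, deterministic=False,
                rngs={'dropout': jax.random.PRNGKey(1)})
```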
@@ -190,7 +190,7 @@ class FP8Helper:
Update the amax history
"""
updated_amax_buffers = jnp.roll(amax_buffers, -1, 1)
updated_amax_buffers.at[:, 0].set(0)
updated_amax_buffers = updated_amax_buffers.at[:, 0].set(0)
return updated_amax_buffers
@staticmethod
......
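A minimal sketch (plain jax.numpy, unrelated to FP8Helper internals beyond the pattern) of the bug fixed here: jnp arrays are immutable, so `.at[...].set(...)` returns a new array, and its result must be re-assigned, otherwise the rolled-in amax slot is never zeroed.

```python
import jax.numpy as jnp

amax_buffers = jnp.arange(6.0).reshape(2, 3)

updated = jnp.roll(amax_buffers, -1, 1)
updated.at[:, 0].set(0)             # bug: the returned array is discarded
updated = updated.at[:, 0].set(0)   # fix: keep the functionally-updated array

assert bool((updated[:, 0] == 0).all())
```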
@@ -683,6 +683,8 @@ class LayerNormMLP(TransformerEngineBase):
Each activation has its own transformation layer.
intermediate_dropout_rate: float, default = 0.1
Dropout probability for the dropout op after the :attr:`activations`.
intermediate_hidden_dropout_dims: Sequence[int], default = ()
Dimensions that will share the same dropout mask for the hidden (intermediate) dropout.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
@@ -716,6 +718,7 @@ class LayerNormMLP(TransformerEngineBase):
return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ('relu',)
intermediate_dropout_rate: float = 0.1
intermediate_hidden_dropout_dims: Sequence[int] = ()
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
@@ -912,8 +915,9 @@ class LayerNormMLP(TransformerEngineBase):
z = functools.reduce(operator.mul, activations)
z = jnp.reshape(z, (*z.shape[:-2], -1))
z = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))(
z, deterministic=deterministic) # Broadcast along length.
z = nn.Dropout(rate=self.intermediate_dropout_rate,
broadcast_dims=self.intermediate_hidden_dropout_dims)(
z, deterministic=deterministic)
# DenseGeneral 2
hidden_size = inputs.shape[-1]
......
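A hedged usage sketch of the new attribute; the import path and the exact call signature are assumptions and may differ between TE versions. intermediate_hidden_dropout_dims replaces the previously hard-coded broadcast_dims=(-2,) for the dropout that follows the activations.

```python
import jax
import jax.numpy as jnp
# Import path is an assumption; older releases may expose this under transformer_engine.jax.
from transformer_engine.jax.flax import LayerNormMLP

mlp = LayerNormMLP(
    intermediate_dim=4096,
    intermediate_dropout_rate=0.1,
    intermediate_hidden_dropout_dims=(0,),  # share the dropout mask along the sequence axis
    transpose_batch_sequence=True)          # inputs are [seq, batch, hidden]

x = jnp.ones((128, 8, 1024))
variables = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
# The output may be a (output, layernorm_output) tuple depending on return_layernorm_output.
y = mlp.apply(variables, x, deterministic=False,
              rngs={'dropout': jax.random.PRNGKey(1)})
```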
@@ -53,6 +53,10 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
.. warning::
Please make sure ShardingResource is set via fp8_autocast before calling this function.
.. note::
This function is only needed when using TransformerLayer. For other modules, such as
DenseGeneral, please set the kernel and bias axes directly.
Parameters
----------
rules : Sequence[Tuple[str, Union[str, None]]]
@@ -73,10 +77,12 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
f"Thie axis_name should be str, but got {type(key)}."
assert isinstance(val, str) or (val is None), \
f"Thie mesh_axis_name should be str or None, but got {type(val)}."
rules_map[key] = val
if key in rules_map:
rules_map[key].append(val)
else:
rules_map[key] = [val]
gsr = global_shard_resource()
te_logical_axis_rules = (('batch', gsr.dp_resource), ('embed', None), ('mlp', gsr.tp_resource),
('heads', gsr.tp_resource), ('kv', None), ('qkv_dim', None),
('kv_dim', None), ('joined_kv', gsr.tp_resource), ('act', None),
@@ -87,7 +93,7 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
key = item[0]
val = item[1]
if key in rules_map:
assert rules_map[key] == val, \
assert len(rules_map[key]) == 1 and rules_map[key][0] == val, \
f"The rule diverged between TE and given rule." \
f"Axis:{key} map to {rules_map[key]} in the given" \
f" rules, but {val} in TE's rules."
@@ -447,21 +453,22 @@ class MultiHeadAttention(nn.Module):
if decode:
is_initialized = self.has_variable('cache', 'cached_key')
# TODO (Ming Huang): Check performance on GPU without swapping dimensions # pylint: disable=fixme
def swap_dims(x):
return x[:-3] + tuple(x[i] for i in [-2, -1, -3])
cached_key = self.variable('cache', 'cached_key', jnp.zeros, swap_dims(key.shape),
key.dtype)
cached_value = self.variable('cache', 'cached_value', jnp.zeros, swap_dims(value.shape),
cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape, key.dtype)
cached_value = self.variable('cache', 'cached_value', jnp.zeros, value.shape,
value.dtype)
cache_index = self.variable('cache', 'cache_index',
lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
batch, num_heads, head_dim, length = cached_key.value.shape
if self.transpose_batch_sequence:
length, batch, num_heads, head_dim = cached_key.value.shape
expected_shape = (1, batch, num_heads, head_dim)
one_hot_indices_shape = (length, 1, 1, 1)
else:
batch, length, num_heads, head_dim = cached_key.value.shape
expected_shape = (batch, 1, num_heads, head_dim)
one_hot_indices_shape = (1, length, 1, 1)
# Sanity shape check of cached key against input query.
expected_shape = (batch, 1, num_heads, head_dim)
if expected_shape != query.shape:
raise ValueError(
'Autoregressive cache shape error, '
@@ -469,19 +476,15 @@
cur_index = cache_index.value
one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
one_token_key = jnp.moveaxis(key, -3, -1)
one_token_value = jnp.moveaxis(value, -3, -1)
key = cached_key.value + one_token_key * one_hot_indices
value = cached_value.value + one_token_value * one_hot_indices
one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
key = cached_key.value + key * one_hot_indices
value = cached_value.value + value * one_hot_indices
cached_key.value = key
cached_value.value = value
cache_index.value = cache_index.value + 1
key = jnp.moveaxis(key, -1, -3)
value = jnp.moveaxis(value, -1, -3)
mask = combine_masks(
mask, jnp.broadcast_to(jnp.arange(length) <= cur_index, (batch, 1, 1, length)))
mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length)))
if bias is not None:
bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
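A plain-jnp sketch (shapes made up, not the MultiHeadAttention code) of the corrected cache update: the cache keeps the same [batch, length, heads, dim] layout as the inputs, the single decoded token is scattered into position cur_index with a one-hot mask reshaped for broadcasting, and positions after cur_index are the ones marked in the mask, assuming TE's convention that a True entry means "do not attend".

```python
import jax.numpy as jnp
from jax import nn as jax_nn

batch, length, heads, dim = 2, 5, 4, 8
cached_key = jnp.zeros((batch, length, heads, dim))
cur_index = 3
new_key = jnp.ones((batch, 1, heads, dim))  # one token per decode step

one_hot = jax_nn.one_hot(cur_index, length, dtype=new_key.dtype)
one_hot = jnp.reshape(one_hot, (1, length, 1, 1))  # broadcast over batch, heads, dim
cached_key = cached_key + new_key * one_hot        # write the token at cur_index

# Positions beyond cur_index are masked out (True = masked under this assumption).
future_mask = jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))
print(cached_key[0, :, 0, 0], future_mask[0, 0, 0])
```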
@@ -889,10 +892,11 @@ class TransformerLayer(nn.Module):
assert isinstance(self.hidden_dropout_dims, Sequence)
x_shape_len = len(x.shape)
for dims in self.hidden_dropout_dims:
assert -x_shape_len < dims < x_shape_len
assert -x_shape_len <= dims < x_shape_len
return nn.Dropout(rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims)(x, deterministic)
broadcast_dims=self.hidden_dropout_dims)(x,
deterministic=deterministic)
x = hidden_dropout(x, deterministic)
if self.drop_path > 0.0:
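A small illustrative sketch of the corrected bound check: negative broadcast dimensions are valid, and -ndim refers to dimension 0, so the lower bound must be inclusive.

```python
import jax.numpy as jnp

x = jnp.ones((8, 2, 16))
ndim = x.ndim  # 3

for dims in (-ndim, -1, 0, ndim - 1):
    # -3, -1, 0 and 2 are all valid axes for a rank-3 array; 3 would not be.
    assert -ndim <= dims < ndim
```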
@@ -944,6 +948,7 @@
intermediate_dim=self.mlp_hidden_size,
activations=self.mlp_activations,
intermediate_dropout_rate=self.hidden_dropout,
intermediate_hidden_dropout_dims=self.hidden_dropout_dims,
dtype=self.dtype,
scale_axes=('embed',),
kernel_init=self.mlp_kernel_init,
......
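A hedged end-to-end sketch of the effect of the last change; the import path, the argument names beyond those visible in this diff, and the mask shape are assumptions. hidden_dropout_dims is now also forwarded to the MLP block, so the post-attention dropout and the intermediate dropout share the same broadcast axes.

```python
import jax
import jax.numpy as jnp
# Import path is an assumption; it may differ between TE releases.
from transformer_engine.jax.flax import TransformerLayer

layer = TransformerLayer(
    hidden_size=1024,
    mlp_hidden_size=4096,
    num_attention_heads=16,
    hidden_dropout=0.1,
    hidden_dropout_dims=(0,),        # broadcast dropout along the sequence axis
    transpose_batch_sequence=True)

x = jnp.ones((128, 8, 1024))         # [seq, batch, hidden]
mask = jnp.zeros((8, 1, 128, 128), dtype=jnp.uint8)  # 0 = attend, assuming TE's mask convention
variables = layer.init(jax.random.PRNGKey(0), x, attention_mask=mask, deterministic=True)
y = layer.apply(variables, x, attention_mask=mask, deterministic=False,
                rngs={'dropout': jax.random.PRNGKey(1)})
```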