Unverified Commit f56e4fd0 authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

Fix Bugs of TE/JAX (#119)



* Support transpose_batch_sequence when decode=True
Signed-off-by: default avatarMing-Xu Huang <mingh@nvidia.com>
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fix Bugs,

1. Fix missing dropout_dims in LayerNormMLP.
2. Fix broadcast issues in decode mode.
Signed-off-by: default avatarMing-Xu Huang <mingh@nvidia.com>
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fix wrong masks in decode mode.
Signed-off-by: default avatarMing-Xu Huang <mingh@nvidia.com>
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fix wrong assert condition in TransformerLayer
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fix amax not being reset to 0 at each step.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Enhance rules conflict checking and docs.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* fix code formatting.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing-Xu Huang <mingh@nvidia.com>
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
parent 2d72c11f
...@@ -38,9 +38,13 @@ MESH_CONFIG = [((4,), ("dp",), ShardingType.DP), ((4,), ("tp",), ShardingType.TP ...@@ -38,9 +38,13 @@ MESH_CONFIG = [((4,), ("dp",), ShardingType.DP), ((4,), ("tp",), ShardingType.TP
((4,), ("tp",), ShardingType.TP_ROW), ((2, 2), ("dp", "tp"), ShardingType.DP_TP_COL), ((4,), ("tp",), ShardingType.TP_ROW), ((2, 2), ("dp", "tp"), ShardingType.DP_TP_COL),
((2, 2), ("dp", "tp"), ShardingType.DP_TP_ROW)] ((2, 2), ("dp", "tp"), ShardingType.DP_TP_ROW)]
LOGICAL_RULES = [[(('a1', None), ('a2', 'ma2')), False], LOGICAL_RULES = [
[(('a1', None), ('a2', 'ma2'), ('a3', ('ma31', 'ma32'))), True], [(('a1', None), ('a2', 'ma2')), False],
[(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True]] [(('a1', None), ('a2', 'ma2'), ('a3', ('ma31', 'ma32'))), True],
[(('a1', None), ('a2', 'ma2'), ('a3', 'ma31'), ('a3', 'ma32')), False],
[(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True],
[(('a1', None), ('a2', 'ma2'), ('a2', 'ma1'), ('batch', 'model'), ('batch', 'data')), True],
]
SRS = [ SRS = [
ShardingResource(), ShardingResource(),
ShardingResource('data', None), ShardingResource('data', None),
......
...@@ -321,8 +321,9 @@ class MlpBlock(nn.Module): ...@@ -321,8 +321,9 @@ class MlpBlock(nn.Module):
# Take elementwise product of above intermediate activations. # Take elementwise product of above intermediate activations.
x = functools.reduce(operator.mul, activations) x = functools.reduce(operator.mul, activations)
dropout_broadcast_dims = (0,) if self.transpose_batch_sequence else (1,)
# Apply dropout and final dense output projection. # Apply dropout and final dense output projection.
x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=dropout_broadcast_dims)(
x, deterministic=deterministic) # Broadcast along length. x, deterministic=deterministic) # Broadcast along length.
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'mlp')) x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'mlp'))
......
...@@ -190,7 +190,7 @@ class FP8Helper: ...@@ -190,7 +190,7 @@ class FP8Helper:
Update the amax history Update the amax history
""" """
updated_amax_buffers = jnp.roll(amax_buffers, -1, 1) updated_amax_buffers = jnp.roll(amax_buffers, -1, 1)
updated_amax_buffers.at[:, 0].set(0) updated_amax_buffers = updated_amax_buffers.at[:, 0].set(0)
return updated_amax_buffers return updated_amax_buffers
@staticmethod @staticmethod
......
...@@ -683,6 +683,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -683,6 +683,8 @@ class LayerNormMLP(TransformerEngineBase):
Each activation has its own transformation layer. Each activation has its own transformation layer.
intermediate_dropout_rate: float, default = 0.1 intermediate_dropout_rate: float, default = 0.1
Dropout probability for the dropout op after the :attr:`activations`. Dropout probability for the dropout op after the :attr:`activations`.
intermediate_hidden_dropout_dims: Sequence[int], default = ()
Dimensions that will share the same dropout mask for hidden
axis: Union[Iterable[int], int], default = -1 axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on. An integer tuple with axes to apply the transformation on.
...@@ -716,6 +718,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -716,6 +718,7 @@ class LayerNormMLP(TransformerEngineBase):
return_layernorm_output: bool = True return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ('relu',) activations: Sequence[Union[str, Callable]] = ('relu',)
intermediate_dropout_rate: float = 0.1 intermediate_dropout_rate: float = 0.1
intermediate_hidden_dropout_dims: Sequence[int] = ()
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
...@@ -912,8 +915,9 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -912,8 +915,9 @@ class LayerNormMLP(TransformerEngineBase):
z = functools.reduce(operator.mul, activations) z = functools.reduce(operator.mul, activations)
z = jnp.reshape(z, (*z.shape[:-2], -1)) z = jnp.reshape(z, (*z.shape[:-2], -1))
z = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( z = nn.Dropout(rate=self.intermediate_dropout_rate,
z, deterministic=deterministic) # Broadcast along length. broadcast_dims=self.intermediate_hidden_dropout_dims)(
z, deterministic=deterministic)
# DenseGeneral 2 # DenseGeneral 2
hidden_size = inputs.shape[-1] hidden_size = inputs.shape[-1]
......
...@@ -53,6 +53,10 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: ...@@ -53,6 +53,10 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
.. warning:: .. warning::
Please make sure ShardingResource is set via fp8_autocast before calling this function. Please make sure ShardingResource is set via fp8_autocast before calling this function.
.. note::
This function is only needed when using TransformerLayer. For other modules, such as
DenseGeneral, please properly set axes of kernels and bias.
Parameters Parameters
---------- ----------
rules : Sequence[Tuple[str, Union[str, None]]] rules : Sequence[Tuple[str, Union[str, None]]]
...@@ -73,10 +77,12 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: ...@@ -73,10 +77,12 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
f"Thie axis_name should be str, but got {type(key)}." f"Thie axis_name should be str, but got {type(key)}."
assert isinstance(val, str) or (val is None), \ assert isinstance(val, str) or (val is None), \
f"Thie mesh_axis_name should be str or None, but got {type(val)}." f"Thie mesh_axis_name should be str or None, but got {type(val)}."
rules_map[key] = val if key in rules_map:
rules_map[key].append(val)
else:
rules_map[key] = [val]
gsr = global_shard_resource() gsr = global_shard_resource()
te_logical_axis_rules = (('batch', gsr.dp_resource), ('embed', None), ('mlp', gsr.tp_resource), te_logical_axis_rules = (('batch', gsr.dp_resource), ('embed', None), ('mlp', gsr.tp_resource),
('heads', gsr.tp_resource), ('kv', None), ('qkv_dim', None), ('heads', gsr.tp_resource), ('kv', None), ('qkv_dim', None),
('kv_dim', None), ('joined_kv', gsr.tp_resource), ('act', None), ('kv_dim', None), ('joined_kv', gsr.tp_resource), ('act', None),
...@@ -87,7 +93,7 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: ...@@ -87,7 +93,7 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
key = item[0] key = item[0]
val = item[1] val = item[1]
if key in rules_map: if key in rules_map:
assert rules_map[key] == val, \ assert len(rules_map[key]) == 1 and rules_map[key][0] == val, \
f"The rule diverged between TE and given rule." \ f"The rule diverged between TE and given rule." \
f"Axis:{key} map to {rules_map[key]} in the given" \ f"Axis:{key} map to {rules_map[key]} in the given" \
f" rules, but {val} in TE's rules." f" rules, but {val} in TE's rules."
...@@ -447,21 +453,22 @@ class MultiHeadAttention(nn.Module): ...@@ -447,21 +453,22 @@ class MultiHeadAttention(nn.Module):
if decode: if decode:
is_initialized = self.has_variable('cache', 'cached_key') is_initialized = self.has_variable('cache', 'cached_key')
# TODO (Ming Huang): Check performance on GPU withou swap dimensions # pylint: disable=fixme cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape, key.dtype)
def swap_dims(x): cached_value = self.variable('cache', 'cached_value', jnp.zeros, value.shape,
return x[:-3] + tuple(x[i] for i in [-2, -1, -3])
cached_key = self.variable('cache', 'cached_key', jnp.zeros, swap_dims(key.shape),
key.dtype)
cached_value = self.variable('cache', 'cached_value', jnp.zeros, swap_dims(value.shape),
value.dtype) value.dtype)
cache_index = self.variable('cache', 'cache_index', cache_index = self.variable('cache', 'cache_index',
lambda: jnp.array(0, dtype=jnp.int32)) lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized: if is_initialized:
batch, num_heads, head_dim, length = cached_key.value.shape if self.transpose_batch_sequence:
length, batch, num_heads, head_dim = cached_key.value.shape
expected_shape = (1, batch, num_heads, head_dim)
one_hot_indices_shape = (length, 1, 1, 1)
else:
batch, length, num_heads, head_dim = cached_key.value.shape
expected_shape = (batch, 1, num_heads, head_dim)
one_hot_indices_shape = (1, length, 1, 1)
# Sanity shape check of cached key against input query. # Sanity shape check of cached key against input query.
expected_shape = (batch, 1, num_heads, head_dim)
if expected_shape != query.shape: if expected_shape != query.shape:
raise ValueError( raise ValueError(
'Autoregressive cache shape error, ' 'Autoregressive cache shape error, '
...@@ -469,19 +476,15 @@ class MultiHeadAttention(nn.Module): ...@@ -469,19 +476,15 @@ class MultiHeadAttention(nn.Module):
cur_index = cache_index.value cur_index = cache_index.value
one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype) one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
one_token_key = jnp.moveaxis(key, -3, -1) one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
one_token_value = jnp.moveaxis(value, -3, -1) key = cached_key.value + key * one_hot_indices
key = cached_key.value + one_token_key * one_hot_indices value = cached_value.value + value * one_hot_indices
value = cached_value.value + one_token_value * one_hot_indices
cached_key.value = key cached_key.value = key
cached_value.value = value cached_value.value = value
cache_index.value = cache_index.value + 1 cache_index.value = cache_index.value + 1
key = jnp.moveaxis(key, -1, -3)
value = jnp.moveaxis(value, -1, -3)
mask = combine_masks( mask = combine_masks(
mask, jnp.broadcast_to(jnp.arange(length) <= cur_index, (batch, 1, 1, length))) mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length)))
if bias is not None: if bias is not None:
bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0), bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
...@@ -889,10 +892,11 @@ class TransformerLayer(nn.Module): ...@@ -889,10 +892,11 @@ class TransformerLayer(nn.Module):
assert isinstance(self.hidden_dropout_dims, Sequence) assert isinstance(self.hidden_dropout_dims, Sequence)
x_shape_len = len(x.shape) x_shape_len = len(x.shape)
for dims in self.hidden_dropout_dims: for dims in self.hidden_dropout_dims:
assert -x_shape_len < dims < x_shape_len assert -x_shape_len <= dims < x_shape_len
return nn.Dropout(rate=self.hidden_dropout, return nn.Dropout(rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims)(x, deterministic) broadcast_dims=self.hidden_dropout_dims)(x,
deterministic=deterministic)
x = hidden_dropout(x, deterministic) x = hidden_dropout(x, deterministic)
if self.drop_path > 0.0: if self.drop_path > 0.0:
...@@ -944,6 +948,7 @@ class TransformerLayer(nn.Module): ...@@ -944,6 +948,7 @@ class TransformerLayer(nn.Module):
intermediate_dim=self.mlp_hidden_size, intermediate_dim=self.mlp_hidden_size,
activations=self.mlp_activations, activations=self.mlp_activations,
intermediate_dropout_rate=self.hidden_dropout, intermediate_dropout_rate=self.hidden_dropout,
intermediate_hidden_dropout_dims=self.hidden_dropout_dims,
dtype=self.dtype, dtype=self.dtype,
scale_axes=('embed',), scale_axes=('embed',),
kernel_init=self.mlp_kernel_init, kernel_init=self.mlp_kernel_init,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment