Unverified Commit 086a12fe authored by Phuong Nguyen, committed by GitHub

[JAX] Made order of gated act consistent in all branches (#902)



- Made order of gated act consistent in all branches
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent f68df153
@@ -953,15 +953,15 @@ class LayerNormMLP(TransformerEngineBase):
                     ('relu',),
                     ('quick_gelu',),
                     ('squared_relu',)]

-        normalize_acts = []
+        normalized_acts = []
         for act in self.activations:
             if not isinstance(act, str):
                 return False
-            normalize_acts.append(act.lower())
-        normalize_acts = tuple(reversed(normalize_acts)
-                               if normalize_acts[0] == 'linear' else normalize_acts)
-        is_act_implemented = normalize_acts in (gated_act_pool + act_pool)
+            normalized_acts.append(act.lower())
+        normalized_acts = tuple(reversed(normalized_acts)
+                                if normalized_acts[0] == 'linear' else normalized_acts)
+        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

         use_fused_layernorm_mlp = fuse_layernorm and is_act_implemented and \
             self.intermediate_dropout_rate < 1e-3
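For context, the hunk above canonicalizes the user-supplied activation tuple before checking it against the implemented pools: names are lowercased, and a leading 'linear' is moved to the back so that ('linear', 'GELU') and ('gelu', 'linear') resolve to the same key. Below is a minimal standalone sketch of that logic; the pool contents and the helper name are illustrative, not the module's actual definitions (the diff only shows the tail of one pool).

# Illustrative pools; only the last three entries of one pool are visible
# in the diff, so these lists are assumptions.
gated_act_pool = [('gelu', 'linear'),
                  ('silu', 'linear'),
                  ('relu', 'linear'),
                  ('quick_gelu', 'linear'),
                  ('squared_relu', 'linear')]
act_pool = [('gelu',), ('silu',), ('relu',),
            ('quick_gelu',), ('squared_relu',)]

def normalize_activations(activations):
    # Lowercase every name; bail out on non-string entries, mirroring the
    # `return False` in the module's implementability check.
    normalized = []
    for act in activations:
        if not isinstance(act, str):
            return None
        normalized.append(act.lower())
    # Canonical order: 'linear' goes last, so both spellings of a gated
    # activation map to the same tuple.
    if normalized[0] == 'linear':
        normalized = list(reversed(normalized))
    return tuple(normalized)

assert normalize_activations(('linear', 'GELU')) == ('gelu', 'linear')
assert normalize_activations(('gelu', 'linear')) in gated_act_pool + act_pool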
@@ -1007,7 +1007,7 @@ class LayerNormMLP(TransformerEngineBase):
             fp8_meta_package = \
                 TransformerEngineBase.get_fp8_meta_package(num_of_gemm)

-            num_activations = len(self.activations)
+            num_activations = len(normalized_acts)
             axis = _canonicalize_tuple(self.axis)
             axis = _normalize_axes(axis, y.ndim)

@@ -1072,7 +1072,7 @@ class LayerNormMLP(TransformerEngineBase):
                 dot_2_input_axes=self.dot_2_input_axes,
                 ffn1_ckpt_name=ffn1_ckpt_name,
                 ffn2_ckpt_name=ffn2_ckpt_name,
-                activation_type = normalize_acts,
+                activation_type = normalized_acts,
                 use_bias = self.use_bias)
         else:  # not use_fused_ln_geglu_mlp
             # DenseGeneral 1
@@ -1139,12 +1139,12 @@ class LayerNormMLP(TransformerEngineBase):
                 x += jnp.reshape(bias_1, bias_1_shape)

             x = checkpoint_name(x, ffn1_ckpt_name)
-            activations = []
             if is_act_implemented:
-                z = activation_lu(x, normalize_acts)
+                z = activation_lu(x, normalized_acts)
             else:
+                activations = []
                 x = jnp.split(x, num_activations, axis=-2)
-                for idx, act_fn in enumerate(self.activations):
+                for idx, act_fn in enumerate(normalized_acts):
                     x_i = _convert_to_activation_function(act_fn)(x[idx])
                     activations.append(x_i)
                 z = functools.reduce(operator.mul, activations)
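The last hunk is the actual consistency fix: the unfused fallback now iterates over normalized_acts rather than self.activations, so each chunk of the FFN1 output is paired with the same activation the fused activation_lu path assumes. Because the gating is an elementwise product of per-chunk activations, applying gelu to chunk 0 versus chunk 1 gives different results even though the product itself is commutative, so the chunk-to-activation pairing has to match across branches. A hedged sketch of that fallback under assumed shapes; the gated_act helper, act_map, and shapes here are illustrative, not TransformerEngine API.

import functools
import operator

import jax
import jax.numpy as jnp

def gated_act(x, normalized_acts):
    # Map canonical names to callables; 'linear' is the identity gate.
    act_map = {'gelu': jax.nn.gelu, 'silu': jax.nn.silu,
               'relu': jax.nn.relu, 'linear': lambda a: a}
    # One chunk per activation along the dedicated -2 axis, as in the diff.
    chunks = jnp.split(x, len(normalized_acts), axis=-2)
    activated = [act_map[name](chunk)
                 for name, chunk in zip(normalized_acts, chunks)]
    # Gate by elementwise product; which chunk gets which activation is
    # exactly what the commit makes consistent.
    return functools.reduce(operator.mul, activated)

x = jnp.ones((4, 2, 8))               # batch x num_activations x hidden
z = gated_act(x, ('gelu', 'linear'))  # GEGLU: gelu(chunk 0) * chunk 1
print(z.shape)                        # (4, 1, 8)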