Unverified Commit 086a12fe authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Made order of gated act consistent in all branches (#902)



- Made order of gated act consistent in all branches
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent f68df153
......@@ -953,15 +953,15 @@ class LayerNormMLP(TransformerEngineBase):
('relu',),
('quick_gelu',),
('squared_relu',)]
normalize_acts = []
normalized_acts = []
for act in self.activations:
if not isinstance(act, str):
return False
normalize_acts.append(act.lower())
normalize_acts = tuple(reversed(normalize_acts)
if normalize_acts[0] == 'linear' else normalize_acts)
normalized_acts.append(act.lower())
normalized_acts = tuple(reversed(normalized_acts)
if normalized_acts[0] == 'linear' else normalized_acts)
is_act_implemented = normalize_acts in (gated_act_pool + act_pool)
is_act_implemented = normalized_acts in (gated_act_pool + act_pool)
use_fused_layernorm_mlp = fuse_layernorm and is_act_implemented and\
self.intermediate_dropout_rate < 1e-3
......@@ -1007,7 +1007,7 @@ class LayerNormMLP(TransformerEngineBase):
fp8_meta_package = \
TransformerEngineBase.get_fp8_meta_package(num_of_gemm)
num_activations = len(self.activations)
num_activations = len(normalized_acts)
axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, y.ndim)
......@@ -1072,7 +1072,7 @@ class LayerNormMLP(TransformerEngineBase):
dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name,
activation_type = normalize_acts,
activation_type = normalized_acts,
use_bias = self.use_bias)
else: # not use_fused_ln_geglu_mlp
# DenseGeneral 1
......@@ -1139,12 +1139,12 @@ class LayerNormMLP(TransformerEngineBase):
x += jnp.reshape(bias_1, bias_1_shape)
x = checkpoint_name(x, ffn1_ckpt_name)
activations = []
if is_act_implemented:
z = activation_lu(x, normalize_acts)
z = activation_lu(x, normalized_acts)
else:
activations = []
x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(self.activations):
for idx, act_fn in enumerate(normalized_acts):
x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i)
z = functools.reduce(operator.mul, activations)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment