Unverified Commit 45e9d8b6 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Lint Fix (#1484)



JAX Lint Fix
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent e19b8281
......@@ -1251,7 +1251,6 @@ class LayerNormMLP(TransformerEngineBase):
# Remove act axis
z = jnp.reshape(z, (*z.shape[:-2], -1))
z = z.astype(self.dtype)
# import pdb; pdb.set_trace()
z = nn.Dropout(
rate=self.intermediate_dropout_rate,
......
......@@ -987,9 +987,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", self.weight_dtype
1.0, "fan_in", "normal", dtype=self.weight_dtype
)
self.kernel_init = _kernel_init.astype(self.dtype)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment