Unverified commit 689ff74f authored by Avital Oliver, committed by GitHub

Minor style improvements for the Flax BERT and RoBERTa examples (#8178)

* Minor style improvements:

1. Use `@nn.compact` rather than `@compact` (so as not to make it seem
   like `compact` is a standard Python decorator).
2. Move attribute docstrings from two `__call__` methods to comments
   on the attributes themselves. (This was probably a remnant from
   the pre-Linen version, where the attributes were arguments to
   `call`.) A minimal sketch of both changes follows below.

* Use black on the Flax modeling code
parent 9eb3a410
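The two points above boil down to the pattern below. This is a minimal, hypothetical sketch for illustration only — `ToyLayerNorm` and its arithmetic are not code from this commit — but the decorator usage (`nn.compact` referenced through the `nn` namespace) and the attribute-comment style match what the diff applies to the Flax BERT and RoBERTa modules.

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class ToyLayerNorm(nn.Module):
    # Attribute docs live as comments next to the attributes, not in
    # the `__call__` docstring.
    epsilon: float = 1e-6  # small constant added to the variance for stability
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    scale_init: jnp.ndarray = nn.initializers.ones  # initializer for scale (gamma)

    @nn.compact  # namespaced, rather than `from flax.linen import compact`
    def __call__(self, x):
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
        scale = self.param("scale", self.scale_init, (x.shape[-1],))
        return (scale * (x - mean) / jnp.sqrt(var + self.epsilon)).astype(self.dtype)


# Linen modules are used functionally: `init` builds the parameters,
# `apply` runs the forward pass with them.
x = jnp.ones((2, 8), dtype=jnp.float32)
params = ToyLayerNorm().init(jax.random.PRNGKey(0), x)
out = ToyLayerNorm().apply(params, x)
```

Keeping the decorator namespaced also makes it obvious at the call site that `compact` is Flax-specific machinery rather than something from the standard library.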
@@ -20,7 +20,6 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen import compact
from .configuration_bert import BertConfig
from .file_utils import add_start_docstrings
@@ -108,13 +107,15 @@ class FlaxBertLayerNorm(nn.Module):
"""
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32
bias: bool = True
scale: bool = True
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
bias: bool = True # If True, bias (beta) is added.
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
# (also e.g. nn.relu), this can be disabled since the scaling will be
# done by the next layer.
bias_init: jnp.ndarray = nn.initializers.zeros
scale_init: jnp.ndarray = nn.initializers.ones
@compact
@nn.compact
def __call__(self, x):
"""
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
@@ -123,13 +124,6 @@ class FlaxBertLayerNorm(nn.Module):
Args:
x: the inputs
epsilon: A small float added to variance to avoid dividing by zero.
dtype: the dtype of the computation (default: float32).
bias: If True, bias (beta) is added.
scale: If True, multiply by scale (gamma). When the next layer is linear
(also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
bias_init: Initializer for bias, by default, zero.
scale_init: Initializer for scale, by default, one
Returns:
Normalized inputs (the same shape as inputs).
@@ -157,7 +151,7 @@ class FlaxBertEmbedding(nn.Module):
hidden_size: int
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
@compact
@nn.compact
def __call__(self, inputs):
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
return jnp.take(embedding, inputs, axis=0)
@@ -171,7 +165,7 @@ class FlaxBertEmbeddings(nn.Module):
type_vocab_size: int
max_length: int
@compact
@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
# Embed
@@ -198,7 +192,7 @@ class FlaxBertAttention(nn.Module):
num_heads: int
head_size: int
@compact
@nn.compact
def __call__(self, hidden_state, attention_mask):
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
hidden_state, attention_mask
@@ -211,7 +205,7 @@ class FlaxBertAttention(nn.Module):
class FlaxBertIntermediate(nn.Module):
output_size: int
@compact
@nn.compact
def __call__(self, hidden_state):
# TODO: Add ACT2FN reference to change activation function
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
@@ -219,7 +213,7 @@ class FlaxBertIntermediate(nn.Module):
class FlaxBertOutput(nn.Module):
@compact
@nn.compact
def __call__(self, intermediate_output, attention_output):
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
hidden_state = FlaxBertLayerNorm(name="layer_norm")(hidden_state + attention_output)
@@ -231,7 +225,7 @@ class FlaxBertLayer(nn.Module):
head_size: int
intermediate_size: int
@compact
@nn.compact
def __call__(self, hidden_state, attention_mask):
attention = FlaxBertAttention(self.num_heads, self.head_size, name="attention")(hidden_state, attention_mask)
intermediate = FlaxBertIntermediate(self.intermediate_size, name="intermediate")(attention)
@@ -250,7 +244,7 @@ class FlaxBertLayerCollection(nn.Module):
head_size: int
intermediate_size: int
@compact
@nn.compact
def __call__(self, inputs, attention_mask):
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
@@ -270,7 +264,7 @@ class FlaxBertEncoder(nn.Module):
head_size: int
intermediate_size: int
@compact
@nn.compact
def __call__(self, hidden_state, attention_mask):
layer = FlaxBertLayerCollection(
self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
@@ -279,7 +273,7 @@ class FlaxBertEncoder(nn.Module):
class FlaxBertPooler(nn.Module):
@compact
@nn.compact
def __call__(self, hidden_state):
cls_token = hidden_state[:, 0]
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
@@ -296,7 +290,7 @@ class FlaxBertModule(nn.Module):
head_size: int
intermediate_size: int
@compact
@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
# Embedding
@@ -19,7 +19,6 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen import compact
from .configuration_roberta import RobertaConfig
from .file_utils import add_start_docstrings
@@ -108,13 +107,15 @@ class FlaxRobertaLayerNorm(nn.Module):
"""
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32
bias: bool = True
scale: bool = True
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
bias: bool = True # If True, bias (beta) is added.
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
# (also e.g. nn.relu), this can be disabled since the scaling will be
# done by the next layer.
bias_init: jnp.ndarray = nn.initializers.zeros
scale_init: jnp.ndarray = nn.initializers.ones
@compact
@nn.compact
def __call__(self, x):
"""
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
@@ -123,13 +124,6 @@ class FlaxRobertaLayerNorm(nn.Module):
Args:
x: the inputs
epsilon: A small float added to variance to avoid dividing by zero.
dtype: the dtype of the computation (default: float32).
bias: If True, bias (beta) is added.
scale: If True, multiply by scale (gamma). When the next layer is linear
(also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
bias_init: Initializer for bias, by default, zero.
scale_init: Initializer for scale, by default, one
Returns:
Normalized inputs (the same shape as inputs).
@@ -158,7 +152,7 @@ class FlaxRobertaEmbedding(nn.Module):
hidden_size: int
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
@compact
@nn.compact
def __call__(self, inputs):
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
return jnp.take(embedding, inputs, axis=0)
@@ -173,7 +167,7 @@ class FlaxRobertaEmbeddings(nn.Module):
type_vocab_size: int
max_length: int
@compact
@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
# Embed
@@ -201,7 +195,7 @@ class FlaxRobertaAttention(nn.Module):
num_heads: int
head_size: int
@compact
@nn.compact
def __call__(self, hidden_state, attention_mask):
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
hidden_state, attention_mask
@@ -215,7 +209,7 @@ class FlaxRobertaAttention(nn.Module):
class FlaxRobertaIntermediate(nn.Module):
output_size: int
@compact
@nn.compact
def __call__(self, hidden_state):
# TODO: Add ACT2FN reference to change activation function
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
@@ -224,7 +218,7 @@ class FlaxRobertaIntermediate(nn.Module):
# Copied from transformers.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
class FlaxRobertaOutput(nn.Module):
@compact
@nn.compact
def __call__(self, intermediate_output, attention_output):
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
hidden_state = FlaxRobertaLayerNorm(name="layer_norm")(hidden_state + attention_output)
@@ -236,7 +230,7 @@ class FlaxRobertaLayer(nn.Module):
head_size: int
intermediate_size: int
@compact
@nn.compact
def __call__(self, hidden_state, attention_mask):
attention = FlaxRobertaAttention(self.num_heads, self.head_size, name="attention")(
hidden_state, attention_mask
@@ -258,7 +252,7 @@ class FlaxRobertaLayerCollection(nn.Module):
head_size: int
intermediate_size: int
@compact
@nn.compact
def __call__(self, inputs, attention_mask):
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
@@ -279,7 +273,7 @@ class FlaxRobertaEncoder(nn.Module):
head_size: int
intermediate_size: int
@compact
@nn.compact
def __call__(self, hidden_state, attention_mask):
layer = FlaxRobertaLayerCollection(
self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
@@ -289,7 +283,7 @@ class FlaxRobertaEncoder(nn.Module):
# Copied from transformers.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
class FlaxRobertaPooler(nn.Module):
@compact
@nn.compact
def __call__(self, hidden_state):
cls_token = hidden_state[:, 0]
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
@@ -307,7 +301,7 @@ class FlaxRobertaModule(nn.Module):
head_size: int
intermediate_size: int
@compact
@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
# Embedding