"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b43c78e5d32b6eb8f367c52950336591ef8d82eb"
Unverified Commit 689ff74f authored by Avital Oliver, committed by GitHub

Minor style improvements for the Flax BERT and RoBERTa examples (#8178)

* Minor style improvements:

1. Use `@nn.compact` rather than `@compact`, so as not to make it seem
   like `compact` is a standard Python decorator (see the sketch below).
2. Move attribute docstrings from two `__call__` methods to comments
   on the attributes themselves. (This was probably a remnant from
   the pre-Linen version where the attributes were arguments to
   `call`.)

* Use black on the Flax modeling code
parent 9eb3a410
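To make the two style points above concrete, here is a minimal before/after sketch using a hypothetical `Scale` module (not one of the classes touched by this diff): the decorator is spelled `@nn.compact` so its Flax origin is obvious, and the attribute documentation moves from the `__call__` docstring to a comment on the attribute itself.

import flax.linen as nn

# Before (old style): bare `compact` import, attribute documented in the
# `__call__` docstring.
#
#   from flax.linen import compact
#
#   class Scale(nn.Module):
#       factor: float = 1.0
#
#       @compact
#       def __call__(self, x):
#           """Scales the input.
#
#           Args:
#               factor: multiplier applied to the input (default: 1.0).
#           """
#           return x * self.factor

# After (new style): `@nn.compact` and a comment on the attribute itself.
class Scale(nn.Module):
    factor: float = 1.0  # multiplier applied to the input

    @nn.compact
    def __call__(self, x):
        return x * self.factor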
modeling_flax_bert.py
@@ -20,7 +20,6 @@ import numpy as np
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.linen import compact
 
 from .configuration_bert import BertConfig
 from .file_utils import add_start_docstrings
@@ -108,13 +107,15 @@ class FlaxBertLayerNorm(nn.Module):
     """
 
     epsilon: float = 1e-6
-    dtype: jnp.dtype = jnp.float32
-    bias: bool = True
-    scale: bool = True
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    bias: bool = True  # If True, bias (beta) is added.
+    scale: bool = True  # If True, multiply by scale (gamma). When the next layer is linear
+    # (also e.g. nn.relu), this can be disabled since the scaling will be
+    # done by the next layer.
     bias_init: jnp.ndarray = nn.initializers.zeros
     scale_init: jnp.ndarray = nn.initializers.ones
 
-    @compact
+    @nn.compact
     def __call__(self, x):
         """
         Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
@@ -123,13 +124,6 @@ class FlaxBertLayerNorm(nn.Module):
 
         Args:
             x: the inputs
-            epsilon: A small float added to variance to avoid dividing by zero.
-            dtype: the dtype of the computation (default: float32).
-            bias: If True, bias (beta) is added.
-            scale: If True, multiply by scale (gamma). When the next layer is linear
-                (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
-            bias_init: Initializer for bias, by default, zero.
-            scale_init: Initializer for scale, by default, one
 
         Returns:
             Normalized inputs (the same shape as inputs).
@@ -157,7 +151,7 @@ class FlaxBertEmbedding(nn.Module):
     hidden_size: int
     emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
 
-    @compact
+    @nn.compact
     def __call__(self, inputs):
         embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
         return jnp.take(embedding, inputs, axis=0)
@@ -171,7 +165,7 @@ class FlaxBertEmbeddings(nn.Module):
     type_vocab_size: int
     max_length: int
 
-    @compact
+    @nn.compact
     def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
         # Embed
@@ -198,7 +192,7 @@ class FlaxBertAttention(nn.Module):
     num_heads: int
     head_size: int
 
-    @compact
+    @nn.compact
     def __call__(self, hidden_state, attention_mask):
         self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
             hidden_state, attention_mask
@@ -211,7 +205,7 @@ class FlaxBertAttention(nn.Module):
 class FlaxBertIntermediate(nn.Module):
     output_size: int
 
-    @compact
+    @nn.compact
     def __call__(self, hidden_state):
         # TODO: Add ACT2FN reference to change activation function
         dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
@@ -219,7 +213,7 @@ class FlaxBertIntermediate(nn.Module):
 class FlaxBertOutput(nn.Module):
-    @compact
+    @nn.compact
     def __call__(self, intermediate_output, attention_output):
         hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
         hidden_state = FlaxBertLayerNorm(name="layer_norm")(hidden_state + attention_output)
@@ -231,7 +225,7 @@ class FlaxBertLayer(nn.Module):
     head_size: int
     intermediate_size: int
 
-    @compact
+    @nn.compact
     def __call__(self, hidden_state, attention_mask):
         attention = FlaxBertAttention(self.num_heads, self.head_size, name="attention")(hidden_state, attention_mask)
         intermediate = FlaxBertIntermediate(self.intermediate_size, name="intermediate")(attention)
@@ -250,7 +244,7 @@ class FlaxBertLayerCollection(nn.Module):
     head_size: int
     intermediate_size: int
 
-    @compact
+    @nn.compact
     def __call__(self, inputs, attention_mask):
         assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
@@ -270,7 +264,7 @@ class FlaxBertEncoder(nn.Module):
     head_size: int
     intermediate_size: int
 
-    @compact
+    @nn.compact
     def __call__(self, hidden_state, attention_mask):
         layer = FlaxBertLayerCollection(
             self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
@@ -279,7 +273,7 @@ class FlaxBertEncoder(nn.Module):
 class FlaxBertPooler(nn.Module):
-    @compact
+    @nn.compact
     def __call__(self, hidden_state):
         cls_token = hidden_state[:, 0]
         out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
@@ -296,7 +290,7 @@ class FlaxBertModule(nn.Module):
     head_size: int
     intermediate_size: int
 
-    @compact
+    @nn.compact
     def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
         # Embedding
...
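For readers less familiar with Linen's compact pattern: a module whose parameters are declared inside an `@nn.compact`-decorated `__call__` is driven through the usual `init`/`apply` pair, which is unchanged by this commit. A rough usage sketch for the `FlaxBertLayerNorm` defined above, assuming the import path below matches your checkout:

import jax
import jax.numpy as jnp

# Import path is illustrative; adjust it to wherever this module lives in
# your version of transformers.
from transformers.modeling_flax_bert import FlaxBertLayerNorm

x = jnp.ones((2, 8))  # dummy (batch, hidden) activations
layer_norm = FlaxBertLayerNorm()

# init() traces __call__ once and materializes the parameters the compact
# method declares lazily.
variables = layer_norm.init(jax.random.PRNGKey(0), x)

# apply() re-runs __call__ with those parameters bound.
y = layer_norm.apply(variables, x)
print(y.shape)  # (2, 8)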
modeling_flax_roberta.py
@@ -19,7 +19,6 @@ import numpy as np
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.linen import compact
 
 from .configuration_roberta import RobertaConfig
 from .file_utils import add_start_docstrings
@@ -108,13 +107,15 @@ class FlaxRobertaLayerNorm(nn.Module):
     """
 
     epsilon: float = 1e-6
-    dtype: jnp.dtype = jnp.float32
-    bias: bool = True
-    scale: bool = True
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    bias: bool = True  # If True, bias (beta) is added.
+    scale: bool = True  # If True, multiply by scale (gamma). When the next layer is linear
+    # (also e.g. nn.relu), this can be disabled since the scaling will be
+    # done by the next layer.
     bias_init: jnp.ndarray = nn.initializers.zeros
     scale_init: jnp.ndarray = nn.initializers.ones
 
-    @compact
+    @nn.compact
     def __call__(self, x):
         """
         Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
@@ -123,13 +124,6 @@ class FlaxRobertaLayerNorm(nn.Module):
 
         Args:
             x: the inputs
-            epsilon: A small float added to variance to avoid dividing by zero.
-            dtype: the dtype of the computation (default: float32).
-            bias: If True, bias (beta) is added.
-            scale: If True, multiply by scale (gamma). When the next layer is linear
-                (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
-            bias_init: Initializer for bias, by default, zero.
-            scale_init: Initializer for scale, by default, one
 
         Returns:
             Normalized inputs (the same shape as inputs).
@@ -158,7 +152,7 @@ class FlaxRobertaEmbedding(nn.Module):
     hidden_size: int
     emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
 
-    @compact
+    @nn.compact
     def __call__(self, inputs):
         embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
         return jnp.take(embedding, inputs, axis=0)
@@ -173,7 +167,7 @@ class FlaxRobertaEmbeddings(nn.Module):
     type_vocab_size: int
     max_length: int
 
-    @compact
+    @nn.compact
     def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
         # Embed
@@ -201,7 +195,7 @@ class FlaxRobertaAttention(nn.Module):
     num_heads: int
     head_size: int
 
-    @compact
+    @nn.compact
     def __call__(self, hidden_state, attention_mask):
         self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
             hidden_state, attention_mask
@@ -215,7 +209,7 @@ class FlaxRobertaAttention(nn.Module):
 class FlaxRobertaIntermediate(nn.Module):
     output_size: int
 
-    @compact
+    @nn.compact
     def __call__(self, hidden_state):
         # TODO: Add ACT2FN reference to change activation function
         dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
@@ -224,7 +218,7 @@ class FlaxRobertaIntermediate(nn.Module):
 # Copied from transformers.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
 class FlaxRobertaOutput(nn.Module):
-    @compact
+    @nn.compact
     def __call__(self, intermediate_output, attention_output):
         hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
         hidden_state = FlaxRobertaLayerNorm(name="layer_norm")(hidden_state + attention_output)
@@ -236,7 +230,7 @@ class FlaxRobertaLayer(nn.Module):
     head_size: int
     intermediate_size: int
 
-    @compact
+    @nn.compact
     def __call__(self, hidden_state, attention_mask):
         attention = FlaxRobertaAttention(self.num_heads, self.head_size, name="attention")(
             hidden_state, attention_mask
@@ -258,7 +252,7 @@ class FlaxRobertaLayerCollection(nn.Module):
     head_size: int
     intermediate_size: int
 
-    @compact
+    @nn.compact
     def __call__(self, inputs, attention_mask):
         assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
@@ -279,7 +273,7 @@ class FlaxRobertaEncoder(nn.Module):
     head_size: int
     intermediate_size: int
 
-    @compact
+    @nn.compact
     def __call__(self, hidden_state, attention_mask):
         layer = FlaxRobertaLayerCollection(
             self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
@@ -289,7 +283,7 @@ class FlaxRobertaEncoder(nn.Module):
 # Copied from transformers.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
 class FlaxRobertaPooler(nn.Module):
-    @compact
+    @nn.compact
     def __call__(self, hidden_state):
         cls_token = hidden_state[:, 0]
         out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
@@ -307,7 +301,7 @@ class FlaxRobertaModule(nn.Module):
     head_size: int
     intermediate_size: int
 
-    @compact
+    @nn.compact
     def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
         # Embedding
...
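The RoBERTa edits mirror the BERT ones one-for-one. Note that switching the decorator from `@compact` to `@nn.compact` cannot change behavior, since both names resolve to the same Flax function; a quick sanity check, assuming Flax is installed:

import flax.linen as nn
from flax.linen import compact

# Both names refer to the same decorator object, so the rename is purely a
# readability change.
assert nn.compact is compact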