Unverified Commit 8780caa3 authored by Patrick von Platen, committed by GitHub

[WIP][Flax] Add general conversion script (#10809)



* save intermediate

* finish first version

* delete some more

* improve import

* fix roberta

* Update src/transformers/modeling_flax_pytorch_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_flax_pytorch_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* small corrections

* apply all comments

* fix deterministic

* make fix-copies
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 604c0850
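
In short, this commit removes the per-model `convert_from_pytorch` hooks and introduces a single `convert_pytorch_state_dict_to_flax` utility (new file `modeling_flax_pytorch_utils.py`), wired into `FlaxPreTrainedModel.from_pretrained` via `from_pt=True`. A minimal usage sketch, assuming both PyTorch and Flax are installed (the checkpoint name simply mirrors the updated tests below):

from transformers import FlaxBertModel

# Downloads the PyTorch weights and converts them on the fly into a Flax parameter tree.
model = FlaxBertModel.from_pretrained("bert-base-cased", from_pt=True)
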
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - Flax general utilities."""
import os
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from .utils import logging
logger = logging.get_logger(__name__)
#####################
# PyTorch => Flax #
#####################
def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False):
"""Load PyTorch checkpoints into a Flax model."""
try:
import torch # noqa: F401
except ImportError:
logger.error(
"Loading a PyTorch model in Flax requires both PyTorch and Flax to be installed. Please see "
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
)
raise
pt_path = os.path.abspath(pytorch_checkpoint_path)
logger.info("Loading PyTorch weights from {}".format(pt_path))
pt_state_dict = torch.load(pt_path, map_location="cpu")
logger.info("PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values())} parameters.")
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
return flax_state_dict
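# Illustrative sketch (assumption, not taken verbatim from this commit): the helper above can also
# be called directly on a local checkpoint, e.g. with a randomly initialized Flax model and a
# hypothetical "./pytorch_model.bin":
#
#   from transformers import BertConfig, FlaxBertModel
#
#   config = BertConfig.from_pretrained("bert-base-cased")
#   flax_model = FlaxBertModel(config)
#   flax_model.params = load_pytorch_checkpoint_in_flax_state_dict(flax_model, "./pytorch_model.bin")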
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# convert pytorch tensor to numpy
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
random_flax_state_dict = flatten_dict(unfreeze(flax_model.params))
flax_state_dict = {}
remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
)
add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
)
# Need to change some parameter names to match Flax names so that we don't have to fork any layer
for pt_key, pt_tensor in pt_state_dict.items():
pt_tuple_key = tuple(pt_key.split("."))
has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict
if remove_base_model_prefix and has_base_model_prefix:
pt_tuple_key = pt_tuple_key[1:]
elif add_base_model_prefix and require_base_model_prefix:
pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
if pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
pt_tensor = pt_tensor.T
elif pt_tuple_key[-1] == "gamma":
pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
elif pt_tuple_key[-1] == "beta":
pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
if pt_tuple_key in random_flax_state_dict:
if random_flax_state_dict[pt_tuple_key].shape != pt_tensor.shape:
raise ValueError(
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
)
# also add unexpected weights so that a warning about unexpected keys is raised later
flax_state_dict[pt_tuple_key] = pt_tensor
return unflatten_dict(flax_state_dict)
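# Illustrative mapping (assumes a standard BERT checkpoint loaded into the base FlaxBertModel;
# not an exhaustive list): the renaming above is expected to act roughly as follows:
#
#   "bert.encoder.layer.0.attention.self.query.weight" -> ("encoder", "layer", "0", "attention", "self", "query", "kernel"), tensor transposed
#   "bert.embeddings.LayerNorm.weight"                 -> ("embeddings", "LayerNorm", "weight"), unchanged
#   "bert.embeddings.LayerNorm.gamma" (older naming)   -> ("embeddings", "LayerNorm", "weight")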
......@@ -14,7 +14,7 @@
# limitations under the License.
import os
from abc import ABC, abstractmethod
from abc import ABC
from functools import partial
from pickle import UnpicklingError
from typing import Dict, Set, Tuple, Union
......@@ -29,6 +29,7 @@ from jax.random import PRNGKey
from .configuration_utils import PretrainedConfig
from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import logging
......@@ -121,11 +122,6 @@ class FlaxPreTrainedModel(ABC):
)
self._params = freeze(params)
@staticmethod
@abstractmethod
def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
raise NotImplementedError()
@classmethod
def from_pretrained(
cls,
......@@ -307,25 +303,18 @@ class FlaxPreTrainedModel(ABC):
else:
resolved_archive_file = None
# Instantiate model.
with open(resolved_archive_file, "rb") as state_f:
try:
if from_pt:
import torch
state = torch.load(state_f)
state = convert_state_dict_from_pt(cls, state, config)
else:
state = from_bytes(cls, state_f.read())
except UnpicklingError:
raise EnvironmentError(
f"Unable to convert pytorch model {archive_file} to Flax deserializable object. "
)
# init random models
model = cls(config, *model_args, **model_kwargs)
if from_pt:
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
else:
with open(resolved_archive_file, "rb") as state_f:
try:
state = from_bytes(cls, state_f.read())
except UnpicklingError:
raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
# if the loaded model is a base model, keep only the sub-dict under the base model prefix
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
state = state[cls.base_model_prefix]
......@@ -341,6 +330,10 @@ class FlaxPreTrainedModel(ABC):
for missing_key in missing_keys:
state[missing_key] = random_state[missing_key]
# remove unexpected keys so that they are not saved again
for unexpected_key in unexpected_keys:
del state[unexpected_key]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
......@@ -393,13 +386,3 @@ class FlaxPreTrainedModel(ABC):
with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
model_bytes = to_bytes(self.params)
f.write(model_bytes)
def convert_state_dict_from_pt(model_class: ABC, state: Dict, config: PretrainedConfig):
"""
Converts a PyTorch parameter state dict to an equivalent Flax parameter state dict
"""
state = {k: v.numpy() for k, v in state.items()}
state = model_class.convert_from_pytorch(state, config)
state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
return state
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Tuple
from typing import Callable, Tuple
import numpy as np
......@@ -21,6 +21,8 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.linen import dot_product_attention
from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
......@@ -99,17 +101,15 @@ class FlaxBertLayerNorm(nn.Module):
hidden_size: int
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
bias: bool = True # If True, bias (beta) is added.
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
# (also e.g. nn.relu), this can be disabled since the scaling will be
# done by the next layer.
dtype: jnp.dtype = jnp.float32
use_bias: bool = True
scale: bool = True
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
def setup(self):
self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
self.beta = self.param("beta", self.scale_init, (self.hidden_size,))
self.weight = self.param("weight", self.scale_init, (self.hidden_size,))
self.bias = self.param("bias", self.bias_init, (self.hidden_size,))
def __call__(self, x):
"""
......@@ -129,11 +129,11 @@ class FlaxBertLayerNorm(nn.Module):
mul = jax.lax.rsqrt(var + self.epsilon)
if self.scale:
mul = mul * jnp.asarray(self.gamma)
mul = mul * jnp.asarray(self.weight)
y = (x - mean) * mul
if self.bias:
y = y + jnp.asarray(self.beta)
if self.use_bias:
y = y + jnp.asarray(self.bias)
return y
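# Taken together, this implements the standard layer-norm formula over the last axis of `x`:
#   y = (x - mean(x)) / sqrt(var(x) + epsilon) * weight + bias
# with the scale and bias terms applied only when `scale` / `use_bias` are enabled.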
......@@ -167,24 +167,21 @@ class FlaxBertEmbeddings(nn.Module):
self.config.vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="word_embeddings",
dtype=self.dtype,
)
self.position_embeddings = FlaxBertEmbedding(
self.config.max_position_embeddings,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="position_embeddings",
dtype=self.dtype,
)
self.token_type_embeddings = FlaxBertEmbedding(
self.config.type_vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="token_type_embeddings",
dtype=self.dtype,
)
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
......@@ -197,35 +194,116 @@ class FlaxBertEmbeddings(nn.Module):
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
# Layer Norm
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states
class FlaxBertAttention(nn.Module):
class FlaxBertSelfAttention(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.self_attention = nn.attention.SelfAttention(
num_heads=self.config.num_attention_heads,
qkv_features=self.config.hidden_size,
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
)
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
def __call__(self, hidden_states, attention_mask, deterministic=True):
head_dim = self.config.hidden_size // self.config.num_attention_heads
query_states = self.query(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
value_states = self.value(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
key_states = self.key(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
)
else:
attention_bias = None
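# For illustration: a padding mask of [1, 1, 0] becomes an additive bias of [0., 0., -1e10],
# which is broadcast against attention weights of shape (batch, num_heads, q_length, kv_length)
# inside dot_product_attention below.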
dropout_rng = None
if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
dropout_rng = self.make_rng("dropout")
attn_output = dot_product_attention(
query_states,
key_states,
value_states,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.config.attention_probs_dropout_prob,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
return attn_output.reshape(attn_output.shape[:2] + (-1,))
class FlaxBertSelfOutput(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
bias_init=jax.nn.initializers.zeros,
name="self",
dtype=self.dtype,
)
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class FlaxBertAttention(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype)
self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic=True):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)
hidden_states = self.layer_norm(self_attn_output + hidden_states)
attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic)
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
return hidden_states
......@@ -237,7 +315,6 @@ class FlaxBertIntermediate(nn.Module):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
......@@ -256,16 +333,15 @@ class FlaxBertOutput(nn.Module):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.layer_norm(hidden_states + attention_output)
hidden_states = self.LayerNorm(hidden_states + attention_output)
return hidden_states
......@@ -274,9 +350,9 @@ class FlaxBertLayer(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.attention = FlaxBertAttention(self.config, name="attention", dtype=self.dtype)
self.intermediate = FlaxBertIntermediate(self.config, name="intermediate", dtype=self.dtype)
self.output = FlaxBertOutput(self.config, name="output", dtype=self.dtype)
self.attention = FlaxBertAttention(self.config, dtype=self.dtype)
self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
self.output = FlaxBertOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
......@@ -305,10 +381,10 @@ class FlaxBertEncoder(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = FlaxBertLayerCollection(self.config, name="layer", dtype=self.dtype)
self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
return self.layers(hidden_states, attention_mask, deterministic=deterministic)
return self.layer(hidden_states, attention_mask, deterministic=deterministic)
class FlaxBertPooler(nn.Module):
......@@ -319,7 +395,6 @@ class FlaxBertPooler(nn.Module):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)
......@@ -334,14 +409,14 @@ class FlaxBertPredictionHeadTransform(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(self.config.hidden_size, name="dense", dtype=self.dtype)
self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
self.activation = ACT2FN[self.config.hidden_act]
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
def __call__(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
return self.layer_norm(hidden_states)
return self.LayerNorm(hidden_states)
class FlaxBertLMPredictionHead(nn.Module):
......@@ -349,14 +424,10 @@ class FlaxBertLMPredictionHead(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
self.transform = FlaxBertPredictionHeadTransform(self.config, name="transform", dtype=self.dtype)
self.decoder = nn.Dense(self.config.vocab_size, name="decoder", dtype=self.dtype)
self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype)
def __call__(self, hidden_states):
# TODO: The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
# Need a link between the two variables so that the bias is correctly
# resized with `resize_token_embeddings`
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
......@@ -367,10 +438,10 @@ class FlaxBertOnlyMLMHead(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
self.mlm_head = FlaxBertLMPredictionHead(self.config, name="predictions", dtype=self.dtype)
self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
def __call__(self, hidden_states):
hidden_states = self.mlm_head(hidden_states)
hidden_states = self.predictions(hidden_states)
return hidden_states
......@@ -405,85 +476,6 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
@staticmethod
def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
jax_state = dict(pt_state)
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
for key, tensor in pt_state.items():
# Key parts
key_parts = set(key.split("."))
# Every dense layer has "kernel" parameters instead of "weight"
if "dense.weight" in key:
del jax_state[key]
key = key.replace("weight", "kernel")
jax_state[key] = tensor
if "decoder.weight" in key:
del jax_state[key]
key = key.replace("weight", "kernel")
jax_state[key] = tensor.T
# SelfAttention needs also to replace "weight" by "kernel"
if {"query", "key", "value"} & key_parts:
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
if "bias" in key:
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
elif "weight":
del jax_state[key]
key = key.replace("weight", "kernel")
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
jax_state[key] = tensor
# SelfAttention output is not a separate layer, remove one nesting
if "attention.output.dense" in key:
del jax_state[key]
key = key.replace("attention.output.dense", "attention.self.out")
jax_state[key] = tensor
# SelfAttention output is not a separate layer, remove nesting on layer norm
if "attention.output.LayerNorm" in key:
del jax_state[key]
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
jax_state[key] = tensor
# There are some transposed parameters w.r.t their PyTorch counterpart
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key or "transform.dense.kernel" in key:
jax_state[key] = tensor.T
# Self Attention output projection needs to be transposed
if "out.kernel" in key:
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
1, 2, 0
)
# Pooler needs to transpose its kernel
if "pooler.dense.kernel" in key:
jax_state[key] = tensor.T
# Hack to correctly load some pytorch models
if "predictions.bias" in key:
del jax_state[key]
jax_state[".".join(key.split(".")[:2]) + ".decoder.bias"] = tensor
# Handle LayerNorm conversion
if "LayerNorm" in key:
del jax_state[key]
# Replace LayerNorm by layer_norm
new_key = key.replace("LayerNorm", "layer_norm")
if "weight" in key:
new_key = new_key.replace("weight", "gamma")
elif "bias" in key:
new_key = new_key.replace("bias", "beta")
jax_state[new_key] = tensor
return jax_state
@add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
......@@ -541,9 +533,9 @@ class FlaxBertModule(nn.Module):
add_pooling_layer: bool = True
def setup(self):
self.embeddings = FlaxBertEmbeddings(self.config, name="embeddings", dtype=self.dtype)
self.encoder = FlaxBertEncoder(self.config, name="encoder", dtype=self.dtype)
self.pooler = FlaxBertPooler(self.config, name="pooler", dtype=self.dtype)
self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
......@@ -602,15 +594,13 @@ class FlaxBertForMaskedLMModule(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
self.encoder = FlaxBertModule(
self.bert = FlaxBertModule(
config=self.config,
add_pooling_layer=False,
name="bert",
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.mlm_head = FlaxBertOnlyMLMHead(
self.cls = FlaxBertOnlyMLMHead(
config=self.config,
name="cls",
dtype=self.dtype,
)
......@@ -618,12 +608,10 @@ class FlaxBertForMaskedLMModule(nn.Module):
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
):
# Model
hidden_states = self.encoder(
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
)
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
# Compute the prediction scores
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
logits = self.mlm_head(hidden_states)
logits = self.cls(hidden_states)
return (logits,)
......@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Tuple
from typing import Callable, Tuple
import numpy as np
......@@ -20,6 +20,8 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.linen import dot_product_attention
from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
......@@ -116,17 +118,15 @@ class FlaxRobertaLayerNorm(nn.Module):
hidden_size: int
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
bias: bool = True # If True, bias (beta) is added.
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
# (also e.g. nn.relu), this can be disabled since the scaling will be
# done by the next layer.
dtype: jnp.dtype = jnp.float32
use_bias: bool = True
scale: bool = True
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
def setup(self):
self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
self.beta = self.param("beta", self.scale_init, (self.hidden_size,))
self.weight = self.param("weight", self.scale_init, (self.hidden_size,))
self.bias = self.param("bias", self.bias_init, (self.hidden_size,))
def __call__(self, x):
"""
......@@ -146,11 +146,11 @@ class FlaxRobertaLayerNorm(nn.Module):
mul = jax.lax.rsqrt(var + self.epsilon)
if self.scale:
mul = mul * jnp.asarray(self.gamma)
mul = mul * jnp.asarray(self.weight)
y = (x - mean) * mul
if self.bias:
y = y + jnp.asarray(self.beta)
if self.use_bias:
y = y + jnp.asarray(self.bias)
return y
......@@ -186,26 +186,21 @@ class FlaxRobertaEmbeddings(nn.Module):
self.config.vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="word_embeddings",
dtype=self.dtype,
)
self.position_embeddings = FlaxRobertaEmbedding(
self.config.max_position_embeddings,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="position_embeddings",
dtype=self.dtype,
)
self.token_type_embeddings = FlaxRobertaEmbedding(
self.config.type_vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="token_type_embeddings",
dtype=self.dtype,
)
self.layer_norm = FlaxRobertaLayerNorm(
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
)
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
......@@ -218,38 +213,119 @@ class FlaxRobertaEmbeddings(nn.Module):
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
# Layer Norm
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
class FlaxRobertaAttention(nn.Module):
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta
class FlaxRobertaSelfAttention(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.self_attention = nn.attention.SelfAttention(
num_heads=self.config.num_attention_heads,
qkv_features=self.config.hidden_size,
dropout_rate=self.config.attention_probs_dropout_prob,
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
)
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
bias_init=jax.nn.initializers.zeros,
name="self",
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
self.layer_norm = FlaxRobertaLayerNorm(
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
def __call__(self, hidden_states, attention_mask, deterministic=True):
head_dim = self.config.hidden_size // self.config.num_attention_heads
query_states = self.query(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
value_states = self.value(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
key_states = self.key(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
)
else:
attention_bias = None
dropout_rng = None
if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
dropout_rng = self.make_rng("dropout")
attn_output = dot_product_attention(
query_states,
key_states,
value_states,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.config.attention_probs_dropout_prob,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
return attn_output.reshape(attn_output.shape[:2] + (-1,))
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta
class FlaxRobertaSelfOutput(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
dtype=self.dtype,
)
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
class FlaxRobertaAttention(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype)
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic=True):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)
hidden_states = self.layer_norm(self_attn_output + hidden_states)
attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic)
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
return hidden_states
......@@ -262,7 +338,6 @@ class FlaxRobertaIntermediate(nn.Module):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
......@@ -282,18 +357,15 @@ class FlaxRobertaOutput(nn.Module):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.layer_norm = FlaxRobertaLayerNorm(
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
)
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.layer_norm(hidden_states + attention_output)
hidden_states = self.LayerNorm(hidden_states + attention_output)
return hidden_states
......@@ -303,9 +375,9 @@ class FlaxRobertaLayer(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.attention = FlaxRobertaAttention(self.config, name="attention", dtype=self.dtype)
self.intermediate = FlaxRobertaIntermediate(self.config, name="intermediate", dtype=self.dtype)
self.output = FlaxRobertaOutput(self.config, name="output", dtype=self.dtype)
self.attention = FlaxRobertaAttention(self.config, dtype=self.dtype)
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
......@@ -336,10 +408,10 @@ class FlaxRobertaEncoder(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = FlaxRobertaLayerCollection(self.config, name="layer", dtype=self.dtype)
self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
return self.layers(hidden_states, attention_mask, deterministic=deterministic)
return self.layer(hidden_states, attention_mask, deterministic=deterministic)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
......@@ -351,7 +423,6 @@ class FlaxRobertaPooler(nn.Module):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)
......@@ -370,75 +441,6 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
config_class = RobertaConfig
base_model_prefix = "roberta"
@staticmethod
def convert_from_pytorch(pt_state: Dict, config: RobertaConfig) -> Dict:
jax_state = dict(pt_state)
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
for key, tensor in pt_state.items():
# Key parts
key_parts = set(key.split("."))
# Every dense layer has "kernel" parameters instead of "weight"
if "dense.weight" in key:
del jax_state[key]
key = key.replace("weight", "kernel")
jax_state[key] = tensor
# SelfAttention needs also to replace "weight" by "kernel"
if {"query", "key", "value"} & key_parts:
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
if "bias" in key:
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
elif "weight":
del jax_state[key]
key = key.replace("weight", "kernel")
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
jax_state[key] = tensor
# SelfAttention output is not a separate layer, remove one nesting
if "attention.output.dense" in key:
del jax_state[key]
key = key.replace("attention.output.dense", "attention.self.out")
jax_state[key] = tensor
# SelfAttention output is not a separate layer, remove nesting on layer norm
if "attention.output.LayerNorm" in key:
del jax_state[key]
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
jax_state[key] = tensor
# There are some transposed parameters w.r.t their PyTorch counterpart
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
jax_state[key] = tensor.T
# Self Attention output projection needs to be transposed
if "out.kernel" in key:
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
1, 2, 0
)
# Pooler needs to transpose its kernel
if "pooler.dense.kernel" in key:
jax_state[key] = tensor.T
# Handle LayerNorm conversion
if "LayerNorm" in key:
del jax_state[key]
# Replace LayerNorm by layer_norm
new_key = key.replace("LayerNorm", "layer_norm")
if "weight" in key:
new_key = new_key.replace("weight", "gamma")
elif "bias" in key:
new_key = new_key.replace("bias", "beta")
jax_state[new_key] = tensor
return jax_state
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
jnp.zeros(input_shape, dtype="i4"), None, None, None
......@@ -523,9 +525,9 @@ class FlaxRobertaModule(nn.Module):
add_pooling_layer: bool = True
def setup(self):
self.embeddings = FlaxRobertaEmbeddings(self.config, name="embeddings", dtype=self.dtype)
self.encoder = FlaxRobertaEncoder(self.config, name="encoder", dtype=self.dtype)
self.pooler = FlaxRobertaPooler(self.config, name="pooler", dtype=self.dtype)
self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype)
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
......
......@@ -115,6 +115,6 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("bert-base-cased")
model = model_class_name.from_pretrained("bert-base-cased", from_pt=True)
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
......@@ -27,7 +27,7 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.modeling_flax_utils import convert_state_dict_from_pt
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
......@@ -79,8 +79,8 @@ class FlaxModelTesterMixin:
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
fx_state = convert_state_dict_from_pt(model_class, pt_model.state_dict(), config)
fx_model = model_class(config, dtype=jnp.float32)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}
......
......@@ -115,6 +115,6 @@ class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("roberta-base")
model = model_class_name.from_pretrained("roberta-base", from_pt=True)
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)