Unverified Commit f7076cd3 authored by Kian Sierra McGettigan's avatar Kian Sierra McGettigan Committed by GitHub
Browse files

Flax mistral (#26943)

* direct copy from llama work

* mistral modules forward pass working

* flax mistral forward pass with sliding window

* added tests

* added layer collection approach

* Revert "added layer collection approach"

This reverts commit 0e2905bf2236ec323163fc1a9f0c016b21aa8b8f.

* Revert "Revert "added layer collection approach""

This reverts commit fb17b6187ac5d16da7c461e1130514dc3d137a43.

* fixed attention outputs

* added mistral to init and auto

* fixed import name

* fixed layernorm weight dtype

* freeze initialized weights

* make sure conversion considers bfloat16

* added backend

* added docstrings

* added cache

* fixed sliding window causal mask

* passes cache tests

* passed all tests

* applied make style

* removed commented out code

* applied fix-copies ignored other model changes

* applied make fix-copies

* removed unused functions

* passed generation integration test

* slow tests pass

* fixed slow tests

* changed default dtype from jax.numpy.float32 to float32 for docstring check

* skip cache test  for FlaxMistralForSequenceClassification since if pad_token_id in input_ids it doesn't score previous input_ids

* updated checkpoint since from_pt not included

* applied black style

* removed unused args

* Applied styling and fixup

* changed checkpoint for doc back

* fixed rf after adding it to hf hub

* Add dummy ckpt

* applied styling

* added tokenizer to new ckpt

* fixed slice format

* fix init and slice

* changed ref for placeholder TODO

* added copies from Llama

* applied styling

* applied fix-copies

* fixed docs

* update weight dtype reconversion for sharded weights

* removed Nullable input ids

* Removed unnecessary output attentions in Module

* added embedding weight initialization

* removed unused past_key_values

* fixed deterministic

* Fixed RMS Norm and added copied from

* removed input_embeds

* applied make style

* removed nullable input ids from sequence classification model

* added copied from GPTJ

* added copied from Llama on FlaxMistralDecoderLayer

* added copied from to FlaxMistralPreTrainedModel methods

* fix test deprecation warning

* freeze gpt neox random_params and fix copies

* applied make style

* fixed doc issue

* skipped docstring test to align # copied from

* applied make style

* removed FlaxMistralForSequenceClassification

* removed unused padding_idx

* removed more sequence classification

* removed sequence classification

* applied styling and consistency

* added copied from in tests

* removed sequence classification test logic

* applied styling

* applied make style

* removed freeze and fixed copies

* undo test change

* changed repeat_kv to tile

* fixed to key value groups

* updated copyright year

* split casual_mask

* empty to rerun failed pt_flax_equivalence test FlaxWav2Vec2ModelTest

* went back to 2023 for tests_pr_documentation_tests

* went back to 2024

* changed tile to repeat

* applied make style

* empty for retry on Wav2Vec2
parent 7a496100
......@@ -190,7 +190,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Megatron-BERT](model_doc/megatron-bert) | ✅ | ❌ | ❌ |
| [Megatron-GPT2](model_doc/megatron_gpt2) | ✅ | ✅ | ✅ |
| [MGP-STR](model_doc/mgp-str) | ✅ | ❌ | ❌ |
| [Mistral](model_doc/mistral) | ✅ | ❌ | ❌ |
| [Mistral](model_doc/mistral) | ✅ | ❌ | ✅ |
| [Mixtral](model_doc/mixtral) | ✅ | ❌ | ❌ |
| [mLUKE](model_doc/mluke) | ✅ | ❌ | ❌ |
| [MMS](model_doc/mms) | ✅ | ✅ | ✅ |
......
......@@ -149,3 +149,13 @@ Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Sin
[[autodoc]] MistralForSequenceClassification
- forward
## FlaxMistralModel
[[autodoc]] FlaxMistralModel
- __call__
## FlaxMistralForCausalLM
[[autodoc]] FlaxMistralForCausalLM
- __call__
......@@ -4678,6 +4678,13 @@ else:
"FlaxMBartPreTrainedModel",
]
)
_import_structure["models.mistral"].extend(
[
"FlaxMistralForCausalLM",
"FlaxMistralModel",
"FlaxMistralPreTrainedModel",
]
)
_import_structure["models.mt5"].extend(["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
_import_structure["models.opt"].extend(
[
......@@ -8830,6 +8837,11 @@ if TYPE_CHECKING:
FlaxMBartModel,
FlaxMBartPreTrainedModel,
)
from .models.mistral import (
FlaxMistralForCausalLM,
FlaxMistralModel,
FlaxMistralPreTrainedModel,
)
from .models.mt5 import (
FlaxMT5EncoderModel,
FlaxMT5ForConditionalGeneration,
......
......@@ -255,7 +255,10 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
# load using msgpack utils
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
pt_state_dict = {
k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
}
model_prefix = flax_model.base_model_prefix
......@@ -278,6 +281,7 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
# Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items():
pt_tuple_key = tuple(pt_key.split("."))
is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16
# remove base model prefix if necessary
has_base_model_prefix = pt_tuple_key[0] == model_prefix
......@@ -314,11 +318,15 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
continue
# also add unexpected weight so that warning is thrown
flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor)
flax_state_dict[("params",) + flax_key] = (
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
)
else:
# also add unexpected weight so that warning is thrown
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
flax_state_dict[flax_key] = (
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
)
return unflatten_dict(flax_state_dict)
......
......@@ -47,6 +47,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
("longt5", "FlaxLongT5Model"),
("marian", "FlaxMarianModel"),
("mbart", "FlaxMBartModel"),
("mistral", "FlaxMistralModel"),
("mt5", "FlaxMT5Model"),
("opt", "FlaxOPTModel"),
("pegasus", "FlaxPegasusModel"),
......@@ -148,6 +149,7 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"),
("llama", "FlaxLlamaForCausalLM"),
("mistral", "FlaxMistralForCausalLM"),
("opt", "FlaxOPTForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
......
......@@ -13,11 +13,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available
_import_structure = {
......@@ -38,6 +34,18 @@ else:
"MistralForSequenceClassification",
]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_mistral"] = [
"FlaxMistralForCausalLM",
"FlaxMistralModel",
"FlaxMistralPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_mistral import MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MistralConfig
......@@ -55,6 +63,18 @@ if TYPE_CHECKING:
MistralPreTrainedModel,
)
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_mistral import (
FlaxMistralForCausalLM,
FlaxMistralModel,
FlaxMistralPreTrainedModel,
)
else:
import sys
......
# coding=utf-8
# Copyright 2024 Mistral AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax Mistral model."""
from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPast,
FlaxCausalLMOutput,
FlaxCausalLMOutputWithCrossAttentions,
)
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, logging
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
from .configuration_mistral import MistralConfig
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
# Names handed to `append_call_sample_docstring` further down in this file: the config class name, the
# checkpoint passed as `real_checkpoint`, and the checkpoint passed as the regular doc checkpoint.
_CONFIG_FOR_DOC = "MistralConfig"
_REAL_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
_CHECKPOINT_FOR_DOC = "ksmcg/Mistral-tiny"
# Class-level documentation shared by the public model classes; it is passed to `add_start_docstrings`
# where those classes are declared.
MISTRAL_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`MistralConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16`, or
`jax.numpy.bfloat16`.
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
# Argument documentation for the model call; it is attached to `FlaxMistralPreTrainedModel.__call__` through
# `add_start_docstrings_to_model_forward`. The text only describes arguments that `__call__` actually accepts.
MISTRAL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`. Must be provided whenever `past_key_values` is passed.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
            Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
            auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRMSNorm with Llama->Mistral
class FlaxMistralRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis, followed by a learned per-feature scale."""

    config: MistralConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.epsilon = self.config.rms_norm_eps

        def ones_init(_, shape):
            # the scale starts at one; the RNG key argument is ignored
            return jnp.ones(shape)

        self.weight = self.param("weight", ones_init, self.config.hidden_size)

    def __call__(self, hidden_states):
        # mean of the squared activations, computed on a float32 copy of the input
        squared = jnp.power(jnp.asarray(hidden_states, dtype=jnp.float32), 2)
        mean_square = squared.mean(-1, keepdims=True)
        # use `jax.numpy.sqrt` as `jax.lax.rsqrt` does not match `torch.rsqrt`
        normalized = hidden_states / jnp.sqrt(mean_square + self.epsilon)
        # cast to the computation dtype before applying the learned scale
        return self.weight * jnp.asarray(normalized, dtype=self.dtype)
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRotaryEmbedding with Llama->Mistral
class FlaxMistralRotaryEmbedding(nn.Module):
    """Applies rotary position embeddings to key and query states using a table built once in `setup`."""

    config: MistralConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        per_head_dim = cfg.hidden_size // cfg.num_attention_heads
        # sin/cos lookup table, indexed by position id in `__call__`
        self.sincos = create_sinusoidal_positions(cfg.max_position_embeddings, per_head_dim)

    def __call__(self, key, query, position_ids):
        # gather the table rows for the requested positions, then split them into the sin half and the cos half
        sin_pos, cos_pos = jnp.split(self.sincos[position_ids], 2, axis=-1)

        rotated_key = apply_rotary_pos_emb(key, sin_pos, cos_pos)
        rotated_query = apply_rotary_pos_emb(query, sin_pos, cos_pos)

        # note the return order: key first, query second
        return jnp.asarray(rotated_key, dtype=self.dtype), jnp.asarray(rotated_query, dtype=self.dtype)
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaMLP with Llama->Mistral
class FlaxMistralMLP(nn.Module):
    """Gated feed-forward block: `down_proj(up_proj(x) * act(gate_proj(x)))`, all projections without bias."""

    config: MistralConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        model_dim = cfg.hidden_size
        # fall back to four times the hidden size when the config gives no intermediate size
        if cfg.intermediate_size is not None:
            ffn_dim = cfg.intermediate_size
        else:
            ffn_dim = 4 * model_dim

        init_fn = jax.nn.initializers.normal(cfg.initializer_range)

        def dense(features):
            # every projection shares the same settings apart from its output width
            return nn.Dense(features, use_bias=False, dtype=self.dtype, kernel_init=init_fn)

        self.act = ACT2FN[cfg.hidden_act]
        self.gate_proj = dense(ffn_dim)
        self.down_proj = dense(model_dim)
        self.up_proj = dense(ffn_dim)

    def __call__(self, hidden_states):
        projected_up = self.up_proj(hidden_states)
        gate = self.act(self.gate_proj(hidden_states))
        return self.down_proj(projected_up * gate)
# Copied from transformers.models.llama.modeling_flax_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
    """Combine `tensor` with its half-rotated copy: `tensor * cos_pos + rotate_half(tensor) * sin_pos`."""
    cos_term = tensor * cos_pos
    sin_term = rotate_half(tensor) * sin_pos
    return cos_term + sin_term
# Adapted from transformers.models.llama.modeling_flax_llama.create_sinusoidal_positions
# (not a verbatim copy: the trailing slice over the feature axis is removed and the base is a parameter)
def create_sinusoidal_positions(num_pos, dim, base=10000.0):
    """
    Build the sin/cos lookup table used for rotary position embeddings.

    Args:
        num_pos (`int`):
            Number of positions (rows) in the table.
        dim (`int`):
            Rotary dimension; one inverse frequency is created for every second index in `range(dim)`.
        base (`float`, *optional*, defaults to 10000.0):
            Base of the geometric progression of inverse frequencies. The default reproduces the previously
            hard-coded value, so existing callers are unaffected.

    Returns:
        `jnp.ndarray` of shape `(num_pos, 1, 2 * emb_width)` where `emb_width` is twice the number of inverse
        frequencies (equal to `dim` for even `dim`). The first half of the last axis holds the sines and the
        second half the cosines, so it can be split in two with `jnp.split(..., 2, axis=-1)`.
    """
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))
    freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")

    # duplicate the frequencies so the table lines up with the two halves handled by `rotate_half`
    emb = np.concatenate((freqs, freqs), axis=-1)
    out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)
    # Return the whole table. The previous `out[:, :, :num_pos]` sliced the *feature* axis by the number of
    # positions, which silently dropped sin/cos entries whenever `num_pos < 2 * dim`.
    return jnp.array(out)
# Copied from transformers.models.llama.modeling_flax_llama.rotate_half
def rotate_half(tensor):
    """Rotates half the hidden dims of the input."""
    midpoint = tensor.shape[-1] // 2
    leading, trailing = tensor[..., :midpoint], tensor[..., midpoint:]
    # the negated trailing part goes first, the untouched leading part goes last
    return jnp.concatenate((-trailing, leading), axis=-1)
class FlaxMistralAttention(nn.Module):
    """
    Self-attention with rotary position embeddings, fewer key/value heads than query heads, and a causal mask
    that is optionally restricted to a sliding window.

    Queries use `config.num_attention_heads` heads while keys and values use `config.num_key_value_heads`; the
    key/value heads are repeated `num_key_value_groups` times before the attention product. All four projections
    are bias-free. When a `"cache"` collection is present (or is being initialised) keys and values are written
    into it so decoding can proceed incrementally.
    """

    config: MistralConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        config = self.config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        # softmax runs in float32 whenever the computation dtype is something else
        self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
        self.rope_theta = config.rope_theta

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Dense(self.num_heads * self.head_dim, use_bias=False, dtype=self.dtype)
        self.k_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype)
        self.v_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype)
        self.o_proj = nn.Dense(self.hidden_size, use_bias=False, dtype=self.dtype)

        causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
        # `triu` with a negative offset drops everything further than `sliding_window` below the diagonal, so
        # together with the causal constraint each query sees itself and at most `sliding_window` earlier keys.
        # A window of `None` means "no window": keep the plain causal mask instead of failing on `-None`.
        if config.sliding_window is not None:
            causal_mask = jnp.triu(causal_mask, k=-config.sliding_window)
        self.causal_mask = causal_mask

        self.rotary_emb = FlaxMistralRotaryEmbedding(config, dtype=self.dtype)

    def _split_heads(self, hidden_states, num_heads):
        """Reshape `(batch, seq, num_heads * head_dim)` into `(batch, seq, num_heads, head_dim)`."""
        return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        """Inverse of `_split_heads` for the query heads: collapse the last two axes into `hidden_size`."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))

    @nn.compact
    # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slighly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
        init_cache: bool = False,
    ) -> Tuple[jnp.ndarray, ...]:
        """
        Run self-attention over `hidden_states`.

        Args:
            hidden_states: input activations; the first two axes are treated as batch and sequence.
            attention_mask: padding mask that is expanded and combined with the causal mask.
            position_ids: positions used to look up the rotary embedding table.
            deterministic: when `False` and `config.attention_dropout > 0`, attention dropout is applied and a
                `"dropout"` RNG stream must be available.
            output_attentions: also return the attention weights.
            init_cache: create the key/value cache variables during this call.

        Returns:
            `(attn_output,)`, or `(attn_output, attn_weights)` when `output_attentions` is set.
        """
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states, self.num_heads)
        key_states = self._split_heads(key_states, self.num_key_value_heads)
        value_states = self._split_heads(value_states, self.num_key_value_heads)

        key_states, query_states = self.rotary_emb(key_states, query_states, position_ids)

        query_length, key_length = query_states.shape[1], key_states.shape[1]

        if self.has_variable("cache", "cached_key"):
            # with a cache, the query rows start at the current cache index and the keys span the whole cache
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
            )
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        batch_size = hidden_states.shape[0]
        causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
        attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)

        # Draw the dropout key only when dropout is really active. Without it the attention helper receives
        # `dropout_rng=None` and cannot sample the dropout mask during training.
        dropout_rng = None
        if not deterministic and self.config.attention_dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        if self.has_variable("cache", "cached_key") or init_cache:
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # expand the key/value heads so every query head has a matching key/value head
        key_states = jnp.repeat(key_states, self.num_key_value_groups, axis=2)
        value_states = jnp.repeat(value_states, self.num_key_value_groups, axis=2)

        # turn the boolean mask into an additive bias: 0 where attention is allowed, dtype-min elsewhere
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
        )

        # usual dot product attention
        attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            deterministic=deterministic,
            dropout_rate=self.config.attention_dropout,
            dtype=attention_dtype,
        )

        if self.attention_softmax_in_fp32:
            attn_weights = attn_weights.astype(self.dtype)

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.o_proj(attn_output)

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaDecoderLayer with Llama->Mistral
class FlaxMistralDecoderLayer(nn.Module):
    """One decoder block: pre-norm self-attention and pre-norm MLP, each wrapped in a residual connection."""

    config: MistralConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg, compute_dtype = self.config, self.dtype
        self.input_layernorm = FlaxMistralRMSNorm(cfg, dtype=compute_dtype)
        self.self_attn = FlaxMistralAttention(cfg, dtype=compute_dtype)
        self.post_attention_layernorm = FlaxMistralRMSNorm(cfg, dtype=compute_dtype)
        self.mlp = FlaxMistralMLP(cfg, dtype=compute_dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        # attention sub-block: normalise, attend, then add the block input back
        attn_outputs = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_outputs[0]

        # feed-forward sub-block: normalise, project, then add the attention result back
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_output

        # anything the attention returned beyond its output (the weights, if requested) is passed through
        return (hidden_states,) + attn_outputs[1:]
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Mistral, GPT_NEO->MISTRAL, transformer->model
class FlaxMistralPreTrainedModel(FlaxPreTrainedModel):
    """
    Abstract base class that takes care of weight initialization and offers a simple interface for downloading
    and loading pretrained models.
    """

    config_class = MistralConfig
    base_model_prefix = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: MistralConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # build the Flax module declared by the subclass, then hand everything over to the base class
        flax_module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, flax_module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # dummy inputs, used only to trace the module so that its parameters get created
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)

        params_rng, dropout_rng = jax.random.split(rng)
        init_rngs = {"params": params_rng, "dropout": dropout_rng}
        random_params = self.module.init(init_rngs, input_ids, attention_mask, position_ids, return_dict=False)[
            "params"
        ]

        if params is None:
            return random_params

        # existing parameters were supplied: only the keys recorded as missing are taken from the random init
        flat_random = flatten_dict(unfreeze(random_params))
        flat_params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            flat_params[missing_key] = flat_random[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(flat_params))

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # dummy full-length inputs; initialising the module with `init_cache=True` creates the cache variables
        input_ids = jnp.ones((batch_size, max_length))
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(variables["cache"])

    @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # unset output flags fall back to the values stored in the config
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        if position_ids is None:
            # positions cannot be inferred while decoding from a cache, so the caller has to supply them
            if past_key_values is not None:
                raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")

            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {} if dropout_rng is None else {"dropout": dropout_rng}

        variables = {"params": params or self.params}

        # A supplied cache is already initialized. It is passed in as the "cache" collection and marked mutable
        # so that the attention modules can update it in place.
        if past_key_values:
            variables["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        # positional order after the three arrays: deterministic, init_cache, output_attentions,
        # output_hidden_states, return_dict
        outputs = self.module.apply(
            variables,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            False,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None:
            outputs, mutated_variables = outputs
            updated_cache = unfreeze(mutated_variables["cache"])
            if return_dict:
                outputs["past_key_values"] = updated_cache
            else:
                # tuple output: the cache is inserted right after the first element
                outputs = outputs[:1] + (updated_cache,) + outputs[1:]

        return outputs
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaLayerCollection with Llama->Mistral
class FlaxMistralLayerCollection(nn.Module):
    """The stack of decoder layers, applied one after another; each layer is named after its index."""

    config: MistralConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        decoder_layers = []
        for layer_index in range(self.config.num_hidden_layers):
            decoder_layers.append(FlaxMistralDecoderLayer(self.config, dtype=self.dtype, name=str(layer_index)))
        self.blocks = decoder_layers

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = False,
    ):
        # collectors stay `None` when the corresponding output was not requested
        collected_attentions = () if output_attentions else None
        collected_hidden_states = () if output_hidden_states else None

        for decoder_layer in self.blocks:
            # record the input of every layer; the final layer's output is not appended here
            if output_hidden_states:
                collected_hidden_states = collected_hidden_states + (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                collected_attentions = collected_attentions + (layer_outputs[1],)

        # this contains possible `None` values - `FlaxMistralModule` will filter them out
        return (hidden_states, collected_hidden_states, collected_attentions)
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModule with Llama->Mistral
class FlaxMistralModule(nn.Module):
    """Token embedding, the decoder layer stack, and a final RMS norm."""

    config: MistralConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.hidden_size = cfg.hidden_size
        self.embed_tokens = nn.Embed(
            cfg.vocab_size,
            self.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
            dtype=self.dtype,
        )
        self.layers = FlaxMistralLayerCollection(cfg, dtype=self.dtype)
        self.norm = FlaxMistralRMSNorm(cfg, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        deterministic=True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        token_embeds = self.embed_tokens(input_ids.astype("i4"))

        # the layer collection returns (last hidden state, per-layer hidden states or None, attentions or None)
        last_hidden, all_hidden_states, all_attentions = self.layers(
            token_embeds,
            position_ids=position_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = self.norm(last_hidden)

        # the normalised final state completes the list of per-layer hidden states
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions)

        if not return_dict:
            # tuple output omits whatever was not requested
            return tuple(item for item in outputs if item is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )
# Public bare model: all behaviour comes from `FlaxMistralPreTrainedModel`; this class only selects the
# Flax module to wrap and receives the shared class documentation through the decorator.
@add_start_docstrings(
    "The bare Mistral Model transformer outputting raw hidden-states without any specific head on top.",
    MISTRAL_START_DOCSTRING,
)
class FlaxMistralModel(FlaxMistralPreTrainedModel):
    module_class = FlaxMistralModule


# Registers the sample-usage documentation for `FlaxMistralModel` (checkpoint names, output class and config
# name are the module-level `*_FOR_DOC` constants).
# NOTE(review): the output class named here is `FlaxBaseModelOutputWithPast`, while `FlaxMistralModule`
# constructs a `FlaxBaseModelOutput` - confirm this is intended.
append_call_sample_docstring(
    FlaxMistralModel,
    _CHECKPOINT_FOR_DOC,
    FlaxBaseModelOutputWithPast,
    _CONFIG_FOR_DOC,
    real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaForCausalLMModule with Llama->Mistral
class FlaxMistralForCausalLMModule(nn.Module):
    """The bare Mistral module followed by a bias-free linear projection onto the vocabulary."""

    config: MistralConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.model = FlaxMistralModule(cfg, dtype=self.dtype)
        head_init = jax.nn.initializers.normal(stddev=cfg.initializer_range)
        self.lm_head = nn.Dense(cfg.vocab_size, use_bias=False, dtype=self.dtype, kernel_init=head_init)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        backbone_outputs = self.model(
            input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # index 0 is the last hidden state for both the tuple and the structured output
        lm_logits = self.lm_head(backbone_outputs[0])

        if return_dict:
            return FlaxCausalLMOutput(
                logits=lm_logits,
                hidden_states=backbone_outputs.hidden_states,
                attentions=backbone_outputs.attentions,
            )

        # tuple output: logits replace the hidden state, the remaining entries are passed through
        return (lm_logits,) + backbone_outputs[1:]
@add_start_docstrings(
    """
The Mistral Model transformer with a language modeling head (linear layer) on top.
""",
    MISTRAL_START_DOCSTRING,
)
# Adapted from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Mistral
class FlaxMistralForCausalLM(FlaxMistralPreTrainedModel):
    module_class = FlaxMistralForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        """Build the initial keyword arguments used for generation.

        Returns a dict with a freshly initialised cache (``past_key_values``),
        an attention mask of shape ``(batch_size, max_length)`` and the
        position ids of the prompt tokens.
        """
        batch_size, seq_length = input_ids.shape
        cache = self.init_cache(batch_size, max_length)

        # Strictly speaking the mask should be 0 for positions beyond input_ids.shape[-1] and below the
        # cache length. Mistral's causal mask already hides those positions, so one static all-ones mask
        # is enough and is cheaper to compile.
        full_mask = jnp.ones((batch_size, max_length), dtype="i4")

        if attention_mask is None:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
        else:
            # Positions are derived from the mask's running sum; the caller's mask is then written into
            # the leading columns of the static mask.
            position_ids = attention_mask.cumsum(axis=-1) - 1
            full_mask = lax.dynamic_update_slice(full_mask, attention_mask, (0, 0))

        return {
            "past_key_values": cache,
            "attention_mask": full_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the cache forward and advance the position ids by one step."""
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        last_position = model_kwargs["position_ids"][:, -1:]
        model_kwargs["position_ids"] = last_position + 1
        return model_kwargs
# Registers the documentation sample for `FlaxMistralForCausalLM` (presumably appended to its `__call__`
# docstring — confirm against `append_call_sample_docstring`).
# NOTE(review): the output class documented here is `FlaxCausalLMOutputWithCrossAttentions`, while
# `FlaxMistralForCausalLMModule.__call__` builds a `FlaxCausalLMOutput` — confirm this difference is intended.
append_call_sample_docstring(
FlaxMistralForCausalLM,
_CHECKPOINT_FOR_DOC,
FlaxCausalLMOutputWithCrossAttentions,
_CONFIG_FOR_DOC,
real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)
......@@ -898,6 +898,27 @@ class FlaxMBartPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
# Placeholder for `FlaxMistralForCausalLM`: construction is routed to `requires_backends(self, ["flax"])`
# (presumably raising when the flax backend is not installed — confirm against `requires_backends`).
class FlaxMistralForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
# Placeholder for `FlaxMistralModel`: construction is routed to `requires_backends(self, ["flax"])`
# (presumably raising when the flax backend is not installed — confirm against `requires_backends`).
class FlaxMistralModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
# Placeholder for `FlaxMistralPreTrainedModel`: construction is routed to `requires_backends(self, ["flax"])`
# (presumably raising when the flax backend is not installed — confirm against `requires_backends`).
class FlaxMistralPreTrainedModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxMT5EncoderModel(metaclass=DummyObject):
_backends = ["flax"]
......
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import MistralConfig, is_flax_available, is_tokenizers_available
from transformers.testing_utils import require_flax, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
import jax.numpy as jnp
from transformers.models.mistral.modeling_flax_mistral import (
FlaxMistralForCausalLM,
FlaxMistralModel,
)
if is_tokenizers_available():
from transformers import LlamaTokenizerFast
class FlaxMistralModelTester:
    """Builds a small `MistralConfig` with random inputs and hosts the cache-consistency checks."""

    def __init__(
        self,
        parent,
        batch_size=2,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        window_size=7,
        initializer_range=0.02,
    ):
        # Owning test case; the `check_*` helpers call its assertion methods.
        self.parent = parent

        # Input shape and feature switches.
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels

        # Values forwarded to `MistralConfig` by `prepare_config_and_inputs`.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.window_size = window_size
        self.initializer_range = initializer_range
        self.scope = None

        # All special tokens share the last vocabulary id.
        last_token_id = vocab_size - 1
        self.bos_token_id = last_token_id
        self.eos_token_id = last_token_id
        self.pad_token_id = last_token_id

    def prepare_config_and_inputs(self):
        """Return ``(config, input_ids, input_mask)``; the mask is ``None`` when `use_input_mask` is false."""
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        # Lower-triangular matrix of ones with shape (batch_size, seq_length).
        input_mask = np.tril(np.ones((self.batch_size, self.seq_length))) if self.use_input_mask else None

        config = MistralConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            use_cache=True,
            is_decoder=False,
            initializer_range=self.initializer_range,
            sliding_window=self.window_size,
        )
        config.pad_token_id = config.eos_token_id
        return config, input_ids, input_mask

    # Adapted from tests.models.gpt_neo.test_modeling_flax_gpt_neo.FlaxGPTNeoModelTester.prepare_config_and_inputs_for_common
    def prepare_config_and_inputs_for_common(self):
        """Return the config and a dict holding ``input_ids`` and ``attention_mask``."""
        config, input_ids, attention_mask = self.prepare_config_and_inputs()
        return config, {"input_ids": input_ids, "attention_mask": attention_mask}

    # Adapted from tests.models.gpt_neo.test_modeling_flax_gpt_neo.FlaxGPTNeoModelTester.check_use_cache_forward
    def check_use_cache_forward(self, model_class_name, config, input_ids, attention_mask):
        """Check that cached decoding of the last token matches a full forward pass.

        The incoming `attention_mask` is replaced by an all-ones mask that spans the cache length.
        """
        max_decoder_length = 20
        batch, seq_len = input_ids.shape[0], input_ids.shape[-1]
        prefix_len = seq_len - 1

        model = model_class_name(config)
        past_key_values = model.init_cache(batch, max_decoder_length)
        attention_mask = jnp.ones((batch, max_decoder_length), dtype="i4")

        # First pass: every token except the last one fills the cache.
        prefix_positions = jnp.broadcast_to(jnp.arange(prefix_len)[None, :], (batch, prefix_len))
        outputs_cache = model(
            input_ids[:, :-1],
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=prefix_positions,
        )

        # Second pass: only the last token, reusing the cache from the first pass.
        last_position = jnp.array(batch * [[prefix_len]], dtype="i4")
        outputs_cache_next = model(
            input_ids[:, -1:],
            attention_mask=attention_mask,
            past_key_values=outputs_cache.past_key_values,
            position_ids=last_position,
        )

        outputs = model(input_ids)

        diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
        self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")

    # Adapted from tests.models.gpt_neo.test_modeling_flax_gpt_neo.FlaxGPTNeoModelTester.check_use_cache_forward_with_attn_mask
    def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input_ids, attention_mask):
        """Same comparison as `check_use_cache_forward`, using the given mask zero-padded to the cache length."""
        max_decoder_length = 20
        batch, seq_len = input_ids.shape[0], input_ids.shape[-1]
        prefix_len = seq_len - 1

        model = model_class_name(config)

        mask_padding = jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))
        attention_mask_cache = jnp.concatenate([attention_mask, mask_padding], axis=-1)

        past_key_values = model.init_cache(batch, max_decoder_length)

        prefix_positions = jnp.broadcast_to(jnp.arange(prefix_len)[None, :], (batch, prefix_len))
        outputs_cache = model(
            input_ids[:, :-1],
            attention_mask=attention_mask_cache,
            past_key_values=past_key_values,
            position_ids=prefix_positions,
        )

        last_position = jnp.array(batch * [[prefix_len]], dtype="i4")
        outputs_cache_next = model(
            input_ids[:, -1:],
            past_key_values=outputs_cache.past_key_values,
            attention_mask=attention_mask_cache,
            position_ids=last_position,
        )

        outputs = model(input_ids, attention_mask=attention_mask)

        diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
        self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
@require_flax
class FlaxMistralModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
    """Common, generation and cache tests for the Flax Mistral model classes."""

    all_model_classes = (FlaxMistralModel, FlaxMistralForCausalLM) if is_flax_available() else ()
    all_generative_model_classes = (FlaxMistralForCausalLM,) if is_flax_available() else ()

    def setUp(self):
        self.model_tester = FlaxMistralModelTester(self)

    def test_use_cache_forward(self):
        """Run the cache-consistency check for every model class."""
        tester = self.model_tester
        for model_cls in self.all_model_classes:
            config, input_ids, attention_mask = tester.prepare_config_and_inputs()
            tester.check_use_cache_forward(model_cls, config, input_ids, attention_mask)

    def test_use_cache_forward_with_attn_mask(self):
        """Run the cache-consistency check with an explicit attention mask for every model class."""
        tester = self.model_tester
        for model_cls in self.all_model_classes:
            config, input_ids, attention_mask = tester.prepare_config_and_inputs()
            tester.check_use_cache_forward_with_attn_mask(model_cls, config, input_ids, attention_mask)

    @slow
    def test_model_from_pretrained(self):
        """Load the checkpoint from its PyTorch weights and check that a forward pass returns something."""
        for model_cls in self.all_model_classes:
            model = model_cls.from_pretrained("mistralai/Mistral-7B-v0.1", from_pt=True)
            self.assertIsNotNone(model(np.ones((1, 1))))
@slow
@require_flax
class FlaxMistralIntegrationTest(unittest.TestCase):
    """Slow integration tests that compare the converted checkpoint against reference values."""

    def setUp(self):
        """Load the causal-LM checkpoint from its PyTorch weights."""
        self.model_id = "mistralai/Mistral-7B-v0.1"
        self.model = FlaxMistralForCausalLM.from_pretrained(self.model_id, from_pt=True)
        self.test_batch = jnp.arange(32).reshape(4, 8) + 1911

    def test_model_logits(self):
        """Compare per-position logit means and a 30-value logit slice with the stored references."""
        input_ids = jnp.array([[1, 306, 4658, 278, 6593, 310, 2834, 338]])
        EXPECTED_MEAN = np.array([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]])
        EXPECTED_SLICE = np.array([-5.8781,-5.8616,-0.1052,-4.7200,-5.8781,-5.8774,-5.8773,-5.8777,-5.8781,-5.8780,-5.8781,-5.8779,-1.0787,1.7583,-5.8779,-5.8780,-5.8783,-5.8778,-5.8776,-5.8781,-5.8784,-5.8778,-5.8778,-5.8777,-5.8779,-5.8778,-5.8776,-5.8780,-5.8779,-5.8781])  # fmt: skip

        flax_logits = self.model(input_ids).logits
        # Largest absolute deviation from each reference must round to 0 at 3 decimal places.
        diff_mean = jnp.abs(flax_logits.mean(-1) - EXPECTED_MEAN).max()
        diff_slice = jnp.abs(flax_logits[0, 0, :30] - EXPECTED_SLICE).max()

        self.assertAlmostEqual(diff_mean, 0, places=3)
        self.assertAlmostEqual(diff_slice, 0, places=3)

    def test_generated_text(self):
        """Generate 20 new tokens from a fixed prompt and compare with the stored completion."""
        # `LlamaTokenizerFast` is only imported when `tokenizers` is installed; skip explicitly
        # instead of failing with a NameError when it is not.
        if not is_tokenizers_available():
            self.skipTest("test requires the `tokenizers` library")

        tokenizer = LlamaTokenizerFast.from_pretrained(self.model_id)
        tokenizer.pad_token_id = 2

        EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. I’m not a big"""
        prompt = "My favourite condiment is "
        inputs = tokenizer(prompt, return_tensors="np", truncation=True, padding=True)

        generated_ids = self.model.generate(**inputs, max_new_tokens=20, temperature=0).sequences
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        self.assertEqual(generated_text, EXPECTED_TEXT_COMPLETION)
......@@ -239,6 +239,8 @@ OBJECTS_TO_IGNORE = [
"FlaxMBartModel",
"FlaxMarianMTModel",
"FlaxMarianModel",
"FlaxMistralForCausalLM",
"FlaxMistralModel",
"FlaxOPTForCausalLM",
"FlaxPegasusForConditionalGeneration",
"FlaxPegasusModel",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment