Unverified Commit f7076cd3 authored by Kian Sierra McGettigan, committed by GitHub

Flax mistral (#26943)

* direct copy from llama work

* mistral modules forward pass working

* flax mistral forward pass with sliding window

* added tests

* added layer collection approach

* Revert "added layer collection approach"

This reverts commit 0e2905bf2236ec323163fc1a9f0c016b21aa8b8f.

* Revert "Revert "added layer collection approach""

This reverts commit fb17b6187ac5d16da7c461e1130514dc3d137a43.

* fixed attention outputs

* added mistral to init and auto

* fixed import name

* fixed layernorm weight dtype

* freeze initialized weights

* make sure conversion considers bfloat16

* added backend

* added docstrings

* added cache

* fixed sliding window causal mask (see the mask sketch after this commit list)

* passes cache tests

* passed all tests

* applied make style

* removed commented out code

* applied fix-copies, ignored other model changes

* applied make fix-copies

* removed unused functions

* passed generation integration test

* slow tests pass

* fixed slow tests

* changed default dtype from jax.numpy.float32 to float32 for docstring check

* skip cache test for FlaxMistralForSequenceClassification since it doesn't score previous input_ids when pad_token_id is in input_ids

* updated checkpoint since from_pt not included

* applied black style

* removed unused args

* Applied styling and fixup

* changed checkpoint for doc back

* fixed ref after adding it to hf hub

* Add dummy ckpt

* applied styling

* added tokenizer to new ckpt

* fixed slice format

* fix init and slice

* changed ref for placeholder TODO

* added copies from Llama

* applied styling

* applied fix-copies

* fixed docs

* update weight dtype reconversion for sharded weights

* removed Nullable input ids

* Removed unnecessary output attentions in Module

* added embedding weight initialization

* removed unused past_key_values

* fixed deterministic

* Fixed RMS Norm and added copied from

* removed input_embeds

* applied make style

* removed nullable input ids from sequence classification model

* added copied from GPTJ

* added copied from Llama on FlaxMistralDecoderLayer

* added copied from to FlaxMistralPreTrainedModel methods

* fix test deprecation warning

* freeze gpt neox random_params and fix copies

* applied make style

* fixed doc issue

* skipped docstring test to align # copied from

* applied make style

* removed FlaxMistralForSequenceClassification

* removed unused padding_idx

* removed more sequence classification

* removed sequence classification

* applied styling and consistency

* added copied from in tests

* removed sequence classification test logic

* applied styling

* applied make style

* removed freeze and fixed copies

* undo test change

* changed repeat_kv to tile (see the repeat_kv sketch after this commit list)

* fixed to key value groups

* updated copyright year

* split causal_mask

* empty commit to rerun failed pt_flax_equivalence test FlaxWav2Vec2ModelTest

* went back to 2023 for tests_pr_documentation_tests

* went back to 2024

* changed tile to repeat

* applied make style

* empty commit for retry on Wav2Vec2
parent 7a496100
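Since the main modeling file's diff is collapsed below, two of the trickier pieces referenced in the commit list deserve a closer look.

Sliding-window causal mask: a query position i may only attend to key positions j that are causal (j <= i) and within the window (i - j < sliding_window). A minimal jax.numpy sketch of the idea; the function name and signature are illustrative, not the exact helper added in this PR:

import jax.numpy as jnp

def sliding_window_causal_mask(seq_length: int, sliding_window: int) -> jnp.ndarray:
    # mask[i, j] is True when query i may attend to key j
    idx = jnp.arange(seq_length)
    causal = idx[None, :] <= idx[:, None]                      # j <= i
    windowed = (idx[:, None] - idx[None, :]) < sliding_window  # i - j < sliding_window
    return causal & windowed

Grouped-query attention repetition ("changed repeat_kv to tile", then "changed tile to repeat"): Mistral has fewer key/value heads than query heads, so key/value states are duplicated per group before the attention product. jnp.repeat duplicates each head consecutively, matching the PyTorch repeat_kv semantics, while jnp.tile would cycle through the heads in the wrong order. A sketch, assuming states shaped (batch, seq_len, num_kv_heads, head_dim):

def repeat_kv(hidden_states: jnp.ndarray, n_rep: int) -> jnp.ndarray:
    # (batch, seq_len, num_kv_heads, head_dim) -> (batch, seq_len, num_kv_heads * n_rep, head_dim)
    return jnp.repeat(hidden_states, n_rep, axis=2)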
@@ -190,7 +190,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [Megatron-BERT](model_doc/megatron-bert) | ✅ | ❌ | ❌ |
 | [Megatron-GPT2](model_doc/megatron_gpt2) | ✅ | ✅ | ✅ |
 | [MGP-STR](model_doc/mgp-str) | ✅ | ❌ | ❌ |
-| [Mistral](model_doc/mistral) | ✅ | ❌ | ❌ |
+| [Mistral](model_doc/mistral) | ✅ | ❌ | ✅ |
 | [Mixtral](model_doc/mixtral) | ✅ | ❌ | ❌ |
 | [mLUKE](model_doc/mluke) | ✅ | ❌ | ❌ |
 | [MMS](model_doc/mms) | ✅ | ✅ | ✅ |
......
@@ -149,3 +149,13 @@ Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Sin
 [[autodoc]] MistralForSequenceClassification
     - forward
+
+## FlaxMistralModel
+
+[[autodoc]] FlaxMistralModel
+    - __call__
+
+## FlaxMistralForCausalLM
+
+[[autodoc]] FlaxMistralForCausalLM
+    - __call__
@@ -4678,6 +4678,13 @@ else:
             "FlaxMBartPreTrainedModel",
         ]
     )
+    _import_structure["models.mistral"].extend(
+        [
+            "FlaxMistralForCausalLM",
+            "FlaxMistralModel",
+            "FlaxMistralPreTrainedModel",
+        ]
+    )
     _import_structure["models.mt5"].extend(["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
     _import_structure["models.opt"].extend(
         [
@@ -8830,6 +8837,11 @@ if TYPE_CHECKING:
             FlaxMBartModel,
             FlaxMBartPreTrainedModel,
         )
+        from .models.mistral import (
+            FlaxMistralForCausalLM,
+            FlaxMistralModel,
+            FlaxMistralPreTrainedModel,
+        )
         from .models.mt5 import (
             FlaxMT5EncoderModel,
             FlaxMT5ForConditionalGeneration,
......
@@ -255,7 +255,10 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
         # load using msgpack utils
         weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
         pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
-        pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+        weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
+        pt_state_dict = {
+            k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
+        }

         model_prefix = flax_model.base_model_prefix
@@ -278,6 +281,7 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
         # Need to change some parameters name to match Flax names
         for pt_key, pt_tensor in pt_state_dict.items():
             pt_tuple_key = tuple(pt_key.split("."))
+            is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16

             # remove base model prefix if necessary
             has_base_model_prefix = pt_tuple_key[0] == model_prefix
@@ -314,11 +318,15 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
                     continue

                 # also add unexpected weight so that warning is thrown
-                flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor)
+                flax_state_dict[("params",) + flax_key] = (
+                    jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
+                )
             else:
                 # also add unexpected weight so that warning is thrown
-                flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
+                flax_state_dict[flax_key] = (
+                    jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
+                )

     return unflatten_dict(flax_state_dict)
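The change above preserves bfloat16 across the PyTorch-to-Flax conversion: NumPy has no native bfloat16, so torch.Tensor.numpy() fails on bf16 tensors. The original dtypes are recorded first, bf16 tensors are upcast to float32 for the NumPy hop, and the Flax array is re-cast to jnp.bfloat16 at the end. A self-contained sketch of that round trip (illustrative only; it assumes torch and jax are installed):

import jax.numpy as jnp
import torch

pt_tensor = torch.ones(2, 2, dtype=torch.bfloat16)

# .numpy() raises on bfloat16, so go through float32 first...
np_tensor = pt_tensor.float().numpy()

# ...then restore the original dtype on the Flax side, as the diff above does.
flax_tensor = jnp.asarray(np_tensor, dtype=jnp.bfloat16)
assert flax_tensor.dtype == jnp.bfloat16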
......
@@ -47,6 +47,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
         ("longt5", "FlaxLongT5Model"),
         ("marian", "FlaxMarianModel"),
         ("mbart", "FlaxMBartModel"),
+        ("mistral", "FlaxMistralModel"),
         ("mt5", "FlaxMT5Model"),
         ("opt", "FlaxOPTModel"),
         ("pegasus", "FlaxPegasusModel"),
@@ -148,6 +149,7 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
         ("gpt_neo", "FlaxGPTNeoForCausalLM"),
         ("gptj", "FlaxGPTJForCausalLM"),
         ("llama", "FlaxLlamaForCausalLM"),
+        ("mistral", "FlaxMistralForCausalLM"),
         ("opt", "FlaxOPTForCausalLM"),
         ("roberta", "FlaxRobertaForCausalLM"),
         ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
......
@@ -13,11 +13,7 @@
 # limitations under the License.
 from typing import TYPE_CHECKING

-from ...utils import (
-    OptionalDependencyNotAvailable,
-    _LazyModule,
-    is_torch_available,
-)
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available

 _import_structure = {
@@ -38,6 +34,18 @@ else:
         "MistralForSequenceClassification",
     ]

+try:
+    if not is_flax_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["modeling_flax_mistral"] = [
+        "FlaxMistralForCausalLM",
+        "FlaxMistralModel",
+        "FlaxMistralPreTrainedModel",
+    ]
+
 if TYPE_CHECKING:
     from .configuration_mistral import MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MistralConfig
@@ -55,6 +63,18 @@ if TYPE_CHECKING:
         MistralPreTrainedModel,
     )

+    try:
+        if not is_flax_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_flax_mistral import (
+            FlaxMistralForCausalLM,
+            FlaxMistralModel,
+            FlaxMistralPreTrainedModel,
+        )
+
 else:
     import sys
......
This diff is collapsed.
@@ -898,6 +898,27 @@ class FlaxMBartPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["flax"])

+class FlaxMistralForCausalLM(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
+class FlaxMistralModel(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
+class FlaxMistralPreTrainedModel(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
 class FlaxMT5EncoderModel(metaclass=DummyObject):
     _backends = ["flax"]
......
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import MistralConfig, is_flax_available, is_tokenizers_available
from transformers.testing_utils import require_flax, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
import jax.numpy as jnp
from transformers.models.mistral.modeling_flax_mistral import (
FlaxMistralForCausalLM,
FlaxMistralModel,
)
if is_tokenizers_available():
from transformers import LlamaTokenizerFast
class FlaxMistralModelTester:
def __init__(
self,
parent,
batch_size=2,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=2,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
window_size=7,
initializer_range=0.02,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.window_size = window_size
self.initializer_range = initializer_range
self.scope = None
self.bos_token_id = vocab_size - 1
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = np.tril(np.ones((self.batch_size, self.seq_length)))
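        # sliding_window below is set to window_size (7), which equals seq_length, so these
        # tests exercise the sliding-window causal mask right at its boundary.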
config = MistralConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
use_cache=True,
is_decoder=False,
initializer_range=self.initializer_range,
sliding_window=self.window_size,
)
config.pad_token_id = config.eos_token_id
return (config, input_ids, input_mask)
# Copied from tests.models.gpt_neo.test_modeling_flax_gpt_neo.FlaxGPTNeoModelTester.prepare_config_and_inputs_for_common
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
return config, inputs_dict
# Copied from tests.models.gpt_neo.test_modeling_flax_gpt_neo.FlaxGPTNeoModelTester.check_use_cache_forward
def check_use_cache_forward(self, model_class_name, config, input_ids, attention_mask):
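        # Two-phase check: prefill an init_cache buffer of length max_decoder_length with
        # all but the last token, then decode the final token with past_key_values fed back
        # in, and compare the last-position logits against a plain full-sequence forward pass.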
max_decoder_length = 20
model = model_class_name(config)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
attention_mask = jnp.ones((input_ids.shape[0], max_decoder_length), dtype="i4")
position_ids = jnp.broadcast_to(
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
)
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model(
input_ids[:, -1:],
attention_mask=attention_mask,
past_key_values=outputs_cache.past_key_values,
position_ids=position_ids,
)
outputs = model(input_ids)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
# Copied from tests.models.gpt_neo.test_modeling_flax_gpt_neo.FlaxGPTNeoModelTester.check_use_cache_forward_with_attn_mask
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input_ids, attention_mask):
max_decoder_length = 20
model = model_class_name(config)
attention_mask_cache = jnp.concatenate(
[attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))],
axis=-1,
)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
position_ids = jnp.broadcast_to(
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask_cache,
past_key_values=past_key_values,
position_ids=position_ids,
)
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model(
input_ids[:, -1:],
past_key_values=outputs_cache.past_key_values,
attention_mask=attention_mask_cache,
position_ids=position_ids,
)
outputs = model(input_ids, attention_mask=attention_mask)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
@require_flax
class FlaxMistralModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
all_model_classes = (FlaxMistralModel, FlaxMistralForCausalLM) if is_flax_available() else ()
all_generative_model_classes = (FlaxMistralForCausalLM,) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxMistralModelTester(self)
def test_use_cache_forward(self):
for model_class_name in self.all_model_classes:
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_use_cache_forward(model_class_name, config, input_ids, attention_mask)
def test_use_cache_forward_with_attn_mask(self):
for model_class_name in self.all_model_classes:
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_use_cache_forward_with_attn_mask(
model_class_name, config, input_ids, attention_mask
)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("mistralai/Mistral-7B-v0.1", from_pt=True)
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
@slow
@require_flax
class FlaxMistralIntegrationTest(unittest.TestCase):
def setUp(self):
self.model_id = "mistralai/Mistral-7B-v0.1"
self.model = FlaxMistralForCausalLM.from_pretrained(self.model_id, from_pt=True)
self.test_batch = jnp.arange(32).reshape(4, 8) + 1911
def test_model_logits(self):
input_ids = jnp.array([[1, 306, 4658, 278, 6593, 310, 2834, 338]])
EXPECTED_MEAN = np.array([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]])
EXPECTED_SLICE = np.array([-5.8781,-5.8616,-0.1052,-4.7200,-5.8781,-5.8774,-5.8773,-5.8777,-5.8781,-5.8780,-5.8781,-5.8779,-1.0787,1.7583,-5.8779,-5.8780,-5.8783,-5.8778,-5.8776,-5.8781,-5.8784,-5.8778,-5.8778,-5.8777,-5.8779,-5.8778,-5.8776,-5.8780,-5.8779,-5.8781]) # fmt: skip
flax_logits = self.model(input_ids).logits
diff_mean = jnp.abs(flax_logits.mean(-1) - EXPECTED_MEAN).max()
diff_slice = jnp.abs(flax_logits[0, 0, :30] - EXPECTED_SLICE).max()
self.assertAlmostEqual(diff_mean, 0, places=3)
self.assertAlmostEqual(diff_slice, 0, places=3)
def test_generated_text(self):
tokenizer = LlamaTokenizerFast.from_pretrained(self.model_id)
tokenizer.pad_token_id = 2
EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. I’m not a big"""
prompt = "My favourite condiment is "
inputs = tokenizer(prompt, return_tensors="np", truncation=True, padding=True)
generated_ids = self.model.generate(**inputs, max_new_tokens=20, temperature=0).sequences
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(generated_text, EXPECTED_TEXT_COMPLETION)
@@ -239,6 +239,8 @@ OBJECTS_TO_IGNORE = [
     "FlaxMBartModel",
     "FlaxMarianMTModel",
     "FlaxMarianModel",
+    "FlaxMistralForCausalLM",
+    "FlaxMistralModel",
     "FlaxOPTForCausalLM",
     "FlaxPegasusForConditionalGeneration",
     "FlaxPegasusModel",
......