"vscode:/vscode.git/clone" did not exist on "97d3390fc8edb210fcf0aad6a079406b018655b9"
Unverified commit 75336c17, authored by Alex McKinney and committed by GitHub

Add Llama Flax Implementation (#24587)

* Copies `modeling_flax_gpt_neo.py` to start

* MLP Block. WIP Attention and Block

* Adds Flax implementation of `LlamaMLP`
Validated with an in-file test.
Some slight numeric differences, but assuming they aren't an issue

* Adds `FlaxLlamaRMSNorm` layer
`flax.linen` includes `RMSNorm` layer but not necessarily in all
versions. Hence, we add in-file.
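(A minimal sketch of an in-file RMSNorm — parameter and argument names here are approximations, not the exact module:)

```python
import flax.linen as nn
import jax.numpy as jnp


class FlaxLlamaRMSNorm(nn.Module):
    hidden_size: int
    eps: float = 1e-6

    def setup(self):
        # learnable per-channel scale, initialised to ones
        self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.hidden_size)

    def __call__(self, hidden_states):
        input_dtype = hidden_states.dtype
        # variance is computed in float32 for stability
        variance = jnp.power(hidden_states.astype(jnp.float32), 2).mean(-1, keepdims=True)
        hidden_states = hidden_states * (1.0 / jnp.sqrt(variance + self.eps))
        return (self.weight * hidden_states).astype(input_dtype)
```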

* Adds FlaxLlamaAttention
Copied from GPT-J as it has an efficient caching implementation as well as
rotary embeddings.
Noticeably different numerically, but not by a huge amount. Needs
investigating

* Adds `FlaxLlamaDecoderLayer`
numerically inaccurate, debugging..

* debugging rotary mismatch
GPT-J uses interleaved rotation whilst LLaMA uses contiguous halves.
I think they match now, but the final result is still wrong.
Maybe drop back to just debugging the attention layer?
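(Roughly the two conventions in question — a sketch, not the exact helpers in the file:)

```python
import jax.numpy as jnp

def rotate_half_contiguous(x):
    # LLaMA-style: rotate the two contiguous halves of the head dimension
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate((-x2, x1), axis=-1)

def rotate_every_two_interleaved(x):
    # GPT-J-style: rotate interleaved (even, odd) channel pairs
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    stacked = jnp.stack((-x2, x1), axis=-1)
    return stacked.reshape(*stacked.shape[:-2], -1)
```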

* fixes bug with decoder layer
still somewhat numerically inaccurate, but close enough for now

* adds markers for what to implement next
the structure here diverges a lot from the PT version.
not a big fan of it, but just get something working for now

* implements `FlaxLlamaBlockCollection`
tolerance must be higher than expected, kinda disconcerting

* Adds `FlaxLlamaModule`
equivalent PyTorch model is `LlamaModel`
yay! a language model🤗

* adds `FlaxLlamaForCausalLMModule`
equivalent to `LlamaForCausalLM`
still missing returning dict or tuple, will add later

* start porting pretrained wrappers
realised it probably needs return dict as a prereq

* cleanup, quality, style

* readds `return_dict` and model output named tuples
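(Roughly the pattern this refers to — the exact output class and fields used here are my assumption:)

```python
from transformers.modeling_flax_outputs import FlaxCausalLMOutput

# sketch of the tail of a Flax causal-LM module's __call__
def wrap_outputs(lm_logits, outputs, return_dict):
    if not return_dict:
        # plain tuple, matching the PyTorch convention
        return (lm_logits,) + outputs[1:]
    return FlaxCausalLMOutput(
        logits=lm_logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
```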

* (tentatively) pretrained wrappers work 🔥

* fixes numerical mismatch in `FlaxLlamaRMSNorm`
seems `jax.lax.rsqrt` does not match `torch.rsqrt`.
manually computing `1 / jax.numpy.sqrt` results in matching values.
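(Quick illustration of the kind of low-precision discrepancy involved; actual deltas depend on backend and dtype:)

```python
import jax.numpy as jnp
from jax import lax

x = jnp.asarray(3.0, dtype=jnp.float32)
a = lax.rsqrt(x)        # hardware reciprocal-sqrt, can differ in the last bits
b = 1.0 / jnp.sqrt(x)   # explicit form that matched the PyTorch reference here
print(a, b, a - b)
```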

* [WIP] debugging numerics

* numerical match
I think the issue was an accidental change of backend; forcing CPU fixes the test.
We expect some mismatch on GPU.
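(For reproducing the check, one standard way to pin JAX to CPU — shown as an example, not necessarily how the test does it:)

```python
import jax

# run everything on CPU so PT/Flax equivalence checks are not affected by GPU numerics
jax.config.update("jax_platform_name", "cpu")

# alternatively, scope it to a block of work
with jax.default_device(jax.devices("cpu")[0]):
    ...
```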

* adds in model and integration tests for Flax Llama
summary of failing tests:
- mul: invalid combination of dimensions
- one numerical mismatch
- bf16 conversion (maybe an issue with my local backend)
- params are not FrozenDict

* adds missing TYPE_CHECKING import and `make fixup`

* adds back missing docstrings
needs review on quality of docstrings, not sure what is required.
Furthermore, need to check if `CHECKPOINT_FOR_DOC` is valid. See TODO

* commenting out equivalence test as we can just use the common one

* debugging

* Fixes bug where mask and pos_ids were swapped in pretrained models
This results in all tests passing now 🔥



* cleanup of modeling file

* cleanup of test file

* Resolving simpler review comments

* addresses more minor review comments

* fixing introduced pytest errors from review

* wip additional slow tests

* wip tests
need to grab a GPU machine to get real logits for comparison
otherwise, slow tests should be okay

* `make quality`, `make style`

* adds slow integration tests
- checking logits
- checking hidden states
- checking generation outputs

* `make fix-copies`

* fix mangled function following `make fix-copies`

* adds missing type checking imports

* fixes missing parameter checkpoint warning

* more finegrained 'Copied from' tags
avoids issue of overwriting `LLAMA_INPUTS_DOCSTRING`

* swaps import guards
??? how did these get swapped initially?

* removing `inv_freq` again as the PyTorch version has now removed it

* attempting to get CI to pass

* adds doc entries for llama flax models

* fixes typo in __init__.py imports

* adds back special equivalence tests
these come from the GPT-Neo Flax tests. There is special behaviour for these models that needs to override the common version

* overrides tests with dummy to see if CI passes
need to fill in these tests later

* adds my contribution to docs

* `make style; make quality`

* replaces random masking with fixed to work with flax version

* `make quality; make style`

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* updates `x`->`tensor` in `rotate_half`

* addresses smaller review comments

* Update docs/source/en/model_doc/llama.md
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* adds integration test class

* adds `dtype` to rotary embedding to cast outputs

* adds type to flax llama rotary layer

* `make style`

* `make fix-copies`

* Apply suggestions from code review
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* applies suggestions from review

* Update modeling_flax_llama.py

* `make fix-copies`

* Update tests/models/llama/test_modeling_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* fixes shape mismatch in FlaxLlamaMLP

* applies some suggestions from reviews

* casts attn output logits to f32 regardless of dtype
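(I read this as computing the attention softmax in float32 even when the module dtype is half precision — a sketch of that idea, which may not be exactly what the file does:)

```python
import jax
import jax.numpy as jnp

def softmax_in_fp32(attn_logits, dtype):
    # take the softmax in float32 for numerical stability, then cast back to the model dtype
    weights = jax.nn.softmax(attn_logits.astype(jnp.float32), axis=-1)
    return weights.astype(dtype)
```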

* adds attn bias using `LlamaConfig.attention_bias`
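(Sketch of what the flag controls — the q/k/v/o projections pick up their bias from the config; helper name here is hypothetical:)

```python
import flax.linen as nn
from transformers import LlamaConfig

def make_attention_proj(config: LlamaConfig, features: int) -> nn.Dense:
    # bias is off for vanilla Llama, on for some derivatives
    return nn.Dense(features, use_bias=config.attention_bias)
```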

* adds Copied From comments to Flax Llama test

* mistral and persimmon test change: copy from llama

* updates docs index

* removes Copied from in tests

it was preventing `make fix-copies` from succeeding

* quality and style

* ignores FlaxLlama input docstring

* adds revision to `_CHECKPOINT_FOR_DOC`

* repo consistency and quality

* removes unused import

* removes copied from from Phi test

now diverges from llama tests following FlaxLlama changes

* adds `_REAL_CHECKPOINT_FOR_DOC`

* removes refs from pr tests

* reformat to make ruff happy

---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent 7fc80724
<!--Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
@@ -94,7 +94,7 @@ Flax), PyTorch, and/or TensorFlow.
| [CLIPSeg](model_doc/clipseg) | ✅ | ❌ | ❌ |
| [CLVP](model_doc/clvp) | ✅ | ❌ | ❌ |
| [CodeGen](model_doc/codegen) | ✅ | ❌ | ❌ |
| [CodeLlama](model_doc/code_llama) | ✅ | ❌ | ❌ |
| [CodeLlama](model_doc/code_llama) | ✅ | ❌ | ✅ |
| [Conditional DETR](model_doc/conditional_detr) | ✅ | ❌ | ❌ |
| [ConvBERT](model_doc/convbert) | ✅ | ✅ | ❌ |
| [ConvNeXT](model_doc/convnext) | ✅ | ✅ | ❌ |
@@ -167,8 +167,8 @@ Flax), PyTorch, and/or TensorFlow.
| [LED](model_doc/led) | ✅ | ✅ | ❌ |
| [LeViT](model_doc/levit) | ✅ | ❌ | ❌ |
| [LiLT](model_doc/lilt) | ✅ | ❌ | ❌ |
| [LLaMA](model_doc/llama) | ✅ | ❌ | ❌ |
| [Llama2](model_doc/llama2) | ✅ | ❌ | ❌ |
| [LLaMA](model_doc/llama) | ✅ | ❌ | ✅ |
| [Llama2](model_doc/llama2) | ✅ | ❌ | ✅ |
| [Longformer](model_doc/longformer) | ✅ | ✅ | ❌ |
| [LongT5](model_doc/longt5) | ✅ | ❌ | ✅ |
| [LUKE](model_doc/luke) | ✅ | ❌ | ❌ |
@@ -50,6 +50,9 @@ come in several checkpoints they each contain a part of each weight of the model
- The LLaMA tokenizer is a BPE model based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string.
This model was contributed by [zphang](https://huggingface.co/zphang) with contributions from [BlackSamorez](https://huggingface.co/BlackSamorez). The code of the implementation in Hugging Face is based on GPT-NeoX [here](https://github.com/EleutherAI/gpt-neox). The original code of the authors can be found [here](https://github.com/facebookresearch/llama). The Flax version of the implementation was contributed by [afmck](https://huggingface.co/afmck) with the code in the implementation based on Hugging Face's Flax GPT-Neo.
Based on the original LLaMA model, Meta AI has released some follow-up works:
- **Llama2**: Llama2 is an improved version of Llama with some architectural tweaks (Grouped Query Attention), and is pre-trained on 2 trillion tokens. Refer to the documentation of Llama2 which can be found [here](llama2).
@@ -112,3 +115,13 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] LlamaForSequenceClassification
- forward
## FlaxLlamaModel
[[autodoc]] FlaxLlamaModel
- __call__
## FlaxLlamaForCausalLM
[[autodoc]] FlaxLlamaForCausalLM
- __call__
@@ -4554,6 +4554,7 @@ else:
["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
)
_import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"])
_import_structure["models.llama"].extend(["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"])
_import_structure["models.longt5"].extend(
[
"FlaxLongT5ForConditionalGeneration",
@@ -8631,6 +8632,11 @@ if TYPE_CHECKING:
FlaxGPTJModel,
FlaxGPTJPreTrainedModel,
)
from .models.llama import (
FlaxLlamaForCausalLM,
FlaxLlamaModel,
FlaxLlamaPreTrainedModel,
)
from .models.longt5 import (
FlaxLongT5ForConditionalGeneration,
FlaxLongT5Model,
@@ -1267,7 +1267,9 @@ def overwrite_call_docstring(model_class, docstring):
model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
def append_call_sample_docstring(model_class, checkpoint, output_type, config_class, mask=None, revision=None):
def append_call_sample_docstring(
model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None
):
model_class.__call__ = copy_func(model_class.__call__)
model_class.__call__ = add_code_sample_docstrings(
checkpoint=checkpoint,
@@ -1275,6 +1277,7 @@ def append_call_sample_docstring(model_class, checkpoint, output_type, config_cl
config_class=config_class,
model_cls=model_class.__name__,
revision=revision,
real_checkpoint=real_checkpoint,
)(model_class.__call__)
@@ -43,6 +43,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
("gpt2", "FlaxGPT2Model"),
("gpt_neo", "FlaxGPTNeoModel"),
("gptj", "FlaxGPTJModel"),
("llama", "FlaxLlamaModel"),
("longt5", "FlaxLongT5Model"),
("marian", "FlaxMarianModel"),
("mbart", "FlaxMBartModel"),
@@ -146,6 +147,7 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"),
("llama", "FlaxLlamaForCausalLM"),
("opt", "FlaxOPTForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
@@ -489,7 +489,7 @@ class BloomPreTrainedModel(PreTrainedModel):
@staticmethod
def _convert_to_bloom_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
@@ -54,7 +54,7 @@ logger = logging.get_logger(__name__)
def make_list_of_list_of_images(
images: Union[List[List[ImageInput]], List[ImageInput], ImageInput]
images: Union[List[List[ImageInput]], List[ImageInput], ImageInput],
) -> List[List[ImageInput]]:
if is_valid_image(images):
return [[images]]
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
is_tokenizers_available,
is_torch_available,
@@ -55,6 +56,14 @@ else:
"LlamaForSequenceClassification",
]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_llama"] = ["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"]
if TYPE_CHECKING:
from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig
@@ -83,6 +92,14 @@ if TYPE_CHECKING:
else:
from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_llama import FlaxLlamaForCausalLM, FlaxLlamaModel, FlaxLlamaPreTrainedModel
else:
import sys
@@ -265,7 +265,7 @@ class MptPreTrainedModel(PreTrainedModel):
@staticmethod
def _convert_to_mpt_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Converts the cache to the format expected by Mpt, i.e. to tuple(tuple([batch_size * num_heads, ...]))
@@ -800,6 +800,27 @@ class FlaxGPTJPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxLlamaForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLlamaModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLlamaPreTrainedModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLongT5ForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import LlamaConfig, is_flax_available, is_tokenizers_available
from transformers.testing_utils import require_flax, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
import jax.numpy as jnp
from transformers.models.llama.modeling_flax_llama import FlaxLlamaForCausalLM, FlaxLlamaModel
if is_tokenizers_available():
from transformers import LlamaTokenizerFast
class FlaxLlamaModelTester:
def __init__(
self,
parent,
batch_size=2,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=2,
intermediate_size=64,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
window_size=7,
initializer_range=0.02,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.window_size = window_size
self.initializer_range = initializer_range
self.scope = None
self.bos_token_id = vocab_size - 1
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = np.tril(np.ones((self.batch_size, self.seq_length)))
config = LlamaConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
use_cache=True,
is_decoder=False,
initializer_range=self.initializer_range,
)
return (config, input_ids, input_mask)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
return config, inputs_dict
def check_use_cache_forward(self, model_class_name, config, input_ids, attention_mask):
max_decoder_length = 20
model = model_class_name(config)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
attention_mask = jnp.ones((input_ids.shape[0], max_decoder_length), dtype="i4")
position_ids = jnp.broadcast_to(
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
)
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model(
input_ids[:, -1:],
attention_mask=attention_mask,
past_key_values=outputs_cache.past_key_values,
position_ids=position_ids,
)
outputs = model(input_ids)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input_ids, attention_mask):
max_decoder_length = 20
model = model_class_name(config)
attention_mask_cache = jnp.concatenate(
[attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))],
axis=-1,
)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
position_ids = jnp.broadcast_to(
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask_cache,
past_key_values=past_key_values,
position_ids=position_ids,
)
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model(
input_ids[:, -1:],
past_key_values=outputs_cache.past_key_values,
attention_mask=attention_mask_cache,
position_ids=position_ids,
)
outputs = model(input_ids, attention_mask=attention_mask)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
@require_flax
class FlaxLlamaModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
all_model_classes = (FlaxLlamaModel, FlaxLlamaForCausalLM) if is_flax_available() else ()
all_generative_model_classes = (FlaxLlamaForCausalLM,) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxLlamaModelTester(self)
def test_use_cache_forward(self):
for model_class_name in self.all_model_classes:
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_use_cache_forward(model_class_name, config, input_ids, attention_mask)
def test_use_cache_forward_with_attn_mask(self):
for model_class_name in self.all_model_classes:
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_use_cache_forward_with_attn_mask(
model_class_name, config, input_ids, attention_mask
)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("openlm-research/open_llama_3b_v2", from_pt=True)
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
@slow
@require_flax
class FlaxLlamaIntegrationTest(unittest.TestCase):
def setUp(self):
self.model_id = "openlm-research/open_llama_3b_v2"
self.model = FlaxLlamaForCausalLM.from_pretrained(self.model_id, from_pt=True)
self.test_batch = jnp.arange(32).reshape(4, 8) + 1911
def test_model_logits(self):
flax_logits = self.model(self.test_batch).logits
# fmt: off
EXPECTED_LOGITS = [-74.4243, -74.0680, -65.2507, -79.1658, -77.7460, -69.2379, -86.4588, -84.8933, -77.8456]
EXPECTED_MIN, EXPECTED_MAX, EXPECTED_MEAN = -96.9952, -18.4571, -65.0608
# fmt: on
self.assertTrue(np.allclose(flax_logits[0, :3, :3].flatten(), EXPECTED_LOGITS, atol=1e-4))
self.assertAlmostEqual(flax_logits.min(), EXPECTED_MIN, places=3)
self.assertAlmostEqual(flax_logits.max(), EXPECTED_MAX, places=3)
self.assertAlmostEqual(flax_logits.mean(), EXPECTED_MEAN, places=3)
def test_model_hidden_states(self):
flax_hidden_states = self.model(self.test_batch, output_hidden_states=True).hidden_states
flax_hidden_means = [h.mean() for h in flax_hidden_states]
# fmt: off
EXPECTED_HIDDEN_MEANS = [
-0.00007,-0.00049,-0.00169,-0.00253,-0.00271,
-0.00290,-0.00252,0.00230,0.00230,0.00198,
0.00196,0.00174,0.00246,0.00205,0.00242,
0.00171,0.00092,0.00054,0.00102,0.00024,
0.00029,0.00037,-0.00101,-0.00062,-0.00341,-0.00636,-0.00357
]
# fmt: on
self.assertTrue(np.allclose(flax_hidden_means, EXPECTED_HIDDEN_MEANS, atol=1e-4))
def test_generated_text(self):
tokenizer = LlamaTokenizerFast.from_pretrained(self.model_id)
tokenizer.pad_token_id = 2
test_batch = ["Aloha, World! ", "2 + 2 = ", "Paris is the capital of ", "我很高興認識"]
inputs = tokenizer(test_batch, return_tensors="np", truncation=True, padding=True)
generated_ids = self.model.generate(**inputs, max_length=15).sequences
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# fmt: off
EXPECTED_GENERATION = [
"Aloha, World! 201",
"2 + 2 = 4\n2",
"Paris is the capital of Île-",
"我很高興認識你,我"
]
# fmt: on
self.assertListEqual(generated_text, EXPECTED_GENERATION)
@@ -14,7 +14,6 @@
# limitations under the License.
""" Testing suite for the PyTorch LLaMA model. """
import unittest
import pytest
@@ -33,7 +32,7 @@ from transformers.testing_utils import (
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -105,7 +104,7 @@ class LlamaModelTester:
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
token_type_ids = None
if self.use_token_type_ids:
@@ -34,7 +34,7 @@ from transformers.testing_utils import (
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -107,7 +107,7 @@ class MistralModelTester:
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
token_type_ids = None
if self.use_token_type_ids:
@@ -32,7 +32,7 @@ from transformers.testing_utils import (
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -104,7 +104,7 @@ class PersimmonModelTester:
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
token_type_ids = None
if self.use_token_type_ids:
@@ -38,7 +38,6 @@ if is_torch_available():
)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Phi
class PhiModelTester:
def __init__(
self,
@@ -233,6 +233,8 @@ OBJECTS_TO_IGNORE = [
"FlaxGPTJModel",
"FlaxGPTNeoForCausalLM",
"FlaxGPTNeoModel",
"FlaxLlamaForCausalLM",
"FlaxLlamaModel",
"FlaxMBartForConditionalGeneration",
"FlaxMBartForQuestionAnswering",
"FlaxMBartForSequenceClassification",