Unverified Commit 75336c17 authored by Alex McKinney's avatar Alex McKinney Committed by GitHub
Browse files

Add Llama Flax Implementation (#24587)

* Copies `modeling_flax_gpt_neo.py` to start

* MLP Block. WIP Attention and Block

* Adds Flax implementation of `LlamaMLP`
Validated with in-file test.
Some slight numeric differences, but assuming it isn't an issue

* Adds `FlaxLlamaRMSNorm` layer
`flax.linen` includes `RMSNorm` layer but not necessarily in all
versions. Hence, we add in-file.

* Adds FlaxLlamaAttention
Copied from GPT-J as it has efficient caching implementation as well as
rotary embeddings.
Notice numerically different, but not by a huge amount. Needs
investigating

* Adds `FlaxLlamaDecoderLayer`
numerically inaccurate, debugging..

* debugging rotary mismatch
gptj uses interleaved whilst llama uses contiguous
i think they match now but still final result is wrong.
maybe drop back to just debugging attention layer?

* fixes bug with decoder layer
still somewhat numerically inaccurate, but close enough for now

* adds markers for what to implement next
the structure here diverges a lot from the PT version.
not a big fan of it, but just get something working for now

* implements `FlaxLlamaBlockCollection`]
tolerance must be higher than expected, kinda disconcerting

* Adds `FlaxLlamaModule`
equivalent PyTorch model is `LlamaModel`
yay! a language model🤗

* adds `FlaxLlamaForCausalLMModule`
equivalent to `LlamaForCausalLM`
still missing returning dict or tuple, will add later

* start porting pretrained wrappers
realised it probably needs return dict as a prereq

* cleanup, quality, style

* readds `return_dict` and model output named tuples

* (tentatively) pretrained wrappers work 🔥

* fixes numerical mismatch in `FlaxLlamaRMSNorm`
seems `jax.lax.rsqrt` does not match `torch.sqrt`.
manually computing `1 / jax.numpy.sqrt` results in matching values.

* [WIP] debugging numerics

* numerical match
I think issue was accidental change of backend. forcing CPU fixes test.
We expect some mismatch on GPU.

* adds in model and integration tests for Flax Llama
summary of failing:
- mul invalid combination of dimensions
- one numerical mismatch
- bf16 conversion (maybe my local backend issue)
- params are not FrozenDict

* adds missing TYPE_CHECKING import and `make fixup`

* adds back missing docstrings
needs review on quality of docstrings, not sure what is required.
Furthermore, need to check if `CHECKPOINT_FOR_DOC` is valid. See TODO

* commenting out equivalence test as can just use common

* debugging

* Fixes bug where mask and pos_ids were swapped in pretrained models
This results in all tests passing now 🔥



* cleanup of modeling file

* cleanup of test file

* Resolving simpler review comments

* addresses more minor review comments

* fixing introduced pytest errors from review

* wip additional slow tests

* wip tests
need to grab a GPU machine to get real logits for comparison
otherwise, slow tests should be okay

* `make quality`, `make style`

* adds slow integration tests
- checking logits
- checking hidden states
- checking generation outputs

* `make fix-copies`

* fix mangled function following `make fix-copies`

* adds missing type checking imports

* fixes missing parameter checkpoint warning

* more finegrained 'Copied from' tags
avoids issue of overwriting `LLAMA_INPUTS_DOCSTRING`

* swaps import guards
??? how did these get swapped initially?

* removing `inv_freq` again as pytorch version has now removed

* attempting to get CI to pass

* adds doc entries for llama flax models

* fixes typo in __init__.py imports

* adds back special equivalence tests
these come from the gpt neo flax tests. there is special behaviour for these models that needs to override the common version

* overrides tests with dummy to see if CI passes
need to fill in these tests later

* adds my contribution to docs

* `make style; make quality`

* replaces random masking with fixed to work with flax version

* `make quality; make style`

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* updates `x`->`tensor` in `rotate_half`

* addresses smaller review comments

* Update docs/source/en/model_doc/llama.md
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* adds integration test class

* adds `dtype` to rotary embedding to cast outputs

* adds type to flax llama rotary layer

* `make style`

* `make fix-copies`

* Apply suggestions from code review
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* applies suggestions from review

* Update modeling_flax_llama.py

* `make fix-copies`

* Update tests/models/llama/test_modeling_llama.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* fixes shape mismatch in FlaxLlamaMLP

* applies some suggestions from reviews

* casts attn output logits to f32 regardless of dtype

* adds attn bias using `LlamaConfig.attention_bias`

* adds Copied From comments to Flax Llama test

* mistral and persimmon test change -copy from llama

* updates docs index

* removes Copied from in tests

it was preventing `make fix-copies` from succeeding

* quality and style

* ignores FlaxLlama input docstring

* adds revision to `_CHECKPOINT_FOR_DOC`

* repo consistency and quality

* removes unused import

* removes copied from from Phi test

now diverges from llama tests following FlaxLlama changes

* adds `_REAL_CHECKPOINT_FOR_DOC`

* removes refs from pr tests

* reformat to make ruff happy

---------
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent 7fc80724
<!--Copyright 2020 The HuggingFace Team. All rights reserved. <!--Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at the License. You may obtain a copy of the License at
...@@ -94,7 +94,7 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -94,7 +94,7 @@ Flax), PyTorch, and/or TensorFlow.
| [CLIPSeg](model_doc/clipseg) | ✅ | ❌ | ❌ | | [CLIPSeg](model_doc/clipseg) | ✅ | ❌ | ❌ |
| [CLVP](model_doc/clvp) | ✅ | ❌ | ❌ | | [CLVP](model_doc/clvp) | ✅ | ❌ | ❌ |
| [CodeGen](model_doc/codegen) | ✅ | ❌ | ❌ | | [CodeGen](model_doc/codegen) | ✅ | ❌ | ❌ |
| [CodeLlama](model_doc/code_llama) | ✅ | ❌ | | | [CodeLlama](model_doc/code_llama) | ✅ | ❌ | |
| [Conditional DETR](model_doc/conditional_detr) | ✅ | ❌ | ❌ | | [Conditional DETR](model_doc/conditional_detr) | ✅ | ❌ | ❌ |
| [ConvBERT](model_doc/convbert) | ✅ | ✅ | ❌ | | [ConvBERT](model_doc/convbert) | ✅ | ✅ | ❌ |
| [ConvNeXT](model_doc/convnext) | ✅ | ✅ | ❌ | | [ConvNeXT](model_doc/convnext) | ✅ | ✅ | ❌ |
...@@ -167,8 +167,8 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -167,8 +167,8 @@ Flax), PyTorch, and/or TensorFlow.
| [LED](model_doc/led) | ✅ | ✅ | ❌ | | [LED](model_doc/led) | ✅ | ✅ | ❌ |
| [LeViT](model_doc/levit) | ✅ | ❌ | ❌ | | [LeViT](model_doc/levit) | ✅ | ❌ | ❌ |
| [LiLT](model_doc/lilt) | ✅ | ❌ | ❌ | | [LiLT](model_doc/lilt) | ✅ | ❌ | ❌ |
| [LLaMA](model_doc/llama) | ✅ | ❌ | | | [LLaMA](model_doc/llama) | ✅ | ❌ | |
| [Llama2](model_doc/llama2) | ✅ | ❌ | | | [Llama2](model_doc/llama2) | ✅ | ❌ | |
| [Longformer](model_doc/longformer) | ✅ | ✅ | ❌ | | [Longformer](model_doc/longformer) | ✅ | ✅ | ❌ |
| [LongT5](model_doc/longt5) | ✅ | ❌ | ✅ | | [LongT5](model_doc/longt5) | ✅ | ❌ | ✅ |
| [LUKE](model_doc/luke) | ✅ | ❌ | ❌ | | [LUKE](model_doc/luke) | ✅ | ❌ | ❌ |
......
...@@ -50,6 +50,9 @@ come in several checkpoints they each contain a part of each weight of the model ...@@ -50,6 +50,9 @@ come in several checkpoints they each contain a part of each weight of the model
- The LLaMA tokenizer is a BPE model based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string. - The LLaMA tokenizer is a BPE model based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string.
This model was contributed by [zphang](https://huggingface.co/zphang) with contributions from [BlackSamorez](https://huggingface.co/BlackSamorez). The code of the implementation in Hugging Face is based on GPT-NeoX [here](https://github.com/EleutherAI/gpt-neox). The original code of the authors can be found [here](https://github.com/facebookresearch/llama). The Flax version of the implementation was contributed by [afmck](https://huggingface.co/afmck) with the code in the implementation based on Hugging Face's Flax GPT-Neo.
Based on the original LLaMA model, Meta AI has released some follow-up works: Based on the original LLaMA model, Meta AI has released some follow-up works:
- **Llama2**: Llama2 is an improved version of Llama with some architectural tweaks (Grouped Query Attention), and is pre-trained on 2Trillion tokens. Refer to the documentation of Llama2 which can be found [here](llama2). - **Llama2**: Llama2 is an improved version of Llama with some architectural tweaks (Grouped Query Attention), and is pre-trained on 2Trillion tokens. Refer to the documentation of Llama2 which can be found [here](llama2).
...@@ -112,3 +115,13 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h ...@@ -112,3 +115,13 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] LlamaForSequenceClassification [[autodoc]] LlamaForSequenceClassification
- forward - forward
## FlaxLlamaModel
[[autodoc]] FlaxLlamaModel
- __call__
## FlaxLlamaForCausalLM
[[autodoc]] FlaxLlamaForCausalLM
- __call__
...@@ -4554,6 +4554,7 @@ else: ...@@ -4554,6 +4554,7 @@ else:
["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"] ["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
) )
_import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"]) _import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"])
_import_structure["models.llama"].extend(["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"])
_import_structure["models.longt5"].extend( _import_structure["models.longt5"].extend(
[ [
"FlaxLongT5ForConditionalGeneration", "FlaxLongT5ForConditionalGeneration",
...@@ -8631,6 +8632,11 @@ if TYPE_CHECKING: ...@@ -8631,6 +8632,11 @@ if TYPE_CHECKING:
FlaxGPTJModel, FlaxGPTJModel,
FlaxGPTJPreTrainedModel, FlaxGPTJPreTrainedModel,
) )
from .models.llama import (
FlaxLlamaForCausalLM,
FlaxLlamaModel,
FlaxLlamaPreTrainedModel,
)
from .models.longt5 import ( from .models.longt5 import (
FlaxLongT5ForConditionalGeneration, FlaxLongT5ForConditionalGeneration,
FlaxLongT5Model, FlaxLongT5Model,
......
...@@ -1267,7 +1267,9 @@ def overwrite_call_docstring(model_class, docstring): ...@@ -1267,7 +1267,9 @@ def overwrite_call_docstring(model_class, docstring):
model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__) model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
def append_call_sample_docstring(model_class, checkpoint, output_type, config_class, mask=None, revision=None): def append_call_sample_docstring(
model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None
):
model_class.__call__ = copy_func(model_class.__call__) model_class.__call__ = copy_func(model_class.__call__)
model_class.__call__ = add_code_sample_docstrings( model_class.__call__ = add_code_sample_docstrings(
checkpoint=checkpoint, checkpoint=checkpoint,
...@@ -1275,6 +1277,7 @@ def append_call_sample_docstring(model_class, checkpoint, output_type, config_cl ...@@ -1275,6 +1277,7 @@ def append_call_sample_docstring(model_class, checkpoint, output_type, config_cl
config_class=config_class, config_class=config_class,
model_cls=model_class.__name__, model_cls=model_class.__name__,
revision=revision, revision=revision,
real_checkpoint=real_checkpoint,
)(model_class.__call__) )(model_class.__call__)
......
...@@ -43,6 +43,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict( ...@@ -43,6 +43,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
("gpt2", "FlaxGPT2Model"), ("gpt2", "FlaxGPT2Model"),
("gpt_neo", "FlaxGPTNeoModel"), ("gpt_neo", "FlaxGPTNeoModel"),
("gptj", "FlaxGPTJModel"), ("gptj", "FlaxGPTJModel"),
("llama", "FlaxLlamaModel"),
("longt5", "FlaxLongT5Model"), ("longt5", "FlaxLongT5Model"),
("marian", "FlaxMarianModel"), ("marian", "FlaxMarianModel"),
("mbart", "FlaxMBartModel"), ("mbart", "FlaxMBartModel"),
...@@ -146,6 +147,7 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ...@@ -146,6 +147,7 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("gpt2", "FlaxGPT2LMHeadModel"), ("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"), ("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"), ("gptj", "FlaxGPTJForCausalLM"),
("llama", "FlaxLlamaForCausalLM"),
("opt", "FlaxOPTForCausalLM"), ("opt", "FlaxOPTForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"), ("roberta", "FlaxRobertaForCausalLM"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"), ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
......
...@@ -489,7 +489,7 @@ class BloomPreTrainedModel(PreTrainedModel): ...@@ -489,7 +489,7 @@ class BloomPreTrainedModel(PreTrainedModel):
@staticmethod @staticmethod
def _convert_to_bloom_cache( def _convert_to_bloom_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]] past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
""" """
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...])) Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
......
...@@ -54,7 +54,7 @@ logger = logging.get_logger(__name__) ...@@ -54,7 +54,7 @@ logger = logging.get_logger(__name__)
def make_list_of_list_of_images( def make_list_of_list_of_images(
images: Union[List[List[ImageInput]], List[ImageInput], ImageInput] images: Union[List[List[ImageInput]], List[ImageInput], ImageInput],
) -> List[List[ImageInput]]: ) -> List[List[ImageInput]]:
if is_valid_image(images): if is_valid_image(images):
return [[images]] return [[images]]
......
...@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING ...@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING
from ...utils import ( from ...utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
is_flax_available,
is_sentencepiece_available, is_sentencepiece_available,
is_tokenizers_available, is_tokenizers_available,
is_torch_available, is_torch_available,
...@@ -55,6 +56,14 @@ else: ...@@ -55,6 +56,14 @@ else:
"LlamaForSequenceClassification", "LlamaForSequenceClassification",
] ]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_llama"] = ["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig
...@@ -83,6 +92,14 @@ if TYPE_CHECKING: ...@@ -83,6 +92,14 @@ if TYPE_CHECKING:
else: else:
from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_llama import FlaxLlamaForCausalLM, FlaxLlamaModel, FlaxLlamaPreTrainedModel
else: else:
import sys import sys
......
This diff is collapsed.
...@@ -265,7 +265,7 @@ class MptPreTrainedModel(PreTrainedModel): ...@@ -265,7 +265,7 @@ class MptPreTrainedModel(PreTrainedModel):
@staticmethod @staticmethod
def _convert_to_mpt_cache( def _convert_to_mpt_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]] past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
""" """
Converts the cache to the format expected by Mpt, i.e. to tuple(tuple([batch_size * num_heads, ...])) Converts the cache to the format expected by Mpt, i.e. to tuple(tuple([batch_size * num_heads, ...]))
......
...@@ -800,6 +800,27 @@ class FlaxGPTJPreTrainedModel(metaclass=DummyObject): ...@@ -800,6 +800,27 @@ class FlaxGPTJPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxLlamaForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLlamaModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLlamaPreTrainedModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLongT5ForConditionalGeneration(metaclass=DummyObject): class FlaxLongT5ForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"] _backends = ["flax"]
......
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import LlamaConfig, is_flax_available, is_tokenizers_available
from transformers.testing_utils import require_flax, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
import jax.numpy as jnp
from transformers.models.llama.modeling_flax_llama import FlaxLlamaForCausalLM, FlaxLlamaModel
if is_tokenizers_available():
from transformers import LlamaTokenizerFast
class FlaxLlamaModelTester:
def __init__(
self,
parent,
batch_size=2,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=2,
intermediate_size=64,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
window_size=7,
initializer_range=0.02,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.window_size = window_size
self.initializer_range = initializer_range
self.scope = None
self.bos_token_id = vocab_size - 1
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = np.tril(np.ones((self.batch_size, self.seq_length)))
config = LlamaConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
use_cache=True,
is_decoder=False,
initializer_range=self.initializer_range,
)
return (config, input_ids, input_mask)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
return config, inputs_dict
def check_use_cache_forward(self, model_class_name, config, input_ids, attention_mask):
max_decoder_length = 20
model = model_class_name(config)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
attention_mask = jnp.ones((input_ids.shape[0], max_decoder_length), dtype="i4")
position_ids = jnp.broadcast_to(
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
)
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model(
input_ids[:, -1:],
attention_mask=attention_mask,
past_key_values=outputs_cache.past_key_values,
position_ids=position_ids,
)
outputs = model(input_ids)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input_ids, attention_mask):
max_decoder_length = 20
model = model_class_name(config)
attention_mask_cache = jnp.concatenate(
[attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))],
axis=-1,
)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
position_ids = jnp.broadcast_to(
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask_cache,
past_key_values=past_key_values,
position_ids=position_ids,
)
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model(
input_ids[:, -1:],
past_key_values=outputs_cache.past_key_values,
attention_mask=attention_mask_cache,
position_ids=position_ids,
)
outputs = model(input_ids, attention_mask=attention_mask)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
@require_flax
class FlaxLlamaModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
all_model_classes = (FlaxLlamaModel, FlaxLlamaForCausalLM) if is_flax_available() else ()
all_generative_model_classes = (FlaxLlamaForCausalLM,) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxLlamaModelTester(self)
def test_use_cache_forward(self):
for model_class_name in self.all_model_classes:
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_use_cache_forward(model_class_name, config, input_ids, attention_mask)
def test_use_cache_forward_with_attn_mask(self):
for model_class_name in self.all_model_classes:
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_use_cache_forward_with_attn_mask(
model_class_name, config, input_ids, attention_mask
)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("openlm-research/open_llama_3b_v2", from_pt=True)
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
@slow
@require_flax
class FlaxLlamaIntegrationTest(unittest.TestCase):
def setUp(self):
self.model_id = "openlm-research/open_llama_3b_v2"
self.model = FlaxLlamaForCausalLM.from_pretrained(self.model_id, from_pt=True)
self.test_batch = jnp.arange(32).reshape(4, 8) + 1911
def test_model_logits(self):
flax_logits = self.model(self.test_batch).logits
# fmt: off
EXPECTED_LOGITS = [-74.4243, -74.0680, -65.2507, -79.1658, -77.7460, -69.2379, -86.4588, -84.8933, -77.8456]
EXPECTED_MIN, EXPECTED_MAX, EXPECTED_MEAN = -96.9952
EXPECTED_MAX = -18.4571
EXPECTED_MEAN = -65.0608
# fmt: on
self.assertTrue(np.allclose(flax_logits[0, :3, :3].flatten(), EXPECTED_LOGITS, atol=1e-4))
self.assertAlmostEqual(flax_logits.min(), EXPECTED_MIN, places=3)
self.assertAlmostEqual(flax_logits.max(), EXPECTED_MAX, places=3)
self.assertAlmostEqual(flax_logits.mean(), EXPECTED_MEAN, places=3)
def test_model_hidden_states(self):
flax_hidden_states = self.model(self.test_batch, output_hidden_states=True).hidden_states
flax_hidden_means = [h.mean() for h in flax_hidden_states]
# fmt: off
EXPECTED_HIDDEN_MEANS = [
-0.00007,-0.00049,-0.00169,-0.00253,-0.00271,
-0.00290,-0.00252,0.00230,0.00230,0.00198,
0.00196,0.00174,0.00246,0.00205,0.00242,
0.00171,0.00092,0.00054,0.00102,0.00024,
0.00029,0.00037,-0.00101,-0.00062,-0.00341,-0.00636,-0.00357
]
# fmt: on
self.assertTrue(np.allclose(flax_hidden_means, EXPECTED_HIDDEN_MEANS, atol=1e-4))
def test_generated_text(self):
tokenizer = LlamaTokenizerFast.from_pretrained(self.model_id)
tokenizer.pad_token_id = 2
test_batch = ["Aloha, World! ", "2 + 2 = ", "Paris is the capital of ", "我很高興認識"]
inputs = tokenizer(test_batch, return_tensors="np", truncation=True, padding=True)
generated_ids = self.model.generate(**inputs, max_length=15).sequences
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# fmt: off
EXPECTED_GENERATION = [
"Aloha, World! 201",
"2 + 2 = 4\n2",
"Paris is the capital of Île-",
"我很高興認識你,我"
]
# fmt: on
self.assertListEqual(generated_text, EXPECTED_GENERATION)
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" Testing suite for the PyTorch LLaMA model. """ """ Testing suite for the PyTorch LLaMA model. """
import unittest import unittest
import pytest import pytest
...@@ -33,7 +32,7 @@ from transformers.testing_utils import ( ...@@ -33,7 +32,7 @@ from transformers.testing_utils import (
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin from ...test_pipeline_mixin import PipelineTesterMixin
...@@ -105,7 +104,7 @@ class LlamaModelTester: ...@@ -105,7 +104,7 @@ class LlamaModelTester:
input_mask = None input_mask = None
if self.use_input_mask: if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length]) input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
token_type_ids = None token_type_ids = None
if self.use_token_type_ids: if self.use_token_type_ids:
......
...@@ -34,7 +34,7 @@ from transformers.testing_utils import ( ...@@ -34,7 +34,7 @@ from transformers.testing_utils import (
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin from ...test_pipeline_mixin import PipelineTesterMixin
...@@ -107,7 +107,7 @@ class MistralModelTester: ...@@ -107,7 +107,7 @@ class MistralModelTester:
input_mask = None input_mask = None
if self.use_input_mask: if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length]) input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
token_type_ids = None token_type_ids = None
if self.use_token_type_ids: if self.use_token_type_ids:
......
...@@ -32,7 +32,7 @@ from transformers.testing_utils import ( ...@@ -32,7 +32,7 @@ from transformers.testing_utils import (
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin from ...test_pipeline_mixin import PipelineTesterMixin
...@@ -104,7 +104,7 @@ class PersimmonModelTester: ...@@ -104,7 +104,7 @@ class PersimmonModelTester:
input_mask = None input_mask = None
if self.use_input_mask: if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length]) input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
token_type_ids = None token_type_ids = None
if self.use_token_type_ids: if self.use_token_type_ids:
......
...@@ -38,7 +38,6 @@ if is_torch_available(): ...@@ -38,7 +38,6 @@ if is_torch_available():
) )
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Phi
class PhiModelTester: class PhiModelTester:
def __init__( def __init__(
self, self,
......
...@@ -233,6 +233,8 @@ OBJECTS_TO_IGNORE = [ ...@@ -233,6 +233,8 @@ OBJECTS_TO_IGNORE = [
"FlaxGPTJModel", "FlaxGPTJModel",
"FlaxGPTNeoForCausalLM", "FlaxGPTNeoForCausalLM",
"FlaxGPTNeoModel", "FlaxGPTNeoModel",
"FlaxLlamaForCausalLM",
"FlaxLlamaModel",
"FlaxMBartForConditionalGeneration", "FlaxMBartForConditionalGeneration",
"FlaxMBartForQuestionAnswering", "FlaxMBartForQuestionAnswering",
"FlaxMBartForSequenceClassification", "FlaxMBartForSequenceClassification",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment