Unverified Commit 0fe44059 authored by Arthur, committed by GitHub

Add recurrent gemma (#30143)



* Fork.

* RecurrentGemma initial commit.

* Updating __init__.py.

* Minor modification to how we initialize the cache.
Changing how the config specifies the architecture.

* Reformat code to 4 spaces.
Fixed a few typos.

* Fixed the forward pass.
Still unclear on the cache?

* Fixed the RecurrentGemmaForCausalLM

* Minor comment that we might not need attention_mask and output_attention arguments.

* Now cache should work as well.

* Adding a temporary example to check whether the model generation works.

* Adding the tests and updating imports.

* Adding the example file missing in the previous commit.

* First working example.

* Removing .gitignore and reverting parts of __init__.

* Re-add .gitignore.

* Addressing comments for configuration.

* Move mask creation to `_prepare_inputs_for_generation`.

* First try at integration tests:
1. AttributeError: 'GriffinCausalLMOutput' object has no attribute 'attentions'.
2. `cache_position` not passed

* Transferring between machines.

* Running normal tests.

* Minor fix.

* More fixes.

* Addressing more comments.

* Minor fixes.

* first stab at cleanup

* more refactoring

* fix copies and else

* renaming and get init to work

* fix causal mask creation

* update

* nit

* fix a hell lot of things

* updates

* update conversion script

* make all keys importable

* nits

* add auto mappings

* properly convert ffw_up and down

* add scaling

* fix generations

* for recurrent dtype

* update

* fix going beyond window

* fixup

* add missing files

* current updates to remove last einops

* finish modeling refactor

* TADA

* fix compile

* fix most failing tests

* update tests

* refactor and update

* update

* nits, fixup and update tests

* more fixup

* nits

* fix imports

* test format

* fixups

* nits

* tuple typing

* fix code quality

* add model card

* fix doc

* skip most generation tests

* nits

* style

* doc fixes

* fix pr and check_copies?

* last nit

* oupsy

* Apply suggestions from code review
Co-authored-by: Lysandre Debut <hi@lysand.re>

* update

* Update src/transformers/models/recurrent_gemma/convert_recurrent_gemma_to_hf.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* update based on review

* doc nit

* fix quality

* quality

* fix slow test model path

* update default dtype

* ignore attributes that can be safely ignored in check config attributes

* 0lallalala come on

* save nit

* style

* remove to dict update

* make sure we can also run in float16

* style

---------
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Co-authored-by: Aleksandar Botev <botev@google.com>
Co-authored-by: Leonard Berrada <lberrada@users.noreply.github.com>
Co-authored-by: anushanf <anushanf@google.com>
Co-authored-by: botev <botevmg@gmail.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 33bca541
@@ -183,6 +183,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("qdqbert", "QDQBertModel"),
("qwen2", "Qwen2Model"),
("qwen2_moe", "Qwen2MoeModel"),
("recurrent_gemma", "RecurrentGemmaModel"),
("reformer", "ReformerModel"),
("regnet", "RegNetModel"),
("rembert", "RemBertModel"),
@@ -469,6 +470,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("qdqbert", "QDQBertLMHeadModel"),
("qwen2", "Qwen2ForCausalLM"),
("qwen2_moe", "Qwen2MoeForCausalLM"),
("recurrent_gemma", "RecurrentGemmaForCausalLM"),
("reformer", "ReformerModelWithLMHead"),
("rembert", "RemBertForCausalLM"),
("roberta", "RobertaForCausalLM"),
......
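With the two mappings above, the new architecture becomes reachable through the Auto classes once a checkpoint carries `model_type: recurrent_gemma` in its config. A minimal usage sketch, assuming the `google/recurrentgemma-2b` checkpoint referenced in the configuration docstring below and a local torch install:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# config.model_type == "recurrent_gemma" routes to RecurrentGemmaForCausalLM
# via the MODEL_FOR_CAUSAL_LM_MAPPING_NAMES entry added above.
tokenizer = AutoTokenizer.from_pretrained("google/recurrentgemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/recurrentgemma-2b")

inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```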
@@ -363,6 +363,13 @@ else:
),
("rag", ("RagTokenizer", None)),
("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
(
"recurrent_gemma",
(
"GemmaTokenizer" if is_sentencepiece_available() else None,
"GemmaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"reformer",
(
......
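Because RecurrentGemma reuses the Gemma tokenizer, the mapping points at the existing `GemmaTokenizer`/`GemmaTokenizerFast` pair rather than a new class; which one `AutoTokenizer` returns depends on the installed backends. A small sketch of that behaviour, using the same checkpoint as above:

```python
from transformers import AutoTokenizer

# Fast (Rust) tokenizer when the `tokenizers` package is installed ...
fast = AutoTokenizer.from_pretrained("google/recurrentgemma-2b")
print(type(fast).__name__)  # GemmaTokenizerFast

# ... and the slow SentencePiece tokenizer when explicitly requested.
slow = AutoTokenizer.from_pretrained("google/recurrentgemma-2b", use_fast=False)
print(type(slow).__name__)  # GemmaTokenizer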
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"configuration_recurrent_gemma": ["RecurrentGemmaConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_recurrent_gemma"] = [
"RecurrentGemmaForCausalLM",
"RecurrentGemmaModel",
"RecurrentGemmaPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_recurrent_gemma import RecurrentGemmaConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_recurrent_gemma import (
RecurrentGemmaForCausalLM,
RecurrentGemmaModel,
RecurrentGemmaPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
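The `_LazyModule` setup means importing `transformers` does not eagerly import the torch-backed modeling file; the names declared in `_import_structure` are only materialized on first access. A rough illustration of what becomes importable, assuming torch is available and using a deliberately tiny config chosen only for this sketch:

```python
from transformers import RecurrentGemmaConfig, RecurrentGemmaForCausalLM

# The config is always importable; the modeling classes resolve to the real
# implementations only because torch is available (otherwise the dummy
# objects further down in this diff are exported instead).
config = RecurrentGemmaConfig(
    num_hidden_layers=3,
    hidden_size=128,
    intermediate_size=384,
    num_attention_heads=4,
    vocab_size=1000,
)
model = RecurrentGemmaForCausalLM(config)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```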
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" RecurrentGemma model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class RecurrentGemmaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`RecurrentGemmaModel`]. It is used to instantiate a RecurrentGemma
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the RecurrentGemma-7B.
e.g. [google/recurrentgemma-2b](https://huggingface.co/google/recurrentgemma-2b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_hidden_layers (`int`, *optional*, defaults to 26):
The number of hidden layers in the model.
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the RecurrentGemma model. Defines the number of
different tokens that can be represented by the
`inputs_ids` passed when calling [`RecurrentGemmaModel`]
hidden_size (`int`, *optional*, defaults to 2560):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 7680):
Dimension of the MLP representations.
num_attention_heads (`int`, *optional*, defaults to 10):
The number of heads for the attention block and the number of
heads/blocks for the block-diagonal layers used in the RG-LRU gates.
This number must divide `hidden_size` and `lru_width`.
lru_width (`int` or `None`, *optional*):
Dimension of the hidden representations of the RG-LRU. If `None`
this will be set to `hidden_size`.
attention_window_size (`int`, *optional*, defaults to 2048):
The size of the attention window used in the attention block.
conv1d_width (`int`, *optional*, defaults to 4):
The kernel size of conv1d layers used in the recurrent blocks.
logits_soft_cap (`float`, *optional*, defaults to 30.0):
The value at which the logits are soft-capped after the transformer and LM-head computation in the causal LM architecture.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/value states
(not used by all models). Only relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The hidden activation used in the recurrent block as well as the MLP layer of the decoder layers.
partial_rotary_factor (`float`, *optional*, defaults to 0.5):
The partial rotary factor used in the initialization of the rotary embeddings.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
block_types (`List[str]`, *optional*, defaults to `('recurrent', 'recurrent', 'attention')`):
List of alternating block types that is repeated cyclically to initialize each `temporal_block` layer.
attention_dropout (`float`, *optional*, defaults to 0.0): The dropout value applied after the attention softmax.
num_key_value_heads (`int`, *optional*): The number of key-value heads used for grouped-query attention (GQA). If `None`, this defaults to `num_attention_heads`.
attention_bias (`bool`, *optional*, defaults to `False`): Whether the query, key and value projections of the attention block should use a bias.
w_init_variance_scale (`float`, *optional*, defaults to 0.01): The variance scale used to initialize the weights.
```python
>>> from transformers import RecurrentGemmaModel, RecurrentGemmaConfig
>>> # Initializing a RecurrentGemma recurrentgemma-2b style configuration
>>> configuration = RecurrentGemmaConfig()
>>> # Initializing a model from the recurrentgemma-2b style configuration
>>> model = RecurrentGemmaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "recurrent_gemma"
def __init__(
self,
num_hidden_layers=26,
vocab_size=256000,
hidden_size=2560,
intermediate_size=3 * 2560,
num_attention_heads=10,
lru_width=None,
attention_window_size=2048,
conv1d_width=4,
logits_soft_cap=30.0,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
bos_token_id=2,
hidden_activation="gelu_pytorch_tanh",
partial_rotary_factor=0.5,
rope_theta=10000.0,
block_types=("recurrent", "recurrent", "attention"),
attention_dropout=0.0,
num_key_value_heads=None,
attention_bias=False,
w_init_variance_scale=0.01,
**kwargs,
):
self.num_hidden_layers = num_hidden_layers
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.lru_width = lru_width if lru_width is not None else hidden_size
self.attention_window_size = attention_window_size
self.conv1d_width = conv1d_width
self.logits_soft_cap = logits_soft_cap
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.partial_rotary_factor = partial_rotary_factor
self.block_types = list(block_types)
self.hidden_activation = hidden_activation
self.head_dim = self.hidden_size // self.num_attention_heads
self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
if self.num_key_value_heads > self.num_attention_heads:
raise ValueError("`num_key_value_heads` must be smaller than or equal to `num_attention_heads`")
self.attention_dropout = attention_dropout
self.attention_bias = attention_bias
self.w_init_variance_scale = w_init_variance_scale
self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs,
)
@property
def layers_block_type(self):
return (self.block_types * 100)[: self.num_hidden_layers]
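The `layers_block_type` property simply tiles `block_types` over the network depth, so the default `('recurrent', 'recurrent', 'attention')` pattern yields two recurrent blocks for every attention block. A quick worked example of the same expression:

```python
block_types = ["recurrent", "recurrent", "attention"]
num_hidden_layers = 26

# Identical expression to the property above: repeat the pattern, then truncate.
layers_block_type = (block_types * 100)[:num_hidden_layers]

print(layers_block_type[:6])
# ['recurrent', 'recurrent', 'attention', 'recurrent', 'recurrent', 'attention']
print(layers_block_type.count("attention"))  # 8 of the 26 layers use attention
```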
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import warnings
import torch
from accelerate import init_empty_weights
from transformers import GemmaTokenizer, RecurrentGemmaConfig, RecurrentGemmaForCausalLM
try:
from transformers import GemmaTokenizerFast
except ImportError as e:
warnings.warn(e)
warnings.warn(
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
)
GemmaTokenizerFast = None
import regex as re
"""
Sample usage:
```
python src/transformers/models/recurrent_gemma/convert_recurrent_gemma_to_hf.py \
    --input_checkpoint /path/to/downloaded/recurrentgemma/checkpoint.pt --model_size 2B --output_dir /output/path
```
Thereafter, models can be loaded via:
```py
from transformers import GemmaTokenizerFast, RecurrentGemmaForCausalLM

model = RecurrentGemmaForCausalLM.from_pretrained("/output/path")
tokenizer = GemmaTokenizerFast.from_pretrained("/output/path")
```
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
come in several checkpoints, each checkpoint contains a part of every weight of the model, so all of them need to be loaded in RAM).
"""
gemma_2b_config = RecurrentGemmaConfig(
num_attention_heads=10,
num_key_value_heads=1,
hidden_size=2560,
intermediate_size=15360,
vocab_size=256000,
num_hidden_layers=26,
)
gemma_7b_config = RecurrentGemmaConfig()
CONFIG_MAPPING = {"2B": gemma_2b_config, "7B": gemma_7b_config}
LAYER_NAME_MAPPING = {"embedder.weight": "model.embed_tokens.weight"}
def write_model(save_path, input_base_path, config, safe_serialization=True, push_to_hub=False, dtype=torch.float32):
print(f"Fetching all parameters from the checkpoint at '{input_base_path}'")
model_state_dict = torch.load(input_base_path, map_location="cpu")
REPLACEMENT = {
"blocks.": "layers.",
".ffw_down.b": ".down_proj.b",
".ffw_down.w": ".down_proj.w",
".ffw_up.b": ".up_proj.bias",
".ffw_up.w": ".up_proj.weight",
"recurrent_block": "temporal_block",
"attention_block": "temporal_block",
"temporal_block.proj_final": "temporal_block.out_proj",
"norm.scale": "norm.weight",
".proj_k": ".k_proj",
".proj_q": ".q_proj",
".proj_v": ".v_proj",
".proj_final": ".o_proj",
"embedder.input_embedding": "embed_tokens.weight",
"conv_1d.w": "conv_1d.weight",
"conv_1d.b": "conv_1d.bias",
"input_gate.w": "input_gate.weight",
"input_gate.b": "input_gate.bias",
"a_param": "recurrent_param",
"a_gate.b": "recurrent_gate.bias",
"a_gate.w": "recurrent_gate.weight",
}
state_dict = {}
for k, v in model_state_dict.items():
k = "model." + k
pattern = re.compile("|".join(map(re.escape, REPLACEMENT.keys())))
key = pattern.sub(lambda match: REPLACEMENT[match.group(0)], k)
if "conv_1d.weight" in key:
v = v[:, None, :].transpose(0, 2)
if "up_proj.weight" in key:
state_dict[key.replace("up_proj", "gate_proj")] = v[0].T.contiguous()
v = v[1].T.contiguous()
if "up_proj.bias" in key:
state_dict[key.replace("up_proj", "gate_proj")] = v[0, 0, 0].clone()
v = v[1, 0, 0].contiguous()
if "recurrent_gate.bias" in key:
state_dict[key.replace("gate.", "gate_")] = v.contiguous().clone()
elif "recurrent_gate.weight" in key:
state_dict[key.replace("gate.", "gate_")] = v.contiguous().clone()
elif "input_gate.b" in key:
state_dict[key.replace("gate.", "gate_")] = v.contiguous().clone()
elif "input_gate.w" in key:
state_dict[key.replace("gate.", "gate_")] = v.contiguous().clone()
elif "embed_tokens" in key:
state_dict[key] = v[: config.vocab_size, :].contiguous().clone()
state_dict["lm_head.weight"] = v[: config.vocab_size, :].contiguous().clone()
else:
state_dict[key] = v.contiguous()
torch.set_default_dtype(dtype)
print("Loading the checkpoint in a Gemma model.")
with init_empty_weights():
model = RecurrentGemmaForCausalLM(config)
model.load_state_dict(state_dict, assign=True, strict=True)
model.config.torch_dtype = torch.float32
del model.config._name_or_path
print("Saving in the Transformers format.")
if push_to_hub:
print(f"pushing the model to {save_path}")
else:
model.save_pretrained(save_path, safe_serialization=safe_serialization)
def write_tokenizer(input_tokenizer_path, save_path, push_to_hub=False):
# Initialize the tokenizer based on the `spm` model
tokenizer_class = GemmaTokenizer if GemmaTokenizerFast is None else GemmaTokenizerFast
print(f"Saving a {tokenizer_class.__name__} to {save_path}.")
tokenizer = tokenizer_class(input_tokenizer_path)
if push_to_hub:
tokenizer.push_to_hub(save_path)
else:
tokenizer.save_pretrained(save_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_checkpoint",
help="Absolute path to the target Gemma weights.",
default="/home/arthur/transformers_recurrentgemma/google/recurrent-gemma-2b-it/ToBeDeleted/2b-it.pt",
)
parser.add_argument(
"--tokenizer_checkpoint",
help="Location of Gemma tokenizer model",
)
parser.add_argument(
"--model_size",
default="2B",
choices=["2B", "7B", "tokenizer_only"],
help="'f' models correspond to the finetuned versions, and are specific to the Gemma2 official release. For more details on Gemma2, checkout the original repo: https://huggingface.co/google/gemma-7b",
)
parser.add_argument(
"--output_dir",
default="google/recurrent-gemma-2b-it-hf",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--pickle_serialization",
help="Whether or not to save using `safetensors`.",
action="store_true",
default=False,
)
parser.add_argument(
"--convert_tokenizer",
help="Whether or not to convert the tokenizer as well.",
action="store_true",
default=False,
)
parser.add_argument(
"--push_to_hub",
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.",
action="store_true",
default=False,
)
parser.add_argument(
"--dtype",
default="float32",
help="Target dtype of the converted model",
)
args = parser.parse_args()
if args.convert_tokenizer:
if args.tokenizer_checkpoint is None:
raise ValueError("Path to the tokenizer is required when passing --convert_tokenizer")
spm_path = os.path.join(args.tokenizer_checkpoint)
write_tokenizer(spm_path, args.output_dir, args.push_to_hub)
config = CONFIG_MAPPING[args.model_size]
dtype = getattr(torch, args.dtype)
write_model(
config=config,
input_base_path=args.input_checkpoint,
save_path=args.output_dir,
safe_serialization=not args.pickle_serialization,
push_to_hub=args.push_to_hub,
dtype=dtype,
)
if __name__ == "__main__":
main()
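Most of the conversion is the key renaming driven by the `REPLACEMENT` table in `write_model`: each original key gets a `model.` prefix and is then rewritten with a single regex substitution over all patterns. A stripped-down sketch of just that step, using a hypothetical checkpoint key and a subset of the table purely for illustration:

```python
import re

# Subset of the REPLACEMENT table above, enough for one example key.
REPLACEMENT = {
    "blocks.": "layers.",
    "recurrent_block": "temporal_block",
    ".a_gate.w": ".recurrent_gate.weight",
    "norm.scale": "norm.weight",
}
pattern = re.compile("|".join(map(re.escape, REPLACEMENT.keys())))

def rename(original_key: str) -> str:
    # Same logic as write_model: prefix with "model." and rewrite every match.
    return pattern.sub(lambda m: REPLACEMENT[m.group(0)], "model." + original_key)

# Hypothetical key layout, for illustration only.
print(rename("blocks.0.recurrent_block.rg_lru.a_gate.w"))
# model.layers.0.temporal_block.rg_lru.recurrent_gate.weight
```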
@@ -7051,6 +7051,27 @@ def load_tf_weights_in_realm(*args, **kwargs):
requires_backends(load_tf_weights_in_realm, ["torch"])
class RecurrentGemmaForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class RecurrentGemmaModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class RecurrentGemmaPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
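The `RecurrentGemma*` dummy objects registered in the hunk above keep top-level imports working in environments without torch while failing loudly on use; roughly, in a hypothetical torch-free environment:

```python
# The import succeeds because the dummy object is exported in place of the
# real class ...
from transformers import RecurrentGemmaModel

try:
    RecurrentGemmaModel()
except ImportError as err:
    # ... but instantiation raises, pointing at the missing `torch` backend.
    print(err)
```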
This diff is collapsed.
@@ -34,6 +34,8 @@ CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
SPECIAL_CASES_TO_ALLOW = {
# used to compute the property `self.chunk_length`
"EncodecConfig": ["overlap"],
# used to compute the property `self.layers_block_type`
"RecurrentGemmaConfig": ["block_types"],
# used as in the config to define `intermediate_size`
"MambaConfig": ["expand"],
# used as `self.bert_model = BertModel(config, ...)`
......
@@ -86,6 +86,7 @@ PRIVATE_MODELS = [
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
"RecurrentGemmaModel", # Building part of bigger (tested) model.
"FuyuForCausalLM", # Not tested fort now
"InstructBlipQFormerModel", # Building part of bigger (tested) model.
"UMT5EncoderModel", # Building part of bigger (tested) model.
......
@@ -768,6 +768,7 @@ src/transformers/models/rag/modeling_tf_rag.py
src/transformers/models/rag/retrieval_rag.py
src/transformers/models/realm/modeling_realm.py
src/transformers/models/realm/retrieval_realm.py
src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py
src/transformers/models/regnet/configuration_regnet.py
src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py
......