Unverified Commit 0cf60f13 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

Add gemma 2 (#31659)



* inital commit

* Add doc

* protect?

* fixup stuffs

* update tests

* fix build documentation

* mmmmmmm config attributes

* style

* nit

* uodate

* nit

* Fix docs

* protect some stuff

---------
Co-authored-by: default avatarLysandre <lysandre@huggingface.co>
parent 4aa17d00
......@@ -145,6 +145,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Funnel Transformer](model_doc/funnel) | ✅ | ✅ | ❌ |
| [Fuyu](model_doc/fuyu) | ✅ | ❌ | ❌ |
| [Gemma](model_doc/gemma) | ✅ | ❌ | ✅ |
| [Gemma2](model_doc/gemma2) | ✅ | ❌ | ❌ |
| [GIT](model_doc/git) | ✅ | ❌ | ❌ |
| [GLPN](model_doc/glpn) | ✅ | ❌ | ❌ |
| [GPT Neo](model_doc/gpt_neo) | ✅ | ❌ | ✅ |
......
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Gemma2
## Overview
The Gemma2 model was proposed in [Gemma2: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/Gemma2-open-models/) by Gemma2 Team, Google.
Gemma2 models are trained on 6T tokens, and released with 2 versions, 2b and 7b.
The abstract from the paper is the following:
*This work introduces Gemma2, a new family of open language models demonstrating strong performance across academic benchmarks for language understanding, reasoning, and safety. We release two sizes of models (2 billion and 7 billion parameters), and provide both pretrained and fine-tuned checkpoints. Gemma2 outperforms similarly sized open models on 11 out of 18 text-based tasks, and we present comprehensive evaluations of safety and responsibility aspects of the models, alongside a detailed description of our model development. We believe the responsible release of LLMs is critical for improving the safety of frontier models, and for enabling the next wave of LLM innovations*
Tips:
- The original checkpoints can be converted using the conversion script `src/transformers/models/Gemma2/convert_Gemma2_weights_to_hf.py`
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [Pedro Cuenca](https://huggingface.co/pcuenq) and [Tom Arsen]().
## Gemma2Config
[[autodoc]] Gemma2Config
## Gemma2Model
[[autodoc]] Gemma2Model
- forward
## Gemma2ForCausalLM
[[autodoc]] Gemma2ForCausalLM
- forward
## Gemma2ForSequenceClassification
[[autodoc]] Gemma2ForSequenceClassification
- forward
## Gemma2ForTokenClassification
[[autodoc]] Gemma2ForTokenClassification
- forward
......@@ -435,6 +435,7 @@ _import_structure = {
],
"models.fuyu": ["FuyuConfig"],
"models.gemma": ["GemmaConfig"],
"models.gemma2": ["Gemma2Config"],
"models.git": [
"GitConfig",
"GitProcessor",
......@@ -2181,6 +2182,15 @@ else:
"GemmaPreTrainedModel",
]
)
_import_structure["models.gemma2"].extend(
[
"Gemma2ForCausalLM",
"Gemma2ForSequenceClassification",
"Gemma2ForTokenClassification",
"Gemma2Model",
"Gemma2PreTrainedModel",
]
)
_import_structure["models.git"].extend(
[
"GitForCausalLM",
......@@ -5062,6 +5072,7 @@ if TYPE_CHECKING:
)
from .models.fuyu import FuyuConfig
from .models.gemma import GemmaConfig
from .models.gemma2 import Gemma2Config
from .models.git import (
GitConfig,
GitProcessor,
......@@ -6694,6 +6705,13 @@ if TYPE_CHECKING:
GemmaModel,
GemmaPreTrainedModel,
)
from .models.gemma2 import (
Gemma2ForCausalLM,
Gemma2ForSequenceClassification,
Gemma2ForTokenClassification,
Gemma2Model,
Gemma2PreTrainedModel,
)
from .models.git import (
GitForCausalLM,
GitModel,
......
......@@ -970,3 +970,125 @@ class SlidingWindowCache(StaticCache):
# in theory there is no limit because the sliding window size is fixed
# no matter how long the sentence is
return None
class HybridCache(Cache):
def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)
self.max_cache_len = max_cache_len
self.max_batch_size = max_batch_size
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
self.is_sliding = torch.tensor(
[i % 2 for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
)
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
sliding_cache_shape = (
max_batch_size,
self.num_key_value_heads,
min(config.sliding_window, max_cache_len),
self.head_dim,
)
for i in range(config.num_hidden_layers):
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
if cache_position.shape[0] > max_cache_len:
k_out = key_states[:, :, -max_cache_len:, :]
v_out = value_states[:, :, -max_cache_len:, :]
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states
slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, max_cache_len - 1)
to_shift = cache_position >= max_cache_len - 1
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
return k_out, v_out
def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
return k_out, v_out
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
sliding_window: Optional[int] = None,
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
if sliding_window:
update_fn = self._sliding_update
else:
update_fn = self._static_update
return update_fn(
cache_position,
layer_idx,
key_states,
value_states,
k_out,
v_out,
k_out.shape[2],
)
def get_max_length(self) -> Optional[int]:
# in theory there is no limit because the sliding window size is fixed
# no matter how long the sentence is
return self.max_cache_len
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
return None
def reset(self):
"""Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
......@@ -400,7 +400,7 @@ class GenerationConfig(PushToHubMixin):
# Cache implementation
self.cache_implementation = kwargs.pop("cache_implementation", None)
self.cache_config = kwargs.pop("cache_config", None)
if self.cache_implementation is not None:
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
if self.cache_config is None:
self.cache_config = cache_config_class()
......
......@@ -28,6 +28,7 @@ from ..cache_utils import (
Cache,
DynamicCache,
HQQQuantizedCache,
HybridCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SlidingWindowCache,
......@@ -112,7 +113,7 @@ logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache, "hybrid": HybridCache}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
......@@ -1395,10 +1396,12 @@ class GenerationMixin:
past_length = 0
if model_kwargs.get("past_key_values") is not None:
if isinstance(model_kwargs["past_key_values"], Cache):
past_length = model_kwargs["past_key_values"].get_seq_length()
else:
past_length = model_kwargs["past_key_values"][0][0].shape[2]
cache = model_kwargs["past_key_values"]
if not isinstance(cache, Cache):
past_length = cache[0][0].shape[2]
elif hasattr(cache, "get_seq_length"):
past_length = cache.get_seq_length()
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
else:
......@@ -1739,7 +1742,9 @@ class GenerationMixin:
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs["past_key_values"] = self._get_cache(
generation_config.cache_implementation, batch_size, generation_config.max_length
generation_config.cache_implementation,
getattr(generation_config, "num_beams", 1) * batch_size,
generation_config.max_length,
)
elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache:
......
......@@ -92,6 +92,7 @@ from . import (
funnel,
fuyu,
gemma,
gemma2,
git,
glpn,
gpt2,
......
......@@ -108,6 +108,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("funnel", "FunnelConfig"),
("fuyu", "FuyuConfig"),
("gemma", "GemmaConfig"),
("gemma2", "Gemma2Config"),
("git", "GitConfig"),
("glpn", "GLPNConfig"),
("gpt-sw3", "GPT2Config"),
......@@ -385,6 +386,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("funnel", "Funnel Transformer"),
("fuyu", "Fuyu"),
("gemma", "Gemma"),
("gemma2", "Gemma2"),
("git", "GIT"),
("glpn", "GLPN"),
("gpt-sw3", "GPT-Sw3"),
......
......@@ -105,6 +105,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("fsmt", "FSMTModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
("gemma", "GemmaModel"),
("gemma2", "Gemma2Model"),
("git", "GitModel"),
("glpn", "GLPNModel"),
("gpt-sw3", "GPT2Model"),
......@@ -454,6 +455,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("falcon", "FalconForCausalLM"),
("fuyu", "FuyuForCausalLM"),
("gemma", "GemmaForCausalLM"),
("gemma2", "Gemma2ForCausalLM"),
("git", "GitForCausalLM"),
("gpt-sw3", "GPT2LMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
......@@ -863,6 +865,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("fnet", "FNetForSequenceClassification"),
("funnel", "FunnelForSequenceClassification"),
("gemma", "GemmaForSequenceClassification"),
("gemma2", "Gemma2ForSequenceClassification"),
("gpt-sw3", "GPT2ForSequenceClassification"),
("gpt2", "GPT2ForSequenceClassification"),
("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
......@@ -1044,6 +1047,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("fnet", "FNetForTokenClassification"),
("funnel", "FunnelForTokenClassification"),
("gemma", "GemmaForTokenClassification"),
("gemma2", "Gemma2ForTokenClassification"),
("gpt-sw3", "GPT2ForTokenClassification"),
("gpt2", "GPT2ForTokenClassification"),
("gpt_bigcode", "GPTBigCodeForTokenClassification"),
......
......@@ -188,6 +188,13 @@ else:
"GemmaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"gemma2",
(
"GemmaTokenizer" if is_sentencepiece_available() else None,
"GemmaTokenizerFast" if is_tokenizers_available() else None,
),
),
("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
......
......@@ -257,6 +257,7 @@ class GemmaAttention(nn.Module):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.scaling = 1 / math.sqrt(config.head_dim)
if self.hidden_size % self.num_heads != 0:
raise ValueError(
......@@ -305,7 +306,7 @@ class GemmaAttention(nn.Module):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
......
......@@ -240,6 +240,7 @@ class GemmaAttention(nn.Module):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.scaling = 1 / math.sqrt(config.head_dim)
if self.hidden_size % self.num_heads != 0:
raise ValueError(
......@@ -288,7 +289,7 @@ class GemmaAttention(nn.Module):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
......@@ -898,6 +899,13 @@ class GemmaModel(GemmaPreTrainedModel):
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)
# decoder layers
all_hidden_states = () if output_hidden_states else None
......@@ -1397,7 +1405,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
......@@ -1407,7 +1415,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
......
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"configuration_gemma2": ["Gemma2Config"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_gemma2"] = [
"Gemma2ForCausalLM",
"Gemma2Model",
"Gemma2PreTrainedModel",
"Gemma2ForSequenceClassification",
"Gemma2ForTokenClassification",
]
if TYPE_CHECKING:
from .configuration_gemma import Gemma2Config
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_gemma import (
Gemma2ForCausalLM,
Gemma2ForSequenceClassification,
Gemma2ForTokenClassification,
Gemma2Model,
Gemma2PreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import PretrainedConfig
class Gemma2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate an Gemma2
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma2-7B.
e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Gemma2Model`]
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 24576):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 28):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 8192):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
size of the sliding window.
```python
>>> from transformers import Gemma2Model, Gemma2Config
>>> # Initializing a Gemma2 gemma2-9b style configuration
>>> configuration = Gemma2Config()
>>> # Initializing a model from the gemma2-9b style configuration
>>> model = Gemma2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma2"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=256000,
hidden_size=3072,
intermediate_size=24576,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_activation="gelu_pytorch_tanh",
max_position_embeddings=8192,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
final_logit_softcapping=30.0,
query_pre_attn_scalar=224,
sliding_window=4096,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.hidden_activation = hidden_activation
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.final_logit_softcapping = final_logit_softcapping
self.query_pre_attn_scalar = query_pre_attn_scalar
self.sliding_window = sliding_window
self.cache_implementation = "hybrid"
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import warnings
import torch
from accelerate import init_empty_weights
from transformers import Gemma2Config, Gemma2ForCausalLM, GemmaTokenizer
try:
from transformers import GemmaTokenizerFast
except ImportError as e:
warnings.warn(e)
warnings.warn(
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
)
GemmaTokenizerFast = None
"""
Sample usage:
```
python src/transformers/models/gemma2/convert_gemma2_weights_to_hf.py \
--input_dir /path/to/downloaded/gemma/weights --model_size 9B --output_dir /output/path
```
Thereafter, models can be loaded via:
```py
from transformers import Gemma2ForCausalLM, GemmaTokenizerFast
model = Gemma2ForCausalLM.from_pretrained("/output/path")
tokenizer = GemmaTokenizerFast.from_pretrained("/output/path")
```
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
"""
gemma_9b_config = Gemma2Config(
num_hidden_layers=42,
num_attention_heads=16,
num_key_value_heads=8,
hidden_size=3584,
intermediate_size=14336,
final_logit_softcapping=30.0,
attn_logit_softcapping=50.0,
head_dim=256,
sliding_window=4096,
query_pre_attn_scalar=224,
)
gemma_27b_config = Gemma2Config(
num_hidden_layers=46,
num_attention_heads=32,
num_key_value_heads=16,
hidden_size=4608,
intermediate_size=36864,
final_logit_softcapping=30.0,
attn_logit_softcapping=50.0,
head_dim=128,
sliding_window=4096,
query_pre_attn_scalar=144,
)
CONFIG_MAPPING = {"9B": gemma_9b_config, "27B": gemma_27b_config}
LAYER_NAME_MAPPING = {"embedder.weight": "model.embed_tokens.weight"}
def write_model(save_path, input_base_path, config, safe_serialization=True, push_to_hub=False, dtype=torch.float32):
num_attn_heads = config.num_attention_heads
hidden_size = config.hidden_size
num_kv_heads = config.num_key_value_heads
head_dim = config.head_dim
print(f"Fetching all parameters from the checkpoint at '{input_base_path}'")
if os.path.isdir(input_base_path):
print("Model seems sharded")
model_state_dict = {}
files = [file for file in os.listdir(input_base_path) if file.endswith(".bin")]
for file in files:
print(file)
loaded_state_dict = torch.load(os.path.join(input_base_path, file), map_location="cpu")
model_state_dict.update(loaded_state_dict)
else:
print("Model does not seem to be sharded")
model_state_dict = torch.load(input_base_path, map_location="cpu")["model_state_dict"]
model_state_dict.pop("freqs_cis")
state_dict = {}
for k, v in model_state_dict.items():
if "qkv_proj" in k:
if num_kv_heads == 1:
v = v.reshape(num_attn_heads + num_kv_heads * 2, head_dim, hidden_size)
q_proj = v[:num_attn_heads, ...]
k_proj = v[num_attn_heads : num_attn_heads + num_kv_heads, ...].repeat(num_kv_heads, 1, 1)
v_proj = v[-num_kv_heads:, ...].repeat(num_kv_heads, 1, 1)
state_dict[k.replace("qkv_proj", "q_proj")] = q_proj.reshape(
num_attn_heads * head_dim, hidden_size
).clone()
state_dict[k.replace("qkv_proj", "k_proj")] = k_proj.reshape(
num_kv_heads * head_dim, hidden_size
).clone()
state_dict[k.replace("qkv_proj", "v_proj")] = v_proj[0].clone()
else:
q_proj, k_proj, v_proj = torch.split(
v, [num_attn_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim], 0
)
state_dict[k.replace("qkv_proj", "q_proj")] = q_proj.reshape(
num_attn_heads * head_dim, hidden_size
).clone()
state_dict[k.replace("qkv_proj", "k_proj")] = k_proj.reshape(
num_kv_heads * head_dim, hidden_size
).clone()
state_dict[k.replace("qkv_proj", "v_proj")] = v_proj.reshape(
num_kv_heads * head_dim, hidden_size
).clone()
elif k == "embedder.weight":
state_dict[LAYER_NAME_MAPPING[k]] = v
state_dict["lm_head.weight"] = v
else:
state_dict[k] = v
torch.set_default_dtype(dtype)
print("Loading the checkpoint in a Gemma2 model.")
with init_empty_weights():
model = Gemma2ForCausalLM(config)
model.load_state_dict(state_dict, assign=True, strict=False)
model.config.torch_dtype = torch.float32
del model.config._name_or_path
print("Saving in the Transformers format.")
if push_to_hub:
print(f"pushing the model to {save_path}")
model.push_to_hub(save_path, safe_serialization=safe_serialization, private=True)
else:
model.save_pretrained(save_path, safe_serialization=safe_serialization)
def write_tokenizer(input_tokenizer_path, save_path, push_to_hub=False):
# Initialize the tokenizer based on the `spm` model
tokenizer_class = GemmaTokenizer if GemmaTokenizerFast is None else GemmaTokenizerFast
print(f"Saving a {tokenizer_class.__name__} to {save_path}.")
tokenizer = tokenizer_class(input_tokenizer_path)
if push_to_hub:
tokenizer.push_to_hub(save_path)
else:
tokenizer.save_pretrained(save_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_checkpoint",
help="Absolute path to the target Gemma2 weights.",
required=True,
)
parser.add_argument(
"--tokenizer_checkpoint",
help="Location of Gemma2 tokenizer model",
)
parser.add_argument(
"--model_size",
default="9B",
choices=["9B", "27B", "tokenizer_only"],
help="'f' models correspond to the finetuned versions, and are specific to the Gemma22 official release. For more details on Gemma2, checkout the original repo: https://huggingface.co/google/gemma-7b",
)
parser.add_argument(
"--output_dir",
default="google/gemma-9b",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--pickle_serialization",
help="Whether or not to save using `safetensors`.",
action="store_true",
default=False,
)
parser.add_argument(
"--convert_tokenizer",
help="Whether or not to convert the tokenizer as well.",
action="store_true",
default=False,
)
parser.add_argument(
"--push_to_hub",
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.",
action="store_true",
default=False,
)
parser.add_argument(
"--dtype",
default="float32",
help="Target dtype of the converted model",
)
args = parser.parse_args()
if args.convert_tokenizer:
if args.tokenizer_checkpoint is None:
raise ValueError("Path to the tokenizer is required when passing --convert_tokenizer")
spm_path = os.path.join(args.tokenizer_checkpoint)
write_tokenizer(spm_path, args.output_dir, args.push_to_hub)
if not args.model_size == "tokenizer_only":
config = CONFIG_MAPPING[args.model_size]
dtype = getattr(torch, args.dtype)
write_model(
config=config,
input_base_path=args.input_checkpoint,
save_path=args.output_dir,
safe_serialization=not args.pickle_serialization,
push_to_hub=args.push_to_hub,
dtype=dtype,
)
if __name__ == "__main__":
main()
This diff is collapsed.
This diff is collapsed.
......@@ -227,7 +227,6 @@ class MistralAttention(nn.Module):
base=self.rope_theta,
)
# Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward with Gemma->Mistral
def forward(
self,
hidden_states: torch.Tensor,
......
......@@ -4197,6 +4197,41 @@ class GemmaPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class Gemma2ForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Gemma2ForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Gemma2ForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Gemma2Model(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Gemma2PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class GitForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -47,11 +47,18 @@ if is_torch_available():
GemmaForSequenceClassification,
GemmaForTokenClassification,
GemmaModel,
GemmaTokenizer,
)
@require_torch
class GemmaModelTester:
config_class = GemmaConfig
if is_torch_available():
model_class = GemmaModel
for_causal_lm_class = GemmaForCausalLM
for_sequence_class = GemmaForSequenceClassification
for_token_class = GemmaForTokenClassification
def __init__(
self,
parent,
......@@ -129,9 +136,8 @@ class GemmaModelTester:
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
# Ignore copy
def get_config(self):
return GemmaConfig(
return self.config_class(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
......@@ -149,18 +155,16 @@ class GemmaModelTester:
head_dim=self.head_dim,
)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Gemma
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = GemmaModel(config=config)
model = self.model_class(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Gemma
def create_and_check_model_as_decoder(
self,
config,
......@@ -174,7 +178,7 @@ class GemmaModelTester:
encoder_attention_mask,
):
config.add_cross_attention = True
model = GemmaModel(config)
model = self.model_class(config)
model.to(torch_device)
model.eval()
result = model(
......@@ -191,7 +195,6 @@ class GemmaModelTester:
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Gemma
def create_and_check_for_causal_lm(
self,
config,
......@@ -204,13 +207,12 @@ class GemmaModelTester:
encoder_hidden_states,
encoder_attention_mask,
):
model = GemmaForCausalLM(config=config)
model = self.for_causal_lm_class(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Gemma
def create_and_check_decoder_model_past_large_inputs(
self,
config,
......@@ -225,7 +227,7 @@ class GemmaModelTester:
):
config.is_decoder = True
config.add_cross_attention = True
model = GemmaForCausalLM(config=config)
model = self.for_causal_lm_class(config=config)
model.to(torch_device)
model.eval()
......@@ -348,7 +350,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = GemmaForSequenceClassification(config)
model = self.model_tester.for_sequence_class(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
......@@ -361,7 +363,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = GemmaForSequenceClassification(config)
model = self.model_tester.for_sequence_class(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
......@@ -376,20 +378,19 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
sequence_labels = ids_tensor(
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
).to(torch.float)
model = GemmaForSequenceClassification(config)
model = self.model_tester.for_sequence_class(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Gemma,llama->Gemma
def test_Gemma_token_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
model = GemmaForTokenClassification(config=config)
model = self.model_tester.for_token_class(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
......@@ -539,47 +540,9 @@ class GemmaIntegrationTest(unittest.TestCase):
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
@require_read_token
def test_model_2b_fp32(self):
model_id = "google/gemma-2b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
@require_read_token
def test_model_2b_fp16(self):
model_id = "google/gemma-2b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
@require_read_token
def test_model_2b_fp16_static_cache(self):
model_id = "google/gemma-2b"
model_id = "google/gemma-2-9b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
......@@ -903,7 +866,7 @@ class GemmaIntegrationTest(unittest.TestCase):
}
prompts = ["Hello I am doing", "Hi today"]
tokenizer = GemmaTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map="sequential", torch_dtype=torch.float16)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment