Unverified Commit 96eb0628 authored by Arthur, committed by GitHub

Diff converter v2 (#30868)

* current working example!

* commit regex and result file

* update

* nit

* push the conversion file

* oops

* roadmap and nits

* attempt diffs for 3 files

* persimmon

* nit

* add diff file that is the same as the modeling_llama.py

* fix rope nits

* updates

* updates with converted versions

* give some breathing space to the code

* delete

* update

* update

* push the actual result

* update regex patterns

* update regex patterns

* fix some issues

* fix some issues

* fix some issues

* updates

* updates

* updates

* updates

* updates

* revert changes done to llama

* updates

* update gemma

* updates

* oops

* current state

* current state

* update

* yes!!!

* nit

* clear diffs

* nit

* fixup

* update

* doc 🚀

* 🔥

* for now use gemma

* deal with comments

* style

* handle functions

* deal with assigns

* todos

* process inheritance

* keep decorators?

* 🤗

* deal with duplicates

* fixup

* correctly remove duplicate code

* run ruff post script

* ruff deals pretty well with imports, let's leave it to him

* ah maybe not lol

* for now remove all imports from child.

* nit

* conversion of llama

* okay

* convert starcoder2

* synch with main

* update llama diff

* updates

* https://docs.astral.sh/ruff/rules/redefined-while-unused/ fixes the imports, but needs a later version of ruff

* updates

* okay actual state

* non zero exit

* update!

* revert unrelated

* remove other diff files

* updates

* cleanup

* update

* less diff!

* stash

* current updates

* updates

* No need for call

* finished finding deps

* update

* current changes

* current state

* current state

* new status

* nit

* finally

* fixes

* nits

* order is now expected

* use logger info instead of prints

* fixup

* up

* nit

* update

* nits

* update

* correct merge

* update

* update

* update

* add warning

* update caution message

* update

* better merging strategy

* copy class statements :wink:

* fixups

* nits

* update

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* nits

* smaller header

* do cleanup some stuff

* even simpler header?

* fixup

* updates

* ruff

* update examples

* nit

* TODO

* state

* PHEW

* current state

* nits

* final state

* add a readme

* fixup

* remove diff llama

* fix

* nit

* dummy not funny

* ruff format tests src utils --check

* even less diffs

* less diffs and fix test

* fixes

* naming nit?

* update converter and add super example

* nits

* updated for function signatures

* update

* update

* add converted dummies

* autoformat

* single target assign fix

* fixup

* fix some imports

* fixes

* don't push them

* `# noqa: F841`

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 372baec2
# Using the `diff_converter` linter
`pip install libcst` is a must!
Run `sh examples/diff-conversion/convert_examples.sh` to get the converted outputs.

The diff converter is a new `linter` specific to `transformers`. It allows us to unpack inheritance in Python and convert a modular `diff` file like `diff_gemma.py` into a single-model, single-file version of the model.
Examples of possible usage are available in `examples/diff-conversion`; see `diff_gemma` for a full model example.
To convert a single file, run e.g.:
`python utils/diff_model_converter.py --files_to_parse "/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model2.py"`
## How it works
We use the `libcst` parser to produce an AST representation of the `diff_xxx.py` file. For any import from `transformers.models.modeling_xxxx`, we parse the source code of that module and build a class dependency mapping, which allows us to unpack the dependencies the diff relies on.
The code from the `diff` file and the class dependency mapping are "merged" to produce the single-model, single-file output.
We then use `ruff` to automatically remove any duplicate imports.
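As an illustration, here is a minimal sketch (not the converter itself) of the kind of class dependency collection `libcst` makes easy; the class and variable names are made up for the example:

```python
import libcst as cst


class ClassDependencyCollector(cst.CSTVisitor):
    """Record, for every class in a module, the names of its base classes."""

    def __init__(self):
        self.bases = {}

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        # node.bases holds the base classes of the class definition, e.g. `LlamaModel`
        self.bases[node.name.value] = [
            base.value.value for base in node.bases if isinstance(base.value, cst.Name)
        ]


module = cst.parse_module("class DummyModel(LlamaModel):\n    pass\n")
collector = ClassDependencyCollector()
module.visit(collector)
print(collector.bases)  # {'DummyModel': ['LlamaModel']}
```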
## Why use `libcst` instead of the native AST?
The native AST is powerful, but it does not preserve docstrings, comments, or code formatting, so we went with `libcst`.
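For instance, round-tripping source through the standard `ast` module silently drops comments (a quick sanity check, independent of the converter):

```python
import ast

source = "x = 1  # this comment matters\n"
print(ast.unparse(ast.parse(source)))  # prints "x = 1" -- the comment is gone
```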
#!/bin/bash
# Iterate over each diff_* example file
for file in examples/diff-conversion/diff_*; do
    # Check if it's a regular file
    if [ -f "$file" ]; then
        # Call the Python script with the file name as an argument
        python utils/diff_model_converter.py --files_to_parse "$file"
    fi
done
from math import log
from typing import List, Optional, Tuple, Union

import torch

from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel


def _pre_process_input(input_ids):
    print(log(input_ids))
    return input_ids


# example where we need some deps and some functions
class DummyModel(LlamaModel):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        input_ids = _pre_process_input(input_ids)

        return super().forward(
            None,
            attention_mask,
            position_ids,
            past_key_values,
            inputs_embeds,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
            cache_position,
        )
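Assuming this example is saved as `examples/diff-conversion/diff_dummy.py` (a hypothetical path that matches the `diff_*` naming convention used by the script above), it can be converted on its own with `python utils/diff_model_converter.py --files_to_parse examples/diff-conversion/diff_dummy.py`.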
from transformers.models.llama.configuration_llama import LlamaConfig


# Example where we only want to add a new config argument and new arg doc
# here there is no `ARG` so we are gonna take parent doc
class MyNewModelConfig(LlamaConfig):
    r"""
    mlp_bias (`bool`, *optional*, defaults to `False`)
    """

    def __init__(self, mlp_bias=True, new_param=0, **super_kwargs):
        self.mlp_bias = mlp_bias
        self.new_param = new_param
        super().__init__(**super_kwargs)
from transformers.models.gemma.modeling_gemma import GemmaForSequenceClassification
from transformers.models.llama.configuration_llama import LlamaConfig


# Example where we only want to modify the docstring
class MyNewModel2Config(LlamaConfig):
    r"""
    This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate a Gemma
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Gemma-7B.
    e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`GemmaModel`]
    ```python
    >>> from transformers import GemmaModel, GemmaConfig
    >>> # Initializing a Gemma gemma-7b style configuration
    >>> configuration = GemmaConfig()
    >>> # Initializing a model from the gemma-7b style configuration
    >>> model = GemmaModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""


# Example where all the dependencies are fetched to just copy the entire class
class MyNewModel2ForSequenceClassification(GemmaForSequenceClassification):
    pass
# Example where we only want to overwrite the defaults of an init
from transformers.models.gemma.configuration_gemma import GemmaConfig


class NewModelConfig(GemmaConfig):
    def __init__(
        self,
        vocab_size=256030,
        hidden_size=64,
        intermediate_size=90,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        hidden_activation=None,
        max_position_embeddings=1500,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
    ):
        # forward the overridden defaults to the parent config
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            head_dim=head_dim,
            hidden_act=hidden_act,
            hidden_activation=hidden_activation,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            bos_token_id=bos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
        )
from typing import List, Optional, Tuple, Union

import torch

from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel


# example where we need some deps and some functions
class SuperModel(LlamaModel):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        out = super().forward(
            input_ids,
            attention_mask,
            position_ids,
            past_key_values,
            inputs_embeds,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
            cache_position,
        )
        out.logits *= 2**4
        return out
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from <path_to_diff_file.py>.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the diff. If any change should be done, please apply the change to the
+# diff.py file directly.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
+# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,13 +19,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Gemma model configuration"""
-
-from ...configuration_utils import PretrainedConfig
-from ...utils import logging
-
-logger = logging.get_logger(__name__)
+from transformers import PretrainedConfig


 class GemmaConfig(PretrainedConfig):
@@ -26,13 +29,9 @@ class GemmaConfig(PretrainedConfig):
     This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
     defaults will yield a similar configuration to that of the Gemma-7B.
-
     e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
-
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
-
     Args:
         vocab_size (`int`, *optional*, defaults to 256000):
             Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
@@ -83,16 +82,12 @@ class GemmaConfig(PretrainedConfig):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
-
     ```python
     >>> from transformers import GemmaModel, GemmaConfig
-
     >>> # Initializing a Gemma gemma-7b style configuration
     >>> configuration = GemmaConfig()
-
     >>> # Initializing a model from the gemma-7b style configuration
     >>> model = GemmaModel(configuration)
-
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
...
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from <path_to_diff_file.py>.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the diff. If any change should be done, please apply the change to the
+# diff.py file directly.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # coding=utf-8
 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
 #
@@ -13,8 +19,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""PyTorch Gemma model."""
-
 import math
 from typing import List, Optional, Tuple, Union
@@ -26,10 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
-from ...modeling_attn_mask_utils import (
-    AttentionMaskConverter,
-    _prepare_4d_causal_attention_mask,
-)
+from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -37,7 +38,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -46,7 +47,6 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
-from ...utils.import_utils import is_torch_fx_available
 from .configuration_gemma import GemmaConfig
@@ -55,25 +55,14 @@ if is_flash_attn_2_available():
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
-# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
-# It means that the function will not be traced through and simply appear as a node in the graph.
-if is_torch_fx_available():
-    if not is_torch_greater_or_equal_than_1_13:
-        import torch.fx
-
-    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
-
 logger = logging.get_logger(__name__)
-_CONFIG_FOR_DOC = "GemmaConfig"
-
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
     return (
         indices,
         cu_seqlens,
@@ -108,7 +97,6 @@ class GemmaRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
         self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
@@ -130,7 +118,35 @@ class GemmaRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-# Copied from transformers.models.llama.modeling_llama.rotate_half
+class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding):
+    """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+    def forward(self, x, position_ids):
+        # difference to the original RoPE: a scaling factor is aplied to the position ids
+        position_ids = position_ids.float() / self.scaling_factor
+        cos, sin = super().forward(x, position_ids)
+        return cos, sin
+
+
+class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
+    """GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+    def forward(self, x, position_ids):
+        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.max_position_embeddings:
+            base = self.base * (
+                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+            ) ** (self.dim / (self.dim - 2))
+            inv_freq = 1.0 / (
+                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation
+
+        cos, sin = super().forward(x, position_ids)
+        return cos, sin
+
+
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -138,7 +154,6 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
@@ -190,7 +205,6 @@ class GemmaMLP(nn.Module):
         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-# Copied from transformers.models.llama.modeling_llama.repeat_kv
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -206,7 +220,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class GemmaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
-    # Ignore copy
     def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
@@ -303,7 +316,6 @@ class GemmaAttention(nn.Module):
         return attn_output, attn_weights, past_key_value
-# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Gemma
 class GemmaFlashAttention2(GemmaAttention):
     """
     Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
@@ -319,7 +331,6 @@ class GemmaFlashAttention2(GemmaAttention):
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
-    # Ignore copy
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -329,13 +340,13 @@ class GemmaFlashAttention2(GemmaAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
             raise ValueError(
                 "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                 "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
             )
-
         output_attentions = False
         bsz, q_len, _ = hidden_states.size()
@@ -351,8 +362,8 @@ class GemmaFlashAttention2(GemmaAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -397,7 +408,7 @@ class GemmaFlashAttention2(GemmaAttention):
             query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
         )
-        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
         attn_output = self.o_proj(attn_output)
         if not output_attentions:
@@ -503,7 +514,6 @@ class GemmaFlashAttention2(GemmaAttention):
         )
-# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemma
 class GemmaSdpaAttention(GemmaAttention):
     """
     Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -511,7 +521,7 @@ class GemmaSdpaAttention(GemmaAttention):
     SDPA API.
     """
-    # Ignore copy
+    # Adapted from GemmaAttention.forward
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -548,8 +558,8 @@ class GemmaSdpaAttention(GemmaAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -584,7 +594,7 @@ class GemmaSdpaAttention(GemmaAttention):
         )
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, -1)
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
@@ -598,7 +608,6 @@ GEMMA_ATTENTION_CLASSES = {
 }
-# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma
 class GemmaDecoderLayer(nn.Module):
     def __init__(self, config: GemmaConfig, layer_idx: int):
         super().__init__()
@@ -692,9 +701,8 @@ class GemmaPreTrainedModel(PreTrainedModel):
     config_class = GemmaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
     _no_split_modules = ["GemmaDecoderLayer"]
-    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
+    _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
@@ -713,6 +721,9 @@ class GemmaPreTrainedModel(PreTrainedModel):
                 module.weight.data[module.padding_idx].zero_()
+_CONFIG_FOR_DOC = "GemmaConfig"
+
+
 GEMMA_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -821,7 +832,6 @@ class GemmaModel(GemmaPreTrainedModel):
         self.embed_tokens = value
     @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-    # Ignore copy
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -989,6 +999,8 @@ class GemmaModel(GemmaPreTrainedModel):
         if attention_mask is not None and attention_mask.dim() == 4:
             # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+            if attention_mask.max() != 0:
+                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
             causal_mask = attention_mask
         else:
             causal_mask = torch.full(
@@ -1020,7 +1032,6 @@ class GemmaModel(GemmaPreTrainedModel):
         return causal_mask
-# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->GEMMA,Llama->Gemma,llama->gemma
 class GemmaForCausalLM(GemmaPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
@@ -1051,7 +1062,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
     def get_decoder(self):
         return self.model
-    # Ignore copy
     @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -1244,7 +1254,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
     """,
     GEMMA_START_DOCSTRING,
 )
-# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->GEMMA,Llama->Gemma
 class GemmaForSequenceClassification(GemmaPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
@@ -1360,7 +1369,6 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
     """,
     GEMMA_START_DOCSTRING,
 )
-# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Gemma, LLAMA->GEMMA
 class GemmaForTokenClassification(GemmaPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
...
@@ -17,8 +17,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""PyTorch LLaMA model."""
-
 import math
 from typing import List, Optional, Tuple, Union
...
@@ -559,8 +559,11 @@ def get_indent(code: str) -> str:
     return ""


-def run_ruff(code):
-    command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
+def run_ruff(code, check=False):
+    if check:
+        command = ["ruff", "check", "-", "--fix", "--exit-zero"]
+    else:
+        command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
     process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
     stdout, _ = process.communicate(input=code.encode())
     return stdout.decode()
...
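A hypothetical round trip through the patched helper (assuming a recent enough `ruff`, whose autofix for the redefined-while-unused rule the commit log references): lint-fix first to drop the duplicated imports left over from merging, then format the result.

```python
merged_code = "import torch\nimport torch\n\nprint(torch.__version__)\n"
cleaned = run_ruff(merged_code, check=True)  # `ruff check --fix` removes the duplicate import (F811)
print(run_ruff(cleaned))  # `ruff format` then normalizes the layout
```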