Unverified Commit 96eb0628 authored by Arthur, committed by GitHub

Diff converter v2 (#30868)

* current working example!

* commit regex and result file

* update

* nit

* push the conversion file

* oups

* roadmap and nits

* attempt diffs for 3 files

* persimmon

* nit

* add diff file that is the same as the modeling_llama.py

* fix rope nits

* updates

* updates with converted versions

* give some breathing space to the code

* delete

* update

* update

* push the actual result

* update regex patterns

* update regex patterns

* fix some issues

* fix some issues

* fix some issues

* updates

* updates

* updates

* updates

* updates

* revert changes done to llama

* updates

* update gemma

* updates

* oups

* current state

* current state

* update

* ouiiii

* nit

* clear diffs

* nit

* fixup

* update

* doc 🚀

* 🔥

* for now use gemma

* deal with comments

* style

* handle funtions

* deal with assigns

* todos

* process inheritage

* keep decorators?

* 🤗

* deal with duplicates

* fixup

* correctly remove duplicate code

* run ruff post script

* ruff deals pretty well with imports, let's leave it to him

* ah maybe not lol

* for now remove all imports from child.

* nit

* conversion of llama

* okay

* convert starcoder2

* synch with main

* update llama diff

* updates

* https://docs.astral.sh/ruff/rules/redefined-while-unused/

 fixes the imports, but needs a later version of ruff

* updates

* okay actual state

* non zero exit

* update!

* revert unrelated

* remove other diff files

* updates

* cleanup

* update

* less diff!

* stash

* current updates

* updates

* No need for call

* finished fining deps

* update

* current changes

* current state

* current state

* new status

* nit

* finally

* fixes

* nits

* order is now expected

* use logger info instead of prints

* fixup

* up

* nit

* update

* nits

* update

* correct merge

* update

* update

* update

* add warning

* update caution message

* update

* better merging strategy

* copy class statements :wink:

* fixups

* nits

* update

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* nits

* smaller header

* do cleanup some stuff

* even simpler header?

* fixup

* updates

* ruff

* update examples

* nit

* TODO

* state

* OUUUUUUF

* current state

* nits

* final state

* add a readme

* fixup

* remove diff llama

* fix

* nit

* dummy noy funny

* ruff format tests src utils --check

* everless diffs

* less diffs and fix test

* fixes

* naming nit?

* update converter and add supper example

* nits

* updated for function signatures

* update

* update

* add converted dummies

* autoformat

* single target assign fix

* fixup

* fix some imports

* fixes

* don't push them

* `# noqa: F841`

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 372baec2
# Using the `diff_converter` linter
`pip install libcst` is required.
Run `sh examples/diff-conversion/convert_examples.sh` to get the converted outputs.
The diff converter is a new `linter` specific to `transformers`. It unpacks inheritance in Python so that a modular `diff` file such as `diff_gemma.py` can be converted into a single-model, single-file definition.
Examples of possible usage are available in `examples/diff-conversion`; see `diff_gemma.py` for a full model example. To convert a single file, run:
`python utils/diff_model_converter.py --files_to_parse "/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model2.py"`
## How it works
We use the `libcst` parser to produce a concrete syntax tree representation of the `diff_xxx.py` file. For any imports made from `transformers.models.modeling_xxxx`, we parse the source code of that module and build a class dependency mapping, which allows us to unpack the dependencies the diff file relies on.
The code from the `diff` file and the class dependency mapping are "merged" to produce the single-file model definition.
We then run ruff to automatically remove any duplicate imports introduced by the merge.
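To illustrate the first step, here is a minimal sketch, assuming only that `libcst` is installed, of how visiting class definitions lets us record their bases; the `SimpleClassFinder` name is made up for this example, and the real logic lives in `ClassFinder` inside `utils/diff_model_converter.py`:
```python
import libcst as cst


class SimpleClassFinder(cst.CSTVisitor):
    """Toy visitor: record, for every class in a module, the bases it inherits from."""

    def __init__(self, module: cst.Module):
        super().__init__()
        self.module = module
        self.class_bases = {}  # e.g. {"GemmaModel": ["LlamaModel"]}

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        # `base.value` is the expression node of each base class; render it back to source text
        self.class_bases[node.name.value] = [self.module.code_for_node(base.value) for base in node.bases]


source = """
from transformers.models.llama.modeling_llama import LlamaModel

class GemmaModel(LlamaModel):
    pass
"""

module = cst.parse_module(source)
finder = SimpleClassFinder(module)
module.visit(finder)
print(finder.class_bases)  # {'GemmaModel': ['LlamaModel']}
```
The actual converter extends this idea to functions, assignments and decorators so that everything a child class depends on can be copied over.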
## Why use libcst instead of the native AST?
The native `ast` module is powerful, but it does not keep docstrings, comments or code formatting, so we went with `libcst`.
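For a quick, self-contained illustration of that difference (a minimal sketch, not part of the converter; `ast.unparse` requires Python 3.9+):
```python
import ast

import libcst as cst

source = "x = 1  # keep me\n"

# `ast` drops the comment (and formatting) on a round trip...
print(ast.unparse(ast.parse(source)))  # -> "x = 1"
# ...while `libcst` reproduces the original source exactly.
print(cst.parse_module(source).code)   # -> "x = 1  # keep me\n"
```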
#!/bin/bash
# Iterate over each file in the current directory
for file in examples/diff-conversion/diff_*; do
# Check if it's a regular file
if [ -f "$file" ]; then
# Call the Python script with the file name as an argument
python utils/diff_model_converter.py --files_to_parse "$file"
fi
done
from math import log
from typing import List, Optional, Tuple, Union
import torch
from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel
def _pre_process_input(input_ids):
print(log(input_ids))
return input_ids
# example where we need some deps and some functions
class DummyModel(LlamaModel):
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
input_ids = _pre_process_input(input_ids)
return super().forward(
None,
attention_mask,
position_ids,
past_key_values,
inputs_embeds,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
cache_position,
)
from transformers.models.llama.configuration_llama import LlamaConfig
# Example where we only want to add a new config argument and new arg doc
# here there is no `ARG` so we are gonna take parent doc
class MyNewModelConfig(LlamaConfig):
r"""
mlp_bias (`bool`, *optional*, defaults to `False`)
"""
def __init__(self, mlp_bias=True, new_param=0, **super_kwargs):
self.mlp_bias = mlp_bias
self.new_param = new_param
        super().__init__(**super_kwargs)
from transformers.models.gemma.modeling_gemma import GemmaForSequenceClassification
from transformers.models.llama.configuration_llama import LlamaConfig
# Example where we only want to modify the docstring
class MyNewModel2Config(LlamaConfig):
r"""
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma-7B.
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GemmaModel`]
```python
>>> from transformers import GemmaModel, GemmaConfig
>>> # Initializing a Gemma gemma-7b style configuration
>>> configuration = GemmaConfig()
>>> # Initializing a model from the gemma-7b style configuration
>>> model = GemmaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
# Example where all the dependencies are fetched to just copy the entire class
class MyNewModel2ForSequenceClassification(GemmaForSequenceClassification):
pass
# Example where we only want to overwrite the defaults of an init
from transformers.models.gemma.configuration_gemma import GemmaConfig
class NewModelConfig(GemmaConfig):
def __init__(
self,
vocab_size=256030,
hidden_size=64,
intermediate_size=90,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
hidden_activation=None,
max_position_embeddings=1500,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
):
        super().__init__()
from typing import List, Optional, Tuple, Union
import torch
from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel
# example where we need some deps and some functions
class SuperModel(LlamaModel):
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
out = super().forward(
input_ids,
attention_mask,
position_ids,
past_key_values,
inputs_embeds,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
cache_position,
)
out.logits *= 2**4
return out
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,13 +19,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Gemma model configuration"""
-from ...configuration_utils import PretrainedConfig
-from ...utils import logging
-logger = logging.get_logger(__name__)
+from transformers import PretrainedConfig
class GemmaConfig(PretrainedConfig):
@@ -26,13 +29,9 @@ class GemmaConfig(PretrainedConfig):
    This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Gemma-7B.
    e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
@@ -83,16 +82,12 @@ class GemmaConfig(PretrainedConfig):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
    ```python
    >>> from transformers import GemmaModel, GemmaConfig
    >>> # Initializing a Gemma gemma-7b style configuration
    >>> configuration = GemmaConfig()
    >>> # Initializing a model from the gemma-7b style configuration
    >>> model = GemmaModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import PretrainedConfig
from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaModel,
apply_rotary_pos_emb,
repeat_kv,
)
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_outputs import CausalLMOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import logging
logger = logging.get_logger(__name__)
class GemmaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma-7B.
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GemmaModel`]
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 24576):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 28):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The legacy activation function. It is overwritten by the `hidden_activation`.
hidden_activation (`str` or `function`, *optional*):
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
max_position_embeddings (`int`, *optional*, defaults to 8192):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import GemmaModel, GemmaConfig
>>> # Initializing a Gemma gemma-7b style configuration
>>> configuration = GemmaConfig()
>>> # Initializing a model from the gemma-7b style configuration
>>> model = GemmaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=256000,
hidden_size=3072,
intermediate_size=24576,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
hidden_activation=None,
max_position_embeddings=8192,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.hidden_activation = hidden_activation
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)
ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
class GemmaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
self.inv_freq.to(x.device)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class GemmaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
if config.hidden_activation is None:
logger.warning_once(
"`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n"
"Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n"
"`config.hidden_activation` if you want to override this behaviour.\n"
"See https://github.com/huggingface/transformers/pull/29402 for more details."
)
config.hidden_activation = "gelu_pytorch_tanh"
hidden_activation = config.hidden_activation
self.act_fn = ACT2FN[hidden_activation]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class GemmaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
if self.hidden_size % self.num_heads != 0:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
self.rotary_emb = GemmaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class GemmaModel(LlamaModel):
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False # noqa: F841
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True # noqa: F841
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# embed positions
hidden_states = inputs_embeds
# normalized
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
return super().forward(
causal_mask,
position_ids,
past_key_values,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
cache_position,
input_ids=None,
inputs_embeds=hidden_states,
)
# Example where we only modify the docstring and call super
class GemmaForCausalLM(LlamaForCausalLM):
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, GemmaForCausalLM
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
>>> prompt = "What is your favorite condiment?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class GemmaForSequenceClassification(LlamaForSequenceClassification):
pass
class GemmaForTokenClassification(LlamaForTokenClassification):
pass
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
@@ -13,8 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""PyTorch Gemma model."""
import math
from typing import List, Optional, Tuple, Union
@@ -26,10 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
-from ...modeling_attn_mask_utils import (
-    AttentionMaskConverter,
-    _prepare_4d_causal_attention_mask,
-)
+from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
@@ -37,7 +38,7 @@ from ...modeling_outputs import (
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
@@ -46,7 +47,6 @@ from ...utils import (
    logging,
    replace_return_docstrings,
)
-from ...utils.import_utils import is_torch_fx_available
from .configuration_gemma import GemmaConfig
@@ -55,25 +55,14 @@ if is_flash_attn_2_available():
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
-# It means that the function will not be traced through and simply appear as a node in the graph.
-if is_torch_fx_available():
-    if not is_torch_greater_or_equal_than_1_13:
-        import torch.fx
-    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
logger = logging.get_logger(__name__)
-_CONFIG_FOR_DOC = "GemmaConfig"
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
@@ -108,7 +97,6 @@ class GemmaRotaryEmbedding(nn.Module):
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
@@ -130,7 +118,35 @@ class GemmaRotaryEmbedding(nn.Module):
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-# Copied from transformers.models.llama.modeling_llama.rotate_half
+class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding):
+    """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+    def forward(self, x, position_ids):
+        # difference to the original RoPE: a scaling factor is aplied to the position ids
+        position_ids = position_ids.float() / self.scaling_factor
+        cos, sin = super().forward(x, position_ids)
+        return cos, sin
+class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
+    """GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+    def forward(self, x, position_ids):
+        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.max_position_embeddings:
+            base = self.base * (
+                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+            ) ** (self.dim / (self.dim - 2))
+            inv_freq = 1.0 / (
+                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation
+        cos, sin = super().forward(x, position_ids)
+        return cos, sin
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
@@ -138,7 +154,6 @@ def rotate_half(x):
    return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.
@@ -190,7 +205,6 @@ class GemmaMLP(nn.Module):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -206,7 +220,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
class GemmaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""
-    # Ignore copy
    def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
@@ -303,7 +316,6 @@ class GemmaAttention(nn.Module):
        return attn_output, attn_weights, past_key_value
-# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Gemma
class GemmaFlashAttention2(GemmaAttention):
    """
    Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
@@ -319,7 +331,6 @@ class GemmaFlashAttention2(GemmaAttention):
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
-    # Ignore copy
    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -329,13 +340,13 @@ class GemmaFlashAttention2(GemmaAttention):
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )
        output_attentions = False
        bsz, q_len, _ = hidden_states.size()
@@ -351,8 +362,8 @@ class GemmaFlashAttention2(GemmaAttention):
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -397,7 +408,7 @@ class GemmaFlashAttention2(GemmaAttention):
            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
        )
-        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)
        if not output_attentions:
@@ -503,7 +514,6 @@ class GemmaFlashAttention2(GemmaAttention):
        )
-# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemma
class GemmaSdpaAttention(GemmaAttention):
    """
    Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -511,7 +521,7 @@ class GemmaSdpaAttention(GemmaAttention):
    SDPA API.
    """
-    # Ignore copy
+    # Adapted from GemmaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -548,8 +558,8 @@ class GemmaSdpaAttention(GemmaAttention):
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -584,7 +594,7 @@ class GemmaSdpaAttention(GemmaAttention):
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, -1)
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
@@ -598,7 +608,6 @@ GEMMA_ATTENTION_CLASSES = {
}
-# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma
class GemmaDecoderLayer(nn.Module):
    def __init__(self, config: GemmaConfig, layer_idx: int):
        super().__init__()
@@ -692,9 +701,8 @@ class GemmaPreTrainedModel(PreTrainedModel):
    config_class = GemmaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
-    _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
    _no_split_modules = ["GemmaDecoderLayer"]
-    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
+    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
@@ -713,6 +721,9 @@ class GemmaPreTrainedModel(PreTrainedModel):
            module.weight.data[module.padding_idx].zero_()
+_CONFIG_FOR_DOC = "GemmaConfig"
GEMMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -821,7 +832,6 @@ class GemmaModel(GemmaPreTrainedModel):
        self.embed_tokens = value
    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-    # Ignore copy
    def forward(
        self,
        input_ids: torch.LongTensor = None,
@@ -989,6 +999,8 @@ class GemmaModel(GemmaPreTrainedModel):
        if attention_mask is not None and attention_mask.dim() == 4:
            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+            if attention_mask.max() != 0:
+                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
            causal_mask = attention_mask
        else:
            causal_mask = torch.full(
@@ -1020,7 +1032,6 @@ class GemmaModel(GemmaPreTrainedModel):
        return causal_mask
-# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->GEMMA,Llama->Gemma,llama->gemma
class GemmaForCausalLM(GemmaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]
@@ -1051,7 +1062,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
    def get_decoder(self):
        return self.model
-    # Ignore copy
    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
@@ -1244,7 +1254,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
    """,
    GEMMA_START_DOCSTRING,
)
-# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->GEMMA,Llama->Gemma
class GemmaForSequenceClassification(GemmaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
@@ -1360,7 +1369,6 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
    """,
    GEMMA_START_DOCSTRING,
)
-# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Gemma, LLAMA->GEMMA
class GemmaForTokenClassification(GemmaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
@@ -17,8 +17,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union
@@ -559,7 +559,10 @@ def get_indent(code: str) -> str:
    return ""
-def run_ruff(code):
-    command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
+def run_ruff(code, check=False):
+    if check:
+        command = ["ruff", "check", "-", "--fix", "--exit-zero"]
+    else:
+        command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
    stdout, _ = process.communicate(input=code.encode())
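The new `check=True` branch above pipes code through `ruff check --fix`, which is what removes duplicated imports and definitions (rule F811, redefined-while-unused) after the diff file and the parent module are merged. A minimal standalone sketch of the same idea, assuming a recent `ruff` with the F811 autofix is on `PATH` (the `deduplicate_with_ruff` helper name is made up for this example):
```python
import subprocess


def deduplicate_with_ruff(code: str) -> str:
    # Pipe the code through `ruff check --fix`; with stdin ("-") the fixed source is written
    # to stdout, and `--exit-zero` keeps the return code at 0 even when violations were found.
    process = subprocess.Popen(
        ["ruff", "check", "-", "--fix", "--exit-zero"],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    stdout, _ = process.communicate(input=code.encode())
    return stdout.decode()


# The duplicated import should be gone in the printed output.
print(deduplicate_with_ruff("import os\nimport os\n\nprint(os.sep)\n"))
```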
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import importlib
import re
from typing import Dict
import libcst as cst
from check_copies import run_ruff
from libcst import ClassDef, CSTTransformer, CSTVisitor
from libcst import matchers as m
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider
from transformers import logging
logger = logging.get_logger(__name__)
AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
"""
def get_module_source_from_name(module_name: str) -> str:
# Extract the source code from the module name
spec = importlib.util.find_spec(module_name)
if spec is None or spec.origin is None:
return f"Module {module_name} not found"
with open(spec.origin, "r") as file:
source_code = file.read()
return source_code
class ClassFinder(CSTVisitor):
"""A visitor class which analyses a module, creating a mapping of dependencies between classes and functions.
For example if the visited code has
```python3
def init_value(): return 1
class LlamaModel(PreTrainedModel):
def __init__(self):
super().__init__(self)
self.value = init_value()
```
    then the `class_dependency_mapping` should be: `{"LlamaModel": ["PreTrainedModel", "init_value"], "init_value": []}`
The dependency mapping is updated via the `visit_Name`, `visit_Arg` and `visit_Decorator`. This is very broad, and by
checking the parent node, or the scope of a `cst.Name` or `cst.Arg` or `cst.Decorator` we are able to map the
    dependency parent -> child.
When visiting such nodes, we update the dependency of the parent node, to take into account the visited node.
    All `visit_XXX` methods correspond to the code executed when visiting a cst.Node of type XXX.
"""
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
def __init__(self, python_module: cst.Module):
# fmt: off
self.python_module: cst.Module = python_module # original cst.Module being visited
self.classes: Dict[str, cst.ClassDef] = {} # stores a mapping from classname to the cst.Node
self.imports = {} # stores all import statements
self.function_def = {} # stores global scope function definition
self.assignments = {} # LLAMA_DOCSTRING
        self.class_dependency_mapping = {}  # "LlamaModel": ["LlamaDecoderLayer", "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer": ["LlamaAttention", "Llama"]
# fmt: on
def _update_class_dependency(self, name, value):
"""Update the dependency mapping for `name` with `value` by appending the previous
dependencies to the new `value`.
"""
dep = set(self.class_dependency_mapping.get(value, set()))
dep |= set(self.class_dependency_mapping.get(name, {})) | set({value})
self.class_dependency_mapping[name] = dep
def visit_ClassDef(self, node: ClassDef) -> None:
"""We don't have non global scope class defs in transformers. Here we add the inheritance dependencies"""
self.classes[node.name.value] = node
for k in node.bases: # deal with inheritance
base_name = self.python_module.code_for_node(k)
self._update_class_dependency(node.name.value, base_name)
def visit_SimpleStatementLine(self, node):
"""
        Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT'` and all import statements
are extracted and saved in their corresponding dict. They are then used when updating dependency mappings.
"""
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches(
self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module()
):
self.assignments[node.body[0].targets[0].target.value] = node
if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
self.imports[node.body[0].names] = node
def visit_FunctionDef(self, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if m.matches(parent_node, m.Module()):
self.function_def[node.name.value] = node
def leave_If(self, node):
for stmt in node.body.body:
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
self.imports[stmt.body[0].names] = node
def leave_Name(self, node):
if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys():
parent = self.get_metadata(cst.metadata.ScopeProvider, node)
if not isinstance(parent, cst.metadata.scope_provider.GlobalScope):
self._update_class_dependency(parent._name_prefix.split(".")[0], node.value)
def leave_Arg(self, node):
if m.matches(node.value, m.Name()):
parent = self.get_metadata(ParentNodeProvider, node)
if m.matches(parent, m.ClassDef()) and parent.bases:
self._update_class_dependency(parent.name.value, node.value.value)
def leave_Dict(self, node):
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if m.matches(parent, m.Assign(targets=[m.AssignTarget()])):
name = parent.targets[0].target.value
if name in self.assignments:
for k in node.elements:
dep_name = k.value.value
if dep_name in self.classes:
self._update_class_dependency(name, dep_name)
def leave_Decorator(self, node):
if hasattr(node.decorator, "args"):
for k in node.decorator.args:
if k.value.value in self.assignments:
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
scope = self.get_metadata(cst.metadata.ScopeProvider, node)
name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value
self._update_class_dependency(name, k.value.value)
def leave_Module(self, node):
"""When leaving the module, we store the position of each global scoped node (Assigns, function def and class def)
to allow sorting the dependencies based on their position in the code. We use the PositionProvider metadata wrapper for this.
"""
self.global_nodes = {**self.assignments, **self.classes, **self.function_def}
# now sort the class dependency_mapping based on the position of the nodes
self.class_start_line = {}
for id, node in self.global_nodes.items():
self.class_start_line[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
"""A transformer that replaces `old_name` with `new_name` in comments, string and any references.
    It should take into account names like `MyNewModel` or `my_new_model`, without using the AUTO_MAPPING.
Supported renaming patterns:
- llama -> my_new_model and my_new_model -> llama
- Llama -> MyNewModel and MyNewModel -> Llama
- LLAMA -> MY_NEW_MODEL and MY_NEW_MODEL -> LLAMA
    - LLaMa -> MyNewModel and MyNewModel -> Llama
"""
def __init__(self, old_name, new_name):
super().__init__()
self.old_name = old_name
self.new_name = new_name
self.default_name = "".join(x.title() for x in new_name.split("_"))
self.patterns = {
old_name: new_name,
old_name.upper(): new_name.upper(),
"".join(x.title() for x in old_name.split("_")): self.default_name,
}
def preserve_case_replace(self, text):
# Create a regex pattern to match all variations
regex_pattern = "|".join(re.escape(key) for key in self.patterns.keys())
compiled_regex = re.compile(regex_pattern, re.IGNORECASE)
def replace(match):
word = match.group(0)
return self.patterns.get(word, self.default_name)
return compiled_regex.sub(replace, text)
@m.leave(m.Name() | m.SimpleString() | m.Comment())
def replace_name(self, original_node, updated_node):
update = self.preserve_case_replace(updated_node.value)
return updated_node.with_changes(value=update)
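# Illustrative sketch (hedged): the renaming behaviour expected from `ReplaceNameTransformer`.
# The `_example_` helper below is not part of the converter; it only documents the patterns above.
def _example_rename_llama_to_my_new_model():
    code = 'LLAMA_DOCSTRING = "Docstring of LlamaModel"\n'
    renamed = cst.parse_module(code).visit(ReplaceNameTransformer("llama", "my_new_model"))
    # Expected result: 'MY_NEW_MODEL_DOCSTRING = "Docstring of MyNewModelModel"\n'
    return renamed.code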
def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma"):
"""Helper function to rename and then parse a source file using the ClassFinder"""
transformer = ReplaceNameTransformer(old_id, new_id)
new_module = module.visit(transformer)
wrapper = MetadataWrapper(new_module)
class_finder = ClassFinder(new_module)
wrapper.visit(class_finder)
return class_finder
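# Illustrative sketch (hedged): what `find_classes_in_file` exposes on the snippet used in the
# `ClassFinder` docstring. The `_example_` helper is only here for documentation purposes.
def _example_find_classes():
    code = (
        "def init_value():\n"
        "    return 1\n"
        "\n"
        "class LlamaModel:\n"
        "    def __init__(self):\n"
        "        self.value = init_value()\n"
    )
    finder = find_classes_in_file(cst.parse_module(code), "llama", "gemma")
    # After renaming, `finder.classes` holds the class under its new name ("GemmaModel"), and
    # `finder.class_dependency_mapping` should record that `GemmaModel` depends on `init_value`,
    # matching the example given in the `ClassFinder` docstring above.
    return finder.class_dependency_mapping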
DOCSTRING_NODE = m.SimpleStatementLine(
body=[
m.Expr(
value=m.SimpleString(
# match anything between """ """
value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None)
)
)
]
)
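# Illustrative sketch (hedged): `DOCSTRING_NODE` matches a statement line that only contains a
# triple-quoted string, which is how docstrings are detected when deduplicating bodies below.
def _example_docstring_match():
    stmt = cst.parse_module('"""A class docstring."""').body[0]
    return m.matches(stmt, DOCSTRING_NODE)  # expected: True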
class SuperTransformer(cst.CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider,)
def __init__(self, python_module: cst.Module, original_methods, updated_methods):
self.python_module = python_module
self.original_methods = original_methods
self.updated_methods = updated_methods
def update_body(self, existing_body, new_statements):
"""
Helper method to update the body by removing duplicates before adding new statements.
"""
deduplicated_new_body = []
existing_nodes = {
self.python_module.code_for_node(node).strip() for node in new_statements if isinstance(node, cst.CSTNode)
}
for stmt in existing_body:
if self.python_module.code_for_node(stmt).strip() not in existing_nodes:
if m.matches(stmt, DOCSTRING_NODE) and self.has_docstring:
continue
deduplicated_new_body.append(stmt)
                existing_nodes.add(self.python_module.code_for_node(stmt).strip())
else:
logger.info(f"\nFound duplicate {self.python_module.code_for_node(stmt)}")
return deduplicated_new_body
def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode:
"""Updates the body of the input `node`'s `func_name` function by replacing calls
to super().func_name() with the source code of the parent class' `func_name`.
It keeps everything that is defined before `super().func_name()`.
"""
new_body = []
self.has_docstring = False
for expr in node.body:
self.has_docstring = m.matches(node.body[0], DOCSTRING_NODE)
if m.matches(
expr,
m.SimpleStatementLine(
body=[
m.Return(
value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
)
| m.Expr(
value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
)
]
),
):
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
else:
new_body.append(expr)
return node.with_changes(body=new_body)
    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
if updated_node.name.value in self.updated_methods:
name = updated_node.name.value
new_body = self.replace_super_calls(updated_node.body, name)
return updated_node.with_changes(body=new_body, params=updated_node.params)
return updated_node
def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode:
""" "When a return statement is reached, it is replaced with the unrolled super code"""
if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))):
func_def = self.get_metadata(ParentNodeProvider, original_node)
            if m.matches(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods:
updated_return_value = updated_node.value.with_changes(
args=[
cst.Arg(
value=cst.Call(func=cst.Name("super"), args=[cst.Arg(value=cst.Name(func_def.name.value))])
)
]
)
return updated_node.with_changes(value=updated_return_value)
return updated_node
def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str):
"""
    Given the `class_name`, the `updated_node`'s calls to super are unpacked.
| ```python | | ```python
| class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module):
| def __init__(self): | | def __init__(self):
Going from: | self.dropout = 0.2 | to: | self.dropout = 0.2
| super().__init__() | | super().__init__(config)
| ``` | | self.padding_idx = config.pad_token_id
| self.vocab_size = config.vocab_size
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
| self.layers = nn.ModuleList(
| [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
| )
| self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
| self.gradient_checkpointing = False
| # Initialize weights and apply final processing
| self.post_init()
| ```
"""
original_node = class_finder.classes[class_name]
original_methods = {f.name.value if hasattr(f, "name") else f: f for f in original_node.body.body}
updated_methods = {f.name.value if hasattr(f, "name") else f: f for f in updated_node.body.body}
end_meth = []
for name, func in original_methods.items():
if name in updated_methods and updated_methods[name] is not None:
new_params = updated_methods[name].params
# Replace the method in the replacement class, preserving decorators
kwarg_name = getattr(updated_methods[name].params, "star_kwarg", None)
if kwarg_name and kwarg_name.name.value == "super_kwargs":
parent_params = {k.name.value: k for k in func.params.params}
parent_params.update({k.name.value: k for k in new_params.params[1:]})
new_params = new_params.with_changes(
params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
)
func = func.with_changes(body=updated_methods[name].body, params=new_params)
end_meth.append(func)
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
temp_module = cst.Module(body=[result_node])
new_module = MetadataWrapper(temp_module)
new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods))
new_replacement_body = new_replacement_class.body[0].body # get the indented block
return original_node.with_changes(body=new_replacement_body)
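# Illustrative sketch (hedged): a minimal end-to-end use of `replace_call_to_super`, mirroring the
# diagram in the docstring above. The class names and bodies are made up for the example and the
# `_example_` helper is not part of the converter.
def _example_replace_call_to_super():
    parent_code = (
        "class LlamaModel:\n"
        "    def forward(self, x):\n"
        "        return x + 1\n"
    )
    class_finder = find_classes_in_file(cst.parse_module(parent_code), "llama", "gemma")
    child = cst.parse_module(
        "class GemmaModel(LlamaModel):\n"
        "    def forward(self, x):\n"
        "        y = x * 2\n"
        "        return super().forward(y)\n"
    ).body[0]
    unrolled = replace_call_to_super(class_finder, child, "GemmaModel")
    # The `super().forward(y)` call should be replaced by the parent's body, so the resulting
    # class roughly reads:
    #     class GemmaModel:
    #         def forward(self, x):
    #             y = x * 2
    #             return x + 1
    return cst.Module(body=[unrolled]).code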
class DiffConverterTransformer(CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
def __init__(self, python_module, new_name):
super().__init__()
self.model_name = (
            new_name  # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3`
)
# fmt: off
self.python_module = python_module # we store the original module to use `code_for_node`
self.transformers_imports = {} # maps the imports name like "from transformers.models.xxx" to the parsed AST module
self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"}
self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama"
self.new_body = {} # store the new body, all global scope nodes should be added here
self.inserted_deps = [] # nodes inserted via super dependency
self.all_imports = [] # just stores all of the imports
self.global_scope_index = 0
# fmt: on
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
"""When visiting imports from `transformers.models.xxx` we need to:
1. Get the original source code
2. Parse it into an AST Tree
3. Add this import to `self.transformers_imports` as visited to not parse it twice
"""
import_statement = self.python_module.code_for_node(node.module)
if m.matches(node.module, m.Attribute()):
for imported_ in node.names:
_import = re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", import_statement)
if _import:
source = _import.groups()[0]
if source == "modeling" and "Config" in self.python_module.code_for_node(imported_):
raise ValueError(
f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead"
)
if import_statement not in self.transformers_imports:
source_code = get_module_source_from_name(import_statement)
tree = cst.parse_module(source_code)
self.transformers_imports[import_statement] = tree
imported_class = self.python_module.code_for_node(imported_.name)
self.imported_mapping[imported_class] = import_statement
def leave_FunctionDef(self, original_node, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
self.global_scope_index += 100
self.new_body[node.name.value] = {"insert_idx": self.global_scope_index, "node": node}
return node
def leave_SimpleStatementLine(self, original_node, updated_node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
if parent_node not in self.all_imports:
self.all_imports.append(updated_node)
return updated_node
elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])):
full_statement = self.python_module.code_for_node(updated_node.body[0].module)
if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", full_statement):
return cst.RemoveFromParent()
if parent_node not in self.all_imports:
self.all_imports.append(updated_node)
return updated_node
self.global_scope_index += 100
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Assign()])):
# TODO This only works for single target assigns!
node_name = updated_node.body[0].targets[0].target.value
else:
node_name = self.python_module.code_for_node(updated_node.body[0])
self.new_body[node_name] = {
"insert_idx": self.global_scope_index,
"node": updated_node,
}
return updated_node
def leave_ClassDef(self, original_node, updated_node):
"""
1. Filter the `base` classes of this class
If they are from `transformers.models.xx` then:
- take the AST tree of the module it comes from and parse it with a `ClassFinder`.
            - rename every instance of `old_name` (llama) to `new_name` (gemma)
        2. We insert the modules which the inherited base depends on. This has to be done in
           the order of the dependencies. If one is already in the new_body (because it is defined in the diff file),
           then we remove it from the new body so it can be added again in the correct order.
        3. Replace the calls to `super().xxxx`, merging in the parent code.
"""
class_name = original_node.name.value
bases = [k.value.value for k in original_node.bases if k.value.value in self.imported_mapping]
self.global_scope_index += 100
for super_class in bases:
if super_class not in self.imported_mapping:
raise ImportError(
f"{super_class} was not imported using `from transformers.models.xxxxx.modeling_xxxx import {super_class}"
)
super_file_name = self.imported_mapping[super_class] # we need to get the parsed tree
model_name = re.search(r"_(\S*)", super_file_name)
if model_name:
model_name = model_name.groups()[0]
else:
raise ValueError(
f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name"
)
if super_file_name not in self.visited_module: # only extract classes once
class_finder = find_classes_in_file(
self.transformers_imports[super_file_name], model_name, self.model_name
)
self.visited_module[super_file_name] = class_finder
else: # we are re-using the previously parsed data
class_finder = self.visited_module[super_file_name]
list_dependencies = {
dep: class_finder.class_start_line.get(dep, 1000)
for dep in class_finder.class_dependency_mapping.get(class_name, [])
}
list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
start_insert_idx = self.global_scope_index
for dependency, _ in list_dependencies:
node = class_finder.global_nodes.get(dependency, None)
if node is not None:
if dependency not in self.new_body:
start_insert_idx -= 1
self.new_body[dependency] = {"insert_idx": start_insert_idx, "node": node}
elif dependency not in self.inserted_deps:
                    # make sure the node is written after its dependencies
start_insert_idx = self.new_body[dependency]["insert_idx"] - 1
self.inserted_deps.append(dependency)
if len(list_dependencies) > 0:
updated_node = replace_call_to_super(class_finder, updated_node, class_name)
if "Config" in class_name:
self.config_body = [updated_node]
else:
self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
return updated_node
def leave_If(self, original_node, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
full_statement = self.python_module.code_for_node(original_node.test)
if re.search(r"[\s\S]*is_.*available", full_statement):
self.all_imports.append(node)
elif full_statement not in self.new_body:
self.new_body[node] = {"insert_idx": self.global_scope_index, "node": node}
return node
    def leave_Module(self, original_node: cst.Module, node):
imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
dependency_imports = {}
for visiter in self.visited_module.values():
dependency_imports.update({self.python_module.code_for_node(k): k for k in visiter.imports.values()})
if hasattr(self, "config_body"):
self.config_body = list(imports.values()) + self.config_body
dependency_imports.update(imports)
new_body = list(dependency_imports.values())
if len(self.new_body.keys()) > 0:
new_body += [k[1]["node"] for k in sorted(self.new_body.items(), key=lambda x: x[1]["insert_idx"])]
else:
new_body = []
return node.with_changes(body=[*new_body])
def convert_file(diff_file, cst_transformers=None):
model_name = re.search(r"diff_(.*)(?=\.py$)", diff_file).groups()[0]
# Parse the Python file
with open(diff_file, "r") as file:
code = file.read()
module = cst.parse_module(code)
wrapper = MetadataWrapper(module)
if cst_transformers is None:
cst_transformers = DiffConverterTransformer(module, model_name)
new_mod = wrapper.visit(cst_transformers)
ruffed_code = run_ruff(new_mod.code, True)
formatted_code = run_ruff(ruffed_code, False)
if len(formatted_code.strip()) > 0:
with open(diff_file.replace("diff_", "modeling_"), "w") as f:
f.write(AUTO_GENERATED_MESSAGE + formatted_code)
if hasattr(cst_transformers, "config_body"):
config_module = cst.Module(body=[*cst_transformers.config_body], header=new_mod.header)
with open(diff_file.replace("diff_", "configuration_"), "w") as f:
ruffed_code = run_ruff(config_module.code, True)
formatted_code = run_ruff(ruffed_code, False)
f.write(AUTO_GENERATED_MESSAGE + formatted_code)
# TODO optimize by re-using the class_finder
return cst_transformers
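# Illustrative sketch (hedged): calling the converter programmatically instead of through the CLI
# below. The path is only an example; any `diff_<model>.py` laid out like the ones in
# `examples/diff-conversion` should behave the same way.
def _example_convert_one_file():
    # Writes `modeling_<model>.py` (and `configuration_<model>.py` when the diff defines a Config
    # class) next to the diff file, each prefixed with AUTO_GENERATED_MESSAGE.
    return convert_file("examples/diff-conversion/diff_my_new_model.py")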
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model.py"],
nargs="+",
help="A list of `diff_xxxx` files that should be converted to single model file",
)
args = parser.parse_args()
if args.files_to_parse == ["all"]:
args.files_to_parse = glob.glob("src/transformers/models/**/diff_*.py", recursive=True)
for file_name in args.files_to_parse:
print(f"Converting {file_name} to a single model single file format")
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
converter = convert_file(file_name)