Unverified Commit d4e92f1a authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Remove add-new-model in favor of add-new-model-like (#30424)

* Remove add-new-model in favor of add-new-model-like

* nits
parent 0eb8fbcd
name: Model templates runner
on:
repository_dispatch:
schedule:
- cron: "0 2 * * *"
jobs:
run_tests_templates:
runs-on: ubuntu-22.04
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install dependencies
run: |
sudo apt -y update && sudo apt install -y libsndfile1-dev
- name: Load cached virtual environment
uses: actions/cache@v2
id: cache
with:
path: ~/venv/
key: v4-tests_templates-${{ hashFiles('setup.py') }}
- name: Create virtual environment on cache miss
if: steps.cache.outputs.cache-hit != 'true'
run: |
python -m venv ~/venv && . ~/venv/bin/activate
pip install --upgrade pip!=21.3
pip install -e .[dev]
- name: Check transformers location
# make `transformers` available as package (required since we use `-e` flag) and check it's indeed from the repo.
run: |
. ~/venv/bin/activate
python setup.py develop
transformer_loc=$(pip show transformers | grep "Location: " | cut -c11-)
transformer_repo_loc=$(pwd .)
if [ "$transformer_loc" != "$transformer_repo_loc/src" ]; then
echo "transformers is from $transformer_loc but it shoud be from $transformer_repo_loc/src."
echo "A fix is required. Stop testing."
exit 1
fi
- name: Create model files
run: |
. ~/venv/bin/activate
transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/encoder-bert-tokenizer.json --path=templates/adding_a_new_model
transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/pt-encoder-bert-tokenizer.json --path=templates/adding_a_new_model
transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/standalone.json --path=templates/adding_a_new_model
transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/tf-encoder-bert-tokenizer.json --path=templates/adding_a_new_model
transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/tf-seq-2-seq-bart-tokenizer.json --path=templates/adding_a_new_model
transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/pt-seq-2-seq-bart-tokenizer.json --path=templates/adding_a_new_model
transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/flax-encoder-bert-tokenizer.json --path=templates/adding_a_new_model
transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/flax-seq-2-seq-bart-tokenizer.json --path=templates/adding_a_new_model
make style
python utils/check_table.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite
python utils/check_copies.py --fix_and_overwrite
- name: Run all non-slow tests
run: |
. ~/venv/bin/activate
python -m pytest -n 2 --dist=loadfile -s --make-reports=tests_templates tests/*template*
- name: Run style changes
run: |
. ~/venv/bin/activate
make style && make quality && make repo-consistency
- name: Failure short reports
if: ${{ always() }}
run: cat reports/tests_templates/failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: run_all_tests_templates_test_reports
path: reports/tests_templates
...@@ -110,7 +110,7 @@ New models are constantly released and if you want to implement a new model, ple ...@@ -110,7 +110,7 @@ New models are constantly released and if you want to implement a new model, ple
If you are willing to contribute the model yourself, let us know so we can help you add it to 🤗 Transformers! If you are willing to contribute the model yourself, let us know so we can help you add it to 🤗 Transformers!
We have added a [detailed guide and templates](https://github.com/huggingface/transformers/tree/main/templates) to help you get started with adding a new model, and we also have a more technical guide for [how to add a model to 🤗 Transformers](https://huggingface.co/docs/transformers/add_new_model). We have a technical guide for [how to add a model to 🤗 Transformers](https://huggingface.co/docs/transformers/add_new_model).
## Do you want to add documentation? ## Do you want to add documentation?
......
...@@ -398,12 +398,14 @@ In dem speziellen Fall, dass Sie ein Modell hinzufügen, dessen Architektur gena ...@@ -398,12 +398,14 @@ In dem speziellen Fall, dass Sie ein Modell hinzufügen, dessen Architektur gena
Modells übereinstimmt, müssen Sie nur ein Konvertierungsskript hinzufügen, wie in [diesem Abschnitt](#write-a-conversion-script) beschrieben. Modells übereinstimmt, müssen Sie nur ein Konvertierungsskript hinzufügen, wie in [diesem Abschnitt](#write-a-conversion-script) beschrieben.
In diesem Fall können Sie einfach die gesamte Modellarchitektur des bereits vorhandenen Modells wiederverwenden. In diesem Fall können Sie einfach die gesamte Modellarchitektur des bereits vorhandenen Modells wiederverwenden.
Andernfalls beginnen wir mit der Erstellung eines neuen Modells. Sie haben hier zwei Möglichkeiten: Andernfalls beginnen wir mit der Erstellung eines neuen Modells. Wir empfehlen die Verwendung des folgenden Skripts, um ein Modell hinzuzufügen
ein bestehendes Modell:
- `transformers-cli add-new-model-like`, um ein neues Modell wie ein bestehendes hinzuzufügen ```bash
- `transformers-cli add-new-model`, um ein neues Modell aus unserer Vorlage hinzuzufügen (sieht dann aus wie BERT oder Bart, je nachdem, welche Art von Modell Sie wählen) transformers-cli add-new-model-like
```
In beiden Fällen werden Sie mit einem Fragebogen aufgefordert, die grundlegenden Informationen zu Ihrem Modell auszufüllen. Für den zweiten Befehl müssen Sie `cookiecutter` installieren, weitere Informationen dazu finden Sie [hier](https://github.com/huggingface/transformers/tree/main/templates/adding_a_new_model). Sie werden mit einem Fragebogen aufgefordert, die grundlegenden Informationen Ihres Modells einzugeben.
**Eröffnen Sie einen Pull Request auf dem Haupt-Repositorium huggingface/transformers** **Eröffnen Sie einen Pull Request auf dem Haupt-Repositorium huggingface/transformers**
......
...@@ -98,7 +98,7 @@ Es werden ständig neue Modelle veröffentlicht. Wenn Sie ein neues Modell imple ...@@ -98,7 +98,7 @@ Es werden ständig neue Modelle veröffentlicht. Wenn Sie ein neues Modell imple
Lassen Sie es uns wissen, wenn Sie bereit sind, das Modell selbst beizutragen. Dann können wir Ihnen helfen, es zu 🤗 Transformers hinzuzufügen! Lassen Sie es uns wissen, wenn Sie bereit sind, das Modell selbst beizutragen. Dann können wir Ihnen helfen, es zu 🤗 Transformers hinzuzufügen!
Wir haben eine [detaillierte Anleitung und Vorlagen](https://github.com/huggingface/transformers/tree/main/templates) hinzugefügt, um Ihnen das Hinzufügen eines neuen Modells zu erleichtern, und wir haben auch einen technischen Leitfaden dazu, [wie man ein Modell zu 🤗 Transformers hinzufügt](https://huggingface.co/docs/transformers/add_new_model). Wir haben auch einen technischen Leitfaden dazu, [wie man ein Modell zu 🤗 Transformers hinzufügt](https://huggingface.co/docs/transformers/add_new_model).
## Möchten Sie die Dokumentation erweitern? ## Möchten Sie die Dokumentation erweitern?
......
...@@ -398,12 +398,14 @@ In the special case that you are adding a model whose architecture exactly match ...@@ -398,12 +398,14 @@ In the special case that you are adding a model whose architecture exactly match
existing model you only have to add a conversion script as described in [this section](#write-a-conversion-script). existing model you only have to add a conversion script as described in [this section](#write-a-conversion-script).
In this case, you can just re-use the whole model architecture of the already existing model. In this case, you can just re-use the whole model architecture of the already existing model.
Otherwise, let's start generating a new model. You have two choices here: Otherwise, let's start generating a new model. We recommend using the following script to add a model starting from
an existing model:
- `transformers-cli add-new-model-like` to add a new model like an existing one ```bash
- `transformers-cli add-new-model` to add a new model from our template (will look like BERT or Bart depending on the type of model you select) transformers-cli add-new-model-like
```
In both cases, you will be prompted with a questionnaire to fill in the basic information of your model. The second command requires to install `cookiecutter`, you can find more information on it [here](https://github.com/huggingface/transformers/tree/main/templates/adding_a_new_model). You will be prompted with a questionnaire to fill in the basic information of your model.
**Open a Pull Request on the main huggingface/transformers repo** **Open a Pull Request on the main huggingface/transformers repo**
......
...@@ -351,13 +351,14 @@ Nel caso speciale in cui stiate aggiungendo un modello, la cui architettura sia ...@@ -351,13 +351,14 @@ Nel caso speciale in cui stiate aggiungendo un modello, la cui architettura sia
dovrete solo aggiugnere uno script di conversione, come descritto [qui](#write-a-conversion-script). dovrete solo aggiugnere uno script di conversione, come descritto [qui](#write-a-conversion-script).
In questo caso, potete riutilizzare l'intera architettura del modello gia esistente. In questo caso, potete riutilizzare l'intera architettura del modello gia esistente.
Se questo non é il caso, cominciamo con il generare un nuovo modello. Avrete due opzioni: Se questo non é il caso, cominciamo con il generare un nuovo modello. Ti consigliamo di utilizzare il seguente script per aggiungere un modello a partire da
un modello esistente:
- `transformers-cli add-new-model-like` per aggiungere un nuovo modello come uno che gia esiste ```bash
- `transformers-cli add-new-model` per aggiungere un nuovo modello da un nostro template (questo assomigliera a BERT o Bart, in base al modello che selezionerete) transformers-cli add-new-model-like
```
In entrambi i casi, l'output vi darà un questionario da riempire con informazioni basi sul modello. Il secondo comando richiede di installare Ti verrà richiesto con un questionario di compilare le informazioni di base del tuo modello.
un `cookiecutter` - maggiori informazioni [qui](https://github.com/huggingface/transformers/tree/main/templates/adding_a_new_model).
**Aprire una Pull Request in main huggingface/transformers repo** **Aprire una Pull Request in main huggingface/transformers repo**
......
...@@ -307,14 +307,15 @@ cd transformers ...@@ -307,14 +307,15 @@ cd transformers
[このセクション](#write-a-conversion-script)で説明されているように、変換スクリプトを追加するだけで済みます。 [このセクション](#write-a-conversion-script)で説明されているように、変換スクリプトを追加するだけで済みます。
この場合、既存のモデルの完全なモデルアーキテクチャを再利用できます。 この場合、既存のモデルの完全なモデルアーキテクチャを再利用できます。
それ以外の場合、新しいモデルの生成を開始します。ここで2つの選択肢があります:
- `transformers-cli add-new-model-like`を使用して既存のモデルのような新しいモデルを追加します それ以外の場合は、新しいモデルの生成を開始しましょう。 次のスクリプトを使用して、以下から始まるモデルを追加することをお勧めします
- `transformers-cli add-new-model`を使用して、テンプレートから新しいモデルを追加します(モデルのタイプに応じてBERTまたはBartのように見えます) 既存のモデル:
どちらの場合でも、モデルの基本情報を入力するための質問事項が表示されます。 ```bash
2番目のコマンドを実行するには、`cookiecutter`をインストールする必要があります。 transformers-cli add-new-model-like
詳細については[こちら](https://github.com/huggingface/transformers/tree/main/templates/adding_a_new_model)をご覧ください。 ```
モデルの基本情報を入力するためのアンケートが表示されます。
**主要な huggingface/transformers リポジトリでプルリクエストを開く** **主要な huggingface/transformers リポジトリでプルリクエストを開く**
......
...@@ -268,12 +268,14 @@ cd transformers ...@@ -268,12 +268,14 @@ cd transformers
다음과 같이 이미 존재하는 모델의 모델 아키텍처와 정확히 일치하는 모델을 추가하는 특별한 경우에는 [이 섹션](#write-a-conversion-script)에 설명된대로 변환 스크립트만 추가하면 됩니다. 이 경우에는 이미 존재하는 모델의 전체 모델 아키텍처를 그대로 재사용할 수 있습니다. 다음과 같이 이미 존재하는 모델의 모델 아키텍처와 정확히 일치하는 모델을 추가하는 특별한 경우에는 [이 섹션](#write-a-conversion-script)에 설명된대로 변환 스크립트만 추가하면 됩니다. 이 경우에는 이미 존재하는 모델의 전체 모델 아키텍처를 그대로 재사용할 수 있습니다.
그렇지 않으면 새로운 모델 생성을 시작합시다. 여기에서 두 가지 선택지가 있습니다: 그렇지 않으면 새 모델 생성을 시작하겠습니다. 다음 스크립트를 사용하여 다음에서 시작하는 모델을 추가하는 것이 좋습니다.
기존 모델:
- `transformers-cli add-new-model-like`를 사용하여 기존 모델과 유사한 새로운 모델 추가하기 ```bash
- `transformers-cli add-new-model`을 사용하여 템플릿을 기반으로 한 새로운 모델 추가하기 (선택한 모델 유형에 따라 BERT 또는 Bart와 유사한 모습일 것입니다) transformers-cli add-new-model-like
```
두 경우 모두, 모델의 기본 정보를 입력하는 설문조사시됩니다. 두 번째 명령어는 `cookiecutter`를 설치해야 합니다. 자세한 정보는 [여기](https://github.com/huggingface/transformers/tree/main/templates/adding_a_new_model)에서 확인할 수 있습니다. 모델의 기본 정보를 입력하는 설문시됩니다.
**huggingface/transformers 메인 저장소에 Pull Request 열기** **huggingface/transformers 메인 저장소에 Pull Request 열기**
......
...@@ -99,7 +99,7 @@ python src/transformers/commands/transformers_cli.py env ...@@ -99,7 +99,7 @@ python src/transformers/commands/transformers_cli.py env
만약 모델을 직접 기여하고 싶으시다면, 알려주세요. 🤗 Transformers에 추가할 수 있도록 도와드리겠습니다! 만약 모델을 직접 기여하고 싶으시다면, 알려주세요. 🤗 Transformers에 추가할 수 있도록 도와드리겠습니다!
새로운 모델을 추가하는 방법에 대한 [상세 안내서와 템플릿](https://github.com/huggingface/transformers/tree/main/templates)을 제공하고 있으며, [🤗 Transformers에 새로운 모델을 추가하는 방법](https://huggingface.co/docs/transformers/add_new_model)에 대한 기술적인 안내서도 있습니다. [🤗 Transformers에 새로운 모델을 추가하는 방법](https://huggingface.co/docs/transformers/add_new_model)에 대한 기술적인 안내서도 있습니다.
## 문서를 추가하고 싶으신가요? [[do-you-want-to-add-documentation]] ## 문서를 추가하고 싶으신가요? [[do-you-want-to-add-documentation]]
......
...@@ -98,7 +98,7 @@ python src/transformers/commands/transformers_cli.py env ...@@ -98,7 +98,7 @@ python src/transformers/commands/transformers_cli.py env
如果你想亲自贡献模型,请告诉我们。让我们帮你把它添加到 🤗 Transformers! 如果你想亲自贡献模型,请告诉我们。让我们帮你把它添加到 🤗 Transformers!
我们已经添加了[详细的指南和模板](https://github.com/huggingface/transformers/tree/main/templates)来帮助你添加新模型。我们还有一个更技术性的指南,告诉你[如何将模型添加到 🤗 Transformers](https://huggingface.co/docs/transformers/add_new_model) 我们还有一个更技术性的指南,告诉你[如何将模型添加到 🤗 Transformers](https://huggingface.co/docs/transformers/add_new_model)
## 你想要添加文档吗? ## 你想要添加文档吗?
......
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import warnings
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import List
from ..utils import logging
from . import BaseTransformersCLICommand
try:
from cookiecutter.main import cookiecutter
_has_cookiecutter = True
except ImportError:
_has_cookiecutter = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def add_new_model_command_factory(args: Namespace):
return AddNewModelCommand(args.testing, args.testing_file, path=args.path)
class AddNewModelCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
add_new_model_parser = parser.add_parser("add-new-model")
add_new_model_parser.add_argument("--testing", action="store_true", help="If in testing mode.")
add_new_model_parser.add_argument("--testing_file", type=str, help="Configuration file on which to run.")
add_new_model_parser.add_argument(
"--path", type=str, help="Path to cookiecutter. Should only be used for testing purposes."
)
add_new_model_parser.set_defaults(func=add_new_model_command_factory)
def __init__(self, testing: bool, testing_file: str, path=None, *args):
self._testing = testing
self._testing_file = testing_file
self._path = path
def run(self):
warnings.warn(
"The command `transformers-cli add-new-model` is deprecated and will be removed in v5 of Transformers. "
"It is not actively maintained anymore, so might give a result that won't pass all tests and quality "
"checks, you should use `transformers-cli add-new-model-like` instead."
)
if not _has_cookiecutter:
raise ImportError(
"Model creation dependencies are required to use the `add_new_model` command. Install them by running "
"the following at the root of your `transformers` clone:\n\n\t$ pip install -e .[modelcreation]\n"
)
# Ensure that there is no other `cookiecutter-template-xxx` directory in the current working directory
directories = [directory for directory in os.listdir() if "cookiecutter-template-" == directory[:22]]
if len(directories) > 0:
raise ValueError(
"Several directories starting with `cookiecutter-template-` in current working directory. "
"Please clean your directory by removing all folders starting with `cookiecutter-template-` or "
"change your working directory."
)
path_to_transformer_root = (
Path(__file__).parent.parent.parent.parent if self._path is None else Path(self._path).parent.parent
)
path_to_cookiecutter = path_to_transformer_root / "templates" / "adding_a_new_model"
# Execute cookiecutter
if not self._testing:
cookiecutter(str(path_to_cookiecutter))
else:
with open(self._testing_file, "r") as configuration_file:
testing_configuration = json.load(configuration_file)
cookiecutter(
str(path_to_cookiecutter if self._path is None else self._path),
no_input=True,
extra_context=testing_configuration,
)
directory = [directory for directory in os.listdir() if "cookiecutter-template-" in directory[:22]][0]
# Retrieve configuration
with open(directory + "/configuration.json", "r") as configuration_file:
configuration = json.load(configuration_file)
lowercase_model_name = configuration["lowercase_modelname"]
generate_tensorflow_pytorch_and_flax = configuration["generate_tensorflow_pytorch_and_flax"]
os.remove(f"{directory}/configuration.json")
output_pytorch = "PyTorch" in generate_tensorflow_pytorch_and_flax
output_tensorflow = "TensorFlow" in generate_tensorflow_pytorch_and_flax
output_flax = "Flax" in generate_tensorflow_pytorch_and_flax
model_dir = f"{path_to_transformer_root}/src/transformers/models/{lowercase_model_name}"
os.makedirs(model_dir, exist_ok=True)
os.makedirs(f"{path_to_transformer_root}/tests/models/{lowercase_model_name}", exist_ok=True)
# Tests require submodules as they have parent imports
with open(f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/__init__.py", "w"):
pass
shutil.move(
f"{directory}/__init__.py",
f"{model_dir}/__init__.py",
)
shutil.move(
f"{directory}/configuration_{lowercase_model_name}.py",
f"{model_dir}/configuration_{lowercase_model_name}.py",
)
def remove_copy_lines(path):
with open(path, "r") as f:
lines = f.readlines()
with open(path, "w") as f:
for line in lines:
if "# Copied from transformers." not in line:
f.write(line)
if output_pytorch:
if not self._testing:
remove_copy_lines(f"{directory}/modeling_{lowercase_model_name}.py")
shutil.move(
f"{directory}/modeling_{lowercase_model_name}.py",
f"{model_dir}/modeling_{lowercase_model_name}.py",
)
shutil.move(
f"{directory}/test_modeling_{lowercase_model_name}.py",
f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/test_modeling_{lowercase_model_name}.py",
)
else:
os.remove(f"{directory}/modeling_{lowercase_model_name}.py")
os.remove(f"{directory}/test_modeling_{lowercase_model_name}.py")
if output_tensorflow:
if not self._testing:
remove_copy_lines(f"{directory}/modeling_tf_{lowercase_model_name}.py")
shutil.move(
f"{directory}/modeling_tf_{lowercase_model_name}.py",
f"{model_dir}/modeling_tf_{lowercase_model_name}.py",
)
shutil.move(
f"{directory}/test_modeling_tf_{lowercase_model_name}.py",
f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/test_modeling_tf_{lowercase_model_name}.py",
)
else:
os.remove(f"{directory}/modeling_tf_{lowercase_model_name}.py")
os.remove(f"{directory}/test_modeling_tf_{lowercase_model_name}.py")
if output_flax:
if not self._testing:
remove_copy_lines(f"{directory}/modeling_flax_{lowercase_model_name}.py")
shutil.move(
f"{directory}/modeling_flax_{lowercase_model_name}.py",
f"{model_dir}/modeling_flax_{lowercase_model_name}.py",
)
shutil.move(
f"{directory}/test_modeling_flax_{lowercase_model_name}.py",
f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/test_modeling_flax_{lowercase_model_name}.py",
)
else:
os.remove(f"{directory}/modeling_flax_{lowercase_model_name}.py")
os.remove(f"{directory}/test_modeling_flax_{lowercase_model_name}.py")
shutil.move(
f"{directory}/{lowercase_model_name}.md",
f"{path_to_transformer_root}/docs/source/en/model_doc/{lowercase_model_name}.md",
)
shutil.move(
f"{directory}/tokenization_{lowercase_model_name}.py",
f"{model_dir}/tokenization_{lowercase_model_name}.py",
)
shutil.move(
f"{directory}/tokenization_fast_{lowercase_model_name}.py",
f"{model_dir}/tokenization_{lowercase_model_name}_fast.py",
)
from os import fdopen, remove
from shutil import copymode, move
from tempfile import mkstemp
def replace(original_file: str, line_to_copy_below: str, lines_to_copy: List[str]):
# Create temp file
fh, abs_path = mkstemp()
line_found = False
with fdopen(fh, "w") as new_file:
with open(original_file) as old_file:
for line in old_file:
new_file.write(line)
if line_to_copy_below in line:
line_found = True
for line_to_copy in lines_to_copy:
new_file.write(line_to_copy)
if not line_found:
raise ValueError(f"Line {line_to_copy_below} was not found in file.")
# Copy the file permissions from the old file to the new file
copymode(original_file, abs_path)
# Remove original file
remove(original_file)
# Move new file
move(abs_path, original_file)
def skip_units(line):
return (
("generating PyTorch" in line and not output_pytorch)
or ("generating TensorFlow" in line and not output_tensorflow)
or ("generating Flax" in line and not output_flax)
)
def replace_in_files(path_to_datafile):
with open(path_to_datafile) as datafile:
lines_to_copy = []
skip_file = False
skip_snippet = False
for line in datafile:
if "# To replace in: " in line and "##" not in line:
file_to_replace_in = line.split('"')[1]
skip_file = skip_units(line)
elif "# Below: " in line and "##" not in line:
line_to_copy_below = line.split('"')[1]
skip_snippet = skip_units(line)
elif "# End." in line and "##" not in line:
if not skip_file and not skip_snippet:
replace(file_to_replace_in, line_to_copy_below, lines_to_copy)
lines_to_copy = []
elif "# Replace with" in line and "##" not in line:
lines_to_copy = []
elif "##" not in line:
lines_to_copy.append(line)
remove(path_to_datafile)
replace_in_files(f"{directory}/to_replace_{lowercase_model_name}.py")
os.rmdir(directory)
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
from argparse import ArgumentParser from argparse import ArgumentParser
from .add_new_model import AddNewModelCommand
from .add_new_model_like import AddNewModelLikeCommand from .add_new_model_like import AddNewModelLikeCommand
from .convert import ConvertCommand from .convert import ConvertCommand
from .download import DownloadCommand from .download import DownloadCommand
...@@ -38,7 +37,6 @@ def main(): ...@@ -38,7 +37,6 @@ def main():
RunCommand.register_subcommand(commands_parser) RunCommand.register_subcommand(commands_parser)
ServeCommand.register_subcommand(commands_parser) ServeCommand.register_subcommand(commands_parser)
UserCommands.register_subcommand(commands_parser) UserCommands.register_subcommand(commands_parser)
AddNewModelCommand.register_subcommand(commands_parser)
AddNewModelLikeCommand.register_subcommand(commands_parser) AddNewModelLikeCommand.register_subcommand(commands_parser)
LfsCommands.register_subcommand(commands_parser) LfsCommands.register_subcommand(commands_parser)
PTtoTFCommand.register_subcommand(commands_parser) PTtoTFCommand.register_subcommand(commands_parser)
......
...@@ -16,257 +16,8 @@ limitations under the License. ...@@ -16,257 +16,8 @@ limitations under the License.
# Adding a new model # Adding a new model
This folder contains templates to generate new models that fit the current API and pass all tests. It generates This page has been updated in light of the removal of the `add_new_model` script in favor of the more complete
models in both PyTorch, TensorFlow, and Flax and completes the `__init__.py` and auto-modeling files, and creates the `add_new_model_like` script.
documentation. Their use is described in the [next section](#cookiecutter-templates).
There is also a CLI tool to generate a new model like an existing one called `transformers-cli add-new-model-like`. We recommend you checkout the documentation of [How to add a model](https://huggingface.co/docs/transformers/main/en/add_new_model)
Jump to the [Add new model like section](#add-new-model-like-command) to learn how to use it. in the Hugging Face Transformers documentation for complete and up-to-date instructions.
## Cookiecutter Templates
Using the `cookiecutter` utility requires to have all the `dev` dependencies installed. Let's first clone the
repository and install it in our environment:
```shell script
git clone https://github.com/huggingface/transformers
cd transformers
pip install -e ".[dev]"
```
Depending on your OS, and since the number of optional dependencies of Transformers is growing, you might get a
failure with this command. If that's the case make sure to install the Deep Learning framework you are working with
(PyTorch, TensorFlow and/or Flax) then do:
```bash
pip install -e ".[quality]"
```
Once the installation is done, you can use the CLI command `add-new-model` to generate your models:
```shell script
transformers-cli add-new-model
```
This should launch the `cookiecutter` package which should prompt you to fill in the configuration.
The `modelname` should be cased according to the plain text casing, i.e., BERT, RoBERTa, DeBERTa.
```
modelname [<ModelNAME>]:
uppercase_modelname [<MODEL_NAME>]:
lowercase_modelname [<model_name>]:
camelcase_modelname [<ModelName>]:
```
Fill in the `authors` with your team members:
```
authors [The HuggingFace Team]:
```
The checkpoint identifier is the checkpoint that will be used in the examples across the files. Put the name you wish,
as it will appear on the modelhub. Do not forget to include the organisation.
```
checkpoint_identifier [organisation/<model_name>-base-cased]:
```
The tokenizer should either be based on BERT if it behaves exactly like the BERT tokenizer, or a standalone otherwise.
```
Select tokenizer_type:
1 - Based on BERT
2 - Standalone
Choose from 1, 2 [1]:
```
<!---
Choose if your model is an encoder-decoder, or an encoder-only architecture.
If your model is an encoder-only architecture, the generated architecture will be based on the BERT model.
If your model is an encoder-decoder architecture, the generated architecture will be based on the BART model. You can,
of course, edit the files once the generation is complete.
```
Select is_encoder_decoder_model:
1 - True
2 - False
Choose from 1, 2 [1]:
```
-->
Once the command has finished, you should have a total of 7 new files spread across the repository:
```
docs/source/model_doc/<model_name>.md
src/transformers/models/<model_name>/configuration_<model_name>.py
src/transformers/models/<model_name>/modeling_<model_name>.py
src/transformers/models/<model_name>/modeling_tf_<model_name>.py
src/transformers/models/<model_name>/tokenization_<model_name>.py
tests/models/<model_name>/test_modeling_<model_name>.py
tests/models/<model_name>/test_modeling_tf_<model_name>.py
```
You can run the tests to ensure that they all pass:
```bash
python -m pytest ./tests/test_*<model_name>*.py
```
Feel free to modify each file to mimic the behavior of your model.
⚠ You should be careful about the classes preceded by the following line:️
```python
# Copied from transformers.[...]
```
This line ensures that the copy does not diverge from the source. If it *should* diverge, because the implementation
is different, this line needs to be deleted. If you don't delete this line and run `make fix-copies`,
your changes will be overwritten.
Once you have edited the files to fit your architecture, simply re-run the tests (and edit them if a change
is needed!) afterwards to make sure everything works as expected.
Once the files are generated and you are happy with your changes, here's a checklist to ensure that your contribution
will be merged quickly:
- You should run the `make fixup` utility to fix the style of the files and to ensure the code quality meets the
library's standards.
- You should complete the documentation file (`docs/source/model_doc/<model_name>.rst`) so that your model may be
usable.
## Add new model like command
Using the `transformers-cli add-new-model-like` command requires to have all the `dev` dependencies installed. Let's
first clone the repository and install it in our environment:
```shell script
git clone https://github.com/huggingface/transformers
cd transformers
pip install -e ".[dev]"
```
Depending on your OS, and since the number of optional dependencies of Transformers is growing, you might get a
failure with this command. If that's the case make sure to install the Deep Learning framework you are working with
(PyTorch, TensorFlow and/or Flax) then do:
```bash
pip install -e ".[quality]"
```
Once the installation is done, you can use the CLI command `add-new-model-like` to generate your models:
```shell script
transformers-cli add-new-model-like
```
This will start a small questionnaire you have to fill.
```
What identifier would you like to use for the model type of this model?
```
You will have to input the model type of the model you want to clone. The model type can be found in several places:
- inside the configuration of any checkpoint of that model
- the name of the documentation page of that model
For instance the doc page of `BigBirdPegasus` is `https://huggingface.co/docs/transformers/model_doc/bigbird_pegasus`
so its model type is `"bigbird_pegasus"`.
If you make a typo, the command will suggest you the closest model types it can find.
Once this is done, the questionnaire will ask you for the new model name and its various casings:
```
What is the name for your new model?
What identifier would you like to use for the model type of this model?
What name would you like to use for the module of this model?
What prefix (camel-cased) would you like to use for the model classes of this model?
What prefix (upper-cased) would you like to use for the constants relative to this model?
```
From your answer to the first question, defaults will be determined for all others. The first name should be written
as you want your model be named in the doc, with no special casing (like RoBERTa) and from there, you can either stick
with the defaults or change the cased versions.
Next will be the name of the config class to use for this model:
```
What will be the name of the config class for this model?
```
Then, you will be asked for a checkpoint identifier:
```
Please give a checkpoint identifier (on the model Hub) for this new model.
```
This is the checkpoint that will be used in the examples across the files and the integration tests. Put the name you
wish, as it will appear on the Model Hub. Do not forget to include the organisation.
Then you will have to say whether your model re-uses the same processing classes as the model you're cloning:
```
Will your new model use the same processing class as Xxx (XxxTokenizer/XxxFeatureExtractor/XxxImageProcessor)
```
Answer yes if you have no intentions to make any change to the class used for preprocessing. It can use different
files (for instance you can reuse the `BertTokenizer` with a new vocab file).
If you answer no, you will have to give the name of the classes
for the new tokenizer/image processor/feature extractor/processor (depending on the model you're cloning).
Next the questionnaire will ask
```
Should we add # Copied from statements when creating the new modeling file?
```
This is the internal mechanism used in the library to make sure code copied from various modeling files stay consistent.
If you plan to completely rewrite the modeling file, you should answer no, whereas if you just want to tweak one part
of the model, you should answer yes.
Lastly, the questionnaire will inquire about frameworks:
```
Should we add a version of your new model in all the frameworks implemented by Old Model (xxx)?
```
If you answer yes, the new model will have files for all the frameworks implemented by the model you're cloning.
Otherwise, you will get a new question to select the frameworks you want.
Once the command has finished, you will see a new subfolder in the `src/transformers/models/` folder, with the
necessary files (configuration and modeling files for all frameworks requested, and maybe the processing files,
depending on your choices).
You will also see a doc file and tests for your new models. First you should run
```bash
make style
make fix-copies
```
and then you can start tweaking your model. You should:
- fill the doc file at `docs/source/model_doc/model_name.md`
- tweak the configuration and modeling files to your need
Once you're done, you can run the tests to ensure that they all pass:
```bash
python -m pytest ./tests/test_*<model_name>*.py
```
⚠ You should be careful about the classes preceded by the following line:️
```python
# Copied from transformers.[...]
```
This line ensures that the copy does not diverge from the source. If it *should* diverge, because the implementation
is different, this line needs to be deleted. If you don't delete this line and run `make fix-copies`,
your changes will be overwritten.
Once you have edited the files to fit your architecture, simply re-run the tests (and edit them if a change
is needed!) afterwards to make sure everything works as expected.
Once the files are generated and you are happy with your changes, here's a checklist to ensure that your contribution
will be merged quickly:
- You should run the `make fixup` utility to fix the style of the files and to ensure the code quality meets the
library's standards.
- You should add your model to the main README then run `make fix-copies`.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule, OptionalDependencyNotAvailable, is_tokenizers_available
{%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
from ...utils import is_tf_available
{% endif %}
{%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
from ...utils import is_torch_available
{% endif %}
{%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
from ...utils import is_flax_available
{% endif %}
_import_structure = {
"configuration_{{cookiecutter.lowercase_modelname}}": ["{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP", "{{cookiecutter.camelcase_modelname}}Config"],
"tokenization_{{cookiecutter.lowercase_modelname}}": ["{{cookiecutter.camelcase_modelname}}Tokenizer"],
}
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_{{cookiecutter.lowercase_modelname}}_fast"] = ["{{cookiecutter.camelcase_modelname}}TokenizerFast"]
{%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_{{cookiecutter.lowercase_modelname}}"] = [
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"{{cookiecutter.camelcase_modelname}}ForMaskedLM",
"{{cookiecutter.camelcase_modelname}}ForCausalLM",
"{{cookiecutter.camelcase_modelname}}ForMultipleChoice",
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"{{cookiecutter.camelcase_modelname}}ForTokenClassification",
"{{cookiecutter.camelcase_modelname}}Layer",
"{{cookiecutter.camelcase_modelname}}Model",
"{{cookiecutter.camelcase_modelname}}PreTrainedModel",
"load_tf_weights_in_{{cookiecutter.lowercase_modelname}}",
]
{% else %}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_{{cookiecutter.lowercase_modelname}}"] = [
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"{{cookiecutter.camelcase_modelname}}ForCausalLM",
"{{cookiecutter.camelcase_modelname}}Model",
"{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
{% endif %}
{% endif %}
{%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_{{cookiecutter.lowercase_modelname}}"] = [
"TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"TF{{cookiecutter.camelcase_modelname}}ForMaskedLM",
"TF{{cookiecutter.camelcase_modelname}}ForCausalLM",
"TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice",
"TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"TF{{cookiecutter.camelcase_modelname}}ForTokenClassification",
"TF{{cookiecutter.camelcase_modelname}}Layer",
"TF{{cookiecutter.camelcase_modelname}}Model",
"TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
{% else %}
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_{{cookiecutter.lowercase_modelname}}"] = [
"TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"TF{{cookiecutter.camelcase_modelname}}Model",
"TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
{% endif %}
{% endif %}
{%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_{{cookiecutter.lowercase_modelname}}"] = [
"Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM",
"Flax{{cookiecutter.camelcase_modelname}}ForCausalLM",
"Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoice",
"Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"Flax{{cookiecutter.camelcase_modelname}}ForTokenClassification",
"Flax{{cookiecutter.camelcase_modelname}}Layer",
"Flax{{cookiecutter.camelcase_modelname}}Model",
"Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
{% else %}
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_{{cookiecutter.lowercase_modelname}}"] = [
"Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"Flax{{cookiecutter.camelcase_modelname}}Model",
"Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
{% endif %}
{% endif %}
if TYPE_CHECKING:
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_{{cookiecutter.lowercase_modelname}}_fast import {{cookiecutter.camelcase_modelname}}TokenizerFast
{%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}ForTokenClassification,
{{cookiecutter.camelcase_modelname}}Layer,
{{cookiecutter.camelcase_modelname}}Model,
{{cookiecutter.camelcase_modelname}}PreTrainedModel,
load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
)
{% else %}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model,
{{cookiecutter.camelcase_modelname}}PreTrainedModel,
)
{% endif %}
{% endif %}
{%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
TF{{cookiecutter.camelcase_modelname}}Layer,
TF{{cookiecutter.camelcase_modelname}}Model,
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
)
{% else %}
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
TF{{cookiecutter.camelcase_modelname}}Model,
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
)
{% endif %}
{% endif %}
{%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM,
Flax{{cookiecutter.camelcase_modelname}}ForCausalLM,
Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
Flax{{cookiecutter.camelcase_modelname}}ForTokenClassification,
Flax{{cookiecutter.camelcase_modelname}}Layer,
Flax{{cookiecutter.camelcase_modelname}}Model,
Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel,
)
{% else %}
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
Flax{{cookiecutter.camelcase_modelname}}Model,
Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel,
)
{% endif %}
{% endif %}
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
{
"modelname": "{{cookiecutter.modelname}}",
"uppercase_modelname": "{{cookiecutter.uppercase_modelname}}",
"lowercase_modelname": "{{cookiecutter.lowercase_modelname}}",
"camelcase_modelname": "{{cookiecutter.camelcase_modelname}}",
"authors": "{{cookiecutter.authors}}",
"checkpoint_identifier": "{{cookiecutter.checkpoint_identifier}}",
"tokenizer_type": "{{cookiecutter.tokenizer_type}}",
"generate_tensorflow_pytorch_and_flax": "{{cookiecutter.generate_tensorflow_pytorch_and_flax}}",
"is_encoder_decoder_model": "{{cookiecutter.is_encoder_decoder_model}}"
}
# coding=utf-8
# Copyright 2022 {{cookiecutter.authors}} and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" {{cookiecutter.modelname}} model configuration """
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`~{{cookiecutter.camelcase_modelname}}Model`].
It is used to instantiate an {{cookiecutter.modelname}} model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the {{cookiecutter.modelname}} [{{cookiecutter.checkpoint_identifier}}](https://huggingface.co/{{cookiecutter.checkpoint_identifier}}) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used
to control the model outputs. Read the documentation from [`PretrainedConfig`]
for more information.
Args:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
vocab_size (`int`, *optional*, defaults to 30522):
Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`~{{cookiecutter.camelcase_modelname}}Model`] or
[`~TF{{cookiecutter.camelcase_modelname}}Model`].
hidden_size (`int`, *optional*, defaults to 768):
Dimension of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler.
If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (`int`, *optional*, defaults to 512):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (`int`, *optional*, defaults to 2):
The vocabulary size of the `token_type_ids` passed when calling [`~{{cookiecutter.camelcase_modelname}}Model`] or
[`~TF{{cookiecutter.camelcase_modelname}}Model`].
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
{% else -%}
vocab_size (`int`, *optional*, defaults to 50265):
Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`~{{cookiecutter.camelcase_modelname}}Model`] or
[`~TF{{cookiecutter.camelcase_modelname}}Model`].
d_model (`int`, *optional*, defaults to 1024):
Dimension of the layers and the pooler layer.
encoder_layers (`int`, *optional*, defaults to 12):
Number of encoder layers.
decoder_layers (`int`, *optional*, defaults to 12):
Number of decoder layers.
encoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
`"gelu"`, `"relu"`, `"silu"` and `"gelu_new"` are supported.
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
classifier_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for classifier.
max_position_embeddings (`int`, *optional*, defaults to 1024):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the encoder. See the [LayerDrop paper](see
https://arxiv.org/abs/1909.11556) for more details.
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the decoder. See the [LayerDrop paper](see
https://arxiv.org/abs/1909.11556) for more details.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
{% endif -%}
Example:
```python
>>> from transformers import {{cookiecutter.camelcase_modelname}}Model, {{cookiecutter.camelcase_modelname}}Config
>>> # Initializing a {{cookiecutter.modelname}} {{cookiecutter.checkpoint_identifier}} style configuration
>>> configuration = {{cookiecutter.camelcase_modelname}}Config()
>>> # Initializing a model from the {{cookiecutter.checkpoint_identifier}} style configuration
>>> model = {{cookiecutter.camelcase_modelname}}Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "{{cookiecutter.lowercase_modelname}}"
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
{% else -%}
keys_to_ignore_at_inference = ["past_key_values"]
{% endif -%}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
{%- else %}
attribute_map = {
"num_attention_heads": "encoder_attention_heads",
"hidden_size": "d_model"
}
{%- endif %}
def __init__(
self,
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
use_cache=True,
{% else -%}
vocab_size=50265,
max_position_embeddings=1024,
encoder_layers=12,
encoder_ffn_dim=4096,
encoder_attention_heads=16,
decoder_layers=12,
decoder_ffn_dim=4096,
decoder_attention_heads=16,
encoder_layerdrop=0.0,
decoder_layerdrop=0.0,
use_cache=True,
is_encoder_decoder=True,
activation_function="gelu",
d_model=1024,
dropout=0.1,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
decoder_start_token_id=2,
classifier_dropout=0.0,
scale_embedding=False,
{% endif -%}
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
**kwargs
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.type_vocab_size = type_vocab_size
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
{% else -%}
self.d_model = d_model
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
{% endif -%}
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
{% else -%}
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
{% endif -%}
**kwargs
)
# coding=utf-8
# Copyright 2022 {{cookiecutter.authors}} and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax {{cookiecutter.modelname}} model. """
{% if cookiecutter.is_encoder_decoder_model == "False" %}
from typing import Callable, Optional, Tuple
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.linen.attention import dot_product_attention_weights
from jax import lax
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import (
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxBaseModelOutputWithPoolingAndCrossAttentions,
FlaxCausalLMOutput,
FlaxCausalLMOutputWithCrossAttentions,
FlaxMaskedLMOutput,
FlaxMultipleChoiceModelOutput,
FlaxQuestionAnsweringModelOutput,
FlaxSequenceClassifierOutput,
FlaxTokenClassifierOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
overwrite_call_docstring,
)
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "{{cookiecutter.checkpoint_identifier}}"
_CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config"
_TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading, saving and converting weights from
PyTorch models)
This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as a regular Flax linen Module
and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`~{{cookiecutter.uppercase_modelname}}Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the
model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on
GPUs) and `jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
[`~FlaxPreTrainedModel.to_fp16`] and [`~FlaxPreTrainedModel.to_bf16`].
"""
{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
Args:
input_ids (`numpy.ndarray` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`~{{cookiecutter.uppercase_modelname}}ConfiTokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
[What are token type IDs?](../glossary#token-type-ids)
position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`.
head_mask (`numpy.ndarray` of shape `({0})`, `optional): Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
remat = nn_partitioning.remat
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.word_embeddings = nn.Embed(
self.config.vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
# Embed
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
position_embeds = self.position_embeddings(position_ids.astype("i4"))
token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
# Sum all embeddings
hidden_states = inputs_embeds + token_type_embeddings + position_embeds
# Layer Norm
hidden_states = self.LayerNorm(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
causal: bool = False
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
: {self.config.num_attention_heads}"
)
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
if self.causal:
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
)
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
@nn.compact
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slighly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# update key, value caches with our new 1d spatial slices
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
key_value_states: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic=True,
output_attentions: bool = False,
):
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
batch_size = hidden_states.shape[0]
# get query proj
query_states = self.query(hidden_states)
# get key, value proj
if is_cross_attention:
# cross_attentions
key_states = self.key(key_value_states)
value_states = self.value(key_value_states)
else:
# self_attention
key_states = self.key(hidden_states)
value_states = self.value(hidden_states)
query_states = self._split_heads(query_states)
key_states = self._split_heads(key_states)
value_states = self._split_heads(value_states)
# handle cache prepare causal attention mask
if self.causal:
query_length, key_length = query_states.shape[1], key_states.shape[1]
if self.has_variable("cache", "cached_key"):
mask_shift = self.variables["cache"]["cache_index"]
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
causal_mask = lax.dynamic_slice(
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
)
else:
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
# combine masks if needed
if attention_mask is not None and self.causal:
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
attention_mask = combine_masks(attention_mask, causal_mask)
elif self.causal:
attention_mask = causal_mask
elif attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
)
# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
)
else:
attention_bias = None
dropout_rng = None
if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
dropout_rng = self.make_rng("dropout")
attn_weights = dot_product_attention_weights(
query_states,
key_states,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.config.attention_probs_dropout_prob,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
# Mask heads if we want to
if layer_head_mask is not None:
attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
causal: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
self.self = Flax{{cookiecutter.camelcase_modelname}}SelfAttention(self.config, dtype=self.dtype)
self.output = Flax{{cookiecutter.camelcase_modelname}}SelfOutput(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
key_value_states=None,
init_cache=False,
deterministic=True,
output_attentions: bool = False,
):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attn_outputs = self.self(
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
key_value_states=key_value_states,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
attn_output = attn_outputs[0]
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_outputs[1],)
return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}Intermediate(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
def __call__(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}Output(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.LayerNorm(hidden_states + attention_output)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.attention = Flax{{cookiecutter.camelcase_modelname}}Attention(self.config, dtype=self.dtype)
self.intermediate = Flax{{cookiecutter.camelcase_modelname}}Intermediate(self.config, dtype=self.dtype)
self.output = Flax{{cookiecutter.camelcase_modelname}}Output(self.config, dtype=self.dtype)
if self.config.add_cross_attention:
self.crossattention = Flax{{cookiecutter.camelcase_modelname}}Attention(self.config, causal=False, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
):
# Self Attention
attention_outputs = self.attention(
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
# Cross-Attention Block
if encoder_hidden_states is not None:
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask,
key_value_states=encoder_hidden_states,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
outputs = (hidden_states,)
if output_attentions:
outputs += (attention_outputs[1],)
if encoder_hidden_states is not None:
outputs += (cross_attention_outputs[1],)
return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
if self.gradient_checkpointing:
Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer = remat(Flax{{cookiecutter.camelcase_modelname}}Layer, static_argnums=(5, 6, 7))
self.layers = [
Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
else:
self.layers = [
Flax{{cookiecutter.camelcase_modelname}}Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
]
def __call__(
self,
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
# Check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
if head_mask.shape[0] != (len(self.layers)):
raise ValueError(
f"The head_mask should be specified for {len(self.layers)} layers, but it is for \
{head_mask.shape[0]}."
)
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(
hidden_states,
attention_mask,
head_mask[i] if head_mask is not None else None,
encoder_hidden_states,
encoder_attention_mask,
init_cache,
deterministic,
output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states,)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
self.layer = Flax{{cookiecutter.camelcase_modelname}}LayerCollection(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
def __call__(
self,
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return self.layer(
hidden_states,
attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}Pooler(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
def __call__(self, hidden_states):
cls_hidden_state = hidden_states[:, 0]
cls_hidden_state = self.dense(cls_hidden_state)
return nn.tanh(cls_hidden_state)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPredictionHeadTransform with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}PredictionHeadTransform(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
self.activation = ACT2FN[self.config.hidden_act]
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
return self.LayerNorm(hidden_states)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLMPredictionHead with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}LMPredictionHead(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
def setup(self):
self.transform = Flax{{cookiecutter.camelcase_modelname}}PredictionHeadTransform(self.config, dtype=self.dtype)
self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
def __call__(self, hidden_states, shared_embedding=None):
hidden_states = self.transform(hidden_states)
if shared_embedding is not None:
hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
hidden_states = self.decoder(hidden_states)
hidden_states += self.bias
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOnlyMLMHead with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.predictions = Flax{{cookiecutter.camelcase_modelname}}LMPredictionHead(self.config, dtype=self.dtype)
def __call__(self, hidden_states, shared_embedding=None):
hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOnlyNSPHead with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}OnlyNSPHead(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
self.seq_relationship = nn.Dense(2, dtype=self.dtype)
def __call__(self, pooled_output):
return self.seq_relationship(pooled_output)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainingHeads with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}PreTrainingHeads(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.predictions = Flax{{cookiecutter.camelcase_modelname}}LMPredictionHead(self.config, dtype=self.dtype)
self.seq_relationship = nn.Dense(2, dtype=self.dtype)
def __call__(self, hidden_states, pooled_output, shared_embedding=None):
prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = {{cookiecutter.camelcase_modelname}}Config
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
module_class: nn.Module = None
def __init__(
self,
config: {{cookiecutter.camelcase_modelname}}Config,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
def enable_gradient_checkpointing(self):
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights with Bert->{{cookiecutter.camelcase_modelname}}
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.zeros_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
attention_mask = jnp.ones_like(input_ids)
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
if self.config.add_cross_attention:
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
encoder_attention_mask = attention_mask
module_init_outputs = self.module.init(
rngs,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
return_dict=False,
)
else:
module_init_outputs = self.module.init(
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
)
random_params = module_init_outputs["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_cache with Bert->{{cookiecutter.camelcase_modelname}}
def init_cache(self, batch_size, max_length):
r"""
Args:
batch_size (`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
"""
# init input variables to retrieve cache
input_ids = jnp.ones((batch_size, max_length))
attention_mask = jnp.ones_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
)
return unfreeze(init_variables["cache"])
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.__call__ with Bert->{{cookiecutter.camelcase_modelname}}
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
past_key_values: dict = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# init input tensors if not passed
if token_type_ids is None:
token_type_ids = jnp.zeros_like(input_ids)
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
if head_mask is None:
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
if self.config.add_cross_attention:
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
# down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
# changed by FlaxBertAttention module
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
head_mask=jnp.array(head_mask, dtype="i4"),
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
deterministic=not train,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
rngs=rngs,
mutable=mutable,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past_key_values = outputs
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past_key_values = outputs
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
else:
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
head_mask=jnp.array(head_mask, dtype="i4"),
deterministic=not train,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
rngs=rngs,
)
return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
gradient_checkpointing: bool = False
def setup(self):
self.embeddings = Flax{{cookiecutter.camelcase_modelname}}Embeddings(self.config, dtype=self.dtype)
self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.pooler = Flax{{cookiecutter.camelcase_modelname}}Pooler(self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
head_mask: Optional[jnp.ndarray] = None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# make sure `token_type_ids` is correctly initialized when not passed
if token_type_ids is None:
token_type_ids = jnp.zeros_like(input_ids)
# make sure `position_ids` is correctly initialized when not passed
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
hidden_states = self.embeddings(
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
)
outputs = self.encoder(
hidden_states,
attention_mask,
head_mask=head_mask,
deterministic=deterministic,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
if not return_dict:
# if pooled is None, don't return it
if pooled is None:
return (hidden_states,) + outputs[1:]
return (hidden_states, pooled) + outputs[1:]
return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=hidden_states,
pooler_output=pooled,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
add_start_docstrings(
"The bare {{cookiecutter.camelcase_modelname}} Model transformer outputting raw hidden-states without any specific head on top.",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class Flax{{cookiecutter.camelcase_modelname}}Model(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
module_class = Flax{{cookiecutter.camelcase_modelname}}Module
class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLMModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = self.{{cookiecutter.lowercase_modelname}}.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
else:
shared_embedding = None
# Compute the prediction scores
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxCausalLMOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings("""{{cookiecutter.camelcase_modelname}} Model with a `language modeling` head on top for MLM training. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING)
class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
module_class = Flax{{cookiecutter.camelcase_modelname}}ForMaskedLMModule
append_call_sample_docstring(
Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC
)
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = self.{{cookiecutter.lowercase_modelname}}.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
else:
shared_embedding = None
# Compute the prediction scores
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxCausalLMOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings("""{{cookiecutter.camelcase_modelname}} Model with a `language modeling` head on top for CLM training. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING)
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule
append_call_sample_docstring(
Flax{{cookiecutter.camelcase_modelname}}ForCausalLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC
)
class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(
self.config.num_labels,
dtype=self.dtype,
)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
logits = self.classifier(pooled_output)
if not return_dict:
return (logits,) + outputs[2:]
return FlaxSequenceClassifierOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
{{cookiecutter.camelcase_modelname}} Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
output) e.g. for GLUE tasks.
""",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
module_class = Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule
append_call_sample_docstring(
Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxSequenceClassifierOutput,
_CONFIG_FOR_DOC,
)
class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoiceModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(1, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
num_choices = input_ids.shape[1]
input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
# Model
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
logits = self.classifier(pooled_output)
reshaped_logits = logits.reshape(-1, num_choices)
if not return_dict:
return (reshaped_logits,) + outputs[2:]
return FlaxMultipleChoiceModelOutput(
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
{{cookiecutter.camelcase_modelname}} Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
""",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoice(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
module_class = Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoiceModule
overwrite_call_docstring(
Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoice, {{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoice, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxMultipleChoiceModelOutput, _CONFIG_FOR_DOC
)
class Flax{{cookiecutter.camelcase_modelname}}ForTokenClassificationModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
logits = self.classifier(hidden_states)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxTokenClassifierOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
{{cookiecutter.camelcase_modelname}} Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class Flax{{cookiecutter.camelcase_modelname}}ForTokenClassification(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
module_class = Flax{{cookiecutter.camelcase_modelname}}ForTokenClassificationModule
append_call_sample_docstring(
Flax{{cookiecutter.camelcase_modelname}}ForTokenClassification, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxTokenClassifierOutput, _CONFIG_FOR_DOC
)
class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing)
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if not return_dict:
return (start_logits, end_logits) + outputs[1:]
return FlaxQuestionAnsweringModelOutput(
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
{{cookiecutter.camelcase_modelname}} Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
module_class = Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule
append_call_sample_docstring(
Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC,
)
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
token_type_ids: Optional[jnp.ndarray] = None,
head_mask: Optional[jnp.ndarray] = None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = self.{{cookiecutter.lowercase_modelname}}.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
else:
shared_embedding = None
# Compute the prediction scores
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxCausalLMOutputWithCrossAttentions(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
@add_start_docstrings(
"""
{{cookiecutter.camelcase_modelname}} Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
autoregressive tasks.
""",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
# initializing the cache
batch_size, seq_length = input_ids.shape
past_key_values = self.init_cache(batch_size, max_length)
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since the decoder uses a causal mask, those positions are masked anyway.
# Thus, we can create a single static attention_mask here, which is more efficient for compilation
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if attention_mask is not None:
position_ids = attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
"position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
return model_kwargs
append_call_sample_docstring(
Flax{{cookiecutter.camelcase_modelname}}ForCausalLM,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxCausalLMOutputWithCrossAttentions,
_CONFIG_FOR_DOC,
)
{# encoder_decoder #}
{% else %}
import math
import random
from functools import partial
from typing import Callable, Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
from ...utils import add_start_docstrings, replace_return_docstrings
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
FlaxSeq2SeqQuestionAnsweringModelOutput,
FlaxSeq2SeqSequenceClassifierOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "{{cookiecutter.checkpoint_identifier}}"
_CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config"
_TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading or saving, resizing the input
embeddings, pruning heads etc.)
This model is also a Flax Linen [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a regular Flax
Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`~{{cookiecutter.camelcase_modelname}}Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the
model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on
GPUs) and `jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
[`~FlaxPreTrainedModel.to_fp16`] and [`~FlaxPreTrainedModel.to_bf16`].
"""
{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
Args:
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`~{{cookiecutter.camelcase_modelname}}Tokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using [`~{{cookiecutter.camelcase_modelname}}Tokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
For translation and summarization training, `decoder_input_ids` should be provided. If no
`decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to
the right for denoising pre-training following the paper.
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will
also be used by default.
If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`.
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
{{cookiecutter.uppercase_modelname}}_ENCODE_INPUTS_DOCSTRING = r"""
Args:
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`~{{cookiecutter.camelcase_modelname}}Tokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
{{cookiecutter.uppercase_modelname}}_DECODE_INPUTS_DOCSTRING = r"""
Args:
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using [`~{{cookiecutter.camelcase_modelname}}Tokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
For translation and summarization training, `decoder_input_ids` should be provided. If no
`decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to
the right for denoising pre-training following the paper.
encoder_outputs (`tuple(tuple(jnp.ndarray)`):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
`attentions`) `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`,
*optional*) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
cross-attention of the decoder.
encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will
also be used by default.
If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
shifted_input_ids = shifted_input_ids.at[(..., 0)].set(decoder_start_token_id)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
return shifted_input_ids
class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
embed_dim: int
num_heads: int
dropout: float = 0.0
causal: bool = False
bias: bool = True
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads
assert (
self.head_dim * self.num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
dense = partial(
nn.Dense,
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
self.out_proj = dense()
self.dropout_layer = nn.Dropout(rate=self.dropout)
if self.causal:
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
)
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
@nn.compact
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slighly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# update key, value caches with our new 1d spatial slices
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask
def __call__(
self,
hidden_states: jnp.ndarray,
key_value_states: Optional[jnp.ndarray] = None,
attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
batch_size = hidden_states.shape[0]
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
if is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states)
value_states = self.v_proj(key_value_states)
else:
# self_attention
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self._split_heads(query_states)
key_states = self._split_heads(key_states)
value_states = self._split_heads(value_states)
# handle cache prepare causal attention mask
if self.causal:
query_length, key_length = query_states.shape[1], key_states.shape[1]
if self.has_variable("cache", "cached_key"):
mask_shift = self.variables["cache"]["cache_index"]
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
causal_mask = lax.dynamic_slice(
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
)
else:
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
# combine masks if needed
if attention_mask is not None and self.causal:
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
attention_mask = combine_masks(attention_mask, causal_mask)
elif self.causal:
attention_mask = causal_mask
elif attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
)
# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
)
else:
attention_bias = None
dropout_rng = None
if not deterministic and self.dropout > 0.0:
dropout_rng = self.make_rng("dropout")
attn_weights = dot_product_attention_weights(
query_states,
key_states,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
attn_output = self._merge_heads(attn_output)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.embed_dim = self.config.d_model
self.self_attn = Flax{{cookiecutter.camelcase_modelname}}Attention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: jnp.ndarray,
output_attentions: bool = True,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
residual = hidden_states
hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class Flax{{cookiecutter.camelcase_modelname}}EncoderLayerCollection(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = [
Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers)
]
self.layerdrop = self.config.encoder_layerdrop
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for encoder_layer in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if not deterministic and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions,
deterministic,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states, all_hidden_states, all_attentions)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.embed_dim = self.config.d_model
self.self_attn = Flax{{cookiecutter.camelcase_modelname}}Attention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
causal=True,
dtype=self.dtype,
)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.encoder_attn = Flax{{cookiecutter.camelcase_modelname}}Attention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense(
self.config.decoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: jnp.ndarray,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
output_attentions: bool = True,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
residual = hidden_states
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Cross-Attention Block
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states, cross_attn_weights = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# Fully Connected
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
class Flax{{cookiecutter.camelcase_modelname}}DecoderLayerCollection(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = [
Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers)
]
self.layerdrop = self.config.decoder_layerdrop
def __call__(
self,
hidden_states,
attention_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if not deterministic and (dropout_probability < self.layerdrop):
layer_outputs = (None, None, None)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions,
deterministic=deterministic,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
class Flax{{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
config: {{cookiecutter.camelcase_modelname}}Config
inner_dim: int
num_classes: int
pooler_dropout: float
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.dropout = nn.Dropout(rate=self.pooler_dropout)
self.out_proj = nn.Dense(
self.num_classes,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.dense(hidden_states)
hidden_states = jnp.tanh(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.out_proj(hidden_states)
return hidden_states
class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
embed_tokens: Optional[nn.Embed] = None
def setup(self):
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
embed_dim = self.config.d_model
self.padding_idx = self.config.pad_token_id
self.max_source_positions = self.config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
if self.embed_tokens is None:
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
# {{cookiecutter.camelcase_modelname}} is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = Flax{{cookiecutter.camelcase_modelname}}EncoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
input_shape = input_ids.shape
input_ids = input_ids.reshape(-1, input_shape[-1])
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(position_ids + self.offset)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
outputs = self.layers(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return outputs
return FlaxBaseModelOutput(
last_hidden_state=outputs.last_hidden_state,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class Flax{{cookiecutter.camelcase_modelname}}Decoder(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
embed_tokens: Optional[nn.Embed] = None
def setup(self):
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
embed_dim = self.config.d_model
self.padding_idx = self.config.pad_token_id
self.max_target_positions = self.config.max_position_embeddings
self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
if self.embed_tokens is None:
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
# {{cookiecutter.camelcase_modelname}} is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = Flax{{cookiecutter.camelcase_modelname}}DecoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
input_shape = input_ids.shape
input_ids = input_ids.reshape(-1, input_shape[-1])
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
# embed positions
positions = self.embed_positions(position_ids + self.offset)
hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
outputs = self.layers(
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return outputs
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=outputs.last_hidden_state,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
self.decoder = Flax{{cookiecutter.camelcase_modelname}}Decoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
def _get_encoder_module(self):
return self.encoder
def _get_decoder_module(self):
return self.decoder
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return FlaxSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedModel):
config_class = {{cookiecutter.camelcase_modelname}}Config
base_model_prefix: str = "model"
module_class: nn.Module = None
def __init__(
self,
config: {{cookiecutter.camelcase_modelname}}Config,
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# make sure initialization pass will work for Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule
input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
attention_mask = jnp.ones_like(input_ids)
decoder_input_ids = input_ids
decoder_attention_mask = jnp.ones_like(input_ids)
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(
rngs,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:
batch_size (`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
`encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`,
*optional*: `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of hidden-states at the output of the last layer of the
encoder. Used in the cross-attention of the decoder.
"""
# init input variables to retrieve cache
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
)
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
decoder_module = module._get_decoder_module()
return decoder_module(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
**kwargs,
)
init_variables = self.module.init(
jax.random.PRNGKey(0),
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
init_cache=True,
method=_decoder_forward, # we only need to call the decoder to init the cache
)
return unfreeze(init_variables["cache"])
@add_start_docstrings({{cookiecutter.uppercase_modelname}}_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class={{cookiecutter.camelcase_modelname}}Config)
def encode(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example:
```python
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
>>> model = Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> text = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer(text, max_length=1024, return_tensors='np')
>>> encoder_outputs = model.encode(**inputs)
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
if position_ids is None:
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
encode_module = module._get_encoder_module()
return encode_module(input_ids, attention_mask, position_ids, **kwargs)
return self.module.apply(
{"params": params or self.params},
input_ids=jnp.array(input_ids, dtype="i4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
method=_encoder_forward,
)
@add_start_docstrings({{cookiecutter.uppercase_modelname}}_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class={{cookiecutter.camelcase_modelname}}Config)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example:
```python
>>> import jax.numpy as jnp
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
>>> model = Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> text = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer(text, max_length=1024, return_tensors='np')
>>> encoder_outputs = model.encode(**inputs)
>>> decoder_start_token_id = model.config.decoder_start_token_id
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
>>> last_decoder_hidden_states = outputs.last_hidden_state
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
encoder_hidden_states = encoder_outputs[0]
if encoder_attention_mask is None:
batch_size, sequence_length = encoder_hidden_states.shape[:2]
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
batch_size, sequence_length = decoder_input_ids.shape
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones((batch_size, sequence_length))
if decoder_position_ids is None:
if past_key_values is not None:
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be
# passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
# it can be changed by Flax{{cookiecutter.camelcase_modelname}}Attention module
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
decoder_module = module._get_decoder_module()
return decoder_module(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
**kwargs,
)
outputs = self.module.apply(
inputs,
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
mutable=mutable,
method=_decoder_forward,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past = outputs
outputs["past_key_values"] = unfreeze(past["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past = outputs
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
return outputs
def __call__(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
decoder_input_ids: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# prepare encoder inputs
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
if position_ids is None:
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
# prepare decoder inputs
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(
input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
)
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
if decoder_position_ids is None:
batch_size, sequence_length = decoder_input_ids.shape
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
# Handle any PRNG if needed
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
return self.module.apply(
{"params": params or self.params},
input_ids=jnp.array(input_ids, dtype="i4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
)
@add_start_docstrings(
"The bare {{cookiecutter.camelcase_modelname}} Model transformer outputting raw hidden-states without any specific head on top.",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class Flax{{cookiecutter.camelcase_modelname}}Model(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
module_class = Flax{{cookiecutter.camelcase_modelname}}Module
append_call_sample_docstring(
Flax{{cookiecutter.camelcase_modelname}}Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC
)
class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGenerationModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
def setup(self):
self.model = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
self.lm_head = nn.Dense(
self.model.shared.num_embeddings,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
def _get_encoder_module(self):
return self.model.encoder
def _get_decoder_module(self):
return self.model.decoder
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
position_ids=position_ids,
decoder_position_ids=decoder_position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = self.model.variables["params"]["shared"]["embedding"]
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
lm_logits = self.lm_head(hidden_states)
lm_logits += self.final_logits_bias.astype(self.dtype)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return output
return FlaxSeq2SeqLMOutput(
logits=lm_logits,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
@add_start_docstrings(
"The {{cookiecutter.uppercase_modelname}} Model with a language modeling head. Can be used for summarization.", {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
)
class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
module_class = Flax{{cookiecutter.camelcase_modelname}}ForConditionalGenerationModule
dtype: jnp.dtype = jnp.float32
@add_start_docstrings({{cookiecutter.uppercase_modelname}}_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class={{cookiecutter.camelcase_modelname}}Config)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
deterministic: bool = True,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example:
```python
>>> import jax.numpy as jnp
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
>>> model = Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> text = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer(text, max_length=1024, return_tensors='np')
>>> encoder_outputs = model.encode(**inputs)
>>> decoder_start_token_id = model.config.decoder_start_token_id
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
>>> logits = outputs.logits
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
encoder_hidden_states = encoder_outputs[0]
if encoder_attention_mask is None:
batch_size, sequence_length = encoder_hidden_states.shape[:2]
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
batch_size, sequence_length = decoder_input_ids.shape
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones((batch_size, sequence_length))
if decoder_position_ids is None:
if past_key_values is not None:
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be
# passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
# it can be changed by Flax{{cookiecutter.camelcase_modelname}}Attention module
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
decoder_module = module._get_decoder_module()
outputs = decoder_module(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
**kwargs,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = module.model.variables["params"]["shared"]["embedding"]
lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
lm_logits = module.lm_head(hidden_states)
lm_logits += module.final_logits_bias.astype(self.dtype)
return lm_logits, outputs
outputs = self.module.apply(
inputs,
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
rngs=rngs,
mutable=mutable,
method=_decoder_forward,
)
if past_key_values is None:
lm_logits, decoder_outputs = outputs
else:
(lm_logits, decoder_outputs), past = outputs
if return_dict:
outputs = FlaxCausalLMOutputWithCrossAttentions(
logits=lm_logits,
hidden_states=decoder_outputs.hidden_states,
attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
)
else:
outputs = (lm_logits,) + decoder_outputs[1:]
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs["past_key_values"] = unfreeze(past["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
return outputs
def prepare_inputs_for_generation(
self,
decoder_input_ids,
max_length,
attention_mask: Optional[jax.Array] = None,
decoder_attention_mask: Optional[jax.Array] = None,
encoder_outputs=None,
**kwargs
):
# initializing the cache
batch_size, seq_length = decoder_input_ids.shape
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since the decoder uses a causal mask, those positions are masked anyways.
# Thus we can create a single static attention_mask here, which is more efficient for compilation
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if decoder_attention_mask is not None:
position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"encoder_attention_mask": attention_mask,
"decoder_attention_mask": extended_attention_mask,
"decoder_position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
return model_kwargs
FLAX_{{cookiecutter.uppercase_modelname}}_CONDITIONAL_GENERATION_DOCSTRING = """
Returns:
Summarization example:
```python
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
>>> model = Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np')
>>> # Generate Summary
>>> summary_ids = model.generate(inputs['input_ids']).sequences
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
```
Mask filling example:
```python
>>> import jax
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
>>> model = Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> TXT = "My friends are <mask> but they eat too many carbs."
>>> input_ids = tokenizer([TXT], return_tensors='np')['input_ids']
>>> logits = model(input_ids).logits
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
>>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
>>> values, predictions = jax.lax.top_k(probs, k=1)
>>> tokenizer.decode(predictions).split()
```
"""
overwrite_call_docstring(
Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, {{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING + FLAX_{{cookiecutter.uppercase_modelname}}_CONDITIONAL_GENERATION_DOCSTRING
)
append_replace_return_docstrings(
Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)
class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
num_labels: Optional[int] = None
def setup(self):
self.model = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
self.classification_head = Flax{{cookiecutter.camelcase_modelname}}ClassificationHead(
config=self.config,
inner_dim=self.config.d_model,
num_classes=self.num_labels if self.num_labels is not None else self.config.num_labels,
pooler_dropout=self.config.classifier_dropout,
)
def _get_encoder_module(self):
return self.model.encoder
def _get_decoder_module(self):
return self.model.decoder
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
position_ids=position_ids,
decoder_position_ids=decoder_position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
hidden_states = outputs[0] # last hidden state
eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)
# The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer:
if len(jnp.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
if any(eos_mask.sum(1) == 0):
raise ValueError("There are missing <eos> tokens in input_ids")
# Ensure to keep 1 only for the last <eos> token for each example
eos_mask_noised = eos_mask + jnp.arange(eos_mask.shape[1]) * 1e-6
eos_mask = jnp.where(eos_mask_noised == eos_mask_noised.max(1).reshape(-1, 1), 1, 0)
sentence_representation = jnp.einsum("ijk, ij -> ijk", hidden_states, eos_mask).sum(1)
logits = self.classification_head(sentence_representation, deterministic=deterministic)
if not return_dict:
output = (logits,) + outputs[1:]
return output
return FlaxSeq2SeqSequenceClassifierOutput(
logits=logits,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
@add_start_docstrings(
"""
{{cookiecutter.camelcase_modelname}} model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
tasks.
""",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
module_class = Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule
dtype = jnp.float32
append_call_sample_docstring(
Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxSeq2SeqSequenceClassifierOutput,
_CONFIG_FOR_DOC,
)
class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
num_labels = 2
def setup(self):
self.model = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
self.qa_outputs = nn.Dense(
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
def _get_encoder_module(self):
return self.model.encoder
def _get_decoder_module(self):
return self.model.decoder
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
position_ids=position_ids,
decoder_position_ids=decoder_position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = jnp.split(logits, logits.shape[-1], axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if not return_dict:
output = (start_logits, end_logits) + outputs[1:]
return output
return FlaxSeq2SeqQuestionAnsweringModelOutput(
start_logits=start_logits,
end_logits=end_logits,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
@add_start_docstrings(
"""
{{cookiecutter.uppercase_modelname}} Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
module_class = Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule
dtype = jnp.float32
append_call_sample_docstring(
Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxSeq2SeqQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC,
)
{% endif -%}
# coding=utf-8
# Copyright 2022 {{cookiecutter.authors}} and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 {{cookiecutter.modelname}} model. """
{% if cookiecutter.is_encoder_decoder_model == "False" %}
import math
from typing import Dict, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...utils import (
DUMMY_INPUTS,
MULTIPLE_CHOICE_DUMMY_INPUTS,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from ...modeling_tf_outputs import (
TFBaseModelOutputWithPastAndCrossAttentions,
TFCausalLMOutputWithCrossAttentions,
TFMaskedLMOutput,
TFMultipleChoiceModelOutput,
TFQuestionAnsweringModelOutput,
TFSequenceClassifierOutput,
TFTokenClassifierOutput,
)
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFMaskedLanguageModelingLoss,
TFModelInputType,
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFSequenceSummary,
TFTokenClassificationLoss,
get_initializer,
keras,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "{{cookiecutter.checkpoint_identifier}}"
_CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config"
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}}
class TF{{cookiecutter.camelcase_modelname}}Embeddings(keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
self.type_vocab_size = config.type_vocab_size
self.hidden_size = config.hidden_size
self.max_position_embeddings = config.max_position_embeddings
self.initializer_range = config.initializer_range
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
with tf.name_scope("token_type_embeddings"):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
with tf.name_scope("position_embeddings"):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def call(
self,
input_ids: tf.Tensor = None,
position_ids: tf.Tensor = None,
token_type_ids: tf.Tensor = None,
inputs_embeds: tf.Tensor = None,
past_key_values_length=0,
training: bool = False,
) -> tf.Tensor:
"""
Applies embedding based on inputs tensor.
Returns:
final_embeddings (`tf.Tensor`): output embedding tensor.
"""
assert not (input_ids is None and inputs_embeds is None)
if input_ids is not None:
check_embeddings_within_bounds(input_ids, self.vocab_size)
inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
input_shape = shape_list(inputs_embeds)[:-1]
if token_type_ids is None:
token_type_ids = tf.fill(dims=input_shape, value=0)
if position_ids is None:
position_ids = tf.expand_dims(
tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
)
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = inputs_embeds + position_embeds + token_type_embeds
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)
return final_embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
class TF{{cookiecutter.camelcase_modelname}}SelfAttention(keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
super().__init__(**kwargs)
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
f"of attention heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
self.query = keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
)
self.key = keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
)
self.value = keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
return tf.transpose(tensor, perm=[0, 2, 1, 3])
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: tf.Tensor,
encoder_attention_mask: tf.Tensor,
past_key_value: Tuple[tf.Tensor],
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(inputs=hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
else:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
if self.is_decoder:
# if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
# (batch size, num_heads, seq_len_q, seq_len_k)
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
attention_scores = tf.divide(attention_scores, dk)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TF{{cookiecutter.camelcase_modelname}}Model call() function)
attention_scores = tf.add(attention_scores, attention_mask)
# Normalize the attention scores to probabilities.
attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(inputs=attention_probs, training=training)
# Mask heads if we want to
if head_mask is not None:
attention_probs = tf.multiply(attention_probs, head_mask)
attention_output = tf.matmul(attention_probs, value_layer)
attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
# (batch_size, seq_len_q, all_head_size)
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->{{cookiecutter.camelcase_modelname}}
class TF{{cookiecutter.camelcase_modelname}}SelfOutput(keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
return hidden_states
# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->{{cookiecutter.camelcase_modelname}}
class TF{{cookiecutter.camelcase_modelname}}Attention(keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
super().__init__(**kwargs)
self.self_attention = TF{{cookiecutter.camelcase_modelname}}SelfAttention(config, name="self")
self.dense_output = TF{{cookiecutter.camelcase_modelname}}SelfOutput(config, name="output")
def prune_heads(self, heads):
raise NotImplementedError
def call(
self,
input_tensor: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: tf.Tensor,
encoder_attention_mask: tf.Tensor,
past_key_value: Tuple[tf.Tensor],
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
self_outputs = self.self_attention(
hidden_states=input_tensor,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
training=training,
)
attention_output = self.dense_output(
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
)
# add attentions (possibly with past_key_value) if we output them
outputs = (attention_output,) + self_outputs[1:]
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->{{cookiecutter.camelcase_modelname}}
class TF{{cookiecutter.camelcase_modelname}}Intermediate(keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(
units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->{{cookiecutter.camelcase_modelname}}
class TF{{cookiecutter.camelcase_modelname}}Output(keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
return hidden_states
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->{{cookiecutter.camelcase_modelname}}
class TF{{cookiecutter.camelcase_modelname}}Layer(keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
super().__init__(**kwargs)
self.attention = TF{{cookiecutter.camelcase_modelname}}Attention(config, name="attention")
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = TF{{cookiecutter.camelcase_modelname}}Attention(config, name="crossattention")
self.intermediate = TF{{cookiecutter.camelcase_modelname}}Intermediate(config, name="intermediate")
self.bert_output = TF{{cookiecutter.camelcase_modelname}}Output(config, name="output")
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: tf.Tensor | None,
encoder_attention_mask: tf.Tensor | None,
past_key_value: Tuple[tf.Tensor] | None,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
input_tensor=hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=self_attn_past_key_value,
output_attentions=output_attentions,
training=training,
)
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
"by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
input_tensor=attention_output,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
training=training,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
intermediate_output = self.intermediate(hidden_states=attention_output)
layer_output = self.bert_output(
hidden_states=intermediate_output, input_tensor=attention_output, training=training
)
outputs = (layer_output,) + outputs # add attentions if we output them
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->{{cookiecutter.camelcase_modelname}}
class TF{{cookiecutter.camelcase_modelname}}Encoder(keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.layer = [TF{{cookiecutter.camelcase_modelname}}Layer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: tf.Tensor | None,
encoder_attention_mask: tf.Tensor | None,
past_key_values: Tuple[Tuple[tf.Tensor]] | None,
use_cache: Optional[bool],
output_attentions: bool,
output_hidden_states: bool,
return_dict: bool,
training: bool = False,
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
training=training,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if self.config.add_cross_attention and encoder_hidden_states is not None:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
)
return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform with Bert->{{cookiecutter.camelcase_modelname}}
class TF{{cookiecutter.camelcase_modelname}}PredictionHeadTransform(keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(
units=config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
if isinstance(config.hidden_act, str):
self.transform_act_fn = get_tf_activation(config.hidden_act)
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(inputs=hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMPredictionHead with Bert->{{cookiecutter.camelcase_modelname}}
class TF{{cookiecutter.camelcase_modelname}}LMPredictionHead(keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, input_embeddings: keras.layers.Layer, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
self.transform = TF{{cookiecutter.camelcase_modelname}}PredictionHeadTransform(config, name="transform")
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape: tf.TensorShape):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def get_output_embeddings(self) -> keras.layers.Layer:
return self.input_embeddings
def set_output_embeddings(self, value: tf.Variable):
self.input_embeddings.weight = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self) -> Dict[str, tf.Variable]:
return {"bias": self.bias}
def set_bias(self, value: tf.Variable):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.transform(hidden_states=hidden_states)
seq_length = shape_list(hidden_states)[1]
hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size])
hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.vocab_size])
hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)
return hidden_states
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMLMHead with Bert->{{cookiecutter.camelcase_modelname}}
class TF{{cookiecutter.camelcase_modelname}}MLMHead(keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, input_embeddings: keras.layers.Layer, **kwargs):
super().__init__(**kwargs)
self.predictions = TF{{cookiecutter.camelcase_modelname}}LMPredictionHead(config, input_embeddings, name="predictions")
def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
prediction_scores = self.predictions(hidden_states=sequence_output)
return prediction_scores
@keras_serializable
class TF{{cookiecutter.camelcase_modelname}}MainLayer(keras.layers.Layer):
config_class = {{cookiecutter.camelcase_modelname}}Config
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, add_pooling_layer: bool = True, **kwargs):
super().__init__(**kwargs)
self.config = config
self.is_decoder = config.is_decoder
self.embeddings = TF{{cookiecutter.camelcase_modelname}}Embeddings(config, name="embeddings")
self.encoder = TF{{cookiecutter.camelcase_modelname}}Encoder(config, name="encoder")
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings
def get_input_embeddings(self) -> keras.layers.Layer:
return self.embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
def set_input_embeddings(self, value: tf.Variable):
self.embeddings.weight = value
self.embeddings.vocab_size = shape_list(value)[0]
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
raise NotImplementedError
@unpack_inputs
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
if not self.config.is_decoder:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
if past_key_values is None:
past_key_values_length = 0
past_key_values = [None] * len(self.encoder.layer)
else:
past_key_values_length = shape_list(past_key_values[0][0])[-2]
if attention_mask is None:
attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
if token_type_ids is None:
token_type_ids = tf.fill(dims=input_shape, value=0)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
training=training,
)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask_shape = shape_list(attention_mask)
mask_seq_length = seq_length + past_key_values_length
# Copied from `modeling_tf_t5.py`
# Provided a padding mask of dimensions [batch_size, mask_seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
if self.is_decoder:
seq_ids = tf.range(mask_seq_length)
causal_mask = tf.less_equal(
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
seq_ids[None, :, None],
)
causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
extended_attention_mask = causal_mask * attention_mask[:, None, :]
attention_mask_shape = shape_list(extended_attention_mask)
extended_attention_mask = tf.reshape(
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
)
if past_key_values[0] is not None:
# attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
else:
extended_attention_mask = tf.reshape(
attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
# Copied from `modeling_tf_t5.py` with -1e9 -> -10000
if self.is_decoder and encoder_attention_mask is not None:
# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
encoder_attention_mask = tf.cast(
encoder_attention_mask, dtype=extended_attention_mask.dtype
)
num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
if num_dims_encoder_attention_mask == 3:
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if num_dims_encoder_attention_mask == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
raise NotImplementedError
else:
head_mask = [None] * self.config.num_hidden_layers
encoder_outputs = self.encoder(
hidden_states=embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output = encoder_outputs[0]
if not return_dict:
return (
sequence_output,
) + encoder_outputs[1:]
return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=sequence_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = {{cookiecutter.camelcase_modelname}}Config
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading or saving, resizing the input
embeddings, pruning heads etc.)
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass.
Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general
usage and behavior.
<Tip>
TensorFlow models and layers in `transformers` accept two formats as input:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional argument.
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second format outside of Keras methods like `fit()` and `predict()`, such as when creating
your own layers or models with the Keras `Functional` API, there are three possibilities you
can use to gather all the input Tensors in the first positional argument:
- a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
Note that when creating models and layers with (subclassing)[https://keras.io/guides/making_new_layers_and_models_via_subclassing/]
then you don't need to worry about any of this, as you can just pass inputs like you would to any other Python
function!
</Tip>
Args:
config ([`~{{cookiecutter.camelcase_modelname}}Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
Args:
input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See
[`PreTrainedTokenizer.__call__`] and [`PreTrainedTokenizer.encode`] for
details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
[What are token type IDs?](../glossary#token-type-ids)
position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
config will be used instead.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
used instead.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This
argument can be used in eager mode, in graph mode the value will always be set to True.
training (`bool`, *optional*, defaults to `False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
@add_start_docstrings(
"The bare {{cookiecutter.modelname}} Model transformer outputing raw hidden-states without any specific head on top.",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFBaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
r"""
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*, defaults to `True`):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up
decoding (see `past_key_values`). Set to `False` during training, `True` during generation
"""
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
return outputs
@add_start_docstrings("""{{cookiecutter.modelname}} Model with a `language modeling` head on top. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING)
class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TFMaskedLanguageModelingLoss):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
if config.is_decoder:
logger.warning(
"If you want to use `TF{{cookiecutter.camelcase_modelname}}ForMaskedLM` make sure `config.is_decoder=False` for "
"bi-directional self-attention."
)
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
self.mlm = TF{{cookiecutter.camelcase_modelname}}MLMHead(config, input_embeddings=self.{{cookiecutter.lowercase_modelname}}.embeddings, name="mlm___cls")
def get_lm_head(self) -> keras.layers.Layer:
return self.mlm.predictions
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFMaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: np.ndarray | tf.Tensor | None = None,
training: Optional[bool] = False,
) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
r"""
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
"""
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
loss = (
None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)
)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFMaskedLMOutput(
loss=loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
)
class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TFCausalLanguageModelingLoss):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
if not config.is_decoder:
logger.warning("If you want to use `TF{{cookiecutter.camelcase_modelname}}ForCausalLM` as a standalone, add `is_decoder=True.`")
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
self.mlm = TF{{cookiecutter.camelcase_modelname}}MLMHead(config, input_embeddings=self.{{cookiecutter.lowercase_modelname}}.embeddings, name="mlm___cls")
def get_lm_head(self) -> keras.layers.Layer:
return self.mlm.predictions
def prepare_inputs_for_generation(self, inputs, past_key_values=None, attention_mask=None, **model_kwargs):
# cut decoder_input_ids if past is used
if past_key_values:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {
"input_ids": inputs,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": model_kwargs["use_cache"],
}
@unpack_inputs
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFCausalLMOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: np.ndarray | tf.Tensor | None = None,
training: Optional[bool] = False,
) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
r"""
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*, defaults to `True`):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up
decoding (see `past_key_values`). Set to `False` during training, `True` during generation
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., config.vocab_size - 1]`.
"""
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output = outputs[0]
logits = self.mlm(sequence_output=sequence_output, training=training)
loss = None
if labels is not None:
# shift labels to the left and cut last logit token
shifted_logits = logits[:, :-1]
labels = labels[:, 1:]
loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFCausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
class TF{{cookiecutter.camelcase_modelname}}ClassificationHead(keras.layers.Layer):
"""Head for sentence-level classification tasks."""
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.dense = keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.out_proj = keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
)
if isinstance(config.hidden_act, str):
self.classifier_act_fn = get_tf_activation(config.hidden_act)
else:
self.classifier_act_fn = config.hidden_act
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.classifier_act_fn(hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.out_proj(hidden_states)
return hidden_states
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model transformer with a sequence classification/regression head on top
e.g., for GLUE tasks. """,
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
self.classifier = TF{{cookiecutter.camelcase_modelname}}ClassificationHead(config, name="classifier")
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: np.ndarray | tf.Tensor | None = None,
training: Optional[bool] = False,
) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
r"""
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
logits = self.classifier(hidden_states=outputs[0], training=training)
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TFMultipleChoiceLoss):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
self.sequence_summary = TFSequenceSummary(
config, config.initializer_range, name="sequence_summary"
)
self.classifier = keras.layers.Dense(
units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFMultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: np.ndarray | tf.Tensor | None = None,
training: Optional[bool] = False,
) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
r"""
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` where `num_choices` is the size of the second dimension of the input tensors. (See
`input_ids` above)
"""
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = (
tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None
)
flat_attention_mask = (
tf.reshape(tensor=attention_mask, shape=(-1, seq_length))
if attention_mask is not None
else None
)
flat_token_type_ids = (
tf.reshape(tensor=token_type_ids, shape=(-1, seq_length))
if token_type_ids is not None
else None
)
flat_position_ids = (
tf.reshape(tensor=position_ids, shape=(-1, seq_length))
if position_ids is not None
else None
)
flat_inputs_embeds = (
tf.reshape(
tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])
)
if inputs_embeds is not None
else None
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=flat_input_ids,
attention_mask=flat_attention_mask,
token_type_ids=flat_token_type_ids,
position_ids=flat_position_ids,
head_mask=head_mask,
inputs_embeds=flat_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
logits = self.sequence_summary(inputs=outputs[0], training=training)
logits = self.classifier(inputs=logits)
reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TFMultipleChoiceModelOutput(
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TFTokenClassificationLoss):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.classifier = keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: np.ndarray | tf.Tensor | None = None,
training: Optional[bool] = False,
) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
r"""
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output = outputs[0]
sequence_output = self.dropout(inputs=sequence_output, training=training)
logits = self.classifier(inputs=sequence_output)
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TFTokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layer on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TFQuestionAnsweringLoss):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
self.qa_outputs = keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
start_positions: np.ndarray | tf.Tensor | None = None,
end_positions: np.ndarray | tf.Tensor | None = None,
training: Optional[bool] = False,
) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
r"""
start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output = outputs[0]
logits = self.qa_outputs(inputs=sequence_output)
start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1)
start_logits = tf.squeeze(input=start_logits, axis=-1)
end_logits = tf.squeeze(input=end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFQuestionAnsweringModelOutput(
loss=loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
{% else %}
import random
from typing import Optional, Tuple, Union
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput,
)
# Public API
from ...modeling_tf_utils import (
DUMMY_INPUTS,
TFPreTrainedModel,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import ContextManagers, logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "{{cookiecutter.checkpoint_identifier}}"
_CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config"
_TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype))
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100,
tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
if past_key_values_length > 0:
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len
one_cst = tf.constant(1.0)
mask = tf.cast(mask, dtype=one_cst.dtype)
expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
return (one_cst - expanded_mask) * LARGE_NEGATIVE
class TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(keras.layers.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
super().__init__(num_embeddings, embedding_dim, **kwargs)
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
"""Input is expected to be of size [bsz x seqlen]."""
seq_len = input_shape[1]
position_ids = tf.range(seq_len, delta=1, name="range")
position_ids += past_key_values_length
return super().call(tf.cast(position_ids, dtype=tf.int32))
class TF{{cookiecutter.camelcase_modelname}}Attention(keras.layers.Layer):
"""Multi-headed attention from "Attention Is All You Need"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = keras.layers.Dropout(dropout)
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.is_decoder = is_decoder
self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
def call(
self,
hidden_states: tf.Tensor,
key_value_states: tf.Tensor | None = None,
past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
attention_mask: tf.Tensor | None = None,
layer_head_mask: tf.Tensor | None = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor | None]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = shape_list(hidden_states)
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = tf.concat([past_key_value[0], key_states], axis=2)
value_states = tf.concat([past_key_value[1], value_states], axis=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
key_states = tf.reshape(key_states, proj_shape)
value_states = tf.reshape(value_states, proj_shape)
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
)
if attention_mask is not None:
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
)
attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
attn_output = self.out_proj(attn_output)
attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
return attn_output, attn_weights, past_key_value
class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
super().__init__(**kwargs)
self.embed_dim = config.d_model
self.self_attn = TF{{cookiecutter.camelcase_modelname}}Attention(
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
)
self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
self.dropout = keras.layers.Dropout(config.dropout)
self.activation_fn = get_tf_activation(config.activation_function)
self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
"""
Args:
hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
attention_mask (`tf.Tensor`): attention mask of size
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
*(encoder_attention_heads,)*
"""
residual = hidden_states
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.activation_dropout(hidden_states, training=training)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states, self_attn_weights
class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
super().__init__(**kwargs)
self.embed_dim = config.d_model
self.self_attn = TF{{cookiecutter.camelcase_modelname}}Attention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
name="self_attn",
is_decoder=True,
)
self.dropout = keras.layers.Dropout(config.dropout)
self.activation_fn = get_tf_activation(config.activation_function)
self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
self.encoder_attn = TF{{cookiecutter.camelcase_modelname}}Attention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
name="encoder_attn",
is_decoder=True,
)
self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(
self,
hidden_states,
attention_mask: tf.Tensor | None = None,
encoder_hidden_states: tf.Tensor | None = None,
encoder_attention_mask: tf.Tensor | None = None,
layer_head_mask: tf.Tensor | None = None,
cross_attn_layer_head_mask: tf.Tensor | None = None,
past_key_value: Tuple[tf.Tensor] | None = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
"""
Args:
hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
attention_mask (`tf.Tensor`): attention mask of size
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
encoder_hidden_states (`tf.Tensor`): cross attention input to the layer of shape *(batch, seq_len, embed_dim)*
encoder_attention_mask (`tf.Tensor`): encoder attention mask of size
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
*(decoder_attention_heads,)*
cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module.
*(decoder_attention_heads,)*
past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states
"""
residual = hidden_states
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
# Fully Connected
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.activation_dropout(hidden_states, training=training)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
return (
hidden_states,
self_attn_weights,
cross_attn_weights,
present_key_value,
)
class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
config_class = {{cookiecutter.camelcase_modelname}}Config
base_model_prefix = "model"
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading or saving, resizing the input
embeddings, pruning heads etc.)
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use
it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage
and behavior.
<Tip>
TensorFlow models and layers in `transformers` accept two formats as input:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional argument.
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second format outside of Keras methods like `fit()` and `predict()`, such as when creating
your own layers or models with the Keras `Functional` API, there are three possibilities you
can use to gather all the input Tensors in the first positional argument:
- a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
Note that when creating models and layers with (subclassing)[https://keras.io/guides/making_new_layers_and_models_via_subclassing/]
then you don't need to worry about any of this, as you can just pass inputs like you would to any other Python
function!
</Tip>
Args:
config ([`~{{cookiecutter.camelcase_modelname}}Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the
model weights.
"""
{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
Args:
input_ids (`tf.Tensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`~{{cookiecutter.camelcase_modelname}}Tokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using [`~{{cookiecutter.camelcase_modelname}}Tokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
{{cookiecutter.camelcase_modelname}} uses the `eos_token_id` as the starting token for
`decoder_input_ids` generation. If `past_key_values` is used, optionally only the last
`decoder_input_ids` have to be input (see `past_key_values`).
For translation and summarization training, `decoder_input_ids` should be provided. If no
`decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to
the right for denoising pre-training following the paper.
decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (`tf.FloatTensor`, *optional*):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*, defaults to `True`):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up
decoding (see `past_key_values`). Set to `False` during training, `True` during generation
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
config will be used instead.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
used instead.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This
argument can be used in eager mode, in graph mode the value will always be set to True.
training (`bool`, *optional*, defaults to `False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
@keras_serializable
class TF{{cookiecutter.camelcase_modelname}}Encoder(keras.layers.Layer):
config_class = {{cookiecutter.camelcase_modelname}}Config
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
[`TF{{cookiecutter.camelcase_modelname}}EncoderLayer`].
Args:
config: {{cookiecutter.camelcase_modelname}}Config
"""
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.dropout = keras.layers.Dropout(config.dropout)
self.layerdrop = config.encoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_position_embeddings
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
self.embed_tokens = embed_tokens
self.embed_positions = TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
name="embed_positions",
)
self.layers = [TF{{cookiecutter.camelcase_modelname}}EncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
def get_embed_tokens(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
@unpack_inputs
def call(
self,
input_ids=None,
inputs_embeds=None,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
):
"""
Args:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`~{{cookiecutter.camelcase_modelname}}Tokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, `optional): Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert `input_ids` indices
into associated vectors than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value
in the config will be used instead.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail. This argument can be used only in eager mode, in graph mode the value in the config
will be used instead.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This
argument can be used in eager mode, in graph mode the value will always be set to True.
training (`bool`, *optional*, defaults to `False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
# check attention mask and invert
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
)
# encoder layers
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if training and (dropout_probability < self.layerdrop): # skip the layer
continue
hidden_states, attn = encoder_layer(
hidden_states,
attention_mask,
head_mask[idx] if head_mask is not None else None,
)
if output_attentions:
all_attentions += (attn,)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
@keras_serializable
class TF{{cookiecutter.camelcase_modelname}}Decoder(keras.layers.Layer):
config_class = {{cookiecutter.camelcase_modelname}}Config
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TF{{cookiecutter.camelcase_modelname}}DecoderLayer`]
Args:
config: {{cookiecutter.camelcase_modelname}}Config
embed_tokens: output embedding
"""
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.padding_idx = config.pad_token_id
self.embed_tokens = embed_tokens
self.layerdrop = config.decoder_layerdrop
self.embed_positions = TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
name="embed_positions",
)
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
self.layers = [TF{{cookiecutter.camelcase_modelname}}DecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
self.dropout = keras.layers.Dropout(config.dropout)
def get_embed_tokens(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
@unpack_inputs
def call(
self,
input_ids=None,
inputs_embeds=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
):
r"""
Args:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`~{{cookiecutter.camelcase_modelname}}Tokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
If `past_key_values` are used, the user can optionally input only the last
`decoder_input_ids` (those that don't have their past key value states given to this model) of
shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size,
sequence_length)`.
inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices
into associated vectors than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value
in the config will be used instead.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail. This argument can be used only in eager mode, in graph mode the value in the config
will be used instead.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This
argument can be used in eager mode, in graph mode the value will always be set to True.
training (`bool`, *optional*, defaults to `False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
past_key_values_length = (
shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
)
# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
if inputs_embeds is None:
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
attention_mask, combined_attention_mask = self.compute_combined_attns_mask(
input_ids, attention_mask, input_shape, past_key_values_length
)
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])
hidden_states = self.layernorm_embedding(hidden_states + positions)
hidden_states = self.dropout(hidden_states, training=training)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None
present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if training and (dropout_probability < self.layerdrop):
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=head_mask[idx] if head_mask is not None else None,
cross_attn_layer_head_mask=cross_attn_head_mask[idx]
if cross_attn_head_mask is not None
else None,
past_key_value=past_key_value,
)
if use_cache:
present_key_values += (present_key_value,)
if output_attentions:
all_self_attns += (layer_self_attn,)
if encoder_hidden_states is not None:
all_cross_attns += (layer_cross_attn,)
if output_hidden_states:
all_hidden_states += (hidden_states,)
if not return_dict:
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=present_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attns,
)
@tf.function
def compute_combined_attns_mask(self, input_ids, attention_mask, input_shape, past_key_values_length):
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
else:
combined_attention_mask = _expand_mask(
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
)
if attention_mask is None and input_ids is not None and input_shape[-1] > 1:
attention_mask = tf.cast(
tf.math.not_equal(input_ids, self.config.pad_token_id), input_ids.dtype
)
attention_mask = tf.concat(
[
tf.ones((input_shape[0], past_key_values_length), dtype=attention_mask.dtype),
attention_mask,
],
axis=-1,
)
else:
attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length))
return attention_mask, combined_attention_mask
@keras_serializable
class TF{{cookiecutter.camelcase_modelname}}MainLayer(keras.layers.Layer):
config_class = {{cookiecutter.camelcase_modelname}}Config
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.shared = keras.layers.Embedding(
input_dim=config.vocab_size,
output_dim=config.d_model,
embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std),
name="model.shared"
)
# Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
self.shared.load_weight_prefix = "model.shared"
self.encoder = TF{{cookiecutter.camelcase_modelname}}Encoder(config, self.shared, name="encoder")
self.decoder = TF{{cookiecutter.camelcase_modelname}}Decoder(config, self.shared, name="decoder")
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared
@unpack_inputs
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs
):
if decoder_input_ids is None and decoder_inputs_embeds is None:
use_cache = False
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
# If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
encoder_outputs = TFBaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
# If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
elif not return_dict and not isinstance(encoder_outputs, tuple):
encoder_outputs = encoder_outputs.to_tuple()
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return TFSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
# The shared/tied weights expect to be in the model base namespace
# Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
# the current one.
with tf.name_scope(self.shared.load_weight_prefix + '/' + self.shared.name + '/'):
self.shared.build(None)
@add_start_docstrings(
"The bare {{cookiecutter.uppercase_modelname}} Model outputting raw hidden-states without any specific head on top.",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model")
def get_encoder(self):
return self.model.encoder
def get_decoder(self):
return self.model.decoder
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFSeq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
return outputs
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(keras.layers.Layer):
"""
Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis,
so all weights have to be registered in a layer.
"""
def __init__(self, shape, initializer, trainable, name, **kwargs):
super().__init__(name=name, **kwargs)
# Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
# "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
# https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)
def call(self, x):
return x + self.bias
@add_start_docstrings(
"The {{cookiecutter.uppercase_modelname}} Model with a language modeling head. Can be used for summarization.",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [
r"model.encoder.embed_tokens.weight",
r"model.decoder.embed_tokens.weight",
]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model")
self.use_cache = config.use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self.bias_layer = BiasLayer(
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
)
def get_decoder(self):
return self.model.decoder
def get_encoder(self):
return self.model.encoder
def get_bias(self):
return {"final_logits_bias": self.bias_layer.bias}
def set_bias(self, value):
# Replaces the existing layers containing bias for correct (de)serialization.
vocab_size = value["final_logits_bias"].shape[-1]
self.bias_layer = BiasLayer(
name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
)
self.bias_layer.bias.assign(value["final_logits_bias"])
def get_output_embeddings(self):
return self.get_input_embeddings()
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
training=False,
):
"""
Returns:
Examples:
```python
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
>>> import tensorflow as tf
>>> mname = '{{cookiecutter.checkpoint_identifier}}'
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained(mname)
>>> TXT = "My friends are <mask> but they eat too many carbs."
>>> model = TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained(mname)
>>> batch = tokenizer([TXT], return_tensors='tf')
>>> logits = model(inputs=batch.input_ids).logits
>>> probs = tf.nn.softmax(logits[0])
>>> # probs[5] is associated with the mask token
```"""
if labels is not None:
use_cache = False
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training
)
lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
lm_logits = self.bias_layer(lm_logits)
masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return TFSeq2SeqLMOutput(
loss=masked_lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def hf_compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens"""
loss_fn = keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=keras.losses.Reduction.NONE,
)
melted_labels = tf.reshape(labels, (-1,))
active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
labels = tf.boolean_mask(melted_labels, active_loss)
return loss_fn(labels, reduced_logits)
{% endif -%}
# coding=utf-8
# Copyright 2022 {{cookiecutter.authors}} The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch {{cookiecutter.modelname}} model. """
{% if cookiecutter.is_encoder_decoder_model == "False" %}
import math
import os
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from typing import Optional, Tuple, Union
from ...activations import ACT2FN
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, SequenceSummary
from ...pytorch_utils import (
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "{{cookiecutter.checkpoint_identifier}}"
_CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config"
def load_tf_weights_in_{{cookiecutter.lowercase_modelname}}(model, config, tf_checkpoint_path):
"""Load tf checkpoints in a pytorch model."""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
logger.info(f"Loading TF weight {name} with shape {shape}")
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
for name, array in zip(names, arrays):
name = name.split("/")
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name
):
logger.info(f"Skipping {'/'.join(name)}")
continue
pointer = model
for m_name in name:
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
scope_names = re.split(r"_(\d+)", m_name)
else:
scope_names = [m_name]
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
pointer = getattr(pointer, "bias")
elif scope_names[0] == "output_weights":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "squad":
pointer = getattr(pointer, "classifier")
else:
try:
pointer = getattr(pointer, scope_names[0])
except AttributeError:
logger.info(f"Skipping {'/'.join(name)}")
continue
if len(scope_names) >= 2:
num = int(scope_names[1])
pointer = pointer[num]
if m_name[-11:] == "_embeddings":
pointer = getattr(pointer, "weight")
elif m_name == "kernel":
array = np.transpose(array)
try:
assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
logger.info(f"Initialize PyTorch weight {name}")
pointer.data = torch.from_numpy(array)
return model
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
persistent=False,
)
def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
# issue #5664
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in {{cookiecutter.camelcase_modelname}}Model forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = {{cookiecutter.camelcase_modelname}}SelfAttention(config, position_embedding_type=position_embedding_type)
self.output = {{cookiecutter.camelcase_modelname}}SelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}Intermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}Output(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}Layer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = {{cookiecutter.camelcase_modelname}}Attention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
self.crossattention = {{cookiecutter.camelcase_modelname}}Attention(config, position_embedding_type="absolute")
self.intermediate = {{cookiecutter.camelcase_modelname}}Intermediate(config)
self.output = {{cookiecutter.camelcase_modelname}}Output(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
assert hasattr(
self, "crossattention"
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([{{cookiecutter.camelcase_modelname}}Layer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
if self.gradient_checkpointing and self.training and use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}PredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}LMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = {{cookiecutter.camelcase_modelname}}PredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}OnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = {{cookiecutter.camelcase_modelname}}LMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = {{cookiecutter.camelcase_modelname}}Config
load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}}
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
usage and behavior.
Parameters:
config ([`~{{cookiecutter.camelcase_modelname}}Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`{{cookiecutter.camelcase_modelname}}Tokenizer`].
See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
[What are token type IDs?](../glossary#token-type-ids)
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range `[0, config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert *input_ids* indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare {{cookiecutter.modelname}} Model transformer outputting raw hidden-states without any specific head on top.",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelname}}PreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well
as a decoder, in which case a layer of cross-attention is added between
the self-attention layers, following the architecture described in [Attention is
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
To behave as an decoder the model needs to be initialized with the
`is_decoder` argument of the configuration set to `True`.
To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder`
argument and `add_cross_attention` set to `True`; an
`encoder_hidden_states` is then expected as an input to the forward pass.
"""
def __init__(self, config):
super().__init__(config)
self.config = config
self.embeddings = {{cookiecutter.camelcase_modelname}}Embeddings(config)
self.encoder = {{cookiecutter.camelcase_modelname}}Encoder(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
if the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
is used in the cross-attention if the model is configured as a decoder.
Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up
decoding (see `past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
if not return_dict:
return (sequence_output,) + encoder_outputs[1:]
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=sequence_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
@add_start_docstrings("""{{cookiecutter.modelname}} Model with a `language modeling` head on top. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING)
class {{cookiecutter.camelcase_modelname}}ForMaskedLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
def __init__(self, config):
super().__init__(config)
if config.is_decoder:
logger.warning(
"If you want to use `{{cookiecutter.camelcase_modelname}}ForMaskedLM` make sure `config.is_decoder=False` for "
"bi-directional self-attention."
)
self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
self.cls = {{cookiecutter.camelcase_modelname}}OnlyMLMHead(config)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss.
Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring)
Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels
in `[0, ..., config.vocab_size]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
effective_batch_size = input_shape[0]
# add a dummy token
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
dummy_token = torch.full(
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
)
input_ids = torch.cat([input_ids, dummy_token], dim=1)
return {"input_ids": input_ids, "attention_mask": attention_mask}
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
)
class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
if not config.is_decoder:
logger.warning("If you want to use `{{cookiecutter.camelcase_modelname}}ForCausalLM` as a standalone, add `is_decoder=True.`")
self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
self.cls = {{cookiecutter.camelcase_modelname}}OnlyMLMHead(config)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2
tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two
additional tensors are only required when the model is used as a decoder in a Sequence to Sequence
model.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential
decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up
decoding (see `past_key_values`).
Returns:
Example:
```python
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, {{cookiecutter.camelcase_modelname}}ForCausalLM, {{cookiecutter.camelcase_modelname}}Config
>>> import torch
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> config = {{cookiecutter.camelcase_modelname}}Config.from_pretrained("{{cookiecutter.checkpoint_identifier}}")
>>> config.is_decoder = True
>>> model = {{cookiecutter.camelcase_modelname}}ForCausalLM.from_pretrained('{{cookiecutter.checkpoint_identifier}}', config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[1:]
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + layer_past[2:],)
return reordered_past
class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
self.config = config
def forward(self, features, **kwargs):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
x = ACT2FN[self.config.hidden_act](x)
x = self.dropout(x)
x = self.out_proj(x)
return x
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model transformer with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutter.camelcase_modelname}}PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
self.classifier = {{cookiecutter.camelcase_modelname}}ClassificationHead(config)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss.
Indices should be in `[0, ..., config.num_labels - 1]`.
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class {{cookiecutter.camelcase_modelname}}ForMultipleChoice({{cookiecutter.camelcase_modelname}}PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
self.sequence_summary = SequenceSummary(config)
self.classifier = nn.Linear(config.hidden_size, 1)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss.
Indices should be in `[0, ..., num_choices-1]` where `num_choices` is the size of the second dimension
of the input tensors. (See `input_ids` above)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
inputs_embeds = (
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
if inputs_embeds is not None
else None
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
pooled_output = self.sequence_summary(sequence_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
if not return_dict:
output = (reshaped_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class {{cookiecutter.camelcase_modelname}}ForTokenClassification({{cookiecutter.camelcase_modelname}}PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss.
Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.camelcase_modelname}}PreTrainedModel):
def __init__(self, config):
super().__init__(config)
config.num_labels = 2
self.num_labels = config.num_labels
self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[1:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
{% else %}
import math
import copy
from typing import Optional, Tuple, List, Union
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...utils import (
add_code_sample_docstrings,
add_end_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
CausalLMOutputWithCrossAttentions
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "{{cookiecutter.checkpoint_identifier}}"
_CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config"
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
shifted_input_ids[:, 0] = decoder_start_token_id
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
class {{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int):
super().__init__(num_embeddings, embedding_dim)
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
bsz, seq_len = input_ids_shape[:2]
positions = torch.arange(
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
)
return super().forward(positions)
class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
self.scaling = self.head_dim ** -0.5
self.is_decoder = is_decoder
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit akward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
class {{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = {{cookiecutter.camelcase_modelname}}Attention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
layer_head_mask: torch.Tensor,
output_attentions: bool = False,
):
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
attention_mask (`torch.FloatTensor`): attention mask of size
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
*(config.encoder_attention_heads,)*.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states, attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = {{cookiecutter.camelcase_modelname}}Attention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = {{cookiecutter.camelcase_modelname}}Attention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
):
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
attention_mask (`torch.FloatTensor`): attention mask of size
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
encoder_hidden_states (`torch.FloatTensor`): cross attention input to the layer of shape *(batch, seq_len, embed_dim)*
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
*(encoder_attention_heads,)*.
cross_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size *(decoder_attention_heads,)*.
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
# Fully Connected
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
if use_cache:
outputs += (present_key_value,)
return outputs
# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(
self,
input_dim: int,
inner_dim: int,
num_classes: int,
pooler_dropout: float,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, hidden_states: torch.Tensor):
hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states)
hidden_states = torch.tanh(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.out_proj(hidden_states)
return hidden_states
class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
config_class = {{cookiecutter.camelcase_modelname}}Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
pruning heads etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
general usage and behavior.
Parameters:
config ([`~{{cookiecutter.camelcase_modelname}}Config`]):
Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model
weights.
"""
{{cookiecutter.uppercase_modelname}}_GENERATION_EXAMPLE = r"""
Summarization example:
```python
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration
>>> model = {{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
>>> # Generate Summary
>>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5)
>>> print(tokenizer.decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
```
"""
{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`~{{cookiecutter.camelcase_modelname}}Tokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Provide for translation and summarization training. By default, the model will create this tensor by
shifting the `input_ids` to the right, following the paper.
decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will
also be used by default.
If you want to change padding behavior, you should read [`modeling_{{cookiecutter.lowercase_modelname}}._prepare_decoder_attention_mask`] and
modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
`attentions`) `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`,
*optional*) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
cross-attention of the decoder.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors
of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds`
have to be input (see `past_key_values`). This is useful if you want more control over how to convert
`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds`
takes the value of `inputs_embeds`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up
decoding (see `past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
{{cookiecutter.uppercase_modelname}}_STANDALONE_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`ProphetNetTokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
[`{{cookiecutter.camelcase_modelname}}EncoderLayer`].
Args:
config: {{cookiecutter.camelcase_modelname}}Config
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, embed_tokens: Optional[nn.Embedding] = None):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.encoder_layerdrop
embed_dim = config.d_model
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
self.embed_positions = {{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
config.max_position_embeddings,
embed_dim,
)
self.layers = nn.ModuleList([{{cookiecutter.camelcase_modelname}}EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`~{{cookiecutter.camelcase_modelname}}Tokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert `input_ids` indices
into associated vectors than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = torch.randn([])
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`{{cookiecutter.camelcase_modelname}}DecoderLayer`]
Args:
config: {{cookiecutter.camelcase_modelname}}Config
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, embed_tokens: Optional[nn.Embedding] = None):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_positions = {{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
)
self.layers = nn.ModuleList([{{cookiecutter.camelcase_modelname}}DecoderLayer(config) for _ in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`~{{cookiecutter.camelcase_modelname}}Tokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2
tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential
decoding.
If `past_key_values` are used, the user can optionally input only the last
`decoder_input_ids` (those that don't have their past key value states given to this model) of
shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size,
sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers
if self.gradient_checkpointing and self.training and use_cache:
logger.warning("`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`...")
use_cache = False
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
assert attn_mask.size()[0] == (
len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = torch.randn([])
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
@add_start_docstrings(
"The bare {{cookiecutter.modelname}} Model outputting raw hidden-states without any specific head on top.",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelname}}PreTrainedModel):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config):
super().__init__(config)
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
self.encoder = {{cookiecutter.camelcase_modelname}}Encoder(config, self.shared)
self.decoder = {{cookiecutter.camelcase_modelname}}Decoder(config, self.shared)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, value):
self.shared = value
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Seq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@add_start_docstrings(
"The {{cookiecutter.modelname}} Model with a language modeling head. Can be used for summarization.", {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
)
class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutter.camelcase_modelname}}PreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
r"encoder\.version",
r"decoder\.version",
r"lm_head\.weight",
]
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config):
super().__init__(config)
self.model = {{cookiecutter.camelcase_modelname}}Model(config)
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_encoder(self):
return self.model.get_encoder()
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
self._resize_final_logits_bias(new_num_tokens)
return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
old_num_tokens = self.final_logits_bias.shape[-1]
if new_num_tokens <= old_num_tokens:
new_bias = self.final_logits_bias[:, :new_num_tokens]
else:
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
self.register_buffer("final_logits_bias", new_bias)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_end_docstrings({{cookiecutter.uppercase_modelname}}_GENERATION_EXAMPLE)
def forward(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Conditional generation example:
```python
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> TXT = "My friends are <mask> but they eat too many carbs."
>>> model = {{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
>>> logits = model(input_ids).logits
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
>>> probs = logits[0, masked_index].softmax(dim=0)
>>> values, predictions = probs.topk(5)
>>> tokenizer.decode(predictions).split()
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return Seq2SeqLMOutput(
loss=masked_lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
return reordered_past
@add_start_docstrings(
"""
{{cookiecutter.camelcase_modelname}} model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
tasks.
""",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutter.camelcase_modelname}}PreTrainedModel):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
super().__init__(config, **kwargs)
self.model = {{cookiecutter.camelcase_modelname}}Model(config)
self.classification_head = {{cookiecutter.camelcase_modelname}}ClassificationHead(
config.d_model,
config.d_model,
config.num_labels,
config.classifier_dropout,
)
self.model._init_weights(self.classification_head.dense)
self.model._init_weights(self.classification_head.out_proj)
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Seq2SeqSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
use_cache = False
if input_ids is None and inputs_embeds is not None:
raise NotImplementedError(
f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
)
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0] # last hidden state
eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
:, -1, :
]
logits = self.classification_head(sentence_representation)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.config.num_labels == 1:
self.config.problem_type = "regression"
elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.config.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return Seq2SeqSequenceClassifierOutput(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
@add_start_docstrings(
"""
{{cookiecutter.modelname}} Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.camelcase_modelname}}PreTrainedModel):
def __init__(self, config):
super().__init__(config)
config.num_labels = 2
self.num_labels = config.num_labels
self.model = {{cookiecutter.camelcase_modelname}}Model(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.model._init_weights(self.qa_outputs)
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Seq2SeqQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
start_positions=None,
end_positions=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if start_positions is not None and end_positions is not None:
use_cache = False
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (
start_logits,
end_logits,
) + outputs[1:]
return ((total_loss,) + output) if total_loss is not None else output
return Seq2SeqQuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcase_modelname}}PreTrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the [`EncoderDecoderModel`] framework.
"""
def __init__(self, config):
super().__init__(config)
self.decoder = {{cookiecutter.camelcase_modelname}}Decoder(config)
def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
def __init__(self, config):
config = copy.deepcopy(config)
config.is_decoder = True
config.is_encoder_decoder = False
super().__init__(config)
self.model = {{cookiecutter.camelcase_modelname}}DecoderWrapper(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.decoder.embed_tokens
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.decoder = decoder
def get_decoder(self):
return self.model.decoder
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`~{{cookiecutter.camelcase_modelname}}Tokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
if the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up
decoding (see `past_key_values`).
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Returns:
Example:
```python
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, {{cookiecutter.camelcase_modelname}}ForCausalLM
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('facebook/bart-large')
>>> model = {{cookiecutter.camelcase_modelname}}ForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0])
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": use_cache,
}
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
return reordered_past
{% endif -%}
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
{% if cookiecutter.is_encoder_decoder_model == "False" %}
import unittest
from transformers import is_flax_available, {{cookiecutter.camelcase_modelname}}Config
from transformers.testing_utils import require_flax, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
import numpy as np
from transformers import (
Flax{{cookiecutter.camelcase_modelname}}ForCausalLM,
Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM,
Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
Flax{{cookiecutter.camelcase_modelname}}ForTokenClassification,
Flax{{cookiecutter.camelcase_modelname}}Model,
)
class Flax{{cookiecutter.camelcase_modelname}}ModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
self.parent = parent
self.batch_size = 13
self.seq_length = 7
self.is_training = True
self.use_input_mask = True
self.use_token_type_ids = True
self.use_labels = True
self.vocab_size = 99
self.hidden_size = 32
self.num_hidden_layers = 5
self.num_attention_heads = 4
self.intermediate_size = 37
self.hidden_act = "gelu"
self.hidden_dropout_prob = 0.1
self.attention_probs_dropout_prob = 0.1
self.max_position_embeddings = 512
self.type_vocab_size = 16
self.type_sequence_label_size = 2
self.initializer_range = 0.02
self.num_labels = 3
self.num_choices = 4
self.scope = None
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = {{cookiecutter.camelcase_modelname}}Config(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
return_dict=True,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = Flax{{cookiecutter.camelcase_modelname}}Model(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
inputs = [input_ids, input_mask]
result = model(*inputs)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_lm_head(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.is_decoder = True
model = Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(config=config)
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
}
prediction_scores = model(**inputs)["logits"]
self.parent.assertListEqual(
list(prediction_scores.shape), [self.batch_size, self.seq_length, self.vocab_size]
)
def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM(config=config)
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
}
result = model(**inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification(config=config)
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
}
result = model(**inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoice(config=config)
multiple_choice_inputs_ids = np.tile(np.expand_dims(input_ids, 1), (1, self.num_choices, 1))
multiple_choice_input_mask = np.tile(np.expand_dims(input_mask, 1), (1, self.num_choices, 1))
multiple_choice_token_type_ids = np.tile(np.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
inputs = {
"input_ids": multiple_choice_inputs_ids,
"attention_mask": multiple_choice_input_mask,
"token_type_ids": multiple_choice_token_type_ids,
}
result = model(**inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def create_and_check_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = Flax{{cookiecutter.camelcase_modelname}}ForTokenClassification(config=config)
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
}
result = model(**inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(config=config)
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
}
result = model(**inputs)
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_flax
class Flax{{cookiecutter.camelcase_modelname}}ModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
Flax{{cookiecutter.camelcase_modelname}}Model,
Flax{{cookiecutter.camelcase_modelname}}ForCausalLM,
Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM,
Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
Flax{{cookiecutter.camelcase_modelname}}ForTokenClassification,
Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
)
if is_flax_available()
else ()
)
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = Flax{{cookiecutter.camelcase_modelname}}ModelTester(self)
self.config_tester = ConfigTester(self, config_class={{cookiecutter.camelcase_modelname}}Config, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_lm_head(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
model = Flax{{cookiecutter.camelcase_modelname}}Model.from_pretrained("{{cookiecutter.checkpoint_identifier}}")
self.assertIsNotNone(model)
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
if a is None and b is None:
return True
try:
if _assert_tensors_equal(a, b, atol=atol):
return True
raise
except Exception:
if len(prefix) > 0:
prefix = f"{prefix}: "
raise AssertionError(f"{prefix}{a} != {b}")
@require_flax
class Flax{{cookiecutter.camelcase_modelname}}ModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_masked_lm(self):
model = Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM.from_pretrained("{{cookiecutter.checkpoint_identifier}}")
input_ids = np.array([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0]
# TODO Replace vocab size
vocab_size = 32000
expected_shape = [1, 6, vocab_size]
self.assertEqual(output.shape, expected_shape)
print(output[:, :3, :3])
# TODO Replace values below with what was printed above.
expected_slice = np.array(
[
[
[-0.05243197, -0.04498899, 0.05512108],
[-0.07444685, -0.01064632, 0.04352357],
[-0.05020351, 0.05530146, 0.00700043],
]
]
)
_assert_tensors_equal(output[:, :3, :3], expected_slice, atol=1e-4)
{% else %}
import unittest
from transformers import (
is_flax_available,
{{cookiecutter.camelcase_modelname}}Config,
{{cookiecutter.camelcase_modelname}}Tokenizer,
)
from transformers.testing_utils import require_sentencepiece, require_flax, require_tokenizers, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
import numpy as np
import jax.numpy as jnp
from transformers import (
Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
Flax{{cookiecutter.camelcase_modelname}}Model,
)
@require_flax
class Flax{{cookiecutter.camelcase_modelname}}ModelTester:
config_cls = {{cookiecutter.camelcase_modelname}}Config
config_updates = {}
hidden_act = "gelu"
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_labels=False,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=20,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
def prepare_config_and_inputs_for_common(self):
input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size).clip(3, self.vocab_size)
eos_tensor = np.expand_dims(np.array([self.eos_token_id] * self.batch_size), 1)
input_ids = np.concatenate([input_ids, eos_tensor], axis=1)
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = self.config_cls(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
decoder_layers=self.num_hidden_layers,
encoder_attention_heads=self.num_attention_heads,
decoder_attention_heads=self.num_attention_heads,
encoder_ffn_dim=self.intermediate_size,
decoder_ffn_dim=self.intermediate_size,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
eos_token_ids=[2],
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.pad_token_id,
**self.config_updates,
)
inputs_dict = prepare_{{cookiecutter.lowercase_modelname}}_inputs_dict(config, input_ids, decoder_input_ids)
return config, inputs_dict
def check_use_cache_forward(self, model_class_name, config, inputs_dict):
max_decoder_length = 20
model = model_class_name(config)
encoder_outputs = model.encode(inputs_dict["input_ids"])
decoder_input_ids, decoder_attention_mask = (
inputs_dict["decoder_input_ids"],
inputs_dict["decoder_attention_mask"],
)
past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs)
decoder_attention_mask = jnp.ones((decoder_input_ids.shape[0], max_decoder_length), dtype="i4")
decoder_position_ids = jnp.broadcast_to(
jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :],
(decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1),
)
outputs_cache = model.decode(
decoder_input_ids[:, :-1],
encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
decoder_position_ids=decoder_position_ids,
)
decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model.decode(
decoder_input_ids[:, -1:],
encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
past_key_values=outputs_cache.past_key_values,
decoder_position_ids=decoder_position_ids,
)
outputs = model.decode(decoder_input_ids, encoder_outputs)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
max_decoder_length = 20
model = model_class_name(config)
encoder_outputs = model.encode(inputs_dict["input_ids"])
decoder_input_ids, decoder_attention_mask = (
inputs_dict["decoder_input_ids"],
inputs_dict["decoder_attention_mask"],
)
decoder_attention_mask_cache = jnp.concatenate(
[
decoder_attention_mask,
jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])),
],
axis=-1,
)
past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :],
(decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1),
)
outputs_cache = model.decode(
decoder_input_ids[:, :-1],
encoder_outputs,
decoder_attention_mask=decoder_attention_mask_cache,
past_key_values=past_key_values,
decoder_position_ids=decoder_position_ids,
)
decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model.decode(
decoder_input_ids[:, -1:],
encoder_outputs,
past_key_values=outputs_cache.past_key_values,
decoder_attention_mask=decoder_attention_mask_cache,
decoder_position_ids=decoder_position_ids,
)
outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
def prepare_{{cookiecutter.lowercase_modelname}}_inputs_dict(
config,
input_ids,
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
):
if attention_mask is None:
attention_mask = np.not_equal(input_ids, config.pad_token_id).astype(np.int8)
if decoder_attention_mask is None:
decoder_attention_mask = np.concatenate([np.ones(decoder_input_ids[:, :1].shape, dtype=np.int8), np.not_equal(decoder_input_ids[:, 1:], config.pad_token_id).astype(np.int8)], axis=-1)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
}
@require_flax
class Flax{{cookiecutter.camelcase_modelname}}ModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
Flax{{cookiecutter.camelcase_modelname}}Model,
) if is_flax_available()
else ()
)
all_generative_model_classes = (Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,) if is_flax_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = Flax{{cookiecutter.camelcase_modelname}}ModelTester(self)
self.config_tester = ConfigTester(self, config_class={{cookiecutter.camelcase_modelname}}Config)
def test_config(self):
self.config_tester.run_common_tests()
def test_use_cache_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
self.model_tester.check_use_cache_forward(model_class, config, inputs_dict)
def test_use_cache_forward_with_attn_mask(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
self.model_tester.check_use_cache_forward_with_attn_mask(model_class, config, inputs_dict)
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
if a is None and b is None:
return True
try:
if _assert_tensors_equal(a, b, atol=atol):
return True
raise
except Exception:
if len(prefix) > 0:
prefix = f"{prefix}: "
raise AssertionError(f"{prefix}{a} != {b}")
def _long_tensor(tok_lst):
return np.array(tok_lst, dtype=np.int32)
TOLERANCE = 1e-4
@slow
@require_sentencepiece
@require_tokenizers
@require_flax
class Flax{{cookiecutter.camelcase_modelname}}ModelIntegrationTest(unittest.TestCase):
def test_inference_no_head(self):
model = Flax{{cookiecutter.camelcase_modelname}}Model.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
# change to intended input here
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
decoder_input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
inputs_dict = prepare_{{cookiecutter.lowercase_modelname}}_inputs_dict(model.config, input_ids, decoder_input_ids)
output = model(**inputs_dict)[0]
expected_shape = (1, 11, 1024)
self.assertEqual(output.shape, expected_shape)
# change to expected output here
expected_slice = np.array(
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]],
)
_assert_tensors_equal(output[:, :3, :3], expected_slice, atol=TOLERANCE)
def test_inference_with_head(self):
model = Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
# change to intended input here
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
decoder_input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
inputs_dict = prepare_{{cookiecutter.lowercase_modelname}}_inputs_dict(model.config, input_ids, decoder_input_ids)
output = model(**inputs_dict)[0]
expected_shape = (1, 11, 1024)
self.assertEqual(output.shape, expected_shape)
# change to expected output here
expected_slice = np.array(
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]],
)
_assert_tensors_equal(output[:, :3, :3], expected_slice, atol=TOLERANCE)
def test_seq_to_seq_generation(self):
hf = Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
tok = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
batch_input = [
# string 1,
# string 2,
# string 3,
# string 4,
]
# The below article tests that we don't add any hypotheses outside of the top n_beams
dct = tok.batch_encode_plus(
batch_input,
max_length=512,
padding="max_length",
truncation_strategy="only_first",
truncation=True,
return_tensors="np",
)
hypotheses_batch = hf.generate(
input_ids=dct["input_ids"],
attention_mask=dct["attention_mask"],
num_beams=2,
)
EXPECTED = [
# here expected 1,
# here expected 2,
# here expected 3,
# here expected 4,
]
generated = tok.batch_decode(
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
)
assert generated == EXPECTED
{%- endif %}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment