Unverified Commit 7f286132 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[TFBart] Split TF-Bart (#9497)

* make templates ready

* make add_new_model_command_ready

* finish tf bart

* prepare tf mbart

* finish tf bart

* add tf mbart

* add marian

* prep pegasus

* add tf pegasus

* push blenderbot tf

* add blenderbot

* add blenderbot small

* clean-up

* make fix copy

* define blend bot tok

* fix

* up

* make style

* add to docs

* add copy statements

* overwrite changes

* improve

* fix docs

* finish

* fix last slow test

* fix missing git conflict line

* fix blenderbot

* up

* fix blenderbot small

* load changes

* finish copied from

* upload fix
parent 0ecbb698
......@@ -159,17 +159,6 @@ class MarianConfig(PretrainedConfig):
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
# IMPORTANT
# DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
self.extra_pos_embeddings = 0
self.normalize_before = False
self.add_final_layer_norm = False
self.do_blenderbot_90_layernorm = False
self.normalize_embedding = False
self.static_position_embeddings = True
self.add_bias_logits = False
self.force_bos_token_to_be_generated = False
@property
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
......
......@@ -47,7 +47,7 @@ if is_torch_available():
]
if is_tf_available():
_import_structure["modeling_tf_mbart"] = ["TFMBartForConditionalGeneration"]
_import_structure["modeling_tf_mbart"] = ["TFMBartForConditionalGeneration", "TFMBartModel"]
if TYPE_CHECKING:
......@@ -70,7 +70,7 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_mbart import TFMBartForConditionalGeneration
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
else:
import importlib
......
......@@ -159,17 +159,6 @@ class MBartConfig(PretrainedConfig):
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
# IMPORTANT
# DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
self.extra_pos_embeddings = 2
self.normalize_before = True
self.add_final_layer_norm = True
self.do_blenderbot_90_layernorm = False
self.normalize_embedding = True
self.static_position_embeddings = False
self.add_bias_logits = False
self.force_bos_token_to_be_generated = False
@property
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
......
......@@ -45,7 +45,7 @@ if is_torch_available():
]
if is_tf_available():
_import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration"]
_import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration", "TFPegasusModel"]
if TYPE_CHECKING:
......@@ -66,7 +66,7 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
else:
import importlib
......
......@@ -159,17 +159,6 @@ class PegasusConfig(PretrainedConfig):
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
# IMPORTANT
# DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
self.extra_pos_embeddings = 0
self.normalize_before = True
self.add_final_layer_norm = True
self.do_blenderbot_90_layernorm = False
self.normalize_embedding = False
self.static_position_embeddings = True
self.add_bias_logits = False
self.force_bos_token_to_be_generated = False
@property
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
......
This diff is collapsed.
# coding=utf-8
# Copyright {{cookiecutter.authors}} and The HuggingFace Inc. team. All rights reserved.
# Copyright 2021 {{cookiecutter.authors}} The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding=utf-8
# Copyright {{cookiecutter.authors}} and The HuggingFace Inc. team. All rights reserved.
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -21,8 +21,8 @@ import unittest
from tests.test_modeling_common import floats_tensor
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
......@@ -31,15 +31,17 @@ if is_torch_available():
from transformers import (
{{cookiecutter.camelcase_modelname}}Config,
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}ForTokenClassification,
{{cookiecutter.camelcase_modelname}}Model,
)
from transformers.models.{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.models.{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
)
class {{cookiecutter.camelcase_modelname}}ModelTester:
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment