Unverified Commit 7f286132, authored by Patrick von Platen, committed by GitHub

[TFBart] Split TF-Bart (#9497)

* make templates ready

* make add_new_model_command_ready

* finish tf bart

* prepare tf mbart

* finish tf bart

* add tf mbart

* add marian

* prep pegasus

* add tf pegasus

* push blenderbot tf

* add blenderbot

* add blenderbot small

* clean-up

* make fix copy

* define blend bot tok

* fix

* up

* make style

* add to docs

* add copy statements

* overwrite changes

* improve

* fix docs

* finish

* fix last slow test

* fix missing git conflict line

* fix blenderbot

* up

* fix blenderbot small

* load changes

* finish copied from

* upload fix
parent 0ecbb698
configuration_marian.py:

@@ -159,17 +159,6 @@ class MarianConfig(PretrainedConfig):
         self.gradient_checkpointing = gradient_checkpointing
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
 
-        # IMPORTANT
-        # DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
-        self.extra_pos_embeddings = 0
-        self.normalize_before = False
-        self.add_final_layer_norm = False
-        self.do_blenderbot_90_layernorm = False
-        self.normalize_embedding = False
-        self.static_position_embeddings = True
-        self.add_bias_logits = False
-        self.force_bos_token_to_be_generated = False
-
     @property
     def num_attention_heads(self) -> int:
         return self.encoder_attention_heads
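The attributes deleted above were temporary shims: they let the old, shared TF-Bart implementation reconfigure itself per checkpoint family at runtime. After the split, each model's own TF file hard-codes its variant instead. A purely illustrative sketch (the class below is not in the library) of the kind of branching a flag like `normalize_before` used to drive:

```python
import tensorflow as tf


class EncoderLayerSketch(tf.keras.layers.Layer):
    """Illustrative only: pre-norm vs. post-norm branching on a config flag."""

    def __init__(self, d_model, num_heads, normalize_before, **kwargs):
        super().__init__(**kwargs)
        # Previously read from the config; now fixed per model file
        # (e.g. False for Marian, True for mBART and Pegasus).
        self.normalize_before = normalize_before
        self.self_attn = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.self_attn_layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, hidden_states):
        residual = hidden_states
        if self.normalize_before:  # pre-norm variant
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states, hidden_states)
        hidden_states = residual + hidden_states
        if not self.normalize_before:  # post-norm variant
            hidden_states = self.self_attn_layer_norm(hidden_states)
        return hidden_states
```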
mbart/__init__.py:

@@ -47,7 +47,7 @@ if is_torch_available():
     ]
 
 if is_tf_available():
-    _import_structure["modeling_tf_mbart"] = ["TFMBartForConditionalGeneration"]
+    _import_structure["modeling_tf_mbart"] = ["TFMBartForConditionalGeneration", "TFMBartModel"]
 
 if TYPE_CHECKING:

@@ -70,7 +70,7 @@ if TYPE_CHECKING:
     )
 
     if is_tf_available():
-        from .modeling_tf_mbart import TFMBartForConditionalGeneration
+        from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
 
 else:
     import importlib
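With `TFMBartModel` now exported, the bare TF encoder-decoder can be used directly. A hedged usage sketch: the checkpoint name is an assumption for illustration, and `from_pt=True` would only be needed if the checkpoint hosts PyTorch weights alone:

```python
from transformers import MBartTokenizer, TFMBartModel

# Checkpoint name is an assumption, not taken from this diff.
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = TFMBartModel.from_pretrained("facebook/mbart-large-cc25", from_pt=True)

inputs = tokenizer("UN Chief says there is no plan to stop chemical weapons in Syria", return_tensors="tf")
outputs = model(inputs)  # TF models accept the tokenizer's dict directly
print(outputs.last_hidden_state.shape)  # (batch, sequence_length, d_model)
```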
configuration_mbart.py:

@@ -159,17 +159,6 @@ class MBartConfig(PretrainedConfig):
         self.gradient_checkpointing = gradient_checkpointing
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
 
-        # IMPORTANT
-        # DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
-        self.extra_pos_embeddings = 2
-        self.normalize_before = True
-        self.add_final_layer_norm = True
-        self.do_blenderbot_90_layernorm = False
-        self.normalize_embedding = True
-        self.static_position_embeddings = False
-        self.add_bias_logits = False
-        self.force_bos_token_to_be_generated = False
-
     @property
     def num_attention_heads(self) -> int:
         return self.encoder_attention_heads
pegasus/__init__.py:

@@ -45,7 +45,7 @@ if is_torch_available():
     ]
 
 if is_tf_available():
-    _import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration"]
+    _import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration", "TFPegasusModel"]
 
 if TYPE_CHECKING:

@@ -66,7 +66,7 @@ if TYPE_CHECKING:
    )
 
     if is_tf_available():
-        from .modeling_tf_pegasus import TFPegasusForConditionalGeneration
+        from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
 
 else:
     import importlib
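The `_import_structure` dict and the `importlib` fallback seen in both `__init__.py` diffs implement lazy imports: heavy submodules (and with them TensorFlow or PyTorch) are only loaded when a name is first accessed. A simplified sketch of the mechanism; the real `_LazyModule` in transformers handles more cases than this:

```python
import importlib
from types import ModuleType


class LazyModuleSketch(ModuleType):
    """Simplified stand-in for transformers' lazy-import machinery."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Invert {"modeling_tf_pegasus": ["TFPegasusModel", ...]} into a
        # map from each public name to the submodule that defines it.
        self._class_to_module = {
            cls_name: module_name
            for module_name, classes in import_structure.items()
            for cls_name in classes
        }

    def __getattr__(self, name):
        # Import the defining submodule only on first attribute access.
        module = importlib.import_module("." + self._class_to_module[name], self.__name__)
        return getattr(module, name)
```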
configuration_pegasus.py:

@@ -159,17 +159,6 @@ class PegasusConfig(PretrainedConfig):
         self.gradient_checkpointing = gradient_checkpointing
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
 
-        # IMPORTANT
-        # DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
-        self.extra_pos_embeddings = 0
-        self.normalize_before = True
-        self.add_final_layer_norm = True
-        self.do_blenderbot_90_layernorm = False
-        self.normalize_embedding = False
-        self.static_position_embeddings = True
-        self.add_bias_logits = False
-        self.force_bos_token_to_be_generated = False
-
     @property
     def num_attention_heads(self) -> int:
         return self.encoder_attention_heads
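Side by side, the deleted defaults show why the shared TF-Bart code path had to be split: the three models needed mutually incompatible settings. Values below are taken directly from the three config diffs above; `do_blenderbot_90_layernorm`, `add_bias_logits`, and `force_bos_token_to_be_generated` were `False` for all three.

| deleted flag | Marian | mBART | Pegasus |
| --- | --- | --- | --- |
| `extra_pos_embeddings` | 0 | 2 | 0 |
| `normalize_before` | False | True | True |
| `add_final_layer_norm` | False | True | True |
| `normalize_embedding` | False | True | False |
| `static_position_embeddings` | True | False | True |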
cookiecutter template:

 # coding=utf-8
-# Copyright {{cookiecutter.authors}} and The HuggingFace Inc. team. All rights reserved.
+# Copyright 2021 {{cookiecutter.authors}} The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
cookiecutter test template:

 # coding=utf-8
-# Copyright {{cookiecutter.authors}} and The HuggingFace Inc. team. All rights reserved.
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -21,8 +21,8 @@ import unittest
 
 from tests.test_modeling_common import floats_tensor
 from transformers import is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
-from .test_configuration_common import ConfigTester
 
+from .test_configuration_common import ConfigTester
 from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask

@@ -31,15 +31,17 @@ if is_torch_available():
     from transformers import (
         {{cookiecutter.camelcase_modelname}}Config,
-        {{cookiecutter.camelcase_modelname}}ForMaskedLM,
         {{cookiecutter.camelcase_modelname}}ForCausalLM,
+        {{cookiecutter.camelcase_modelname}}ForMaskedLM,
         {{cookiecutter.camelcase_modelname}}ForMultipleChoice,
         {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
         {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
         {{cookiecutter.camelcase_modelname}}ForTokenClassification,
         {{cookiecutter.camelcase_modelname}}Model,
     )
-    from transformers.models.{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST
+    from transformers.models.{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
+        {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
+    )
 
 
 class {{cookiecutter.camelcase_modelname}}ModelTester:
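For reference, a hedged sketch of how the reordered imports are typically consumed in a concrete (non-template) test file; the model names below are placeholders chosen for illustration, not part of this diff:

```python
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch

from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin

if is_torch_available():
    from transformers import MBartConfig, MBartModel


@require_torch
class MBartModelTestSketch(ModelTesterMixin, unittest.TestCase):
    # The conditional keeps the class importable even without torch installed.
    all_model_classes = (MBartModel,) if is_torch_available() else ()

    def setUp(self):
        # A full test class would also set self.model_tester here so the
        # ModelTesterMixin checks have something to run against.
        self.config_tester = ConfigTester(self, config_class=MBartConfig)

    def test_config(self):
        # Runs the shared serialization/round-trip checks on the config.
        self.config_tester.run_common_tests()
```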