Unverified Commit c3d9ac76 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Expose get_config() on ModelTesters (#12812)

* Expose get_config() on ModelTesters

* Typo
parent cabcc751
......@@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers import FlaubertConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
......@@ -27,7 +26,6 @@ if is_torch_available():
import torch
from transformers import (
FlaubertConfig,
FlaubertForMultipleChoice,
FlaubertForQuestionAnswering,
FlaubertForQuestionAnsweringSimple,
......@@ -96,7 +94,22 @@ class FlaubertModelTester(object):
is_impossible_labels = ids_tensor([self.batch_size], 2).float()
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = FlaubertConfig(
config = self.get_config()
return (
config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
choice_labels,
input_mask,
)
def get_config(self):
return FlaubertConfig(
vocab_size=self.vocab_size,
n_special=self.n_special,
emb_dim=self.hidden_size,
......@@ -115,18 +128,6 @@ class FlaubertModelTester(object):
use_proj=self.use_proj,
)
return (
config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
choice_labels,
input_mask,
)
def create_and_check_flaubert_model(
self,
config,
......
......@@ -19,7 +19,7 @@ import unittest
import timeout_decorator # noqa
from parameterized import parameterized
from transformers import is_torch_available
from transformers import FSMTConfig, is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
......@@ -32,7 +32,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import FSMTConfig, FSMTForConditionalGeneration, FSMTModel, FSMTTokenizer
from transformers import FSMTForConditionalGeneration, FSMTModel, FSMTTokenizer
from transformers.models.fsmt.modeling_fsmt import (
SinusoidalPositionalEmbedding,
_prepare_fsmt_decoder_inputs,
......@@ -42,8 +42,7 @@ if is_torch_available():
from transformers.pipelines import TranslationPipeline
@require_torch
class ModelTester:
class FSMTModelTester:
def __init__(
self,
parent,
......@@ -78,7 +77,12 @@ class ModelTester:
)
input_ids[:, -1] = 2 # Eos Token
config = FSMTConfig(
config = self.get_config()
inputs_dict = prepare_fsmt_inputs_dict(config, input_ids)
return config, inputs_dict
def get_config(self):
return FSMTConfig(
vocab_size=self.src_vocab_size, # hack needed for common tests
src_vocab_size=self.src_vocab_size,
tgt_vocab_size=self.tgt_vocab_size,
......@@ -97,8 +101,6 @@ class ModelTester:
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
)
inputs_dict = prepare_fsmt_inputs_dict(config, input_ids)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
......@@ -141,7 +143,7 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
test_missing_keys = False
def setUp(self):
self.model_tester = ModelTester(self)
self.model_tester = FSMTModelTester(self)
self.langs = ["en", "ru"]
config = {
"langs": self.langs,
......
......@@ -16,7 +16,7 @@
import unittest
from transformers import FunnelTokenizer, is_torch_available
from transformers import FunnelConfig, FunnelTokenizer, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
......@@ -30,7 +30,6 @@ if is_torch_available():
from transformers import (
MODEL_FOR_PRETRAINING_MAPPING,
FunnelBaseModel,
FunnelConfig,
FunnelForMaskedLM,
FunnelForMultipleChoice,
FunnelForPreTraining,
......@@ -127,7 +126,21 @@ class FunnelModelTester:
choice_labels = ids_tensor([self.batch_size], self.num_choices)
fake_token_labels = ids_tensor([self.batch_size, self.seq_length], 1)
config = FunnelConfig(
config = self.get_config()
return (
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
fake_token_labels,
)
def get_config(self):
return FunnelConfig(
vocab_size=self.vocab_size,
block_sizes=self.block_sizes,
num_decoder_layers=self.num_decoder_layers,
......@@ -143,17 +156,6 @@ class FunnelModelTester:
type_vocab_size=self.type_vocab_size,
)
return (
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
fake_token_labels,
)
def create_and_check_model(
self,
config,
......
......@@ -17,7 +17,7 @@
import datetime
import unittest
from transformers import is_torch_available
from transformers import GPT2Config, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
......@@ -30,7 +30,6 @@ if is_torch_available():
from transformers import (
GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
GPT2Config,
GPT2DoubleHeadsModel,
GPT2ForSequenceClassification,
GPT2LMHeadModel,
......@@ -119,25 +118,7 @@ class GPT2ModelTester:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = GPT2Config(
vocab_size=self.vocab_size,
n_embd=self.hidden_size,
n_layer=self.num_hidden_layers,
n_head=self.num_attention_heads,
# intermediate_size=self.intermediate_size,
# hidden_act=self.hidden_act,
# hidden_dropout_prob=self.hidden_dropout_prob,
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
n_positions=self.max_position_embeddings,
n_ctx=self.max_position_embeddings,
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range,
use_cache=not gradient_checkpointing,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
)
config = self.get_config(gradient_checkpointing=gradient_checkpointing)
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
......@@ -153,6 +134,27 @@ class GPT2ModelTester:
choice_labels,
)
def get_config(self, gradient_checkpointing=False):
return GPT2Config(
vocab_size=self.vocab_size,
n_embd=self.hidden_size,
n_layer=self.num_hidden_layers,
n_head=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
n_positions=self.max_position_embeddings,
n_ctx=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
use_cache=not gradient_checkpointing,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
)
def prepare_config_and_inputs_for_decoder(self):
(
config,
......
......@@ -17,7 +17,7 @@
import unittest
from transformers import is_torch_available
from transformers import GPTNeoConfig, is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device
......@@ -32,7 +32,6 @@ if is_torch_available():
from transformers import (
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
GPT2Tokenizer,
GPTNeoConfig,
GPTNeoForCausalLM,
GPTNeoForSequenceClassification,
GPTNeoModel,
......@@ -123,20 +122,7 @@ class GPTNeoModelTester:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = GPTNeoConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_layers=self.num_hidden_layers,
num_heads=self.num_attention_heads,
max_position_embeddings=self.max_position_embeddings,
use_cache=not gradient_checkpointing,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
window_size=self.window_size,
attention_types=self.attention_types,
)
config = self.get_config(gradient_checkpointing=False)
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
......@@ -152,6 +138,22 @@ class GPTNeoModelTester:
choice_labels,
)
def get_config(self, gradient_checkpointing=False):
return GPTNeoConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_layers=self.num_hidden_layers,
num_heads=self.num_attention_heads,
max_position_embeddings=self.max_position_embeddings,
use_cache=not gradient_checkpointing,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
window_size=self.window_size,
attention_types=self.attention_types,
)
def prepare_config_and_inputs_for_decoder(self):
(
config,
......
......@@ -21,7 +21,7 @@ import unittest
import pytest
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import is_torch_available
from transformers import HubertConfig, is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
......@@ -31,7 +31,7 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
if is_torch_available():
import torch
from transformers import HubertConfig, HubertForCTC, HubertModel, Wav2Vec2Processor
from transformers import HubertForCTC, HubertModel, Wav2Vec2Processor
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
......@@ -98,7 +98,12 @@ class HubertModelTester:
input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = HubertConfig(
config = self.get_config()
return config, input_values, attention_mask
def get_config(self):
return HubertConfig(
hidden_size=self.hidden_size,
feat_extract_norm=self.feat_extract_norm,
feat_extract_dropout=self.feat_extract_dropout,
......@@ -119,8 +124,6 @@ class HubertModelTester:
vocab_size=self.vocab_size,
)
return config, input_values, attention_mask
def create_and_check_model(self, config, input_values, attention_mask):
model = HubertModel(config=config)
model.to(torch_device)
......
......@@ -17,7 +17,7 @@
import copy
import unittest
from transformers import is_torch_available
from transformers import IBertConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
......@@ -30,7 +30,6 @@ if is_torch_available():
from transformers import (
IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
IBertConfig,
IBertForMaskedLM,
IBertForMultipleChoice,
IBertForQuestionAnswering,
......@@ -97,7 +96,12 @@ class IBertModelTester:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = IBertConfig(
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
return IBertConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
......@@ -112,8 +116,6 @@ class IBertModelTester:
quant_mode=True,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
......
......@@ -16,7 +16,7 @@
import unittest
from transformers import is_torch_available
from transformers import LayoutLMConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
......@@ -27,7 +27,6 @@ if is_torch_available():
import torch
from transformers import (
LayoutLMConfig,
LayoutLMForMaskedLM,
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
......@@ -120,7 +119,12 @@ class LayoutLMModelTester:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = LayoutLMConfig(
config = self.get_config()
return config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
return LayoutLMConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
......@@ -134,8 +138,6 @@ class LayoutLMModelTester:
initializer_range=self.initializer_range,
)
return config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_model(
self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
......
......@@ -19,7 +19,7 @@ import copy
import tempfile
import unittest
from transformers import is_torch_available
from transformers import LEDConfig, is_torch_available
from transformers.file_utils import cached_property
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
......@@ -34,7 +34,6 @@ if is_torch_available():
from transformers import (
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
LEDConfig,
LEDForConditionalGeneration,
LEDForQuestionAnswering,
LEDForSequenceClassification,
......@@ -75,7 +74,6 @@ def prepare_led_inputs_dict(
}
@require_torch
class LEDModelTester:
def __init__(
self,
......@@ -141,7 +139,12 @@ class LEDModelTester:
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = LEDConfig(
config = self.get_config()
inputs_dict = prepare_led_inputs_dict(config, input_ids, decoder_input_ids)
return config, inputs_dict
def get_config(self):
return LEDConfig(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
......@@ -158,8 +161,6 @@ class LEDModelTester:
pad_token_id=self.pad_token_id,
attention_window=self.attention_window,
)
inputs_dict = prepare_led_inputs_dict(config, input_ids, decoder_input_ids)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
......
......@@ -16,7 +16,7 @@
import unittest
from transformers import is_torch_available
from transformers import LongformerConfig, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
......@@ -27,7 +27,6 @@ if is_torch_available():
import torch
from transformers import (
LongformerConfig,
LongformerForMaskedLM,
LongformerForMultipleChoice,
LongformerForQuestionAnswering,
......@@ -100,7 +99,12 @@ class LongformerModelTester:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = LongformerConfig(
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
return LongformerConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
......@@ -115,8 +119,6 @@ class LongformerModelTester:
attention_window=self.attention_window,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_attention_mask_determinism(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
......
......@@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch LUKE model. """
import unittest
from transformers import is_torch_available
from transformers import LukeConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
......@@ -27,7 +26,6 @@ if is_torch_available():
import torch
from transformers import (
LukeConfig,
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
......@@ -154,7 +152,25 @@ class LukeModelTester:
[self.batch_size, self.entity_length], self.num_entity_span_classification_labels
)
config = LukeConfig(
config = self.get_config()
return (
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
)
def get_config(self):
return LukeConfig(
vocab_size=self.vocab_size,
entity_vocab_size=self.entity_vocab_size,
entity_emb_size=self.entity_emb_size,
......@@ -172,21 +188,6 @@ class LukeModelTester:
use_entity_aware_attention=self.use_entity_aware_attention,
)
return (
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
)
def create_and_check_model(
self,
config,
......
......@@ -19,7 +19,7 @@ import unittest
import numpy as np
from transformers import is_torch_available
from transformers import LxmertConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
......@@ -33,7 +33,6 @@ if is_torch_available():
from transformers import (
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
LxmertConfig,
LxmertForPreTraining,
LxmertForQuestionAnswering,
LxmertModel,
......@@ -170,7 +169,24 @@ class LxmertModelTester:
if self.task_matched:
matched_label = ids_tensor([self.batch_size], self.num_labels)
config = LxmertConfig(
config = self.get_config()
return (
config,
input_ids,
visual_feats,
bounding_boxes,
token_type_ids,
input_mask,
obj_labels,
masked_lm_labels,
matched_label,
ans,
output_attentions,
)
def get_config(self):
return LxmertConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_attention_heads=self.num_attention_heads,
......@@ -204,20 +220,6 @@ class LxmertModelTester:
output_hidden_states=self.output_hidden_states,
)
return (
config,
input_ids,
visual_feats,
bounding_boxes,
token_type_ids,
input_mask,
obj_labels,
masked_lm_labels,
matched_label,
ans,
output_attentions,
)
def create_and_check_lxmert_model(
self,
config,
......
......@@ -19,7 +19,7 @@ import copy
import tempfile
import unittest
from transformers import is_torch_available
from transformers import M2M100Config, is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
......@@ -31,7 +31,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
import torch
from transformers import M2M100Config, M2M100ForConditionalGeneration, M2M100Model, M2M100Tokenizer
from transformers import M2M100ForConditionalGeneration, M2M100Model, M2M100Tokenizer
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Decoder, M2M100Encoder
......@@ -66,7 +66,6 @@ def prepare_m2m_100_inputs_dict(
}
@require_torch
class M2M100ModelTester:
def __init__(
self,
......@@ -125,7 +124,12 @@ class M2M100ModelTester:
input_ids = input_ids.clamp(self.pad_token_id + 1)
decoder_input_ids = decoder_input_ids.clamp(self.pad_token_id + 1)
config = M2M100Config(
config = self.get_config()
inputs_dict = prepare_m2m_100_inputs_dict(config, input_ids, decoder_input_ids)
return config, inputs_dict
def get_config(self):
return M2M100Config(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
......@@ -143,8 +147,6 @@ class M2M100ModelTester:
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
)
inputs_dict = prepare_m2m_100_inputs_dict(config, input_ids, decoder_input_ids)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
......
......@@ -17,7 +17,7 @@
import tempfile
import unittest
from transformers import is_torch_available
from transformers import MarianConfig, is_torch_available
from transformers.file_utils import cached_property
from transformers.hf_api import HfApi
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
......@@ -34,7 +34,6 @@ if is_torch_available():
AutoConfig,
AutoModelWithLMHead,
AutoTokenizer,
MarianConfig,
MarianModel,
MarianMTModel,
TranslationPipeline,
......@@ -83,7 +82,6 @@ def prepare_marian_inputs_dict(
}
@require_torch
class MarianModelTester:
def __init__(
self,
......@@ -126,7 +124,6 @@ class MarianModelTester:
self.decoder_start_token_id = decoder_start_token_id
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
3,
)
......@@ -134,7 +131,12 @@ class MarianModelTester:
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = MarianConfig(
config = self.get_config()
inputs_dict = prepare_marian_inputs_dict(config, input_ids, decoder_input_ids)
return config, inputs_dict
def get_config(self):
return MarianConfig(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
......@@ -151,8 +153,6 @@ class MarianModelTester:
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
)
inputs_dict = prepare_marian_inputs_dict(config, input_ids, decoder_input_ids)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
......
......@@ -19,7 +19,7 @@ import copy
import tempfile
import unittest
from transformers import is_torch_available
from transformers import MBartConfig, is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
......@@ -34,7 +34,6 @@ if is_torch_available():
from transformers import (
AutoTokenizer,
BatchEncoding,
MBartConfig,
MBartForCausalLM,
MBartForConditionalGeneration,
MBartForQuestionAnswering,
......@@ -75,7 +74,6 @@ def prepare_mbart_inputs_dict(
}
@require_torch
class MBartModelTester:
def __init__(
self,
......@@ -124,7 +122,12 @@ class MBartModelTester:
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = MBartConfig(
config = self.get_config()
inputs_dict = prepare_mbart_inputs_dict(config, input_ids, decoder_input_ids)
return config, inputs_dict
def get_config(self):
return MBartConfig(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
......@@ -140,8 +143,6 @@ class MBartModelTester:
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
)
inputs_dict = prepare_mbart_inputs_dict(config, input_ids, decoder_input_ids)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
......
......@@ -19,7 +19,7 @@ import math
import os
import unittest
from transformers import is_torch_available
from transformers import MegatronBertConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
......@@ -32,7 +32,6 @@ if is_torch_available():
from transformers import (
MODEL_FOR_PRETRAINING_MAPPING,
MegatronBertConfig,
MegatronBertForCausalLM,
MegatronBertForMaskedLM,
MegatronBertForMultipleChoice,
......@@ -115,7 +114,12 @@ class MegatronBertModelTester:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = MegatronBertConfig(
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
return MegatronBertConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
......@@ -131,8 +135,6 @@ class MegatronBertModelTester:
initializer_range=self.initializer_range,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_megatron_bert_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
......
......@@ -16,7 +16,7 @@
import unittest
from transformers import is_torch_available
from transformers import MobileBertConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
......@@ -29,7 +29,6 @@ if is_torch_available():
from transformers import (
MODEL_FOR_PRETRAINING_MAPPING,
MobileBertConfig,
MobileBertForMaskedLM,
MobileBertForMultipleChoice,
MobileBertForNextSentencePrediction,
......@@ -111,7 +110,12 @@ class MobileBertModelTester:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = MobileBertConfig(
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
return MobileBertConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
......@@ -127,8 +131,6 @@ class MobileBertModelTester:
initializer_range=self.initializer_range,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_mobilebert_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
......
......@@ -16,7 +16,7 @@
import unittest
from transformers import is_torch_available
from transformers import MPNetConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
......@@ -27,7 +27,6 @@ if is_torch_available():
import torch
from transformers import (
MPNetConfig,
MPNetForMaskedLM,
MPNetForMultipleChoice,
MPNetForQuestionAnswering,
......@@ -104,7 +103,11 @@ class MPNetModelTester:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = MPNetConfig(
config = self.get_config()
return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
return MPNetConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
......@@ -116,7 +119,6 @@ class MPNetModelTester:
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range,
)
return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_mpnet_model(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
......
......@@ -17,7 +17,7 @@
import tempfile
import unittest
from transformers import is_torch_available
from transformers import PegasusConfig, is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
......@@ -30,7 +30,7 @@ from .test_modeling_mbart import AbstractSeq2SeqIntegrationTest
if is_torch_available():
import torch
from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration, PegasusModel
from transformers import AutoModelForSeq2SeqLM, PegasusForConditionalGeneration, PegasusModel
from transformers.models.pegasus.modeling_pegasus import PegasusDecoder, PegasusEncoder, PegasusForCausalLM
......@@ -65,7 +65,6 @@ def prepare_pegasus_inputs_dict(
}
@require_torch
class PegasusModelTester:
def __init__(
self,
......@@ -114,7 +113,12 @@ class PegasusModelTester:
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = PegasusConfig(
config = self.get_config()
inputs_dict = prepare_pegasus_inputs_dict(config, input_ids, decoder_input_ids)
return config, inputs_dict
def get_config(self):
return PegasusConfig(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
......@@ -130,8 +134,6 @@ class PegasusModelTester:
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
)
inputs_dict = prepare_pegasus_inputs_dict(config, input_ids, decoder_input_ids)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
......
......@@ -13,12 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import tempfile
import unittest
from transformers import is_torch_available
from transformers import ProphetNetConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
......@@ -30,7 +29,6 @@ if is_torch_available():
import torch
from transformers import (
ProphetNetConfig,
ProphetNetDecoder,
ProphetNetEncoder,
ProphetNetForCausalLM,
......@@ -124,7 +122,19 @@ class ProphetNetModelTester:
if self.use_labels:
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
config = ProphetNetConfig(
config = self.get_config()
return (
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
)
def get_config(self):
return ProphetNetConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_encoder_layers=self.num_encoder_layers,
......@@ -145,15 +155,6 @@ class ProphetNetModelTester:
is_encoder_decoder=self.is_encoder_decoder,
)
return (
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
)
def prepare_config_and_inputs_for_decoder(self):
(
config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment