"INSTALL/tool/git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "72d0fd0dd42604afad158b6d1e12736b5268532a"
Unverified Commit c3d9ac76, authored by Lysandre Debut, committed by GitHub

Expose get_config() on ModelTesters (#12812)

* Expose get_config() on ModelTesters

* Typo
parent cabcc751
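
Every file in the diff below receives the same refactoring: the config construction that used to be inlined in prepare_config_and_inputs() moves into a new get_config() method, and the original call site becomes config = self.get_config(), so a test can fetch a tester's small configuration without also building input tensors. A minimal sketch of the pattern, using a hypothetical DemoModelTester built around BertConfig purely for illustration (it is not part of this commit):

import torch
from transformers import BertConfig


class DemoModelTester:
    """Illustrative tester following the refactored pattern (not from the commit)."""

    def __init__(self, parent, batch_size=13, seq_length=7, vocab_size=99):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size

    def prepare_config_and_inputs(self):
        # Inputs are still built here, but the config construction is now
        # delegated to get_config() instead of being written inline.
        input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_length))
        config = self.get_config()
        return config, input_ids

    def get_config(self):
        # The newly exposed hook: returns the tiny test configuration on its
        # own, so other tests can use it without generating inputs.
        return BertConfig(
            vocab_size=self.vocab_size,
            hidden_size=32,
            num_hidden_layers=5,
            num_attention_heads=4,
            intermediate_size=37,
        )

With a tester like this, a test that only needs a small config can call DemoModelTester(self).get_config() directly instead of unpacking the full prepare_config_and_inputs() tuple.
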
@@ -13,10 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import unittest
-from transformers import is_torch_available
+from transformers import FlaubertConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 from .test_configuration_common import ConfigTester
@@ -27,7 +26,6 @@ if is_torch_available():
     import torch
     from transformers import (
-        FlaubertConfig,
         FlaubertForMultipleChoice,
         FlaubertForQuestionAnswering,
         FlaubertForQuestionAnsweringSimple,
@@ -96,7 +94,22 @@ class FlaubertModelTester(object):
         is_impossible_labels = ids_tensor([self.batch_size], 2).float()
         choice_labels = ids_tensor([self.batch_size], self.num_choices)
-        config = FlaubertConfig(
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids,
+            token_type_ids,
+            input_lengths,
+            sequence_labels,
+            token_labels,
+            is_impossible_labels,
+            choice_labels,
+            input_mask,
+        )
+
+    def get_config(self):
+        return FlaubertConfig(
             vocab_size=self.vocab_size,
             n_special=self.n_special,
             emb_dim=self.hidden_size,
@@ -115,18 +128,6 @@ class FlaubertModelTester(object):
             use_proj=self.use_proj,
         )
-        return (
-            config,
-            input_ids,
-            token_type_ids,
-            input_lengths,
-            sequence_labels,
-            token_labels,
-            is_impossible_labels,
-            choice_labels,
-            input_mask,
-        )
     def create_and_check_flaubert_model(
         self,
         config,
...
@@ -19,7 +19,7 @@ import unittest
 import timeout_decorator  # noqa
 from parameterized import parameterized
-from transformers import is_torch_available
+from transformers import FSMTConfig, is_torch_available
 from transformers.file_utils import cached_property
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
@@ -32,7 +32,7 @@ if is_torch_available():
     import torch
     from torch import nn
-    from transformers import FSMTConfig, FSMTForConditionalGeneration, FSMTModel, FSMTTokenizer
+    from transformers import FSMTForConditionalGeneration, FSMTModel, FSMTTokenizer
     from transformers.models.fsmt.modeling_fsmt import (
         SinusoidalPositionalEmbedding,
         _prepare_fsmt_decoder_inputs,
@@ -42,8 +42,7 @@ if is_torch_available():
     from transformers.pipelines import TranslationPipeline
-@require_torch
-class ModelTester:
+class FSMTModelTester:
     def __init__(
         self,
         parent,
@@ -78,7 +77,12 @@ class ModelTester:
         )
         input_ids[:, -1] = 2  # Eos Token
-        config = FSMTConfig(
+        config = self.get_config()
+        inputs_dict = prepare_fsmt_inputs_dict(config, input_ids)
+        return config, inputs_dict
+
+    def get_config(self):
+        return FSMTConfig(
             vocab_size=self.src_vocab_size,  # hack needed for common tests
             src_vocab_size=self.src_vocab_size,
             tgt_vocab_size=self.tgt_vocab_size,
@@ -97,8 +101,6 @@ class ModelTester:
             bos_token_id=self.bos_token_id,
             pad_token_id=self.pad_token_id,
         )
-        inputs_dict = prepare_fsmt_inputs_dict(config, input_ids)
-        return config, inputs_dict
     def prepare_config_and_inputs_for_common(self):
         config, inputs_dict = self.prepare_config_and_inputs()
@@ -141,7 +143,7 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     test_missing_keys = False
     def setUp(self):
-        self.model_tester = ModelTester(self)
+        self.model_tester = FSMTModelTester(self)
         self.langs = ["en", "ru"]
         config = {
             "langs": self.langs,
...
@@ -16,7 +16,7 @@
 import unittest
-from transformers import FunnelTokenizer, is_torch_available
+from transformers import FunnelConfig, FunnelTokenizer, is_torch_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
@@ -30,7 +30,6 @@ if is_torch_available():
     from transformers import (
         MODEL_FOR_PRETRAINING_MAPPING,
         FunnelBaseModel,
-        FunnelConfig,
         FunnelForMaskedLM,
         FunnelForMultipleChoice,
         FunnelForPreTraining,
@@ -127,7 +126,21 @@ class FunnelModelTester:
         choice_labels = ids_tensor([self.batch_size], self.num_choices)
         fake_token_labels = ids_tensor([self.batch_size, self.seq_length], 1)
-        config = FunnelConfig(
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+            fake_token_labels,
+        )
+
+    def get_config(self):
+        return FunnelConfig(
             vocab_size=self.vocab_size,
             block_sizes=self.block_sizes,
             num_decoder_layers=self.num_decoder_layers,
@@ -143,17 +156,6 @@ class FunnelModelTester:
             type_vocab_size=self.type_vocab_size,
         )
-        return (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-            fake_token_labels,
-        )
     def create_and_check_model(
         self,
         config,
...
@@ -17,7 +17,7 @@
 import datetime
 import unittest
-from transformers import is_torch_available
+from transformers import GPT2Config, is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 from .test_configuration_common import ConfigTester
@@ -30,7 +30,6 @@ if is_torch_available():
     from transformers import (
         GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
-        GPT2Config,
         GPT2DoubleHeadsModel,
         GPT2ForSequenceClassification,
         GPT2LMHeadModel,
@@ -119,25 +118,7 @@ class GPT2ModelTester:
         token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
         choice_labels = ids_tensor([self.batch_size], self.num_choices)
-        config = GPT2Config(
-            vocab_size=self.vocab_size,
-            n_embd=self.hidden_size,
-            n_layer=self.num_hidden_layers,
-            n_head=self.num_attention_heads,
-            # intermediate_size=self.intermediate_size,
-            # hidden_act=self.hidden_act,
-            # hidden_dropout_prob=self.hidden_dropout_prob,
-            # attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-            n_positions=self.max_position_embeddings,
-            n_ctx=self.max_position_embeddings,
-            # type_vocab_size=self.type_vocab_size,
-            # initializer_range=self.initializer_range,
-            use_cache=not gradient_checkpointing,
-            bos_token_id=self.bos_token_id,
-            eos_token_id=self.eos_token_id,
-            pad_token_id=self.pad_token_id,
-            gradient_checkpointing=gradient_checkpointing,
-        )
+        config = self.get_config(gradient_checkpointing=gradient_checkpointing)
         head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -153,6 +134,27 @@ class GPT2ModelTester:
             choice_labels,
         )
+    def get_config(self, gradient_checkpointing=False):
+        return GPT2Config(
+            vocab_size=self.vocab_size,
+            n_embd=self.hidden_size,
+            n_layer=self.num_hidden_layers,
+            n_head=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            hidden_act=self.hidden_act,
+            hidden_dropout_prob=self.hidden_dropout_prob,
+            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+            n_positions=self.max_position_embeddings,
+            n_ctx=self.max_position_embeddings,
+            type_vocab_size=self.type_vocab_size,
+            initializer_range=self.initializer_range,
+            use_cache=not gradient_checkpointing,
+            bos_token_id=self.bos_token_id,
+            eos_token_id=self.eos_token_id,
+            pad_token_id=self.pad_token_id,
+            gradient_checkpointing=gradient_checkpointing,
+        )
+
     def prepare_config_and_inputs_for_decoder(self):
         (
             config,
...
@@ -17,7 +17,7 @@
 import unittest
-from transformers import is_torch_available
+from transformers import GPTNeoConfig, is_torch_available
 from transformers.file_utils import cached_property
 from transformers.testing_utils import require_torch, slow, torch_device
@@ -32,7 +32,6 @@ if is_torch_available():
     from transformers import (
         GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
         GPT2Tokenizer,
-        GPTNeoConfig,
         GPTNeoForCausalLM,
         GPTNeoForSequenceClassification,
         GPTNeoModel,
@@ -123,20 +122,7 @@ class GPTNeoModelTester:
         token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
         choice_labels = ids_tensor([self.batch_size], self.num_choices)
-        config = GPTNeoConfig(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_layers=self.num_hidden_layers,
-            num_heads=self.num_attention_heads,
-            max_position_embeddings=self.max_position_embeddings,
-            use_cache=not gradient_checkpointing,
-            bos_token_id=self.bos_token_id,
-            eos_token_id=self.eos_token_id,
-            pad_token_id=self.pad_token_id,
-            gradient_checkpointing=gradient_checkpointing,
-            window_size=self.window_size,
-            attention_types=self.attention_types,
-        )
+        config = self.get_config(gradient_checkpointing=False)
         head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -152,6 +138,22 @@ class GPTNeoModelTester:
             choice_labels,
         )
+    def get_config(self, gradient_checkpointing=False):
+        return GPTNeoConfig(
+            vocab_size=self.vocab_size,
+            hidden_size=self.hidden_size,
+            num_layers=self.num_hidden_layers,
+            num_heads=self.num_attention_heads,
+            max_position_embeddings=self.max_position_embeddings,
+            use_cache=not gradient_checkpointing,
+            bos_token_id=self.bos_token_id,
+            eos_token_id=self.eos_token_id,
+            pad_token_id=self.pad_token_id,
+            gradient_checkpointing=gradient_checkpointing,
+            window_size=self.window_size,
+            attention_types=self.attention_types,
+        )
+
     def prepare_config_and_inputs_for_decoder(self):
         (
             config,
...
@@ -21,7 +21,7 @@ import unittest
 import pytest
 from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
-from transformers import is_torch_available
+from transformers import HubertConfig, is_torch_available
 from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
 from .test_configuration_common import ConfigTester
@@ -31,7 +31,7 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
 if is_torch_available():
     import torch
-    from transformers import HubertConfig, HubertForCTC, HubertModel, Wav2Vec2Processor
+    from transformers import HubertForCTC, HubertModel, Wav2Vec2Processor
     from transformers.models.hubert.modeling_hubert import _compute_mask_indices
@@ -98,7 +98,12 @@ class HubertModelTester:
         input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
         attention_mask = random_attention_mask([self.batch_size, self.seq_length])
-        config = HubertConfig(
+        config = self.get_config()
+
+        return config, input_values, attention_mask
+
+    def get_config(self):
+        return HubertConfig(
             hidden_size=self.hidden_size,
             feat_extract_norm=self.feat_extract_norm,
             feat_extract_dropout=self.feat_extract_dropout,
@@ -119,8 +124,6 @@ class HubertModelTester:
             vocab_size=self.vocab_size,
         )
-        return config, input_values, attention_mask
     def create_and_check_model(self, config, input_values, attention_mask):
         model = HubertModel(config=config)
         model.to(torch_device)
...
@@ -17,7 +17,7 @@
 import copy
 import unittest
-from transformers import is_torch_available
+from transformers import IBertConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 from .test_configuration_common import ConfigTester
@@ -30,7 +30,6 @@ if is_torch_available():
     from transformers import (
         IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
-        IBertConfig,
         IBertForMaskedLM,
         IBertForMultipleChoice,
         IBertForQuestionAnswering,
@@ -97,7 +96,12 @@ class IBertModelTester:
         token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
         choice_labels = ids_tensor([self.batch_size], self.num_choices)
-        config = IBertConfig(
+        config = self.get_config()
+
+        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+    def get_config(self):
+        return IBertConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
             num_hidden_layers=self.num_hidden_layers,
@@ -112,8 +116,6 @@ class IBertModelTester:
             quant_mode=True,
         )
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     def create_and_check_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
...
@@ -16,7 +16,7 @@
 import unittest
-from transformers import is_torch_available
+from transformers import LayoutLMConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 from .test_configuration_common import ConfigTester
@@ -27,7 +27,6 @@ if is_torch_available():
     import torch
     from transformers import (
-        LayoutLMConfig,
         LayoutLMForMaskedLM,
         LayoutLMForSequenceClassification,
         LayoutLMForTokenClassification,
@@ -120,7 +119,12 @@ class LayoutLMModelTester:
         token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
         choice_labels = ids_tensor([self.batch_size], self.num_choices)
-        config = LayoutLMConfig(
+        config = self.get_config()
+
+        return config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+    def get_config(self):
+        return LayoutLMConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
             num_hidden_layers=self.num_hidden_layers,
@@ -134,8 +138,6 @@ class LayoutLMModelTester:
             initializer_range=self.initializer_range,
         )
-        return config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     def create_and_check_model(
         self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
...
@@ -19,7 +19,7 @@ import copy
 import tempfile
 import unittest
-from transformers import is_torch_available
+from transformers import LEDConfig, is_torch_available
 from transformers.file_utils import cached_property
 from transformers.models.auto import get_values
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
@@ -34,7 +34,6 @@ if is_torch_available():
     from transformers import (
         MODEL_FOR_QUESTION_ANSWERING_MAPPING,
-        LEDConfig,
         LEDForConditionalGeneration,
         LEDForQuestionAnswering,
         LEDForSequenceClassification,
@@ -75,7 +74,6 @@ def prepare_led_inputs_dict(
     }
-@require_torch
 class LEDModelTester:
     def __init__(
         self,
@@ -141,7 +139,12 @@ class LEDModelTester:
         decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-        config = LEDConfig(
+        config = self.get_config()
+        inputs_dict = prepare_led_inputs_dict(config, input_ids, decoder_input_ids)
+        return config, inputs_dict
+
+    def get_config(self):
+        return LEDConfig(
             vocab_size=self.vocab_size,
             d_model=self.hidden_size,
             encoder_layers=self.num_hidden_layers,
@@ -158,8 +161,6 @@ class LEDModelTester:
             pad_token_id=self.pad_token_id,
             attention_window=self.attention_window,
         )
-        inputs_dict = prepare_led_inputs_dict(config, input_ids, decoder_input_ids)
-        return config, inputs_dict
     def prepare_config_and_inputs_for_common(self):
         config, inputs_dict = self.prepare_config_and_inputs()
...
@@ -16,7 +16,7 @@
 import unittest
-from transformers import is_torch_available
+from transformers import LongformerConfig, is_torch_available
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
 from .test_configuration_common import ConfigTester
@@ -27,7 +27,6 @@ if is_torch_available():
     import torch
     from transformers import (
-        LongformerConfig,
         LongformerForMaskedLM,
         LongformerForMultipleChoice,
         LongformerForQuestionAnswering,
@@ -100,7 +99,12 @@ class LongformerModelTester:
         token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
         choice_labels = ids_tensor([self.batch_size], self.num_choices)
-        config = LongformerConfig(
+        config = self.get_config()
+
+        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+    def get_config(self):
+        return LongformerConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
             num_hidden_layers=self.num_hidden_layers,
@@ -115,8 +119,6 @@ class LongformerModelTester:
             attention_window=self.attention_window,
         )
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     def create_and_check_attention_mask_determinism(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
...
@@ -13,10 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Testing suite for the PyTorch LUKE model. """
 import unittest
-from transformers import is_torch_available
+from transformers import LukeConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 from .test_configuration_common import ConfigTester
@@ -27,7 +26,6 @@ if is_torch_available():
     import torch
     from transformers import (
-        LukeConfig,
         LukeForEntityClassification,
         LukeForEntityPairClassification,
         LukeForEntitySpanClassification,
@@ -154,7 +152,25 @@ class LukeModelTester:
             [self.batch_size, self.entity_length], self.num_entity_span_classification_labels
         )
-        config = LukeConfig(
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids,
+            attention_mask,
+            token_type_ids,
+            entity_ids,
+            entity_attention_mask,
+            entity_token_type_ids,
+            entity_position_ids,
+            sequence_labels,
+            entity_classification_labels,
+            entity_pair_classification_labels,
+            entity_span_classification_labels,
+        )
+
+    def get_config(self):
+        return LukeConfig(
             vocab_size=self.vocab_size,
             entity_vocab_size=self.entity_vocab_size,
             entity_emb_size=self.entity_emb_size,
@@ -172,21 +188,6 @@ class LukeModelTester:
             use_entity_aware_attention=self.use_entity_aware_attention,
         )
-        return (
-            config,
-            input_ids,
-            attention_mask,
-            token_type_ids,
-            entity_ids,
-            entity_attention_mask,
-            entity_token_type_ids,
-            entity_position_ids,
-            sequence_labels,
-            entity_classification_labels,
-            entity_pair_classification_labels,
-            entity_span_classification_labels,
-        )
     def create_and_check_model(
         self,
         config,
...
@@ -19,7 +19,7 @@ import unittest
 import numpy as np
-from transformers import is_torch_available
+from transformers import LxmertConfig, is_torch_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, slow, torch_device
@@ -33,7 +33,6 @@ if is_torch_available():
     from transformers import (
         MODEL_FOR_PRETRAINING_MAPPING,
         MODEL_FOR_QUESTION_ANSWERING_MAPPING,
-        LxmertConfig,
         LxmertForPreTraining,
         LxmertForQuestionAnswering,
         LxmertModel,
@@ -170,7 +169,24 @@ class LxmertModelTester:
         if self.task_matched:
             matched_label = ids_tensor([self.batch_size], self.num_labels)
-        config = LxmertConfig(
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids,
+            visual_feats,
+            bounding_boxes,
+            token_type_ids,
+            input_mask,
+            obj_labels,
+            masked_lm_labels,
+            matched_label,
+            ans,
+            output_attentions,
+        )
+
+    def get_config(self):
+        return LxmertConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
             num_attention_heads=self.num_attention_heads,
@@ -204,20 +220,6 @@ class LxmertModelTester:
             output_hidden_states=self.output_hidden_states,
         )
-        return (
-            config,
-            input_ids,
-            visual_feats,
-            bounding_boxes,
-            token_type_ids,
-            input_mask,
-            obj_labels,
-            masked_lm_labels,
-            matched_label,
-            ans,
-            output_attentions,
-        )
     def create_and_check_lxmert_model(
         self,
         config,
...
@@ -19,7 +19,7 @@ import copy
 import tempfile
 import unittest
-from transformers import is_torch_available
+from transformers import M2M100Config, is_torch_available
 from transformers.file_utils import cached_property
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
@@ -31,7 +31,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor
 if is_torch_available():
     import torch
-    from transformers import M2M100Config, M2M100ForConditionalGeneration, M2M100Model, M2M100Tokenizer
+    from transformers import M2M100ForConditionalGeneration, M2M100Model, M2M100Tokenizer
     from transformers.models.m2m_100.modeling_m2m_100 import M2M100Decoder, M2M100Encoder
@@ -66,7 +66,6 @@ def prepare_m2m_100_inputs_dict(
     }
-@require_torch
 class M2M100ModelTester:
     def __init__(
         self,
@@ -125,7 +124,12 @@ class M2M100ModelTester:
         input_ids = input_ids.clamp(self.pad_token_id + 1)
         decoder_input_ids = decoder_input_ids.clamp(self.pad_token_id + 1)
-        config = M2M100Config(
+        config = self.get_config()
+        inputs_dict = prepare_m2m_100_inputs_dict(config, input_ids, decoder_input_ids)
+        return config, inputs_dict
+
+    def get_config(self):
+        return M2M100Config(
             vocab_size=self.vocab_size,
             d_model=self.hidden_size,
             encoder_layers=self.num_hidden_layers,
@@ -143,8 +147,6 @@ class M2M100ModelTester:
             bos_token_id=self.bos_token_id,
             pad_token_id=self.pad_token_id,
         )
-        inputs_dict = prepare_m2m_100_inputs_dict(config, input_ids, decoder_input_ids)
-        return config, inputs_dict
     def prepare_config_and_inputs_for_common(self):
         config, inputs_dict = self.prepare_config_and_inputs()
...
@@ -17,7 +17,7 @@
 import tempfile
 import unittest
-from transformers import is_torch_available
+from transformers import MarianConfig, is_torch_available
 from transformers.file_utils import cached_property
 from transformers.hf_api import HfApi
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
@@ -34,7 +34,6 @@ if is_torch_available():
         AutoConfig,
         AutoModelWithLMHead,
         AutoTokenizer,
-        MarianConfig,
         MarianModel,
         MarianMTModel,
         TranslationPipeline,
@@ -83,7 +82,6 @@ def prepare_marian_inputs_dict(
     }
-@require_torch
 class MarianModelTester:
     def __init__(
         self,
@@ -126,7 +124,6 @@ class MarianModelTester:
         self.decoder_start_token_id = decoder_start_token_id
     def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
             3,
         )
@@ -134,7 +131,12 @@ class MarianModelTester:
         decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-        config = MarianConfig(
+        config = self.get_config()
+        inputs_dict = prepare_marian_inputs_dict(config, input_ids, decoder_input_ids)
+        return config, inputs_dict
+
+    def get_config(self):
+        return MarianConfig(
             vocab_size=self.vocab_size,
             d_model=self.hidden_size,
             encoder_layers=self.num_hidden_layers,
@@ -151,8 +153,6 @@ class MarianModelTester:
             pad_token_id=self.pad_token_id,
             decoder_start_token_id=self.decoder_start_token_id,
         )
-        inputs_dict = prepare_marian_inputs_dict(config, input_ids, decoder_input_ids)
-        return config, inputs_dict
     def prepare_config_and_inputs_for_common(self):
         config, inputs_dict = self.prepare_config_and_inputs()
...
@@ -19,7 +19,7 @@ import copy
 import tempfile
 import unittest
-from transformers import is_torch_available
+from transformers import MBartConfig, is_torch_available
 from transformers.file_utils import cached_property
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
@@ -34,7 +34,6 @@ if is_torch_available():
     from transformers import (
         AutoTokenizer,
         BatchEncoding,
-        MBartConfig,
         MBartForCausalLM,
         MBartForConditionalGeneration,
         MBartForQuestionAnswering,
@@ -75,7 +74,6 @@ def prepare_mbart_inputs_dict(
     }
-@require_torch
 class MBartModelTester:
     def __init__(
         self,
@@ -124,7 +122,12 @@ class MBartModelTester:
         decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-        config = MBartConfig(
+        config = self.get_config()
+        inputs_dict = prepare_mbart_inputs_dict(config, input_ids, decoder_input_ids)
+        return config, inputs_dict
+
+    def get_config(self):
+        return MBartConfig(
             vocab_size=self.vocab_size,
             d_model=self.hidden_size,
             encoder_layers=self.num_hidden_layers,
@@ -140,8 +143,6 @@ class MBartModelTester:
             bos_token_id=self.bos_token_id,
             pad_token_id=self.pad_token_id,
         )
-        inputs_dict = prepare_mbart_inputs_dict(config, input_ids, decoder_input_ids)
-        return config, inputs_dict
     def prepare_config_and_inputs_for_common(self):
         config, inputs_dict = self.prepare_config_and_inputs()
...
@@ -19,7 +19,7 @@ import math
 import os
 import unittest
-from transformers import is_torch_available
+from transformers import MegatronBertConfig, is_torch_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
@@ -32,7 +32,6 @@ if is_torch_available():
     from transformers import (
         MODEL_FOR_PRETRAINING_MAPPING,
-        MegatronBertConfig,
         MegatronBertForCausalLM,
         MegatronBertForMaskedLM,
         MegatronBertForMultipleChoice,
@@ -115,7 +114,12 @@ class MegatronBertModelTester:
         token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
         choice_labels = ids_tensor([self.batch_size], self.num_choices)
-        config = MegatronBertConfig(
+        config = self.get_config()
+
+        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+    def get_config(self):
+        return MegatronBertConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
             num_hidden_layers=self.num_hidden_layers,
@@ -131,8 +135,6 @@ class MegatronBertModelTester:
             initializer_range=self.initializer_range,
         )
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     def create_and_check_megatron_bert_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
...
@@ -16,7 +16,7 @@
 import unittest
-from transformers import is_torch_available
+from transformers import MobileBertConfig, is_torch_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
@@ -29,7 +29,6 @@ if is_torch_available():
     from transformers import (
         MODEL_FOR_PRETRAINING_MAPPING,
-        MobileBertConfig,
         MobileBertForMaskedLM,
         MobileBertForMultipleChoice,
         MobileBertForNextSentencePrediction,
@@ -111,7 +110,12 @@ class MobileBertModelTester:
         token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
         choice_labels = ids_tensor([self.batch_size], self.num_choices)
-        config = MobileBertConfig(
+        config = self.get_config()
+
+        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+    def get_config(self):
+        return MobileBertConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
             num_hidden_layers=self.num_hidden_layers,
@@ -127,8 +131,6 @@ class MobileBertModelTester:
             initializer_range=self.initializer_range,
         )
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     def create_and_check_mobilebert_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
...
@@ -16,7 +16,7 @@
 import unittest
-from transformers import is_torch_available
+from transformers import MPNetConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 from .test_configuration_common import ConfigTester
@@ -27,7 +27,6 @@ if is_torch_available():
     import torch
     from transformers import (
-        MPNetConfig,
         MPNetForMaskedLM,
         MPNetForMultipleChoice,
         MPNetForQuestionAnswering,
@@ -104,7 +103,11 @@ class MPNetModelTester:
         token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
         choice_labels = ids_tensor([self.batch_size], self.num_choices)
-        config = MPNetConfig(
+        config = self.get_config()
+        return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+    def get_config(self):
+        return MPNetConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
             num_hidden_layers=self.num_hidden_layers,
@@ -116,7 +119,6 @@ class MPNetModelTester:
             max_position_embeddings=self.max_position_embeddings,
             initializer_range=self.initializer_range,
         )
-        return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
     def create_and_check_mpnet_model(
         self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
...
@@ -17,7 +17,7 @@
 import tempfile
 import unittest
-from transformers import is_torch_available
+from transformers import PegasusConfig, is_torch_available
 from transformers.file_utils import cached_property
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
@@ -30,7 +30,7 @@ from .test_modeling_mbart import AbstractSeq2SeqIntegrationTest
 if is_torch_available():
     import torch
-    from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration, PegasusModel
+    from transformers import AutoModelForSeq2SeqLM, PegasusForConditionalGeneration, PegasusModel
     from transformers.models.pegasus.modeling_pegasus import PegasusDecoder, PegasusEncoder, PegasusForCausalLM
@@ -65,7 +65,6 @@ def prepare_pegasus_inputs_dict(
     }
-@require_torch
 class PegasusModelTester:
     def __init__(
         self,
@@ -114,7 +113,12 @@ class PegasusModelTester:
         decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-        config = PegasusConfig(
+        config = self.get_config()
+        inputs_dict = prepare_pegasus_inputs_dict(config, input_ids, decoder_input_ids)
+        return config, inputs_dict
+
+    def get_config(self):
+        return PegasusConfig(
             vocab_size=self.vocab_size,
             d_model=self.hidden_size,
             encoder_layers=self.num_hidden_layers,
@@ -130,8 +134,6 @@ class PegasusModelTester:
             bos_token_id=self.bos_token_id,
             pad_token_id=self.pad_token_id,
         )
-        inputs_dict = prepare_pegasus_inputs_dict(config, input_ids, decoder_input_ids)
-        return config, inputs_dict
     def prepare_config_and_inputs_for_common(self):
         config, inputs_dict = self.prepare_config_and_inputs()
...
@@ -13,12 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
 import tempfile
 import unittest
-from transformers import is_torch_available
+from transformers import ProphetNetConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 from .test_configuration_common import ConfigTester
@@ -30,7 +29,6 @@ if is_torch_available():
     import torch
     from transformers import (
-        ProphetNetConfig,
         ProphetNetDecoder,
         ProphetNetEncoder,
         ProphetNetForCausalLM,
@@ -124,7 +122,19 @@ class ProphetNetModelTester:
         if self.use_labels:
             lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
-        config = ProphetNetConfig(
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids,
+            decoder_input_ids,
+            attention_mask,
+            decoder_attention_mask,
+            lm_labels,
+        )
+
+    def get_config(self):
+        return ProphetNetConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
             num_encoder_layers=self.num_encoder_layers,
@@ -145,15 +155,6 @@ class ProphetNetModelTester:
             is_encoder_decoder=self.is_encoder_decoder,
         )
-        return (
-            config,
-            input_ids,
-            decoder_input_ids,
-            attention_mask,
-            decoder_attention_mask,
-            lm_labels,
-        )
     def prepare_config_and_inputs_for_decoder(self):
         (
             config,
...