Unverified Commit c3d9ac76 authored by Lysandre Debut, committed by GitHub

Expose get_config() on ModelTesters (#12812)

* Expose get_config() on ModelTesters

* Typo
parent cabcc751
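
Every file in this diff applies the same refactor: the config constructor call that used to sit inline in prepare_config_and_inputs() is pulled out into a dedicated get_config() method on the tester, so a config can be obtained without also generating model inputs. A minimal sketch of the pattern, using stand-in DummyConfig/DummyModelTester classes rather than any of the real classes touched below:

class DummyConfig:
    # stand-in for a real *Config class such as ReformerConfig or T5Config
    def __init__(self, vocab_size, hidden_size):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size


class DummyModelTester:
    def __init__(self, parent, batch_size=13, seq_length=7, vocab_size=99, hidden_size=32):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    def prepare_config_and_inputs(self):
        # inputs are generated exactly as before
        input_ids = [[i % self.vocab_size for i in range(self.seq_length)] for _ in range(self.batch_size)]
        # the config now comes from the shared helper instead of being built inline
        config = self.get_config()
        return config, input_ids

    def get_config(self):
        # previously this constructor call lived directly inside prepare_config_and_inputs()
        return DummyConfig(vocab_size=self.vocab_size, hidden_size=self.hidden_size)
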
@@ -15,7 +15,7 @@
import unittest

-from transformers import is_torch_available
+from transformers import ReformerConfig, is_torch_available
from transformers.testing_utils import (
    require_sentencepiece,
    require_tokenizers,

@@ -36,7 +36,6 @@ if is_torch_available():
    from transformers import (
        REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
-        ReformerConfig,
        ReformerForMaskedLM,
        ReformerForQuestionAnswering,
        ReformerForSequenceClassification,

@@ -51,44 +50,44 @@ class ReformerModelTester:
    def __init__(
        self,
        parent,
-        batch_size=None,
+        batch_size=13,
-        seq_length=None,
+        seq_length=32,
-        is_training=None,
+        is_training=True,
-        is_decoder=None,
+        is_decoder=True,
-        use_input_mask=None,
+        use_input_mask=True,
-        use_labels=None,
+        use_labels=True,
-        vocab_size=None,
+        vocab_size=32,
-        attention_head_size=None,
+        attention_head_size=16,
-        hidden_size=None,
+        hidden_size=32,
-        num_attention_heads=None,
+        num_attention_heads=2,
-        local_attn_chunk_length=None,
+        local_attn_chunk_length=4,
-        local_num_chunks_before=None,
+        local_num_chunks_before=1,
-        local_num_chunks_after=None,
+        local_num_chunks_after=0,
        num_buckets=None,
        num_hashes=1,
        lsh_attn_chunk_length=None,
        lsh_num_chunks_before=None,
        lsh_num_chunks_after=None,
-        chunk_size_lm_head=None,
+        chunk_size_lm_head=0,
-        chunk_size_feed_forward=None,
+        chunk_size_feed_forward=0,
-        feed_forward_size=None,
+        feed_forward_size=32,
-        hidden_act=None,
+        hidden_act="gelu",
-        hidden_dropout_prob=None,
+        hidden_dropout_prob=0.1,
-        local_attention_probs_dropout_prob=None,
+        local_attention_probs_dropout_prob=0.1,
        lsh_attention_probs_dropout_prob=None,
-        max_position_embeddings=None,
+        max_position_embeddings=512,
-        initializer_range=None,
+        initializer_range=0.02,
-        axial_norm_std=None,
+        axial_norm_std=1.0,
-        layer_norm_eps=None,
+        layer_norm_eps=1e-12,
-        axial_pos_embds=None,
+        axial_pos_embds=True,
-        axial_pos_shape=None,
+        axial_pos_shape=[4, 8],
-        axial_pos_embds_dim=None,
+        axial_pos_embds_dim=[16, 16],
-        attn_layers=None,
+        attn_layers=["local", "local", "local", "local"],
-        pad_token_id=None,
+        pad_token_id=0,
-        eos_token_id=None,
+        eos_token_id=2,
        scope=None,
-        hash_seed=None,
+        hash_seed=0,
-        num_labels=None,
+        num_labels=2,
    ):
        self.parent = parent
        self.batch_size = batch_size

@@ -101,7 +100,7 @@ class ReformerModelTester:
        self.attention_head_size = attention_head_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
-        self.num_hidden_layers = len(attn_layers)
+        self.num_hidden_layers = len(attn_layers) if attn_layers is not None else 0
        self.local_attn_chunk_length = local_attn_chunk_length
        self.local_num_chunks_after = local_num_chunks_after
        self.local_num_chunks_before = local_num_chunks_before

@@ -149,7 +148,17 @@ class ReformerModelTester:
        if self.use_labels:
            choice_labels = ids_tensor([self.batch_size], 2)

-        config = ReformerConfig(
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids,
+            input_mask,
+            choice_labels,
+        )
+
+    def get_config(self):
+        return ReformerConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,

@@ -177,13 +186,6 @@ class ReformerModelTester:
            hash_seed=self.hash_seed,
        )

-        return (
-            config,
-            input_ids,
-            input_mask,
-            choice_labels,
-        )
-
    def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
        model = ReformerModel(config=config)
        model.to(torch_device)

@@ -593,45 +595,8 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
    test_torchscript = False
    test_sequence_classification_problem_types = True

-    def prepare_kwargs(self):
-        return {
-            "batch_size": 13,
-            "seq_length": 32,
-            "is_training": True,
-            "is_decoder": True,
-            "use_input_mask": True,
-            "use_labels": True,
-            "vocab_size": 32,
-            "attention_head_size": 16,
-            "hidden_size": 32,
-            "num_attention_heads": 2,
-            "local_attn_chunk_length": 4,
-            "local_num_chunks_before": 1,
-            "local_num_chunks_after": 0,
-            "chunk_size_lm_head": 0,
-            "chunk_size_feed_forward": 0,
-            "feed_forward_size": 32,
-            "hidden_act": "gelu",
-            "hidden_dropout_prob": 0.1,
-            "local_attention_probs_dropout_prob": 0.1,
-            "max_position_embeddings": 512,
-            "initializer_range": 0.02,
-            "axial_norm_std": 1.0,
-            "layer_norm_eps": 1e-12,
-            "axial_pos_embds": True,
-            "axial_pos_shape": [4, 8],
-            "axial_pos_embds_dim": [16, 16],
-            "attn_layers": ["local", "local", "local", "local"],
-            "pad_token_id": 0,
-            "eos_token_id": 2,
-            "scope": None,
-            "hash_seed": 0,
-            "num_labels": 2,
-        }
-
    def setUp(self):
-        tester_kwargs = self.prepare_kwargs()
-        self.model_tester = ReformerModelTester(self, **tester_kwargs)
+        self.model_tester = ReformerModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)

    @slow

@@ -716,49 +681,46 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, Generation
    test_headmasking = False
    test_torchscript = False

-    def prepare_kwargs(self):
-        return {
-            "batch_size": 13,
-            "seq_length": 13,
-            "use_input_mask": True,
-            "use_labels": True,
-            "is_training": False,
-            "is_decoder": True,
-            "vocab_size": 32,
-            "attention_head_size": 16,
-            "hidden_size": 64,
-            "num_attention_heads": 2,
-            "num_buckets": 2,
-            "num_hashes": 4,
-            "lsh_attn_chunk_length": 4,
-            "lsh_num_chunks_before": 1,
-            "lsh_num_chunks_after": 0,
-            "chunk_size_lm_head": 5,
-            "chunk_size_feed_forward": 6,
-            "feed_forward_size": 32,
-            "hidden_act": "relu",
-            "hidden_dropout_prob": 0.1,
-            "lsh_attention_probs_dropout_prob": 0.1,
-            "max_position_embeddings": 512,
-            "initializer_range": 0.02,
-            "axial_norm_std": 1.0,
-            "layer_norm_eps": 1e-12,
-            "axial_pos_embds": True,
-            "axial_pos_shape": [4, 8],
-            "axial_pos_embds_dim": [16, 48],
-            # sanotheu
-            # "attn_layers": ["lsh", "lsh", "lsh", "lsh"],
-            "attn_layers": ["lsh"],
-            "pad_token_id": 0,
-            "eos_token_id": 2,
-            "scope": None,
-            "hash_seed": 0,
-            "num_labels": 2,
-        }
-
    def setUp(self):
-        tester_kwargs = self.prepare_kwargs()
-        self.model_tester = ReformerModelTester(self, **tester_kwargs)
+        self.model_tester = ReformerModelTester(
+            self,
+            batch_size=13,
+            seq_length=13,
+            use_input_mask=True,
+            use_labels=True,
+            is_training=False,
+            is_decoder=True,
+            vocab_size=32,
+            attention_head_size=16,
+            hidden_size=64,
+            num_attention_heads=2,
+            num_buckets=2,
+            num_hashes=4,
+            lsh_attn_chunk_length=4,
+            lsh_num_chunks_before=1,
+            lsh_num_chunks_after=0,
+            chunk_size_lm_head=5,
+            chunk_size_feed_forward=6,
+            feed_forward_size=32,
+            hidden_act="relu",
+            hidden_dropout_prob=0.1,
+            lsh_attention_probs_dropout_prob=0.1,
+            max_position_embeddings=512,
+            initializer_range=0.02,
+            axial_norm_std=1.0,
+            layer_norm_eps=1e-12,
+            axial_pos_embds=True,
+            axial_pos_shape=[4, 8],
+            axial_pos_embds_dim=[16, 48],
+            # sanotheu
+            # attn_layers=[lsh,lsh,lsh,lsh],
+            attn_layers=["lsh"],
+            pad_token_id=0,
+            eos_token_id=2,
+            scope=None,
+            hash_seed=0,
+            num_labels=2,
+        )
        self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)

    def _check_attentions_for_generate(
...
@@ -17,7 +17,7 @@
import unittest
from copy import deepcopy

-from transformers import is_torch_available
+from transformers import RobertaConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device

from .test_configuration_common import ConfigTester

@@ -29,7 +29,6 @@ if is_torch_available():
    import torch

    from transformers import (
-        RobertaConfig,
        RobertaForCausalLM,
        RobertaForMaskedLM,
        RobertaForMultipleChoice,

@@ -94,7 +93,12 @@ class RobertaModelTester:
        token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
        choice_labels = ids_tensor([self.batch_size], self.num_choices)

-        config = RobertaConfig(
+        config = self.get_config()
+
+        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+    def get_config(self):
+        return RobertaConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,

@@ -108,8 +112,6 @@ class RobertaModelTester:
            initializer_range=self.initializer_range,
        )

-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
    def prepare_config_and_inputs_for_decoder(self):
        (
            config,
...
@@ -18,7 +18,7 @@
import unittest

from tests.test_modeling_common import floats_tensor
-from transformers import is_torch_available
+from transformers import RoFormerConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester

@@ -29,7 +29,6 @@ if is_torch_available():
    import torch

    from transformers import (
-        RoFormerConfig,
        RoFormerForCausalLM,
        RoFormerForMaskedLM,
        RoFormerForMultipleChoice,

@@ -113,7 +112,12 @@ class RoFormerModelTester:
        token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
        choice_labels = ids_tensor([self.batch_size], self.num_choices)

-        config = RoFormerConfig(
+        config = self.get_config()
+
+        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+    def get_config(self):
+        return RoFormerConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,

@@ -128,8 +132,6 @@ class RoFormerModelTester:
            initializer_range=self.initializer_range,
        )

-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
    def prepare_config_and_inputs_for_decoder(self):
        (
            config,
...
@@ -14,13 +14,13 @@
# limitations under the License.
""" Testing suite for the PyTorch Speech2Text model. """

import copy
import inspect
import os
import tempfile
import unittest

+from transformers import Speech2TextConfig
from transformers.file_utils import cached_property
from transformers.testing_utils import (
    is_torch_available,

@@ -40,12 +40,7 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init, floats_te
if is_torch_available():
    import torch

-    from transformers import (
-        Speech2TextConfig,
-        Speech2TextForConditionalGeneration,
-        Speech2TextModel,
-        Speech2TextProcessor,
-    )
+    from transformers import Speech2TextForConditionalGeneration, Speech2TextModel, Speech2TextProcessor
    from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder

@@ -142,7 +137,17 @@ class Speech2TextModelTester:
        attention_mask = torch.ones([self.batch_size, self.seq_length], dtype=torch.long, device=torch_device)
        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(2)

-        config = Speech2TextConfig(
+        config = self.get_config()
+
+        inputs_dict = prepare_speech_to_text_inputs_dict(
+            config,
+            input_features=input_features,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+        )
+        return config, inputs_dict
+
+    def get_config(self):
+        return Speech2TextConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,

@@ -165,13 +170,6 @@ class Speech2TextModelTester:
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
        )
-        inputs_dict = prepare_speech_to_text_inputs_dict(
-            config,
-            input_features=input_features,
-            decoder_input_ids=decoder_input_ids,
-            attention_mask=attention_mask,
-        )
-        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
...
@@ -16,7 +16,7 @@
import unittest

-from transformers import is_torch_available
+from transformers import SqueezeBertConfig, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device

from .test_configuration_common import ConfigTester

@@ -28,7 +28,6 @@ if is_torch_available():
    from transformers import (
        SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
-        SqueezeBertConfig,
        SqueezeBertForMaskedLM,
        SqueezeBertForMultipleChoice,
        SqueezeBertForQuestionAnswering,

@@ -37,179 +36,181 @@ if is_torch_available():
        SqueezeBertModel,
    )


class SqueezeBertModelTester(object):
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=64,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
        q_groups=2,
        k_groups=2,
        v_groups=2,
        post_attention_groups=2,
        intermediate_groups=4,
        output_groups=1,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope
        self.q_groups = q_groups
        self.k_groups = k_groups
        self.v_groups = v_groups
        self.post_attention_groups = post_attention_groups
        self.intermediate_groups = intermediate_groups
        self.output_groups = output_groups

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

-        config = SqueezeBertConfig(
-            embedding_size=self.hidden_size,
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            intermediate_size=self.intermediate_size,
-            hidden_act=self.hidden_act,
-            attention_probs_dropout_prob=self.hidden_dropout_prob,
-            attention_dropout=self.attention_probs_dropout_prob,
-            max_position_embeddings=self.max_position_embeddings,
-            initializer_range=self.initializer_range,
-            q_groups=self.q_groups,
-            k_groups=self.k_groups,
-            v_groups=self.v_groups,
-            post_attention_groups=self.post_attention_groups,
-            intermediate_groups=self.intermediate_groups,
-            output_groups=self.output_groups,
-        )
+        config = self.get_config()

        return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels

+    def get_config(self):
+        return SqueezeBertConfig(
+            embedding_size=self.hidden_size,
+            vocab_size=self.vocab_size,
+            hidden_size=self.hidden_size,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            hidden_act=self.hidden_act,
+            attention_probs_dropout_prob=self.hidden_dropout_prob,
+            attention_dropout=self.attention_probs_dropout_prob,
+            max_position_embeddings=self.max_position_embeddings,
+            initializer_range=self.initializer_range,
+            q_groups=self.q_groups,
+            k_groups=self.k_groups,
+            v_groups=self.v_groups,
+            post_attention_groups=self.post_attention_groups,
+            intermediate_groups=self.intermediate_groups,
+            output_groups=self.output_groups,
+        )
+
    def create_and_check_squeezebert_model(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = SqueezeBertModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, input_mask)
        result = model(input_ids)
-        self.parent.assertEqual(
-            result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)
-        )
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_squeezebert_for_masked_lm(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = SqueezeBertForMaskedLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_squeezebert_for_question_answering(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = SqueezeBertForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids, attention_mask=input_mask, start_positions=sequence_labels, end_positions=sequence_labels
        )
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def create_and_check_squeezebert_for_sequence_classification(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = SqueezeBertForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_squeezebert_for_token_classification(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = SqueezeBertForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def create_and_check_squeezebert_for_multiple_choice(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_choices = self.num_choices
        model = SqueezeBertForMultipleChoice(config=config)
        model.to(torch_device)
        model.eval()
        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        result = model(
            multiple_choice_inputs_ids,
            attention_mask=multiple_choice_input_mask,
            labels=choice_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
...
@@ -18,7 +18,7 @@ import copy
import tempfile
import unittest

-from transformers import is_torch_available
+from transformers import T5Config, is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device

@@ -30,7 +30,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
    import torch

-    from transformers import ByT5Tokenizer, T5Config, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer
+    from transformers import ByT5Tokenizer, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer
    from transformers.models.t5.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST

@@ -100,7 +100,19 @@ class T5ModelTester:
        if self.use_labels:
            lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)

-        config = T5Config(
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids,
+            decoder_input_ids,
+            attention_mask,
+            decoder_attention_mask,
+            lm_labels,
+        )
+
+    def get_config(self):
+        return T5Config(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            d_ff=self.d_ff,

@@ -117,15 +129,6 @@ class T5ModelTester:
            decoder_start_token_id=self.decoder_start_token_id,
        )

-        return (
-            config,
-            input_ids,
-            decoder_input_ids,
-            attention_mask,
-            decoder_attention_mask,
-            lm_labels,
-        )
-
    def check_prepare_lm_labels_via_shift_left(
        self,
        config,
...
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest

@@ -29,6 +28,7 @@ from transformers import (
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
    MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+    TapasConfig,
    is_torch_available,
)
from transformers.file_utils import cached_property

@@ -43,7 +43,6 @@ if is_torch_available():
    import torch

    from transformers import (
-        TapasConfig,
        TapasForMaskedLM,
        TapasForQuestionAnswering,
        TapasForSequenceClassification,

@@ -183,7 +182,24 @@ class TapasModelTester:
        float_answer = floats_tensor([self.batch_size]).to(torch_device)
        aggregation_labels = ids_tensor([self.batch_size], self.num_aggregation_labels).to(torch_device)

-        config = TapasConfig(
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids,
+            input_mask,
+            token_type_ids,
+            sequence_labels,
+            token_labels,
+            labels,
+            numeric_values,
+            numeric_values_scale,
+            float_answer,
+            aggregation_labels,
+        )
+
+    def get_config(self):
+        return TapasConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,

@@ -220,20 +236,6 @@ class TapasModelTester:
            disable_per_token_loss=self.disable_per_token_loss,
        )

-        return (
-            config,
-            input_ids,
-            input_mask,
-            token_type_ids,
-            sequence_labels,
-            token_labels,
-            labels,
-            numeric_values,
-            numeric_values_scale,
-            float_answer,
-            aggregation_labels,
-        )
-
    def create_and_check_model(
        self,
        config,
...
@@ -17,7 +17,7 @@ import copy
import random
import unittest

-from transformers import is_torch_available
+from transformers import TransfoXLConfig, is_torch_available
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device

from .test_configuration_common import ConfigTester

@@ -29,7 +29,7 @@ if is_torch_available():
    import torch
    from torch import nn

-    from transformers import TransfoXLConfig, TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
+    from transformers import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
    from transformers.models.transfo_xl.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST

@@ -69,7 +69,12 @@ class TransfoXLModelTester:
        if self.use_labels:
            lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

-        config = TransfoXLConfig(
+        config = self.get_config()
+
+        return (config, input_ids_1, input_ids_2, lm_labels)
+
+    def get_config(self):
+        return TransfoXLConfig(
            vocab_size=self.vocab_size,
            mem_len=self.mem_len,
            clamp_len=self.clamp_len,

@@ -85,8 +90,6 @@ class TransfoXLModelTester:
            pad_token_id=self.pad_token_id,
        )

-        return (config, input_ids_1, input_ids_2, lm_labels)
-
    def set_seed(self):
        random.seed(self.seed)
        torch.manual_seed(self.seed)
...
@@ -14,12 +14,11 @@
# limitations under the License.
""" Testing suite for the PyTorch VisualBERT model. """

import copy
import unittest

from tests.test_modeling_common import floats_tensor
-from transformers import is_torch_available
+from transformers import VisualBertConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester

@@ -30,7 +29,6 @@ if is_torch_available():
    import torch

    from transformers import (
-        VisualBertConfig,
        VisualBertForMultipleChoice,
        VisualBertForPreTraining,
        VisualBertForQuestionAnswering,

@@ -98,7 +96,7 @@ class VisualBertModelTester:
        self.num_choices = num_choices
        self.scope = scope

-    def prepare_config(self):
+    def get_config(self):
        return VisualBertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,

@@ -138,7 +136,7 @@ class VisualBertModelTester:
        if self.use_visual_token_type_ids:
            visual_token_type_ids = ids_tensor([self.batch_size, self.visual_seq_length], self.type_vocab_size)

-        config = self.prepare_config()
+        config = self.get_config()
        return config, {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,

@@ -198,7 +196,7 @@ class VisualBertModelTester:
        if self.use_labels:
            labels = ids_tensor([self.batch_size], self.num_choices)

-        config = self.prepare_config()
+        config = self.get_config()
        return config, {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
...
@@ -18,6 +18,7 @@
import inspect
import unittest

+from transformers import ViTConfig
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision, slow, torch_device

@@ -29,7 +30,7 @@ if is_torch_available():
    import torch
    from torch import nn

-    from transformers import ViTConfig, ViTForImageClassification, ViTModel
+    from transformers import ViTForImageClassification, ViTModel
    from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple

@@ -86,7 +87,12 @@ class ViTModelTester:
        if self.use_labels:
            labels = ids_tensor([self.batch_size], self.type_sequence_label_size)

-        config = ViTConfig(
+        config = self.get_config()
+
+        return config, pixel_values, labels
+
+    def get_config(self):
+        return ViTConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,

@@ -101,8 +107,6 @@ class ViTModelTester:
            initializer_range=self.initializer_range,
        )

-        return config, pixel_values, labels
-
    def create_and_check_model(self, config, pixel_values, labels):
        model = ViTModel(config=config)
        model.to(torch_device)
...
@@ -21,7 +21,7 @@ import unittest
import pytest

from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
-from transformers import is_torch_available
+from transformers import Wav2Vec2Config, is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device

from .test_configuration_common import ConfigTester

@@ -32,7 +32,6 @@ if is_torch_available():
    import torch

    from transformers import (
-        Wav2Vec2Config,
        Wav2Vec2FeatureExtractor,
        Wav2Vec2ForCTC,
        Wav2Vec2ForMaskedLM,

@@ -106,7 +105,12 @@ class Wav2Vec2ModelTester:
        input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
        attention_mask = random_attention_mask([self.batch_size, self.seq_length])

-        config = Wav2Vec2Config(
+        config = self.get_config()
+
+        return config, input_values, attention_mask
+
+    def get_config(self):
+        return Wav2Vec2Config(
            hidden_size=self.hidden_size,
            feat_extract_norm=self.feat_extract_norm,
            feat_extract_dropout=self.feat_extract_dropout,

@@ -127,8 +131,6 @@ class Wav2Vec2ModelTester:
            vocab_size=self.vocab_size,
        )

-        return config, input_values, attention_mask
-
    def create_and_check_model(self, config, input_values, attention_mask):
        model = Wav2Vec2Model(config=config)
        model.to(torch_device)
...
@@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

-from transformers import is_torch_available
+from transformers import XLMConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester

@@ -28,7 +27,6 @@ if is_torch_available():
    import torch

    from transformers import (
-        XLMConfig,
        XLMForMultipleChoice,
        XLMForQuestionAnswering,
        XLMForQuestionAnsweringSimple,

@@ -97,7 +95,22 @@ class XLMModelTester:
        is_impossible_labels = ids_tensor([self.batch_size], 2).float()
        choice_labels = ids_tensor([self.batch_size], self.num_choices)

-        config = XLMConfig(
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids,
+            token_type_ids,
+            input_lengths,
+            sequence_labels,
+            token_labels,
+            is_impossible_labels,
+            choice_labels,
+            input_mask,
+        )
+
+    def get_config(self):
+        return XLMConfig(
            vocab_size=self.vocab_size,
            n_special=self.n_special,
            emb_dim=self.hidden_size,

@@ -118,18 +131,6 @@ class XLMModelTester:
            bos_token_id=self.bos_token_id,
        )

-        return (
-            config,
-            input_ids,
-            token_type_ids,
-            input_lengths,
-            sequence_labels,
-            token_labels,
-            is_impossible_labels,
-            choice_labels,
-            input_mask,
-        )
-
    def create_and_check_xlm_model(
        self,
        config,
...
@@ -13,11 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest

-from transformers import is_torch_available
+from transformers import XLNetConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester

@@ -29,7 +28,6 @@ if is_torch_available():
    import torch

    from transformers import (
-        XLNetConfig,
        XLNetForMultipleChoice,
        XLNetForQuestionAnswering,
        XLNetForQuestionAnsweringSimple,

@@ -131,7 +129,25 @@ class XLNetModelTester:
        is_impossible_labels = ids_tensor([self.batch_size], 2).float()
        token_labels = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

-        config = XLNetConfig(
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids_1,
+            input_ids_2,
+            input_ids_q,
+            perm_mask,
+            input_mask,
+            target_mapping,
+            segment_ids,
+            lm_labels,
+            sequence_labels,
+            is_impossible_labels,
+            token_labels,
+        )
+
+    def get_config(self):
+        return XLNetConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            n_head=self.num_attention_heads,

@@ -150,21 +166,6 @@ class XLNetModelTester:
            eos_token_id=self.eos_token_id,
        )

-        return (
-            config,
-            input_ids_1,
-            input_ids_2,
-            input_ids_q,
-            perm_mask,
-            input_mask,
-            target_mapping,
-            segment_ids,
-            lm_labels,
-            sequence_labels,
-            is_impossible_labels,
-            token_labels,
-        )
-
    def set_seed(self):
        random.seed(self.seed)
        torch.manual_seed(self.seed)
...