"vscode:/vscode.git/clone" did not exist on "b4f9464f90d9bc215b36dea4ed5f7eaf83144301"
Unverified commit c852036b authored by Amil Khare, committed by GitHub

[cleanup] Hoist ModelTester objects to top level (#4939)


Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent 0c55a384
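
The commit applies the same two-part pattern to each of the five test files below (TF XLM, TF XLNet, Transfo-XL, XLM, XLNet): the *ModelTester helper class is hoisted out of its unittest.TestCase, so setUp can instantiate it directly instead of through SomeModelTest.SomeModelTester(self), and constructor keyword arguments that no caller ever overrode are collapsed into hard-coded attribute assignments. A minimal runnable sketch of the before/after shape, using hypothetical Foo* names rather than code from this commit:

import unittest


class FooModelTester:
    # After the cleanup: a plain top-level helper. Other test modules can
    # now import and reuse it without touching the TestCase that owns it.
    def __init__(self, parent):
        self.parent = parent
        # Formerly __init__ kwargs (batch_size=13, seq_length=7, ...);
        # no test ever overrode them, so they become fixed attributes.
        self.batch_size = 13
        self.seq_length = 7

    def create_and_check_shapes(self):
        # Stand-in for the real create_and_check_* methods, which build a
        # config plus dummy inputs and assert on the model's output shapes.
        self.parent.assertEqual(self.batch_size, 13)


class FooModelTest(unittest.TestCase):
    def setUp(self):
        # Before the cleanup the tester was nested, and this line read:
        #   self.model_tester = FooModelTest.FooModelTester(self)
        self.model_tester = FooModelTester(self)

    def test_shapes(self):
        self.model_tester.create_and_check_shapes()


if __name__ == "__main__":
    unittest.main()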
@@ -35,81 +35,39 @@ if is_tf_available():
     )
 
-@require_tf
-class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
-
-    all_model_classes = (
-        (TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple)
-        if is_tf_available()
-        else ()
-    )
-    all_generative_model_classes = (
-        (TFXLMWithLMHeadModel,) if is_tf_available() else ()
-    )  # TODO (PVP): Check other models whether language generation is also applicable
-
-    class TFXLMModelTester(object):
-        def __init__(
-            self,
-            parent,
-            batch_size=13,
-            seq_length=7,
-            is_training=True,
-            use_input_lengths=True,
-            use_token_type_ids=True,
-            use_labels=True,
-            gelu_activation=True,
-            sinusoidal_embeddings=False,
-            causal=False,
-            asm=False,
-            n_langs=2,
-            vocab_size=99,
-            n_special=0,
-            hidden_size=32,
-            num_hidden_layers=5,
-            num_attention_heads=4,
-            hidden_dropout_prob=0.1,
-            attention_probs_dropout_prob=0.1,
-            max_position_embeddings=512,
-            type_vocab_size=16,
-            type_sequence_label_size=2,
-            initializer_range=0.02,
-            num_labels=3,
-            num_choices=4,
-            summary_type="last",
-            use_proj=True,
-            scope=None,
-            bos_token_id=0,
-        ):
-            self.parent = parent
-            self.batch_size = batch_size
-            self.seq_length = seq_length
-            self.is_training = is_training
-            self.use_input_lengths = use_input_lengths
-            self.use_token_type_ids = use_token_type_ids
-            self.use_labels = use_labels
-            self.gelu_activation = gelu_activation
-            self.sinusoidal_embeddings = sinusoidal_embeddings
-            self.asm = asm
-            self.n_langs = n_langs
-            self.vocab_size = vocab_size
-            self.n_special = n_special
-            self.summary_type = summary_type
-            self.causal = causal
-            self.use_proj = use_proj
-            self.hidden_size = hidden_size
-            self.num_hidden_layers = num_hidden_layers
-            self.num_attention_heads = num_attention_heads
-            self.hidden_dropout_prob = hidden_dropout_prob
-            self.attention_probs_dropout_prob = attention_probs_dropout_prob
-            self.max_position_embeddings = max_position_embeddings
-            self.n_langs = n_langs
-            self.type_sequence_label_size = type_sequence_label_size
-            self.initializer_range = initializer_range
-            self.summary_type = summary_type
-            self.num_labels = num_labels
-            self.num_choices = num_choices
-            self.scope = scope
-            self.bos_token_id = bos_token_id
+class TFXLMModelTester:
+    def __init__(
+        self, parent,
+    ):
+        self.parent = parent
+        self.batch_size = 13
+        self.seq_length = 7
+        self.is_training = True
+        self.use_input_lengths = True
+        self.use_token_type_ids = True
+        self.use_labels = True
+        self.gelu_activation = True
+        self.sinusoidal_embeddings = False
+        self.causal = False
+        self.asm = False
+        self.n_langs = 2
+        self.vocab_size = 99
+        self.n_special = 0
+        self.hidden_size = 32
+        self.num_hidden_layers = 5
+        self.num_attention_heads = 4
+        self.hidden_dropout_prob = 0.1
+        self.attention_probs_dropout_prob = 0.1
+        self.max_position_embeddings = 512
+        self.type_vocab_size = 16
+        self.type_sequence_label_size = 2
+        self.initializer_range = 0.02
+        self.num_labels = 3
+        self.num_choices = 4
+        self.summary_type = "last"
+        self.use_proj = True
+        self.scope = None
+        self.bos_token_id = 0
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -211,9 +169,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
             "logits": logits.numpy(),
         }
 
-        self.parent.assertListEqual(
-            list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size]
-        )
+        self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
 
     def create_and_check_xlm_qa(
         self,
@@ -283,8 +239,21 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
         }
         return config, inputs_dict
 
+
+@require_tf
+class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (
+        (TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple)
+        if is_tf_available()
+        else ()
+    )
+    all_generative_model_classes = (
+        (TFXLMWithLMHeadModel,) if is_tf_available() else ()
+    )  # TODO (PVP): Check other models whether language generation is also applicable
+
     def setUp(self):
-        self.model_tester = TFXLMModelTest.TFXLMModelTester(self)
+        self.model_tester = TFXLMModelTester(self)
         self.config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)
 
     def test_config(self):
...
@@ -37,78 +37,35 @@ if is_tf_available():
     )
 
-@require_tf
-class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
-
-    all_model_classes = (
-        (
-            TFXLNetModel,
-            TFXLNetLMHeadModel,
-            TFXLNetForSequenceClassification,
-            TFXLNetForTokenClassification,
-            TFXLNetForQuestionAnsweringSimple,
-        )
-        if is_tf_available()
-        else ()
-    )
-    all_generative_model_classes = (
-        (TFXLNetLMHeadModel,) if is_tf_available() else ()
-    )  # TODO (PVP): Check other models whether language generation is also applicable
-    test_pruning = False
-
-    class TFXLNetModelTester(object):
-        def __init__(
-            self,
-            parent,
-            batch_size=13,
-            seq_length=7,
-            mem_len=10,
-            clamp_len=-1,
-            reuse_len=15,
-            is_training=True,
-            use_labels=True,
-            vocab_size=99,
-            cutoffs=[10, 50, 80],
-            hidden_size=32,
-            num_attention_heads=4,
-            d_inner=128,
-            num_hidden_layers=5,
-            type_sequence_label_size=2,
-            untie_r=True,
-            bi_data=False,
-            same_length=False,
-            initializer_range=0.05,
-            seed=1,
-            type_vocab_size=2,
-            bos_token_id=1,
-            eos_token_id=2,
-            pad_token_id=5,
-        ):
-            self.parent = parent
-            self.batch_size = batch_size
-            self.seq_length = seq_length
-            self.mem_len = mem_len
-            # self.key_len = seq_length + mem_len
-            self.clamp_len = clamp_len
-            self.reuse_len = reuse_len
-            self.is_training = is_training
-            self.use_labels = use_labels
-            self.vocab_size = vocab_size
-            self.cutoffs = cutoffs
-            self.hidden_size = hidden_size
-            self.num_attention_heads = num_attention_heads
-            self.d_inner = d_inner
-            self.num_hidden_layers = num_hidden_layers
-            self.bi_data = bi_data
-            self.untie_r = untie_r
-            self.same_length = same_length
-            self.initializer_range = initializer_range
-            self.seed = seed
-            self.type_vocab_size = type_vocab_size
-            self.type_sequence_label_size = type_sequence_label_size
-            self.bos_token_id = bos_token_id
-            self.pad_token_id = pad_token_id
-            self.eos_token_id = eos_token_id
+class TFXLNetModelTester:
+    def __init__(
+        self, parent,
+    ):
+        self.parent = parent
+        self.batch_size = 13
+        self.seq_length = 7
+        self.mem_len = 10
+        # self.key_len = seq_length + mem_len
+        self.clamp_len = -1
+        self.reuse_len = 15
+        self.is_training = True
+        self.use_labels = True
+        self.vocab_size = 99
+        self.cutoffs = [10, 50, 80]
+        self.hidden_size = 32
+        self.num_attention_heads = 4
+        self.d_inner = 128
+        self.num_hidden_layers = 5
+        self.type_sequence_label_size = 2
+        self.untie_r = True
+        self.bi_data = False
+        self.same_length = False
+        self.initializer_range = 0.05
+        self.seed = 1
+        self.type_vocab_size = 2
+        self.bos_token_id = 1
+        self.eos_token_id = 2
+        self.pad_token_id = 5
 
     def prepare_config_and_inputs(self):
         input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -377,8 +334,28 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
         inputs_dict = {"input_ids": input_ids_1}
         return config, inputs_dict
 
+
+@require_tf
+class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (
+        (
+            TFXLNetModel,
+            TFXLNetLMHeadModel,
+            TFXLNetForSequenceClassification,
+            TFXLNetForTokenClassification,
+            TFXLNetForQuestionAnsweringSimple,
+        )
+        if is_tf_available()
+        else ()
+    )
+    all_generative_model_classes = (
+        (TFXLNetLMHeadModel,) if is_tf_available() else ()
+    )  # TODO (PVP): Check other models whether language generation is also applicable
+    test_pruning = False
+
     def setUp(self):
-        self.model_tester = TFXLNetModelTest.TFXLNetModelTester(self)
+        self.model_tester = TFXLNetModelTester(self)
         self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)
 
     def test_config(self):
...
@@ -29,58 +29,30 @@ if is_torch_available():
     from transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST
 
-@require_torch
-class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
-
-    all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
-    all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
-    test_pruning = False
-    test_torchscript = False
-    test_resize_embeddings = True
-
-    class TransfoXLModelTester(object):
-        def __init__(
-            self,
-            parent,
-            batch_size=14,
-            seq_length=7,
-            mem_len=30,
-            clamp_len=15,
-            is_training=True,
-            use_labels=True,
-            vocab_size=99,
-            cutoffs=[10, 50, 80],
-            hidden_size=32,
-            d_embed=32,
-            num_attention_heads=4,
-            d_head=8,
-            d_inner=128,
-            div_val=2,
-            num_hidden_layers=5,
-            scope=None,
-            seed=1,
-            eos_token_id=0,
-        ):
-            self.parent = parent
-            self.batch_size = batch_size
-            self.seq_length = seq_length
-            self.mem_len = mem_len
-            self.key_length = seq_length + mem_len
-            self.clamp_len = clamp_len
-            self.is_training = is_training
-            self.use_labels = use_labels
-            self.vocab_size = vocab_size
-            self.cutoffs = cutoffs
-            self.hidden_size = hidden_size
-            self.d_embed = d_embed
-            self.num_attention_heads = num_attention_heads
-            self.d_head = d_head
-            self.d_inner = d_inner
-            self.div_val = div_val
-            self.num_hidden_layers = num_hidden_layers
-            self.scope = scope
-            self.seed = seed
-            self.eos_token_id = eos_token_id
+class TransfoXLModelTester:
+    def __init__(
+        self, parent,
+    ):
+        self.parent = parent
+        self.batch_size = 14
+        self.seq_length = 7
+        self.mem_len = 30
+        self.key_length = self.seq_length + self.mem_len
+        self.clamp_len = 15
+        self.is_training = True
+        self.use_labels = True
+        self.vocab_size = 99
+        self.cutoffs = [10, 50, 80]
+        self.hidden_size = 32
+        self.d_embed = 32
+        self.num_attention_heads = 4
+        self.d_head = 8
+        self.d_inner = 128
+        self.div_val = 2
+        self.num_hidden_layers = 5
+        self.scope = None
+        self.seed = 1
+        self.eos_token_id = 0
 
     def prepare_config_and_inputs(self):
         input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -187,6 +159,16 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
         inputs_dict = {"input_ids": input_ids_1}
         return config, inputs_dict
 
+
+@require_torch
+class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
+    all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
+    test_pruning = False
+    test_torchscript = False
+    test_resize_embeddings = True
+
     def check_cutoffs_and_n_token(
         self, copied_cutoffs, layer, model_embed, model, model_class, resized_value, vocab_size
     ):
@@ -210,7 +192,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
         self.assertEqual(model.crit.n_token, vocab_size + resized_value)
 
     def setUp(self):
-        self.model_tester = TransfoXLModelTest.TransfoXLModelTester(self)
+        self.model_tester = TransfoXLModelTester(self)
         self.config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)
 
    def test_config(self):
...
@@ -37,87 +37,38 @@ if is_torch_available():
     from transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_LIST
 
-@require_torch
-class XLMModelTest(ModelTesterMixin, unittest.TestCase):
-
-    all_model_classes = (
-        (
-            XLMModel,
-            XLMWithLMHeadModel,
-            XLMForQuestionAnswering,
-            XLMForSequenceClassification,
-            XLMForQuestionAnsweringSimple,
-        )
-        if is_torch_available()
-        else ()
-    )
-    all_generative_model_classes = (
-        (XLMWithLMHeadModel,) if is_torch_available() else ()
-    )  # TODO (PVP): Check other models whether language generation is also applicable
-
-    class XLMModelTester(object):
-        def __init__(
-            self,
-            parent,
-            batch_size=13,
-            seq_length=7,
-            is_training=True,
-            use_input_lengths=True,
-            use_token_type_ids=True,
-            use_labels=True,
-            gelu_activation=True,
-            sinusoidal_embeddings=False,
-            causal=False,
-            asm=False,
-            n_langs=2,
-            vocab_size=99,
-            n_special=0,
-            hidden_size=32,
-            num_hidden_layers=5,
-            num_attention_heads=4,
-            hidden_dropout_prob=0.1,
-            attention_probs_dropout_prob=0.1,
-            max_position_embeddings=512,
-            type_vocab_size=16,
-            type_sequence_label_size=2,
-            initializer_range=0.02,
-            num_labels=3,
-            num_choices=4,
-            summary_type="last",
-            use_proj=True,
-            scope=None,
-            bos_token_id=0,
-        ):
-            self.parent = parent
-            self.batch_size = batch_size
-            self.seq_length = seq_length
-            self.is_training = is_training
-            self.use_input_lengths = use_input_lengths
-            self.use_token_type_ids = use_token_type_ids
-            self.use_labels = use_labels
-            self.gelu_activation = gelu_activation
-            self.sinusoidal_embeddings = sinusoidal_embeddings
-            self.asm = asm
-            self.n_langs = n_langs
-            self.vocab_size = vocab_size
-            self.n_special = n_special
-            self.summary_type = summary_type
-            self.causal = causal
-            self.use_proj = use_proj
-            self.hidden_size = hidden_size
-            self.num_hidden_layers = num_hidden_layers
-            self.num_attention_heads = num_attention_heads
-            self.hidden_dropout_prob = hidden_dropout_prob
-            self.attention_probs_dropout_prob = attention_probs_dropout_prob
-            self.max_position_embeddings = max_position_embeddings
-            self.n_langs = n_langs
-            self.type_sequence_label_size = type_sequence_label_size
-            self.initializer_range = initializer_range
-            self.summary_type = summary_type
-            self.num_labels = num_labels
-            self.num_choices = num_choices
-            self.scope = scope
-            self.bos_token_id = bos_token_id
+class XLMModelTester:
+    def __init__(
+        self, parent,
+    ):
+        self.parent = parent
+        self.batch_size = 13
+        self.seq_length = 7
+        self.is_training = True
+        self.use_input_lengths = True
+        self.use_token_type_ids = True
+        self.use_labels = True
+        self.gelu_activation = True
+        self.sinusoidal_embeddings = False
+        self.causal = False
+        self.asm = False
+        self.n_langs = 2
+        self.vocab_size = 99
+        self.n_special = 0
+        self.hidden_size = 32
+        self.num_hidden_layers = 5
+        self.num_attention_heads = 4
+        self.hidden_dropout_prob = 0.1
+        self.attention_probs_dropout_prob = 0.1
+        self.max_position_embeddings = 512
+        self.type_sequence_label_size = 2
+        self.initializer_range = 0.02
+        self.num_labels = 3
+        self.num_choices = 4
+        self.summary_type = "last"
+        self.use_proj = True
+        self.scope = None
+        self.bos_token_id = 0
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -223,9 +174,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
         }
 
         self.parent.assertListEqual(list(result["loss"].size()), [])
-        self.parent.assertListEqual(
-            list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
-        )
+        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
 
     def create_and_check_xlm_simple_qa(
         self,
@@ -318,8 +267,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
             [self.batch_size, model.config.start_n_top * model.config.end_n_top],
         )
         self.parent.assertListEqual(
-            list(result["end_top_index"].size()),
-            [self.batch_size, model.config.start_n_top * model.config.end_n_top],
+            list(result["end_top_index"].size()), [self.batch_size, model.config.start_n_top * model.config.end_n_top],
         )
         self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])
@@ -347,9 +295,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
         }
 
         self.parent.assertListEqual(list(result["loss"].size()), [])
-        self.parent.assertListEqual(
-            list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size]
-        )
+        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size])
 
     def create_and_check_xlm_for_token_classification(
         self,
@@ -372,9 +318,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
             "loss": loss,
             "logits": logits,
         }
-        self.parent.assertListEqual(
-            list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
-        )
+        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
         self.check_loss_output(result)
 
     def prepare_config_and_inputs_for_common(self):
@@ -392,8 +336,27 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
         inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "lengths": input_lengths}
         return config, inputs_dict
 
+
+@require_torch
+class XLMModelTest(ModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (
+        (
+            XLMModel,
+            XLMWithLMHeadModel,
+            XLMForQuestionAnswering,
+            XLMForSequenceClassification,
+            XLMForQuestionAnsweringSimple,
+        )
+        if is_torch_available()
+        else ()
+    )
+    all_generative_model_classes = (
+        (XLMWithLMHeadModel,) if is_torch_available() else ()
+    )  # TODO (PVP): Check other models whether language generation is also applicable
+
     def setUp(self):
-        self.model_tester = XLMModelTest.XLMModelTester(self)
+        self.model_tester = XLMModelTester(self)
         self.config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)
 
     def test_config(self):
...
@@ -39,27 +39,7 @@ if is_torch_available():
     from transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_LIST
 
-@require_torch
-class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
-
-    all_model_classes = (
-        (
-            XLNetModel,
-            XLNetLMHeadModel,
-            XLNetForTokenClassification,
-            XLNetForSequenceClassification,
-            XLNetForQuestionAnswering,
-            XLNetForMultipleChoice,
-        )
-        if is_torch_available()
-        else ()
-    )
-    all_generative_model_classes = (
-        (XLNetLMHeadModel,) if is_torch_available() else ()
-    )  # TODO (PVP): Check other models whether language generation is also applicable
-    test_pruning = False
-
-    class XLNetModelTester(object):
+class XLNetModelTester:
     def __init__(
         self,
         parent,
@@ -89,31 +69,31 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
         num_choices=4,
     ):
         self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.mem_len = mem_len
+        self.batch_size = 14
+        self.seq_length = 7
+        self.mem_len = 10
         # self.key_len = seq_length + mem_len
-        self.clamp_len = clamp_len
-        self.reuse_len = reuse_len
-        self.is_training = is_training
-        self.use_labels = use_labels
-        self.vocab_size = vocab_size
-        self.cutoffs = cutoffs
-        self.hidden_size = hidden_size
-        self.num_attention_heads = num_attention_heads
-        self.d_inner = d_inner
-        self.num_hidden_layers = num_hidden_layers
-        self.bi_data = bi_data
-        self.untie_r = untie_r
-        self.same_length = same_length
-        self.initializer_range = initializer_range
-        self.seed = seed
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.bos_token_id = bos_token_id
-        self.pad_token_id = pad_token_id
-        self.eos_token_id = eos_token_id
-        self.num_choices = num_choices
+        self.clamp_len = -1
+        self.reuse_len = 15
+        self.is_training = True
+        self.use_labels = True
+        self.vocab_size = 99
+        self.cutoffs = [10, 50, 80]
+        self.hidden_size = 32
+        self.num_attention_heads = 4
+        self.d_inner = 128
+        self.num_hidden_layers = 5
+        self.type_sequence_label_size = 2
+        self.untie_r = True
+        self.bi_data = False
+        self.same_length = False
+        self.initializer_range = 0.05
+        self.seed = 1
+        self.type_vocab_size = 2
+        self.bos_token_id = 1
+        self.eos_token_id = 2
+        self.pad_token_id = 5
+        self.num_choices = 4
 
     def prepare_config_and_inputs(self):
         input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -126,9 +106,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
             self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
         )
         perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
-        target_mapping = torch.zeros(
-            self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
-        )
+        target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device,)
         target_mapping[:, 0, -1] = 1.0  # predict last token
 
         sequence_labels = None
@@ -270,9 +248,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
         loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
 
-        loss_2, all_logits_2, mems_2 = model(
-            input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1
-        )
+        loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)
 
         logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
@@ -370,8 +346,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
             [self.batch_size, model.config.start_n_top * model.config.end_n_top],
        )
         self.parent.assertListEqual(
-            list(result["end_top_index"].size()),
-            [self.batch_size, model.config.start_n_top * model.config.end_n_top],
+            list(result["end_top_index"].size()), [self.batch_size, model.config.start_n_top * model.config.end_n_top],
         )
         self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])
         self.parent.assertListEqual(
@@ -472,8 +447,29 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
         inputs_dict = {"input_ids": input_ids_1}
         return config, inputs_dict
 
+
+@require_torch
+class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (
+        (
+            XLNetModel,
+            XLNetLMHeadModel,
+            XLNetForTokenClassification,
+            XLNetForSequenceClassification,
+            XLNetForQuestionAnswering,
+            XLNetForMultipleChoice,
+        )
+        if is_torch_available()
+        else ()
+    )
+    all_generative_model_classes = (
+        (XLNetLMHeadModel,) if is_torch_available() else ()
+    )  # TODO (PVP): Check other models whether language generation is also applicable
+    test_pruning = False
+
     def setUp(self):
-        self.model_tester = XLNetModelTest.XLNetModelTester(self)
+        self.model_tester = XLNetModelTester(self)
         self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)
 
     def test_config(self):
...