Unverified Commit 573bdb0a authored by Sylvain Gugger, committed by GitHub

Add tests to Trainer (#6605)

* Add tests to Trainer

* Test if removing long breaks everything

* Remove ugly hack

* Fix distributed test

* Use float for number of epochs
parent 039d8d65
...
@@ -62,7 +62,7 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
        if isinstance(v, torch.Tensor):
            batch[k] = torch.stack([f[k] for f in features])
        else:
-            batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long)
+            batch[k] = torch.tensor([f[k] for f in features])
    return batch
...
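The change above drops the hard-coded `dtype=torch.long`, letting `torch.tensor` infer the dtype from the feature values: float labels (regression targets such as STS-B scores) are no longer silently truncated to integers, while integer labels still come out as `torch.long`. A standalone sketch of the difference (values are illustrative, not from the diff):

import torch

float_labels = [0.8, 1.3, 2.1]

old = torch.tensor(float_labels, dtype=torch.long)  # tensor([0, 1, 2]) -- regression targets destroyed
new = torch.tensor(float_labels)                    # tensor([0.8000, 1.3000, 2.1000])
ints = torch.tensor([0, 1, 2])                      # integers still infer dtype=torch.int64 (long)

print(old.dtype, new.dtype, ints.dtype)  # torch.int64 torch.float32 torch.int64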
@@ -449,6 +449,7 @@ class Trainer:
        else:
            t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs
+            self.args.max_steps = t_total
        self.create_optimizer_and_scheduler(num_training_steps=t_total)
@@ -530,7 +531,7 @@ class Trainer:
        logging_loss = 0.0
        model.zero_grad()
        train_iterator = trange(
-            epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_process_zero()
+            epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=not self.is_local_process_zero()
        )
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
@@ -626,10 +627,10 @@ class Trainer:
                    torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
-                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
+                if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                    epoch_iterator.close()
                    break
-            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
+            if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug or self.args.debug:
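Taken together, the three hunks above are what make a fractional `num_train_epochs` work: `t_total` is derived from the float epoch count, `max_steps` is set to it so the new `>=` comparison stops training mid-epoch (with `>`, one extra step would run), and the epoch iterator runs to the ceiling of the epoch count so the partial final epoch is entered at all. A standalone sketch of the arithmetic, with illustrative numbers not taken from the diff:

import numpy as np

steps_per_epoch = 8               # e.g. 64 samples with batch size 8
num_train_epochs = 1.5            # one full epoch plus half of a second one
gradient_accumulation_steps = 1

# Mirrors t_total and max_steps in the first hunk.
t_total = int(steps_per_epoch // gradient_accumulation_steps * num_train_epochs)  # 12
max_steps = t_total

# Mirrors the trange bound: ceil(1.5) == 2 epochs, so the partial epoch starts.
global_step = 0
for epoch in range(int(np.ceil(num_train_epochs))):
    for step in range(steps_per_epoch):
        global_step += 1
        if max_steps > 0 and global_step >= max_steps:
            break
    if max_steps > 0 and global_step >= max_steps:
        break

print(global_step)  # 12 -> exactly 1.5 epochs of 8 steps each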
@@ -986,10 +987,13 @@ class Trainer:
        if self.args.past_index >= 0:
            self._past = None
+        samples_count = 0
        for inputs in tqdm(dataloader, desc=description):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
+            batch_size = inputs[list(inputs.keys())[0]].shape[0]
+            samples_count += batch_size
            if loss is not None:
-                eval_losses.append(loss)
+                eval_losses.append(loss * batch_size)
            if logits is not None:
                preds = logits if preds is None else torch.cat((preds, logits), dim=0)
            if labels is not None:
@@ -1023,7 +1027,7 @@ class Trainer:
        else:
            metrics = {}
        if len(eval_losses) > 0:
-            metrics["eval_loss"] = np.mean(eval_losses)
+            metrics["eval_loss"] = np.sum(eval_losses) / samples_count
        # Prefix all keys with eval_
        for key in list(metrics.keys()):
...
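`prediction_step` returns the mean loss over each batch, so when the evaluation set is not a round multiple of the batch size, averaging the per-batch means over-weights the smaller final batch. The two hunks above therefore weight each batch loss by its size and divide by the total sample count. A small standalone illustration with made-up numbers:

import numpy as np

batch_sizes = np.array([32, 32, 2])            # 66 eval samples, batch size 32
batch_mean_losses = np.array([0.5, 0.7, 4.0])  # the 2-sample batch happens to be hard

naive = batch_mean_losses.mean()                                        # ~1.73
weighted = np.sum(batch_mean_losses * batch_sizes) / batch_sizes.sum()  # ~0.70

print(naive, weighted)  # only the weighted value is the true per-sample mean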
@@ -69,7 +69,8 @@ class TrainingArguments:
        max_grad_norm (:obj:`float`, `optional`, defaults to 1.0):
            Maximum gradient norm (for gradient clipping).
        num_train_epochs(:obj:`float`, `optional`, defaults to 3.0):
-            Total number of training epochs to perform.
+            Total number of training epochs to perform (if not an integer, will perform the decimal part of the
+            last epoch before stopping training).
        max_steps (:obj:`int`, `optional`, defaults to -1):
            If set to a positive number, the total number of training steps to perform. Overrides
            :obj:`num_train_epochs`.
...
import unittest

from transformers import AutoTokenizer, is_torch_available
from transformers.testing_utils import require_torch

if is_torch_available():
    import torch

    from transformers import (
        DataCollatorForLanguageModeling,
        DataCollatorForPermutationLanguageModeling,
        GlueDataset,
        GlueDataTrainingArguments,
        LineByLineTextDataset,
        TextDataset,
        default_data_collator,
    )

PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"

@require_torch
class DataCollatorIntegrationTest(unittest.TestCase):
    def test_default_with_dict(self):
        features = [{"label": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
        batch = default_data_collator(features)
        self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
        self.assertEqual(batch["labels"].dtype, torch.long)
        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))

        # With label_ids
        features = [{"label_ids": [0, 1, 2], "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
        batch = default_data_collator(features)
        self.assertTrue(batch["labels"].equal(torch.tensor([[0, 1, 2]] * 8)))
        self.assertEqual(batch["labels"].dtype, torch.long)
        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))

        # Features can already be tensors
        features = [{"label": i, "inputs": torch.randint(10, [10])} for i in range(8)]
        batch = default_data_collator(features)
        self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
        self.assertEqual(batch["labels"].dtype, torch.long)
        self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))

        # Labels can already be tensors
        features = [{"label": torch.tensor(i), "inputs": torch.randint(10, [10])} for i in range(8)]
        batch = default_data_collator(features)
        self.assertEqual(batch["labels"].dtype, torch.long)
        self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
        self.assertEqual(batch["labels"].dtype, torch.long)
        self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))

    def test_default_with_no_labels(self):
        features = [{"label": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
        batch = default_data_collator(features)
        self.assertTrue("labels" not in batch)
        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))

        # With label_ids
        features = [{"label_ids": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
        batch = default_data_collator(features)
        self.assertTrue("labels" not in batch)
        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))

    def test_default_classification(self):
        MODEL_ID = "bert-base-cased-finetuned-mrpc"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        data_args = GlueDataTrainingArguments(
            task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
        )
        dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
        data_collator = default_data_collator
        batch = data_collator(dataset.features)
        self.assertEqual(batch["labels"].dtype, torch.long)

    def test_default_regression(self):
        MODEL_ID = "distilroberta-base"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        data_args = GlueDataTrainingArguments(
            task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
        )
        dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
        data_collator = default_data_collator
        batch = data_collator(dataset.features)
        self.assertEqual(batch["labels"].dtype, torch.float)

    def test_lm_tokenizer_without_padding(self):
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
        # ^ causal lm

        dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
        examples = [dataset[i] for i in range(len(dataset))]
        with self.assertRaises(ValueError):
            # Expect error due to padding token missing on gpt2:
            data_collator(examples)

        dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
        examples = [dataset[i] for i in range(len(dataset))]
        batch = data_collator(examples)
        self.assertIsInstance(batch, dict)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))

    def test_lm_tokenizer_with_padding(self):
        tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
        data_collator = DataCollatorForLanguageModeling(tokenizer)
        # ^ masked lm

        dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
        examples = [dataset[i] for i in range(len(dataset))]
        batch = data_collator(examples)
        self.assertIsInstance(batch, dict)
        self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
        self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))

        dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
        examples = [dataset[i] for i in range(len(dataset))]
        batch = data_collator(examples)
        self.assertIsInstance(batch, dict)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))

    def test_plm(self):
        tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
        data_collator = DataCollatorForPermutationLanguageModeling(tokenizer)
        # ^ permutation lm

        dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
        examples = [dataset[i] for i in range(len(dataset))]
        batch = data_collator(examples)
        self.assertIsInstance(batch, dict)
        self.assertEqual(batch["input_ids"].shape, torch.Size((31, 112)))
        self.assertEqual(batch["perm_mask"].shape, torch.Size((31, 112, 112)))
        self.assertEqual(batch["target_mapping"].shape, torch.Size((31, 112, 112)))
        self.assertEqual(batch["labels"].shape, torch.Size((31, 112)))

        dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
        examples = [dataset[i] for i in range(len(dataset))]
        batch = data_collator(examples)
        self.assertIsInstance(batch, dict)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
        self.assertEqual(batch["perm_mask"].shape, torch.Size((2, 512, 512)))
        self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 512, 512)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))

        example = [torch.randint(5, [5])]
        with self.assertRaises(ValueError):
            # Expect error due to odd sequence length
            data_collator(example)
import unittest
+
+import numpy as np

from transformers import AutoTokenizer, TrainingArguments, is_torch_available
from transformers.testing_utils import require_torch

@@ -10,149 +12,38 @@ if is_torch_available():
    from transformers import (
        AutoModelForSequenceClassification,
-        DataCollatorForLanguageModeling,
-        DataCollatorForPermutationLanguageModeling,
        GlueDataset,
        GlueDataTrainingArguments,
        LineByLineTextDataset,
-        TextDataset,
        Trainer,
-        default_data_collator,
    )

PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"

-@require_torch
-class DataCollatorIntegrationTest(unittest.TestCase):
-    ...  # class removed here in full; its body is identical to the new data collator test file above

+class RegressionDataset:
+    def __init__(self, a=2, b=3, length=64, seed=42):
+        np.random.seed(seed)
+        self.length = length
+        self.x = np.random.normal(size=(length,)).astype(np.float32)
+        self.y = a * self.x + b + np.random.normal(scale=0.1, size=(length,))
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, i):
+        return {"input_x": self.x[i], "label": self.y[i]}
+
+
+class AlmostAccuracy:
+    def __init__(self, thresh=0.25):
+        self.thresh = thresh
+
+    def __call__(self, eval_pred):
+        predictions, labels = eval_pred
+        true = np.abs(predictions - labels) <= self.thresh
+        return {"accuracy": true.astype(np.float32).mean().item()}
if is_torch_available():

@@ -168,9 +59,137 @@ if is_torch_available():
        def __iter__(self):
            return iter(self.parse_file())
    class RegressionModel(torch.nn.Module):
        def __init__(self, a=0, b=0):
            super().__init__()
            self.a = torch.nn.Parameter(torch.tensor(a).float())
            self.b = torch.nn.Parameter(torch.tensor(b).float())

        def forward(self, input_x=None, labels=None):
            y = input_x * self.a + self.b
            if labels is None:
                return (y,)
            loss = torch.nn.functional.mse_loss(y, labels)
            return (loss, y)

    def get_regression_trainer(a=0, b=0, train_len=64, eval_len=64, **kwargs):
        train_dataset = RegressionDataset(length=train_len)
        eval_dataset = RegressionDataset(length=eval_len)
        model = RegressionModel(a, b)
        compute_metrics = kwargs.pop("compute_metrics", None)
        data_collator = kwargs.pop("data_collator", None)
        optimizers = kwargs.pop("optimizers", (None, None))
        args = TrainingArguments("./regression", **kwargs)
        return Trainer(
            model,
            args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics,
            optimizers=optimizers,
        )
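The helper above wires the synthetic regression problem into a real `Trainer`, so each test below can override a single `TrainingArguments` field through `**kwargs`. A usage sketch (the argument values are illustrative, mirroring the tests that follow):

trainer = get_regression_trainer(
    a=1.5,
    b=2.5,
    learning_rate=0.1,
    num_train_epochs=1.5,
    compute_metrics=AlmostAccuracy(),
)
train_output = trainer.train()  # runs 1.5 epochs on the synthetic data
metrics = trainer.evaluate()    # reports eval_loss and eval_accuracy
print(train_output.global_step, metrics["eval_loss"], metrics["eval_accuracy"])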
@require_torch
class TrainerIntegrationTest(unittest.TestCase):
    def setUp(self):
        # Get the default values (in case they change):
        args = TrainingArguments(".")
        self.n_epochs = args.num_train_epochs
        self.batch_size = args.per_device_train_batch_size

    def test_reproducible_training(self):
        # Checks that training ran, the model trained, and the seed made training reproducible.
        trainer = get_regression_trainer(learning_rate=0.1)
        trainer.train()
        self.assertTrue(torch.abs(trainer.model.a - 0.6975) < 1e-4)
        self.assertTrue(torch.abs(trainer.model.b - 1.2415) < 1e-4)

        # Checks that a different seed gets different (reproducible) results.
        trainer = get_regression_trainer(learning_rate=0.1, seed=314)
        trainer.train()
        self.assertTrue(torch.abs(trainer.model.a - 1.0171) < 1e-4)
        self.assertTrue(torch.abs(trainer.model.b - 1.2494) < 1e-4)

    def test_number_of_steps_in_training(self):
        # Regular training has n_epochs * len(train_dl) steps
        trainer = get_regression_trainer(learning_rate=0.1)
        train_output = trainer.train()
        self.assertEqual(train_output.global_step, self.n_epochs * 64 / self.batch_size)

        # Check passing num_train_epochs works (and a float version too):
        trainer = get_regression_trainer(learning_rate=0.1, num_train_epochs=1.5)
        train_output = trainer.train()
        self.assertEqual(train_output.global_step, int(1.5 * 64 / self.batch_size))

        # If we pass max_steps, num_train_epochs is ignored
        trainer = get_regression_trainer(learning_rate=0.1, max_steps=10)
        train_output = trainer.train()
        self.assertEqual(train_output.global_step, 10)

    def test_train_and_eval_dataloaders(self):
        trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16)
        self.assertEqual(trainer.get_train_dataloader().batch_size, 16)
        trainer = get_regression_trainer(learning_rate=0.1, per_device_eval_batch_size=16)
        self.assertEqual(trainer.get_eval_dataloader().batch_size, 16)

        # Check drop_last works
        trainer = get_regression_trainer(
            train_len=66, eval_len=74, learning_rate=0.1, per_device_train_batch_size=16, per_device_eval_batch_size=32
        )
        self.assertEqual(len(trainer.get_train_dataloader()), 66 // 16 + 1)
        self.assertEqual(len(trainer.get_eval_dataloader()), 74 // 32 + 1)

        trainer = get_regression_trainer(
            train_len=66,
            eval_len=74,
            learning_rate=0.1,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=32,
            dataloader_drop_last=True,
        )
        self.assertEqual(len(trainer.get_train_dataloader()), 66 // 16)
        self.assertEqual(len(trainer.get_eval_dataloader()), 74 // 32)

        # Check passing a new dataset for evaluation works
        new_eval_dataset = RegressionDataset(length=128)
        self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // 32)

    def test_evaluate(self):
        trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy())
        results = trainer.evaluate()

        x, y = trainer.eval_dataset.x, trainer.eval_dataset.y
        pred = 1.5 * x + 2.5
        expected_loss = ((pred - y) ** 2).mean()
        self.assertAlmostEqual(results["eval_loss"], expected_loss)
        expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)

        # With a number of elements not a round multiple of the batch size
        trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracy())
        results = trainer.evaluate()

        x, y = trainer.eval_dataset.x, trainer.eval_dataset.y
        pred = 1.5 * x + 2.5
        expected_loss = ((pred - y) ** 2).mean()
        self.assertAlmostEqual(results["eval_loss"], expected_loss)
        expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)

    def test_predict(self):
        trainer = get_regression_trainer(a=1.5, b=2.5)
        preds = trainer.predict(trainer.eval_dataset).predictions
        x = trainer.eval_dataset.x
        self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

        # With a number of elements not a round multiple of the batch size
        trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66)
        preds = trainer.predict(trainer.eval_dataset).predictions
        x = trainer.eval_dataset.x
        self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
    def test_trainer_eval_mrpc(self):
        MODEL_ID = "bert-base-cased-finetuned-mrpc"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
...
@@ -60,7 +60,8 @@ if is_torch_available():
if __name__ == "__main__":
    parser = HfArgumentParser((TrainingArguments,))
-    training_args = parser.parse_args_into_dataclasses(sys.argv + ["--output_dir", "./examples"])[0]
+    sys.argv += ["--output_dir", "./examples"]
+    training_args = parser.parse_args_into_dataclasses()[0]
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s",
...
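The old line passed `sys.argv + ["--output_dir", "./examples"]` straight to `parse_args_into_dataclasses`, which meant the script name in `sys.argv[0]` was parsed as if it were an argument. Appending to `sys.argv` first and then parsing normally avoids that, since argparse skips the program name by itself. A minimal standalone sketch of the corrected pattern (the path is illustrative):

import sys

from transformers import HfArgumentParser, TrainingArguments

if __name__ == "__main__":
    # Inject the one required argument, then let the parser read sys.argv normally.
    sys.argv += ["--output_dir", "./examples"]
    parser = HfArgumentParser((TrainingArguments,))
    training_args = parser.parse_args_into_dataclasses()[0]
    print(training_args.output_dir)  # ./examples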