Unverified Commit 115d97dd authored by Sylvain Gugger, committed by GitHub

Remove subclass for sortish sampler (#9907)

* Remove subclass for sortish sampler

* Use old Seq2SeqTrainer in script

* Styling
parent 1682804e
@@ -20,6 +20,8 @@ from dataclasses import dataclass, field
 from typing import Optional

 import transformers
+from seq2seq_trainer import Seq2SeqTrainer
+from seq2seq_training_args import Seq2SeqTrainingArguments
 from transformers import (
     AutoConfig,
     AutoModelForSeq2SeqLM,
@@ -27,8 +29,6 @@ from transformers import (
     HfArgumentParser,
     MBartTokenizer,
     MBartTokenizerFast,
-    Seq2SeqTrainer,
-    Seq2SeqTrainingArguments,
     set_seed,
 )
 from transformers.trainer_utils import EvaluationStrategy, is_main_process
@@ -286,6 +286,7 @@ def main():
     trainer = Seq2SeqTrainer(
         model=model,
         args=training_args,
+        data_args=data_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         data_collator=Seq2SeqDataCollator(
@@ -323,9 +324,7 @@ def main():
     if training_args.do_eval:
         logger.info("*** Evaluate ***")

-        metrics = trainer.evaluate(
-            metric_key_prefix="val", max_length=data_args.val_max_target_length, num_beams=data_args.eval_beams
-        )
+        metrics = trainer.evaluate(metric_key_prefix="val")
         metrics["val_n_objs"] = data_args.n_val
         metrics["val_loss"] = round(metrics["val_loss"], 4)
@@ -337,12 +336,7 @@ def main():
     if training_args.do_predict:
         logger.info("*** Predict ***")

-        test_output = trainer.predict(
-            test_dataset=test_dataset,
-            metric_key_prefix="test",
-            max_length=data_args.val_max_target_length,
-            num_beams=data_args.eval_beams,
-        )
+        test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
         metrics = test_output.metrics
         metrics["test_n_objs"] = data_args.n_test
......
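Note: with the script back on the examples-local trainer (seq2seq_trainer.py), evaluate() and predict() lose their per-call generation kwargs; the max target length and beam count presumably come from the data_args object now passed at construction. A minimal sketch of that shape, assuming internals not shown in this diff (the data_args attribute and the _gen_kwargs helper are illustrative, not the actual seq2seq_trainer.py code):

    # Sketch only: assumes the legacy trainer stores data_args and reads
    # generation settings from it rather than from per-call kwargs.
    from transformers import Trainer

    class Seq2SeqTrainer(Trainer):
        def __init__(self, *args, data_args=None, **kwargs):
            super().__init__(*args, **kwargs)
            self.data_args = data_args

        def _gen_kwargs(self):
            # hypothetical helper; the attribute names match the arguments
            # dropped from the evaluate()/predict() calls above
            return dict(
                max_length=self.data_args.val_max_target_length,
                num_beams=self.data_args.eval_beams,
            )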
@@ -17,14 +17,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 from packaging import version
 from torch import nn
-from torch.utils.data import DistributedSampler, RandomSampler
 from torch.utils.data.dataset import Dataset

-from .file_utils import is_torch_tpu_available
 from .trainer import Trainer
-from .trainer_pt_utils import get_tpu_sampler
 from .trainer_utils import PredictionOutput
-from .training_args import ParallelMode
 from .utils import logging
@@ -36,24 +32,6 @@ logger = logging.get_logger(__name__)
 class Seq2SeqTrainer(Trainer):
-    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
-        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
-            return None
-        elif is_torch_tpu_available():
-            return get_tpu_sampler(self.train_dataset)
-        else:
-            if self.args.sortish_sampler:
-                self.train_dataset.make_sortish_sampler(
-                    self.args.per_device_train_batch_size,
-                    distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED),
-                )
-
-            return (
-                RandomSampler(self.train_dataset)
-                if self.args.local_rank == -1
-                else DistributedSampler(self.train_dataset)
-            )
-
     def evaluate(
         self,
         eval_dataset: Optional[Dataset] = None,
......
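Note: the library Seq2SeqTrainer no longer overrides _get_train_sampler, so sortish sampling is gone from it (as written, the removed override also never returned the result of make_sortish_sampler, so it fell through to a plain random or distributed sampler anyway). A minimal sketch of a user-side replacement, assuming a custom sortish_sampler flag on the training arguments and a dataset that exposes make_sortish_sampler(), as the legacy examples datasets did:

    # Sketch only, not the library's API: keep sortish sampling via a subclass.
    from transformers import Trainer

    class SortishSamplerTrainer(Trainer):
        def _get_train_sampler(self):
            if getattr(self.args, "sortish_sampler", False):
                # unlike the removed override, actually return the sampler
                return self.train_dataset.make_sortish_sampler(
                    self.args.per_device_train_batch_size
                )
            return super()._get_train_sampler()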