"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "8fe836af5a164dac60b8c45cc73a7b9ed72393a3"
Unverified commit 115d97dd, authored by Sylvain Gugger, committed by GitHub

Remove subclass for sortish sampler (#9907)

* Remove subclass for sortish sampler

* Use old Seq2SeqTrainer in script

* Styling
parent 1682804e
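For context: a "sortish" sampler shuffles the training set, then sorts indices within coarse chunks by example length, so each batch holds similarly sized sequences and wastes little padding while the epoch order still varies. A minimal sketch of the idea in PyTorch (illustrative only, not the library's implementation):

import random

from torch.utils.data import Sampler


class SortishSampler(Sampler):
    """Yield indices so that each batch sees similar-length examples (sketch)."""

    def __init__(self, lengths, batch_size):
        self.lengths = lengths
        self.batch_size = batch_size

    def __iter__(self):
        indices = list(range(len(self.lengths)))
        random.shuffle(indices)  # keep epoch-level randomness
        # Sort within coarse chunks so neighboring indices have similar lengths.
        chunk = self.batch_size * 50
        pools = (indices[i : i + chunk] for i in range(0, len(indices), chunk))
        return iter(i for pool in pools for i in sorted(pool, key=lambda j: -self.lengths[j]))

    def __len__(self):
        return len(self.lengths)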
@@ -20,6 +20,8 @@ from dataclasses import dataclass, field
 from typing import Optional

 import transformers
+from seq2seq_trainer import Seq2SeqTrainer
+from seq2seq_training_args import Seq2SeqTrainingArguments
 from transformers import (
     AutoConfig,
     AutoModelForSeq2SeqLM,
@@ -27,8 +29,6 @@ from transformers import (
     HfArgumentParser,
     MBartTokenizer,
     MBartTokenizerFast,
-    Seq2SeqTrainer,
-    Seq2SeqTrainingArguments,
     set_seed,
 )
 from transformers.trainer_utils import EvaluationStrategy, is_main_process
@@ -286,6 +286,7 @@ def main():
     trainer = Seq2SeqTrainer(
         model=model,
         args=training_args,
+        data_args=data_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         data_collator=Seq2SeqDataCollator(
@@ -323,9 +324,7 @@
     if training_args.do_eval:
         logger.info("*** Evaluate ***")

-        metrics = trainer.evaluate(
-            metric_key_prefix="val", max_length=data_args.val_max_target_length, num_beams=data_args.eval_beams
-        )
+        metrics = trainer.evaluate(metric_key_prefix="val")
         metrics["val_n_objs"] = data_args.n_val
         metrics["val_loss"] = round(metrics["val_loss"], 4)
@@ -337,12 +336,7 @@
     if training_args.do_predict:
         logger.info("*** Predict ***")

-        test_output = trainer.predict(
-            test_dataset=test_dataset,
-            metric_key_prefix="test",
-            max_length=data_args.val_max_target_length,
-            num_beams=data_args.eval_beams,
-        )
+        test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
         metrics = test_output.metrics
         metrics["test_n_objs"] = data_args.n_test
@@ -17,14 +17,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 from packaging import version
 from torch import nn
-from torch.utils.data import DistributedSampler, RandomSampler
 from torch.utils.data.dataset import Dataset

 from .file_utils import is_torch_tpu_available
 from .trainer import Trainer
-from .trainer_pt_utils import get_tpu_sampler
 from .trainer_utils import PredictionOutput
-from .training_args import ParallelMode
 from .utils import logging
@@ -36,24 +32,6 @@ logger = logging.get_logger(__name__)
 class Seq2SeqTrainer(Trainer):
-    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
-        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
-            return None
-        elif is_torch_tpu_available():
-            return get_tpu_sampler(self.train_dataset)
-        else:
-            if self.args.sortish_sampler:
-                self.train_dataset.make_sortish_sampler(
-                    self.args.per_device_train_batch_size,
-                    distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED),
-                )
-
-            return (
-                RandomSampler(self.train_dataset)
-                if self.args.local_rank == -1
-                else DistributedSampler(self.train_dataset)
-            )
-
     def evaluate(
         self,
         eval_dataset: Optional[Dataset] = None,
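With the override deleted, Seq2SeqTrainer simply inherits the base Trainer's sampler selection. A simplified stand-in for that default logic, assuming the same local_rank convention as the removed code (the real base implementation handles more cases):

import torch
from torch.utils.data import DistributedSampler, RandomSampler


def default_train_sampler(train_dataset, local_rank):
    # No sortish branch anymore: plain random sampling, sharded across
    # processes when running distributed.
    if isinstance(train_dataset, torch.utils.data.IterableDataset):
        return None  # iterable datasets manage their own ordering
    if local_rank == -1:
        return RandomSampler(train_dataset)
    return DistributedSampler(train_dataset)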