Unverified Commit 379005c9 authored by Stas Bekman, committed by GitHub

start using training_args.parallel_mode (#8882)

parent b08843cf
@@ -11,6 +11,7 @@ from seq2seq_trainer import Seq2SeqTrainer
 from seq2seq_training_args import Seq2SeqTrainingArguments
 from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed
 from transformers.trainer_utils import EvaluationStrategy, is_main_process
+from transformers.training_args import ParallelMode
 from utils import (
     Seq2SeqDataCollator,
     Seq2SeqDataset,
@@ -132,7 +133,7 @@ def main():
         training_args.local_rank,
         training_args.device,
         training_args.n_gpu,
-        bool(training_args.local_rank != -1),
+        bool(training_args.parallel_mode == ParallelMode.DISTRIBUTED),
         training_args.fp16,
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
@@ -18,6 +18,7 @@ from transformers.optimization import (
     get_polynomial_decay_schedule_with_warmup,
 )
 from transformers.trainer_pt_utils import get_tpu_sampler
+from transformers.training_args import ParallelMode

 logger = logging.get_logger(__name__)
@@ -123,7 +124,7 @@ class Seq2SeqTrainer(Trainer):
         if self.args.sortish_sampler:
             self.train_dataset.make_sortish_sampler(
                 self.args.per_device_train_batch_size,
-                distributed=(self.args.local_rank != -1),
+                distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED),
             )
         return (
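
For context, a minimal sketch of the equivalence this commit relies on (assuming a local transformers install with PyTorch; the output directory name below is an arbitrary placeholder, not part of the commit): on a single-process run, the new parallel_mode check gives the same answer as the old local_rank != -1 test.

from transformers import TrainingArguments
from transformers.training_args import ParallelMode

# Default, single-process training arguments (no torch.distributed launcher);
# "tmp_out" is just a placeholder output directory for this sketch.
args = TrainingArguments(output_dir="tmp_out")

old_check = args.local_rank != -1                           # legacy test for distributed training
new_check = args.parallel_mode == ParallelMode.DISTRIBUTED  # check this commit switches to
print(old_check, new_check)  # both False when not launched in distributed mode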