Commit c547f15a authored by Julien Chaumond

Use Filelock to ensure distributed barriers

see context in https://github.com/huggingface/transformers/pull/4223
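
The change swaps the `torch_distributed_zero_first(local_rank)` barrier for a `filelock.FileLock` taken on a `.lock` file next to the cached features: whichever process acquires the lock first preprocesses and saves the dataset, and every other process blocks on the lock and then loads the cache it finds on disk. A minimal sketch of the pattern (the helper name `load_or_build_features` and its arguments are illustrative, not code from the repo):

import os

import torch
from filelock import FileLock


def load_or_build_features(cached_features_file, build_features, overwrite_cache=False):
    # All processes contend for the same lock file; only one holds it at a time.
    lock_path = cached_features_file + ".lock"
    with FileLock(lock_path):
        # The process holding the lock either reuses an existing cache or builds
        # and saves it; the others block above and then take this fast path.
        if os.path.exists(cached_features_file) and not overwrite_cache:
            return torch.load(cached_features_file)
        features = build_features()
        torch.save(features, cached_features_file)
        return features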
parent 015f7812
@@ -118,13 +118,9 @@ class DataTrainingArguments:
 def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False, local_rank=-1):
     file_path = args.eval_data_file if evaluate else args.train_data_file
     if args.line_by_line:
-        return LineByLineTextDataset(
-            tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, local_rank=local_rank
-        )
+        return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
     else:
-        return TextDataset(
-            tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, local_rank=local_rank,
-        )
+        return TextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)


 def main():
...
@@ -159,7 +159,6 @@ def main():
             max_seq_length=data_args.max_seq_length,
             overwrite_cache=data_args.overwrite_cache,
             mode=Split.train,
-            local_rank=training_args.local_rank,
         )
         if training_args.do_train
         else None
@@ -172,7 +171,6 @@ def main():
             max_seq_length=data_args.max_seq_length,
             overwrite_cache=data_args.overwrite_cache,
             mode=Split.dev,
-            local_rank=training_args.local_rank,
         )
         if training_args.do_eval
         else None
...
@@ -26,6 +26,7 @@ from enum import Enum
 from typing import List, Optional

 import tqdm
+from filelock import FileLock

 from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available

@@ -77,7 +78,6 @@ class Split(Enum):
 if is_torch_available():
     import torch
     from torch.utils.data.dataset import Dataset
-    from transformers import torch_distributed_zero_first

     class MultipleChoiceDataset(Dataset):
         """
@@ -95,7 +95,6 @@ if is_torch_available():
             max_seq_length: Optional[int] = None,
             overwrite_cache=False,
             mode: Split = Split.train,
-            local_rank=-1,
         ):
             processor = processors[task]()

@@ -103,9 +102,11 @@ if is_torch_available():
                 data_dir,
                 "cached_{}_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length), task,),
             )
-            with torch_distributed_zero_first(local_rank):
-                # Make sure only the first process in distributed training processes the dataset,
-                # and the others will use the cache.
+
+            # Make sure only the first process in distributed training processes the dataset,
+            # and the others will use the cache.
+            lock_path = cached_features_file + ".lock"
+            with FileLock(lock_path):

                 if os.path.exists(cached_features_file) and not overwrite_cache:
                     logger.info(f"Loading features from cached file {cached_features_file}")
@@ -130,9 +131,8 @@ if is_torch_available():
                         pad_token=tokenizer.pad_token_id,
                         pad_token_segment_id=tokenizer.pad_token_type_id,
                     )
-                    if local_rank in [-1, 0]:
-                        logger.info("Saving features into cached file %s", cached_features_file)
-                        torch.save(self.features, cached_features_file)
+                    logger.info("Saving features into cached file %s", cached_features_file)
+                    torch.save(self.features, cached_features_file)

         def __len__(self):
             return len(self.features)
...
@@ -171,7 +171,6 @@ def main():
             max_seq_length=data_args.max_seq_length,
             overwrite_cache=data_args.overwrite_cache,
             mode=Split.train,
-            local_rank=training_args.local_rank,
         )
         if training_args.do_train
         else None
@@ -185,7 +184,6 @@ def main():
             max_seq_length=data_args.max_seq_length,
             overwrite_cache=data_args.overwrite_cache,
             mode=Split.dev,
-            local_rank=training_args.local_rank,
         )
         if training_args.do_eval
         else None
@@ -261,7 +259,6 @@ def main():
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.test,
-           local_rank=training_args.local_rank,
        )

        predictions, label_ids, metrics = trainer.predict(test_dataset)
...
@@ -22,6 +22,8 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import List, Optional, Union

+from filelock import FileLock
+
 from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available

@@ -68,7 +70,6 @@ if is_torch_available():
     import torch
     from torch import nn
     from torch.utils.data.dataset import Dataset
-    from transformers import torch_distributed_zero_first

     class NerDataset(Dataset):
         """
@@ -90,16 +91,16 @@ if is_torch_available():
             max_seq_length: Optional[int] = None,
             overwrite_cache=False,
             mode: Split = Split.train,
-            local_rank=-1,
         ):
             # Load data features from cache or dataset file
             cached_features_file = os.path.join(
                 data_dir, "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length)),
             )
-            with torch_distributed_zero_first(local_rank):
-                # Make sure only the first process in distributed training processes the dataset,
-                # and the others will use the cache.
+
+            # Make sure only the first process in distributed training processes the dataset,
+            # and the others will use the cache.
+            lock_path = cached_features_file + ".lock"
+            with FileLock(lock_path):

                 if os.path.exists(cached_features_file) and not overwrite_cache:
                     logger.info(f"Loading features from cached file {cached_features_file}")
@@ -125,9 +126,8 @@ if is_torch_available():
                         pad_token_segment_id=tokenizer.pad_token_type_id,
                         pad_token_label_id=self.pad_token_label_id,
                     )
-                    if local_rank in [-1, 0]:
-                        logger.info(f"Saving features into cached file {cached_features_file}")
-                        torch.save(self.features, cached_features_file)
+                    logger.info(f"Saving features into cached file {cached_features_file}")
+                    torch.save(self.features, cached_features_file)

         def __len__(self):
             return len(self.features)
...
@@ -4,10 +4,10 @@ import pickle
 import time

 import torch
+from filelock import FileLock
 from torch.utils.data.dataset import Dataset

 from ...tokenization_utils import PreTrainedTokenizer
-from ...trainer import torch_distributed_zero_first

 logger = logging.getLogger(__name__)

@@ -20,7 +20,7 @@ class TextDataset(Dataset):
     """

     def __init__(
-        self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, overwrite_cache=False, local_rank=-1,
+        self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, overwrite_cache=False,
     ):
         assert os.path.isfile(file_path)

@@ -31,9 +31,10 @@ class TextDataset(Dataset):
             directory, "cached_lm_{}_{}_{}".format(tokenizer.__class__.__name__, str(block_size), filename,),
         )

-        with torch_distributed_zero_first(local_rank):
-            # Make sure only the first process in distributed training processes the dataset,
-            # and the others will use the cache.
+        # Make sure only the first process in distributed training processes the dataset,
+        # and the others will use the cache.
+        lock_path = cached_features_file + ".lock"
+        with FileLock(lock_path):

             if os.path.exists(cached_features_file) and not overwrite_cache:
                 start = time.time()
@@ -80,7 +81,7 @@ class LineByLineTextDataset(Dataset):
     soon.
     """

-    def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, local_rank=-1):
+    def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
         assert os.path.isfile(file_path)
         # Here, we do not cache the features, operating under the assumption
         # that we will soon use fast multithreaded tokenizers from the
...
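
Since the barrier now lives next to the cache file rather than in torch.distributed, callers no longer thread local_rank through to the datasets, and the same guard also holds when worker processes are launched without a distributed backend. A rough two-process illustration of the behaviour (the file name and helper are hypothetical; assumes `filelock` and `torch` are installed):

import multiprocessing as mp
import os

import torch
from filelock import FileLock

CACHED_FEATURES = "cached_demo_features.bin"


def load_or_build(rank):
    # Both workers race for the same lock; the winner builds and saves the
    # "features", the loser blocks here and then loads the cache from disk.
    with FileLock(CACHED_FEATURES + ".lock"):
        if not os.path.exists(CACHED_FEATURES):
            print(f"rank {rank}: building features")
            torch.save(list(range(100)), CACHED_FEATURES)
        else:
            print(f"rank {rank}: reusing cached features")
        return torch.load(CACHED_FEATURES)


if __name__ == "__main__":
    workers = [mp.Process(target=load_or_build, args=(r,)) for r in range(2)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()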