"csrc/git@developer.sourcefind.cn:norm/vllm.git" did not exist on "621980bdc0d5a41e224febf962a6e0474e2b14ef"
Unverified commit 64e60980, authored by Stas Bekman, committed by GitHub

[trainer] add main_process_first context manager (#12351)

* main_process_first context manager

* handle multi-node, add context description

* sync desc
parent f8664258
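
In short, this commit wraps the example scripts' `datasets` `map()` pre-processing calls so that the main rank runs them first and the remaining ranks reuse the resulting cache. A minimal sketch of the intended call pattern (here `raw_dataset` and `tokenize_fn` are hypothetical placeholders, not names from this commit):

    # Illustrative sketch only: `training_args` is a transformers.TrainingArguments
    # instance; `raw_dataset` and `tokenize_fn` are hypothetical placeholders.
    with training_args.main_process_first(desc="dataset map pre-processing"):
        # Rank 0 enters immediately and runs the map, writing the datasets cache;
        # all other ranks block on a barrier and then load the cached result.
        dataset = raw_dataset.map(tokenize_fn, batched=True)
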
@@ -428,14 +428,15 @@ def main():
         train_dataset = raw_datasets["train"]
         if data_args.max_train_samples is not None:
             train_dataset = train_dataset.select(range(data_args.max_train_samples))
-        train_dataset = train_dataset.map(
-            preprocess_function,
-            batched=True,
-            num_proc=data_args.preprocessing_num_workers,
-            remove_columns=column_names,
-            load_from_cache_file=not data_args.overwrite_cache,
-            desc="Running tokenizer on train dataset",
-        )
+        with training_args.main_process_first(desc="train dataset map pre-processing"):
+            train_dataset = train_dataset.map(
+                preprocess_function,
+                batched=True,
+                num_proc=data_args.preprocessing_num_workers,
+                remove_columns=column_names,
+                load_from_cache_file=not data_args.overwrite_cache,
+                desc="Running tokenizer on train dataset",
+            )

     if training_args.do_eval:
         max_target_length = data_args.val_max_target_length
@@ -444,14 +445,15 @@ def main():
         eval_dataset = raw_datasets["validation"]
         if data_args.max_eval_samples is not None:
             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
-        eval_dataset = eval_dataset.map(
-            preprocess_function,
-            batched=True,
-            num_proc=data_args.preprocessing_num_workers,
-            remove_columns=column_names,
-            load_from_cache_file=not data_args.overwrite_cache,
-            desc="Running tokenizer on validation dataset",
-        )
+        with training_args.main_process_first(desc="validation dataset map pre-processing"):
+            eval_dataset = eval_dataset.map(
+                preprocess_function,
+                batched=True,
+                num_proc=data_args.preprocessing_num_workers,
+                remove_columns=column_names,
+                load_from_cache_file=not data_args.overwrite_cache,
+                desc="Running tokenizer on validation dataset",
+            )

     if training_args.do_predict:
         max_target_length = data_args.val_max_target_length
@@ -460,14 +462,15 @@ def main():
         predict_dataset = raw_datasets["test"]
         if data_args.max_predict_samples is not None:
             predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
-        predict_dataset = predict_dataset.map(
-            preprocess_function,
-            batched=True,
-            num_proc=data_args.preprocessing_num_workers,
-            remove_columns=column_names,
-            load_from_cache_file=not data_args.overwrite_cache,
-            desc="Running tokenizer on prediction dataset",
-        )
+        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
+            predict_dataset = predict_dataset.map(
+                preprocess_function,
+                batched=True,
+                num_proc=data_args.preprocessing_num_workers,
+                remove_columns=column_names,
+                load_from_cache_file=not data_args.overwrite_cache,
+                desc="Running tokenizer on prediction dataset",
+            )

     # Data collator
     label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
...
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
 import json
 import os
 import warnings
@@ -968,6 +969,49 @@ class TrainingArguments:
         """
         return not (self.deepspeed or is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled())

+    @contextlib.contextmanager
+    def main_process_first(self, local=True, desc="work"):
+        """
+        A context manager for torch distributed environments where one needs to do something on the main
+        process while blocking the replicas, and release the replicas once it's finished.
+
+        One such use is for ``datasets``'s ``map`` feature, which, to be efficient, should be run once on
+        the main process; upon completion it saves a cached version of the results, which then gets loaded
+        automatically by the replicas.
+
+        Args:
+            local (:obj:`bool`, `optional`, defaults to :obj:`True`):
+                if :obj:`True`, "first" means the process of rank 0 of each node; if :obj:`False`, "first"
+                means the process of rank 0 of node rank 0. In a multi-node environment with a shared
+                filesystem you will most likely want to use ``local=False`` so that only the main process
+                of the first node does the processing. If, however, the filesystem is not shared, then the
+                main process of each node will need to do the processing, which is the default behavior.
+            desc (:obj:`str`, `optional`, defaults to ``"work"``):
+                a work description to be used in debug logs
+        """
+        if is_torch_available() and self.world_size > 1:
+            if local:
+                is_main_process = self.local_process_index == 0
+                main_process_desc = "main local process"
+            else:
+                is_main_process = self.process_index == 0
+                main_process_desc = "main process"
+
+            try:
+                if not is_main_process:
+                    # tell all replicas to wait
+                    logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
+                    torch.distributed.barrier()
+                yield
+            finally:
+                if is_main_process:
+                    # the wait is over
+                    logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas")
+                    torch.distributed.barrier()
+        else:
+            yield
+
     def to_dict(self):
         """
         Serializes this instance while replacing `Enum` by their values (for JSON serialization support).
...
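
The barrier choreography above works as follows: every non-main rank hits a barrier before the `yield`, so it blocks while the main rank runs the body; the main rank then hits the matching barrier in the `finally` clause, releasing everyone. A self-contained sketch of the same pattern in plain PyTorch, independent of `TrainingArguments` (the function and parameter names here are illustrative, and `torch.distributed` must already be initialized):

    import contextlib

    import torch.distributed as dist

    @contextlib.contextmanager
    def main_process_first(is_main_process):
        # Generic "main process first" barrier pattern (illustrative sketch).
        try:
            if not is_main_process:
                dist.barrier()  # replicas wait here until the main rank finishes the body
            yield
        finally:
            if is_main_process:
                dist.barrier()  # main rank releases the waiting replicas

Putting the releasing barrier in `finally` is what keeps the job from deadlocking: even if the main rank's body raises, it still reaches the barrier and unblocks the replicas before the exception propagates.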