"test/srt/git@developer.sourcefind.cn:change/sglang.git" did not exist on "e51046beaa67c3dd39f1814488ffc147ce5e740d"
Unverified commit 02842855, authored by Zach Mueller, committed by GitHub

Fix the `pad_across_processes` dim in `Trainer`, and allow the DDP timeout to be set (#24775)



* dim, and rm copy

* Don't rm copy for now

* Oops

* pad index

* Should be a working test

* Trickle down ddp timeout

* Put fix back in now that testing locally is done

* Better comment specifying timeout
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 4f85aaa6
@@ -3131,9 +3131,9 @@ class Trainer:
                 losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)))
                 losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
             if labels is not None:
-                labels = self.accelerator.pad_across_processes(labels)
+                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
             if inputs_decode is not None:
-                inputs_decode = self.accelerator.pad_across_processes(inputs_decode)
+                inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
                 inputs_decode = self.accelerator.gather_for_metrics((inputs_decode))
                 inputs_host = (
                     inputs_decode
@@ -3141,7 +3141,7 @@ class Trainer:
                     else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                 )
             if logits is not None:
-                logits = self.accelerator.pad_across_processes(logits)
+                logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
                 if self.preprocess_logits_for_metrics is not None:
                     logits = self.preprocess_logits_for_metrics(logits, labels)
                 logits = self.accelerator.gather_for_metrics((logits))
...
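For context: `Accelerator.pad_across_processes` pads along `dim=0` with `pad_index=0` by default, so variable-length labels, inputs, and logits were previously padded on the batch dimension rather than the sequence dimension, and with a pad value that downstream `nested_concat(..., padding_index=-100)` does not expect. Below is a minimal single-process sketch of the behavior the fix relies on; the `pad_to_len` helper and the two `rank*` tensors are illustrative, not part of the patch:

```python
import torch
import torch.nn.functional as F

def pad_to_len(t: torch.Tensor, max_len: int, pad_index: int = -100) -> torch.Tensor:
    # Right-pad dim=1 (the sequence dimension) up to max_len with pad_index,
    # mimicking pad_across_processes(..., dim=1, pad_index=-100) on one process.
    return F.pad(t, (0, max_len - t.size(1)), value=pad_index)

# Two "processes" produced generations with different sequence lengths.
rank0 = torch.ones(2, 5, dtype=torch.long)  # (batch=2, seq_len=5)
rank1 = torch.ones(2, 8, dtype=torch.long)  # (batch=2, seq_len=8)
max_len = max(rank0.size(1), rank1.size(1))

# With matching seq_len, the gathered tensors can be concatenated on the
# batch dimension; -100 marks the padded positions, matching the ignore
# index used by nested_concat(..., padding_index=-100) in the Trainer.
gathered = torch.cat([pad_to_len(rank0, max_len), pad_to_len(rank1, max_len)], dim=0)
print(gathered.shape)            # torch.Size([4, 8])
print((gathered == -100).sum())  # tensor(6): the padding added to the shorter tensor
```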
@@ -1714,7 +1714,9 @@ class TrainingArguments:
                 del os.environ["ACCELERATE_USE_DEEPSPEED"]
             self._n_gpu = 1
         else:
-            self.distributed_state = PartialState(backend=self.ddp_backend)
+            self.distributed_state = PartialState(
+                backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout)
+            )
             self._n_gpu = 1
         if not is_sagemaker_mp_enabled():
             device = self.distributed_state.device
...
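A short sketch of the user-facing effect, assuming a distributed launch (the values here are illustrative; `ddp_timeout` defaults to 1800 seconds): the `ddp_timeout` argument is now forwarded to the process-group initialization instead of being silently dropped.

```python
from datetime import timedelta
from transformers import TrainingArguments

# Raise or lower the timeout for workloads with long, rank-imbalanced steps.
args = TrainingArguments(output_dir="out", ddp_timeout=60)

# Under a distributed launch, device setup now roughly amounts to (see the diff above):
#   PartialState(backend=args.ddp_backend, timeout=timedelta(seconds=args.ddp_timeout))
print(timedelta(seconds=args.ddp_timeout))  # 0:01:00
```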
@@ -49,6 +49,7 @@ from transformers.testing_utils import (
     USER,
     CaptureLogger,
     TestCasePlus,
+    execute_subprocess_async,
     get_gpu_count,
     get_tests_dir,
     is_staging_test,
...
@@ -2098,6 +2099,51 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params)
         self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
 
+    @slow
+    @require_torch_multi_gpu
+    def test_end_to_end_example(self):
+        # Tests that `translation.py` will run without issues
+        script_path = os.path.abspath(
+            os.path.join(
+                os.path.dirname(__file__), "..", "..", "examples", "pytorch", "translation", "run_translation.py"
+            )
+        )
+        with tempfile.TemporaryDirectory() as tmpdir:
+            command = [
+                "accelerate",
+                "launch",
+                script_path,
+                "--model_name_or_path",
+                "t5-small",
+                "--per_device_train_batch_size",
+                "1",
+                "--output_dir",
+                tmpdir,
+                "--overwrite_output_dir",
+                "--do_train",
+                "--max_train_samples",
+                "64",
+                "--num_train_epochs",
+                "1",
+                "--dataset_name",
+                "wmt16",
+                "--dataset_config",
+                "ro-en",
+                "--source_lang",
+                "en",
+                "--target_lang",
+                "ro",
+                "--do_predict",
+                "--max_predict_samples",
+                "64",
+                "--predict_with_generate",
+                "--ddp_timeout",
+                "60",
+            ]
+            execute_subprocess_async(command)
+            # successful return here == success - any errors would have caused an error or a timeout in the sub-call
+
 
 @require_torch
 @is_staging_test
...
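A brief note on the test's success criterion, with a minimal sketch assuming `transformers` is installed: `execute_subprocess_async` from `transformers.testing_utils` runs the command to completion and raises on a non-zero exit, so returning normally is the success signal; a hung process group would instead hit the 60-second `--ddp_timeout` and fail the run.

```python
from transformers.testing_utils import execute_subprocess_async

# Illustrative only: runs the child command and raises a RuntimeError with the
# captured output if it exits non-zero, which is exactly how the test detects failure.
execute_subprocess_async(["python", "-c", "print('ok')"])
```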