Unverified Commit 4f5faaf0 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[deepspeed] fix a bug in a test (#15493)

* [deepspeed] fix a bug in a test

* consistency
parent 90166121
...@@ -25,6 +25,7 @@ from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available ...@@ -25,6 +25,7 @@ from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available
from transformers.file_utils import WEIGHTS_NAME from transformers.file_utils import WEIGHTS_NAME
from transformers.testing_utils import ( from transformers.testing_utils import (
CaptureLogger, CaptureLogger,
CaptureStd,
CaptureStderr, CaptureStderr,
ExtendSysPath, ExtendSysPath,
LoggingLevel, LoggingLevel,
...@@ -972,7 +973,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus): ...@@ -972,7 +973,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
with CaptureStderr() as cs: with CaptureStderr() as cs:
execute_subprocess_async(cmd, env=self.get_env()) execute_subprocess_async(cmd, env=self.get_env())
assert "Detected DeepSpeed ZeRO-3" in cs.err self.assertIn("Detected DeepSpeed ZeRO-3", cs.err)
@parameterized.expand(stages) @parameterized.expand(stages)
def test_load_best_model(self, stage): def test_load_best_model(self, stage):
...@@ -1008,14 +1009,14 @@ class TestDeepSpeedWithLauncher(TestCasePlus): ...@@ -1008,14 +1009,14 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
""".split() """.split()
args.extend(["--source_prefix", "translate English to Romanian: "]) args.extend(["--source_prefix", "translate English to Romanian: "])
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_zero3.json".split() ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"] script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
launcher = get_launcher(distributed=False) launcher = get_launcher(distributed=False)
cmd = launcher + script + args + ds_args cmd = launcher + script + args + ds_args
# keep for quick debug # keep for quick debug
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
with CaptureStderr() as cs: with CaptureStd() as cs:
execute_subprocess_async(cmd, env=self.get_env()) execute_subprocess_async(cmd, env=self.get_env())
# enough to test it didn't fail # enough to test it didn't fail
assert "Detected DeepSpeed ZeRO-3" in cs.err self.assertIn("DeepSpeed info", cs.out)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment