Unverified Commit 7ecd229b authored by Sourab Mangrulkar's avatar Sourab Mangrulkar Committed by GitHub
Browse files

Smangrul/fix failing ds ci tests (#27358)

* fix failing DeepSpeed CI tests due to `safetensors` being default

* debug

* remove debug statements

* resolve comments

* Update test_deepspeed.py
parent ced9fd86
...@@ -48,7 +48,7 @@ from transformers.testing_utils import ( ...@@ -48,7 +48,7 @@ from transformers.testing_utils import (
slow, slow,
) )
from transformers.trainer_utils import get_last_checkpoint, set_seed from transformers.trainer_utils import get_last_checkpoint, set_seed
from transformers.utils import WEIGHTS_NAME, is_torch_bf16_gpu_available from transformers.utils import SAFE_WEIGHTS_NAME, is_torch_bf16_gpu_available
if is_torch_available(): if is_torch_available():
...@@ -565,8 +565,7 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T ...@@ -565,8 +565,7 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype): def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
# adapted from TrainerIntegrationCommon.check_saved_checkpoints # adapted from TrainerIntegrationCommon.check_saved_checkpoints
file_list = [SAFE_WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
file_list = [WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
if stage == ZERO2: if stage == ZERO2:
ds_file_list = ["mp_rank_00_model_states.pt"] ds_file_list = ["mp_rank_00_model_states.pt"]
...@@ -581,7 +580,6 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T ...@@ -581,7 +580,6 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
for step in range(freq, total, freq): for step in range(freq, total, freq):
checkpoint = os.path.join(output_dir, f"checkpoint-{step}") checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
self.assertTrue(os.path.isdir(checkpoint), f"[{stage}] {checkpoint} dir is not found") self.assertTrue(os.path.isdir(checkpoint), f"[{stage}] {checkpoint} dir is not found")
# common files # common files
for filename in file_list: for filename in file_list:
path = os.path.join(checkpoint, filename) path = os.path.join(checkpoint, filename)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment