Unverified Commit 238d2e3c authored by Sourab Mangrulkar, committed by GitHub

fix resuming from ckpt when using FSDP with FULL_STATE_DICT (#27891)

* fix resuming from ckpt when using FSDP with FULL_STATE_DICT

* update tests

* fix tests
parent ebfdb9ca
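For context, the on-disk layout of an FSDP checkpoint depends on the configured state dict type. A sketch of the two layouts the patch has to distinguish (entries are illustrative; they assume the transformers.trainer constant FSDP_MODEL_NAME resolves to "pytorch_model_fsdp"):

    checkpoint-115/                      # saved with FULL_STATE_DICT
    └── pytorch_model_fsdp.bin           # single consolidated file -> caught by the new isfile() branch

    checkpoint-115/                      # saved with SHARDED_STATE_DICT
    └── pytorch_model_fsdp_0/            # shard folder whose name contains FSDP_MODEL_NAME -> caught by any(...)
        └── ...

Before this fix, only the folder-based SHARDED_STATE_DICT layout was recognized, so a FULL_STATE_DICT checkpoint was not detected as an FSDP checkpoint when resuming.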
src/transformers/trainer.py
@@ -2033,10 +2033,15 @@ class Trainer:
         weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
         safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
         safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
-        is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and any(
-            FSDP_MODEL_NAME in folder_name
-            for folder_name in os.listdir(resume_from_checkpoint)
-            if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+        is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
+            # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
+            any(
+                FSDP_MODEL_NAME in folder_name
+                for folder_name in os.listdir(resume_from_checkpoint)
+                if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+            )
+            # this checks the FSDP state dict when `FULL_STATE_DICT` is used
+            or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
         )
         if is_fsdp_ckpt and not self.is_fsdp_enabled:
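Pulled out of the diff, the new check amounts to the standalone predicate below (a minimal sketch: the is_fsdp_checkpoint helper and the hard-coded FSDP_MODEL_NAME value are illustrative, not part of the patch):

    import os

    FSDP_MODEL_NAME = "pytorch_model_fsdp"  # assumed value of the transformers.trainer constant


    def is_fsdp_checkpoint(path: str) -> bool:
        """Return True if `path` looks like an FSDP checkpoint under either state dict type."""
        if not os.path.isdir(path):
            return False
        # SHARDED_STATE_DICT: shards live in a subfolder whose name contains FSDP_MODEL_NAME
        has_shard_folder = any(
            FSDP_MODEL_NAME in folder_name
            for folder_name in os.listdir(path)
            if os.path.isdir(os.path.join(path, folder_name))
        )
        # FULL_STATE_DICT: a single consolidated .bin file -- the case this commit fixes
        has_full_state_file = os.path.isfile(os.path.join(path, f"{FSDP_MODEL_NAME}.bin"))
        return has_shard_folder or has_full_state_file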
tests/fsdp/test_fsdp.py
@@ -41,6 +41,7 @@ from transformers.utils import is_accelerate_available, is_torch_bf16_available_
 if is_torch_available():
     from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1
+    from transformers.trainer import FSDP_MODEL_NAME
 else:
     is_torch_greater_or_equal_than_2_1 = False
@@ -211,6 +212,19 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
         # resume from ckpt
         checkpoint = os.path.join(output_dir, "checkpoint-115")
         resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()
+        is_fsdp_ckpt = os.path.isdir(checkpoint) and (
+            # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
+            any(
+                FSDP_MODEL_NAME in folder_name
+                for folder_name in os.listdir(checkpoint)
+                if os.path.isdir(os.path.join(checkpoint, folder_name))
+            )
+            # this checks the FSDP state dict when `FULL_STATE_DICT` is used
+            or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
+        )
+        self.assertTrue(is_fsdp_ckpt)
+
         logs_resume = self.run_cmd_and_get_logs(
             use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
         )
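The new test lines assert the predicate against a checkpoint produced by a real training run. A cheaper sanity check of the same logic, using synthetic layouts (a hypothetical snippet that reuses the is_fsdp_checkpoint sketch above):

    import os
    import tempfile

    # FULL_STATE_DICT layout: a single consolidated file is enough to be detected
    with tempfile.TemporaryDirectory() as full_ckpt:
        open(os.path.join(full_ckpt, "pytorch_model_fsdp.bin"), "wb").close()
        assert is_fsdp_checkpoint(full_ckpt)

    # SHARDED_STATE_DICT layout: an empty shard folder is enough to be detected
    with tempfile.TemporaryDirectory() as sharded_ckpt:
        os.makedirs(os.path.join(sharded_ckpt, "pytorch_model_fsdp_0"))
        assert is_fsdp_checkpoint(sharded_ckpt)

    # a plain (non-FSDP) checkpoint directory should not be detected
    with tempfile.TemporaryDirectory() as plain_ckpt:
        assert not is_fsdp_checkpoint(plain_ckpt)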