"git@developer.sourcefind.cn:change/sglang.git" did not exist on "1a6e97577acb017fa9c25daf8a533969e941aa09"
Commit d80ad54f authored by Sujit Verma, committed by Facebook GitHub Bot

Added option to save checkpoints using Path Manager.

Summary: Added option to save checkpoints using Path Manager.

Reviewed By: hudeven

Differential Revision: D17392754

fbshipit-source-id: 4b8e556ef8455a1548e5a083d779ed809cd785be
parent 02b74c58
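The hunks below all apply the same fallback pattern: try to route checkpoint I/O through the internal Path Manager (fairseq.fb_pathmgr) and fall back to plain local-file I/O when that module is not importable, e.g. in the open-source build. A minimal sketch of the pattern, not part of the commit; the helper name load_state_on_cpu is hypothetical:

import torch
from torch.serialization import default_restore_location


def load_state_on_cpu(path):
    # Prefer the Facebook-internal Path Manager when it is available.
    try:
        from fairseq.fb_pathmgr import fb_pathmgr
        with fb_pathmgr.open(path, "rb") as f:
            # Path Manager resolves the path (local or non-local backend)
            # and hands torch.load an open file object.
            return torch.load(
                f, map_location=lambda s, l: default_restore_location(s, "cpu")
            )
    except (ModuleNotFoundError, ImportError):
        # Module not present (open-source build): read the local file directly.
        return torch.load(
            path, map_location=lambda s, l: default_restore_location(s, "cpu")
        )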
@@ -65,6 +65,10 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
     if len(checkpoints) > 0:
         trainer.save_checkpoint(checkpoints[0], extra_state)
         for cp in checkpoints[1:]:
-            shutil.copyfile(checkpoints[0], cp)
+            try:
+                from fairseq.fb_pathmgr import fb_pathmgr
+                fb_pathmgr.copy(checkpoints[0], cp, True)
+            except (ModuleNotFoundError, ImportError):
+                shutil.copyfile(checkpoints[0], cp)
 
         write_timer.stop()
@@ -132,6 +136,14 @@ def load_checkpoint(args, trainer, data_selector=None):
 
 def load_checkpoint_to_cpu(path, arg_overrides=None):
     """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
-    state = torch.load(
-        path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
-    )
+    try:
+        from fairseq.fb_pathmgr import fb_pathmgr
+        with fb_pathmgr.open(path, "rb") as f:
+            state = torch.load(
+                f, map_location=lambda s, l: default_restore_location(s, 'cpu'),
+            )
+    except (ModuleNotFoundError, ImportError):
+        # if path manager not found, continue with local file.
+        state = torch.load(
+            path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
+        )
@@ -244,6 +256,13 @@ def save_state(
         state_dict['criterion'] = criterion.state_dict()
     if not args.no_save_optimizer_state:
         state_dict['last_optimizer_state'] = convert_state_dict_type(optimizer.state_dict())
-    torch_persistent_save(state_dict, filename)
+
+    try:
+        from fairseq.fb_pathmgr import fb_pathmgr
+        with fb_pathmgr.open(filename, "wb") as f:
+            torch_persistent_save(state_dict, f)
+    except (ModuleNotFoundError, ImportError):
+        # if path manager not found, continue with local file.
+        torch_persistent_save(state_dict, filename)
......
@@ -170,7 +170,13 @@ class Trainer(object):
         """Load all training state from a checkpoint file."""
         extra_state, self._optim_history, last_optim_state = None, [], None
-        if os.path.exists(filename):
+        try:
+            from fairseq.fb_pathmgr import fb_pathmgr
+            bexists = fb_pathmgr.isfile(filename)
+        except Exception:
+            bexists = os.path.exists(filename)
+
+        if bexists:
             state = checkpoint_utils.load_checkpoint_to_cpu(filename)
             # load model parameters
......
@@ -19,10 +19,21 @@ from fairseq.data import iterators
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
 
+fb_pathmgr_registerd = False
+
 
 def main(args, init_distributed=False):
     utils.import_user_module(args)
 
+    try:
+        from fairseq.fb_pathmgr import fb_pathmgr
+        global fb_pathmgr_registerd
+        if not fb_pathmgr_registerd:
+            fb_pathmgr.register()
+            fb_pathmgr_registerd = True
+    except (ModuleNotFoundError, ImportError):
+        pass
+
     assert args.max_tokens is not None or args.max_sentences is not None, \
         'Must specify batch size either with --max-tokens or --max-sentences'
......
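Two notes on the pattern: the try/except import guard is repeated at each call site so the open-source build, which does not ship fairseq.fb_pathmgr, keeps working unchanged while internal jobs get Path Manager routing; and fb_pathmgr.register() in train.py is guarded by the module-level fb_pathmgr_registerd flag, so registration runs at most once per process even if main() is entered more than once. Callers are unaffected either way; a hypothetical usage sketch (path and keys are illustrative only):

from fairseq import checkpoint_utils

# The caller still passes a plain path string; whether the bytes come from the
# local filesystem or a Path Manager backend is decided inside the helper.
state = checkpoint_utils.load_checkpoint_to_cpu("checkpoints/checkpoint_last.pt")
print(sorted(state.keys()))  # e.g. 'args', 'extra_state', 'model', 'optimizer_history'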