"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "192b3b3ceba3b4eec6a729c0b171ddd6cdc10025"
Unverified Commit 8f1cb0c7 authored by Gao, Xiang; committed by GitHub
Browse files

Allow manually specifying checkpoint filename (#95)

parent 615f8144
...@@ -314,16 +314,20 @@ class Trainer: ...@@ -314,16 +314,20 @@ class Trainer:
filename (str): Input file name filename (str): Input file name
device (:class:`torch.device`): device to train the model device (:class:`torch.device`): device to train the model
tqdm (bool): whether to enable tqdm tqdm (bool): whether to enable tqdm
tensorboard (str): Directory to store tensorboard log file, set to\ tensorboard (str): Directory to store tensorboard log file, set to
``None`` to disable tensorboardX. ``None`` to disable tensorboardX.
aev_caching (bool): Whether to use AEV caching. aev_caching (bool): Whether to use AEV caching.
checkpoint_name (str): Name of the checkpoint file, checkpoints will be
stored in the network directory with this file name.
""" """
def __init__(self, filename, device=torch.device('cuda'), def __init__(self, filename, device=torch.device('cuda'), tqdm=False,
tqdm=False, tensorboard=None, aev_caching=False): tensorboard=None, aev_caching=False,
checkpoint_name='model.pt'):
self.filename = filename self.filename = filename
self.device = device self.device = device
self.aev_caching = aev_caching self.aev_caching = aev_caching
self.checkpoint_name = checkpoint_name
if tqdm: if tqdm:
import tqdm import tqdm
self.tqdm = tqdm.tqdm self.tqdm = tqdm.tqdm
...@@ -475,7 +479,7 @@ class Trainer: ...@@ -475,7 +479,7 @@ class Trainer:
network_dir = os.path.join(dir, params['ntwkStoreDir']) network_dir = os.path.join(dir, params['ntwkStoreDir'])
if not os.path.exists(network_dir): if not os.path.exists(network_dir):
os.makedirs(network_dir) os.makedirs(network_dir)
self.model_checkpoint = os.path.join(network_dir, 'model.pt') self.model_checkpoint = os.path.join(network_dir, self.checkpoint_name)
del params['ntwkStoreDir'] del params['ntwkStoreDir']
self.max_nonimprove = params['tolr'] self.max_nonimprove = params['tolr']
del params['tolr'] del params['tolr']
......
...@@ -28,10 +28,13 @@ if __name__ == '__main__': ...@@ -28,10 +28,13 @@ if __name__ == '__main__':
default=None) default=None)
parser.add_argument('--cache-aev', dest='cache_aev', action='store_true', parser.add_argument('--cache-aev', dest='cache_aev', action='store_true',
help='Whether to cache AEV', default=None) help='Whether to cache AEV', default=None)
parser.add_argument('--checkpoint_name',
help='Name of checkpoint file',
default='model.pt')
parser = parser.parse_args() parser = parser.parse_args()
d = torch.device(parser.device) d = torch.device(parser.device)
trainer = Trainer(parser.config_path, d, parser.tqdm, parser.tensorboard, trainer = Trainer(parser.config_path, d, parser.tqdm, parser.tensorboard,
parser.cache_aev) parser.cache_aev, parser.checkpoint_name)
trainer.load_data(parser.training_path, parser.validation_path) trainer.load_data(parser.training_path, parser.validation_path)
trainer.run() trainer.run()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment