Commit adbb02f6 authored by Zeqiang Lai's avatar Zeqiang Lai Committed by zhe chen
Browse files

Support EMA for main_deepspeed.py, fix torch.distributed.launch (#88)

parent 1c6361d8
import torch
import torch.nn as nn
import deepspeed
from deepspeed.runtime.zero import GatheredParameters
from contextlib import contextmanager
class EMADeepspeed(nn.Module):
    """Exponential moving average (EMA) of model parameters under DeepSpeed ZeRO.

    Shadow copies of every trainable parameter are kept as buffers on this
    module (buffer names are the parameter names with '.' stripped, since '.'
    is not allowed in buffer names). All parameter access goes through
    ``GatheredParameters`` so the class also works when parameters are
    partitioned by ZeRO stage 3. The shadow state is maintained and consumed
    on global rank 0 only; writes back into the model use ``modifier_rank=0``
    so rank 0's values are broadcast to the partitions.

    Migrated from https://github.com/microsoft/DeepSpeed/issues/2056

    Args:
        model: model whose trainable parameters are tracked.
        decay: EMA decay rate, must lie in [0, 1].
        use_num_updates: if True, warm up the effective decay with
            ``min(decay, (1 + n) / (10 + n))`` over the first updates.
    """

    def __init__(self, model, decay=0.9999, use_num_updates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        # Maps parameter name -> shadow buffer name.
        self.m_name2s_name = {}
        self.decay = decay
        # FIX: keep the update counter in a registered buffer so it is part of
        # state_dict() and survives checkpoint save/resume. Previously it was
        # a plain int attribute, so load_state_dict() silently reset it to 0
        # on resume, restarting the decay warm-up schedule.
        # NOTE(review): checkpoints written before this fix lack the
        # 'num_updates' entry; resume those with load_state_dict(..., strict=False).
        self.register_buffer(
            'num_updates',
            torch.tensor(0 if use_num_updates else -1, dtype=torch.int64))
        with GatheredParameters(model.parameters(), fwd_module=self):
            for name, p in model.named_parameters():
                if p.requires_grad:
                    # '.'-characters are not allowed in buffer names.
                    s_name = name.replace('.', '')
                    self.m_name2s_name.update({name: s_name})
                    self.register_buffer(s_name, p.clone().detach().data)
        # Parameters stashed by store() for a later restore() (rank 0 only).
        self.collected_params = []

    def forward(self, model):
        """Update the shadow parameters one EMA step towards *model*'s parameters."""
        decay = self.decay
        if self.num_updates >= 0:
            self.num_updates += 1  # in-place on the buffer tensor
            n = self.num_updates.item()
            # Warm-up: use a smaller effective decay for the first updates.
            decay = min(self.decay, (1 + n) / (10 + n))
        one_minus_decay = 1.0 - decay
        shadow_params = dict(self.named_buffers())
        with torch.no_grad():
            with GatheredParameters(model.parameters()):
                # Shadow state is only maintained on rank 0; copy_to() later
                # broadcasts rank 0's values via modifier_rank=0.
                if deepspeed.comm.get_rank() == 0:
                    m_param = dict(model.named_parameters())
                    for key in m_param:
                        if m_param[key].requires_grad:
                            sname = self.m_name2s_name[key]
                            shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                            # shadow <- shadow - (1 - decay) * (shadow - param)
                            shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                        else:
                            assert key not in self.m_name2s_name

    def copy_to(self, model):
        """Copy the shadow (EMA) parameters into *model* in place."""
        shadow_params = dict(self.named_buffers())
        # modifier_rank=0: rank 0 writes, result is broadcast to all partitions.
        with GatheredParameters(model.parameters(), modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                m_param = dict(model.named_parameters())
                for key in m_param:
                    if m_param[key].requires_grad:
                        m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
                    else:
                        assert key not in self.m_name2s_name

    def store(self, model):
        """
        Save the current parameters for restoring later.
        Args:
            model: A model that parameters will be stored
        """
        with GatheredParameters(model.parameters()):
            # Only rank 0 keeps the stash; restore() likewise only modifies
            # rank 0 and broadcasts from there.
            if deepspeed.comm.get_rank() == 0:
                parameters = model.parameters()
                self.collected_params = [param.clone() for param in parameters]

    def restore(self, model):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            model: A model that to restore its parameters.
        """
        with GatheredParameters(model.parameters(), modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                parameters = model.parameters()
                for c_param, param in zip(self.collected_params, parameters):
                    param.data.copy_(c_param.data)

    @contextmanager
    def activate(self, model):
        """Context manager: temporarily swap *model*'s parameters for the EMA
        parameters (e.g. for evaluation), restoring the originals on exit."""
        try:
            self.store(model)
            self.copy_to(model)
            yield
        finally:
            self.restore(model)
...@@ -131,7 +131,7 @@ def parse_option(): ...@@ -131,7 +131,7 @@ def parse_option():
help="whether to use ZeroRedundancyOptimizer (ZeRO) to save memory") help="whether to use ZeroRedundancyOptimizer (ZeRO) to save memory")
# distributed training # distributed training
parser.add_argument("--local_rank", parser.add_argument("--local-rank",
type=int, type=int,
required=True, required=True,
help='local rank for DistributedDataParallel') help='local rank for DistributedDataParallel')
......
...@@ -27,7 +27,7 @@ from optimizer import set_weight_decay_and_lr ...@@ -27,7 +27,7 @@ from optimizer import set_weight_decay_and_lr
from logger import create_logger from logger import create_logger
from utils import load_pretrained, reduce_tensor, MyAverageMeter from utils import load_pretrained, reduce_tensor, MyAverageMeter
from ddp_hooks import fp16_compress_hook from ddp_hooks import fp16_compress_hook
from ema_deepspeed import EMADeepspeed
def parse_option(): def parse_option():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -57,7 +57,7 @@ def parse_option(): ...@@ -57,7 +57,7 @@ def parse_option():
parser.add_argument('--accumulation-steps', type=int, default=1, help="gradient accumulation steps") parser.add_argument('--accumulation-steps', type=int, default=1, help="gradient accumulation steps")
# distributed training # distributed training
parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') parser.add_argument("--local-rank", type=int, required=True, help='local rank for DistributedDataParallel')
parser.add_argument('--disable-grad-scalar', action='store_true', help='disable Grad Scalar') parser.add_argument('--disable-grad-scalar', action='store_true', help='disable Grad Scalar')
args, unparsed = parser.parse_known_args() args, unparsed = parser.parse_known_args()
...@@ -211,7 +211,7 @@ def throughput(data_loader, model, logger): ...@@ -211,7 +211,7 @@ def throughput(data_loader, model, logger):
return return
def train_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler): def train_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, model_ema=None):
model.train() model.train()
num_steps = len(data_loader) num_steps = len(data_loader)
...@@ -237,6 +237,9 @@ def train_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_f ...@@ -237,6 +237,9 @@ def train_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_f
model.backward(loss) model.backward(loss)
model.step() model.step()
if model_ema is not None:
model_ema(model)
if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
lr_scheduler.step_update(epoch * num_steps + idx) lr_scheduler.step_update(epoch * num_steps + idx)
...@@ -348,9 +351,14 @@ def train(config, ds_config): ...@@ -348,9 +351,14 @@ def train(config, ds_config):
lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
criterion = build_criterion(config) criterion = build_criterion(config)
model_ema = None
if config.TRAIN.EMA.ENABLE:
model_ema = EMADeepspeed(model, config.TRAIN.EMA.DECAY)
# -------------- resume ---------------- # # -------------- resume ---------------- #
max_accuracy = 0.0 max_accuracy = 0.0
max_accuracy_ema = 0.0
client_state = {} client_state = {}
if config.MODEL.RESUME == '' and config.TRAIN.AUTO_RESUME: if config.MODEL.RESUME == '' and config.TRAIN.AUTO_RESUME:
if os.path.exists(os.path.join(config.OUTPUT, 'latest')): if os.path.exists(os.path.join(config.OUTPUT, 'latest')):
...@@ -367,6 +375,10 @@ def train(config, ds_config): ...@@ -367,6 +375,10 @@ def train(config, ds_config):
logger.info(f'client_state={client_state.keys()}') logger.info(f'client_state={client_state.keys()}')
lr_scheduler.load_state_dict(client_state['custom_lr_scheduler']) lr_scheduler.load_state_dict(client_state['custom_lr_scheduler'])
max_accuracy = client_state['max_accuracy'] max_accuracy = client_state['max_accuracy']
if model_ema is not None:
max_accuracy_ema = client_state.get('max_accuracy_ema', 0.0)
model_ema.load_state_dict((client_state['model_ema']))
# -------------- training ---------------- # # -------------- training ---------------- #
...@@ -378,9 +390,11 @@ def train(config, ds_config): ...@@ -378,9 +390,11 @@ def train(config, ds_config):
log_model_statistic(model_without_ddp) log_model_statistic(model_without_ddp)
start_time = time.time() start_time = time.time()
for epoch in range(client_state.get('epoch', config.TRAIN.START_EPOCH), config.TRAIN.EPOCHS): start_epoch = client_state['epoch'] + 1 if 'epoch' in client_state else config.TRAIN.START_EPOCH
for epoch in range(start_epoch, config.TRAIN.EPOCHS):
data_loader_train.sampler.set_epoch(epoch) data_loader_train.sampler.set_epoch(epoch)
train_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler) train_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler,
model_ema=model_ema)
if epoch % config.SAVE_FREQ == 0 or epoch == config.TRAIN.EPOCHS - 1: if epoch % config.SAVE_FREQ == 0 or epoch == config.TRAIN.EPOCHS - 1:
model.save_checkpoint( model.save_checkpoint(
...@@ -390,13 +404,16 @@ def train(config, ds_config): ...@@ -390,13 +404,16 @@ def train(config, ds_config):
'custom_lr_scheduler': lr_scheduler.state_dict(), 'custom_lr_scheduler': lr_scheduler.state_dict(),
'max_accuracy': max_accuracy, 'max_accuracy': max_accuracy,
'epoch': epoch, 'epoch': epoch,
'config': config 'config': config,
'max_accuracy_ema': max_accuracy_ema if model_ema is not None else 0.0,
'model_ema': model_ema.state_dict() if model_ema is not None else None,
} }
) )
if epoch % config.EVAL_FREQ == 0: if epoch % config.EVAL_FREQ == 0:
acc1, _, _ = eval_epoch(config, data_loader_val, model, epoch) acc1, _, _ = eval_epoch(config, data_loader_val, model, epoch)
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
if acc1 > max_accuracy: if acc1 > max_accuracy:
model.save_checkpoint( model.save_checkpoint(
save_dir=config.OUTPUT, save_dir=config.OUTPUT,
...@@ -405,13 +422,22 @@ def train(config, ds_config): ...@@ -405,13 +422,22 @@ def train(config, ds_config):
'custom_lr_scheduler': lr_scheduler.state_dict(), 'custom_lr_scheduler': lr_scheduler.state_dict(),
'max_accuracy': max_accuracy, 'max_accuracy': max_accuracy,
'epoch': epoch, 'epoch': epoch,
'config': config 'config': config,
'max_accuracy_ema': max_accuracy_ema if model_ema is not None else 0.0,
'model_ema': model_ema.state_dict() if model_ema is not None else None,
} }
) )
max_accuracy = max(max_accuracy, acc1) max_accuracy = max(max_accuracy, acc1)
logger.info(f'Max accuracy: {max_accuracy:.2f}%') logger.info(f'Max accuracy: {max_accuracy:.2f}%')
if model_ema is not None:
with model_ema.activate(model):
acc1_ema, _, _ = eval_epoch(config, data_loader_val, model, epoch)
logger.info(f"[EMA] Accuracy of the network on the {len(dataset_val)} test images: {acc1_ema:.1f}%")
max_accuracy_ema = max(max_accuracy_ema, acc1_ema)
logger.info(f'[EMA] Max accuracy: {max_accuracy_ema:.2f}%')
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info('Training time {}'.format(total_time_str)) logger.info('Training time {}'.format(total_time_str))
...@@ -453,7 +479,7 @@ if __name__ == '__main__': ...@@ -453,7 +479,7 @@ if __name__ == '__main__':
args, config = parse_option() args, config = parse_option()
# init distributed env # init distributed env
if 'SLURM_PROCID' in os.environ: if 'SLURM_PROCID' in os.environ and int(os.environ['SLURM_TASKS_PER_NODE']) != 1:
print("\nDist init: SLURM") print("\nDist init: SLURM")
rank = int(os.environ['SLURM_PROCID']) rank = int(os.environ['SLURM_PROCID'])
gpu = rank % torch.cuda.device_count() gpu = rank % torch.cuda.device_count()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment