Commit 61b452e8 authored by Michael Carilli

Merging latest master

parents fb7d4e1d 6143b30f
@@ -5,6 +5,10 @@ def variable_is_tensor():
     v = torch.autograd.Variable()
     return isinstance(v, torch.Tensor)
 
+def tensor_is_variable():
+    x = torch.Tensor()
+    return type(x) == torch.autograd.Variable
+
 # False for post-0.4
 def tensor_is_float_tensor():
     x = torch.Tensor()
......
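A minimal sketch of how these probes combine to pick a patch target (the version boundaries are assumptions around PyTorch's 0.4 Variable/Tensor merge; the next hunk applies the same logic):

    import torch

    # Which class should be monkey-patched, given the probes above?
    v_is_t = isinstance(torch.autograd.Variable(), torch.Tensor)  # variable_is_tensor()
    t_is_v = type(torch.Tensor()) == torch.autograd.Variable      # tensor_is_variable()
    if v_is_t and not t_is_v:
        target = torch.Tensor              # post-0.4: Tensor and Variable are merged
    else:
        target = torch.autograd.Variable   # pre-0.4, or a 0.4 dev build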
@@ -5,7 +5,7 @@ import importlib
 
 import torch
 
-if compat.variable_is_tensor():
+if compat.variable_is_tensor() and not compat.tensor_is_variable():
     MODULE = torch.Tensor
 else:
     MODULE = torch.autograd.Variable
......
@@ -8,8 +8,6 @@ import torch
 
 def cached_cast(mod, fn, cast_fn, handle,
                 try_caching=False, verbose=False):
     if not utils.has_func(mod, fn):
-        # Should happen only pre-0.4
-        assert not compat.variable_is_tensor()
         return
     orig_fn = utils.get_func(mod, fn)
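For context, a hedged sketch of the monkey-patching pattern cached_cast participates in, using plain getattr/setattr in place of apex's utils helpers (which are not shown in this diff):

    import functools

    def wrap_with_cast(mod, fn_name, cast_fn):
        # Mirror of the has_func guard above: silently skip missing functions.
        if not hasattr(mod, fn_name):
            return
        orig_fn = getattr(mod, fn_name)

        @functools.wraps(orig_fn)
        def wrapper(*args, **kwargs):
            # Crudely cast tensor-like positional args before dispatching
            # to the original function (illustration only).
            args = [cast_fn(a) if hasattr(a, 'is_cuda') else a for a in args]
            return orig_fn(*args, **kwargs)

        setattr(mod, fn_name, wrapper)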
@@ -140,7 +138,7 @@ def rnn_cast(backend, fn, verbose=False):
     # autograd graph correctly backprops from the wgrads computed
     # inside cuDNN (on fp16 weights) into the fp32 weights.
     assert utils.type_string(flat_weight) == 'FloatTensor'
-    if compat.tensor_is_float_tensor():
+    if compat.tensor_is_float_tensor() or compat.tensor_is_variable():
         # Pre-0.4. A little slower, since it zeros out memory.
         flat_weight_fp16 = flat_weight.new().half().resize_(flat_weight.shape)
     else:
......
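The changed branch above picks how to build the fp16 shadow of the flat cuDNN weight. A small illustration of the two paths; the post-0.4 line is an assumption, since the else branch is elided in this diff:

    import torch

    flat_weight = torch.zeros(1024)  # stand-in for the flat FP32 RNN weight
    # Pre-0.4 path (as in the diff): allocate via .new(), then resize;
    # the comment above notes this is a little slower because it zeros memory.
    fp16_a = flat_weight.new().half().resize_(flat_weight.shape)
    # Post-0.4 alternative (assumed): empty_like with an explicit dtype.
    fp16_b = torch.empty_like(flat_weight, dtype=torch.float16)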
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter

# LARC (layer-wise adaptive rate control): wraps an existing optimizer and
# rescales each parameter's gradient by an adaptive per-parameter rate
# before delegating the actual update to the wrapped optimizer.
class LARC(object):
    def __init__(self, optimizer, trust_coefficient=0.02, epsilon=1e-8):
        self.param_groups = optimizer.param_groups
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient
        self.eps = epsilon

    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        self.optim.load_state_dict(state_dict)

    def zero_grad(self):
        self.optim.zero_grad()

    def add_param_group(self, param_group):
        self.optim.add_param_group(param_group)

    def step(self):
        with torch.no_grad():
            weight_decays = []
            for group in self.optim.param_groups:
                # absorb weight decay control from optimizer
                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
                weight_decays.append(weight_decay)
                group['weight_decay'] = 0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    param_norm = torch.norm(p.data)

                    # calculate adaptive lr + weight decay
                    adaptive_lr = (param_norm + self.eps) / (torch.norm(p.grad.data) + param_norm * weight_decay + self.eps)
                    p.grad.data += weight_decay * p.data
                    p.grad.data *= self.trust_coefficient * adaptive_lr

        self.optim.step()
        # return weight decay control to optimizer
        for i, group in enumerate(self.optim.param_groups):
            group['weight_decay'] = weight_decays[i]
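A usage sketch for the wrapper above (the model, learning rate, and weight decay are placeholder values): LARC scales each parameter's gradient by trust_coefficient * (||w|| + eps) / (||grad|| + wd * ||w|| + eps), then lets the wrapped optimizer apply the update.

    import torch

    model = torch.nn.Linear(10, 10)  # placeholder model
    base = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)
    optim = LARC(base, trust_coefficient=0.02)

    loss = model(torch.randn(4, 10)).sum()
    optim.zero_grad()
    loss.backward()
    optim.step()  # gradients rescaled per-layer, then SGD steps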
@@ -13,7 +13,7 @@ argslist = list(sys.argv)[1:]
 world_size = torch.cuda.device_count()
 
 if '--world-size' in argslist:
-    argslist[argslist.index('--world-size')+1] = str(world_size)
+    world_size = int(argslist[argslist.index('--world-size')+1])
 else:
     argslist.append('--world-size')
     argslist.append(str(world_size))
......
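For context, the launcher rewrites argv so every worker sees a consistent --world-size (after this change, an explicit flag is respected instead of overwritten). A hedged sketch of the subsequent spawn loop; the --rank flag and output handling are assumptions:

    import sys
    import subprocess
    import torch

    argslist = list(sys.argv)[1:]
    workers = []
    for rank in range(torch.cuda.device_count()):
        # One child process per local GPU, each told its own rank.
        cmd = [sys.executable] + argslist + ['--rank', str(rank)]
        workers.append(subprocess.Popen(cmd))
    for w in workers:
        w.wait()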
@@ -300,6 +300,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
         loss.backward()
         optimizer.step()
 
+        torch.cuda.synchronize()
 
         # measure elapsed time
         batch_time.update(time.time() - end)
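The added torch.cuda.synchronize() makes the host-side timer meaningful: CUDA kernels launch asynchronously, so without a sync, time.time() would largely measure launch overhead rather than GPU work. A standalone illustration:

    import time
    import torch

    x = torch.randn(4096, 4096, device='cuda')
    end = time.time()
    y = x @ x                   # queued asynchronously on the GPU
    torch.cuda.synchronize()    # block until the GPU has actually finished
    print('batch time: %.4f s' % (time.time() - end))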
@@ -309,11 +310,15 @@ def train(train_loader, model, criterion, optimizer, epoch):
         if args.rank == 0 and i % args.print_freq == 0 and i > 1:
             print('Epoch: [{0}][{1}/{2}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'Speed {3:.3f} ({4:.3f})\t'
                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
-                   epoch, i, len(train_loader), batch_time=batch_time,
+                   epoch, i, len(train_loader),
+                   args.world_size * args.batch_size / batch_time.val,
+                   args.world_size * args.batch_size / batch_time.avg,
+                   batch_time=batch_time,
                    data_time=data_time, loss=losses, top1=top1, top5=top5))
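The new Speed column is end-to-end throughput in images per second, computed as world_size * batch_size / batch_time. For example, with an assumed world_size of 8, a per-GPU batch_size of 256, and a batch_time.val of 0.5 s, it would report 8 * 256 / 0.5 = 4096 images/sec; the same expression over batch_time.avg gives the running average. The validate() hunk below adds the identical field.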
@@ -362,10 +367,14 @@ def validate(val_loader, model, criterion):
         if args.rank == 0 and i % args.print_freq == 0:
             print('Test: [{0}/{1}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'Speed {2:.3f} ({3:.3f})\t'
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
-                   i, len(val_loader), batch_time=batch_time, loss=losses,
+                   i, len(val_loader),
+                   args.world_size * args.batch_size / batch_time.val,
+                   args.world_size * args.batch_size / batch_time.avg,
+                   batch_time=batch_time, loss=losses,
                    top1=top1, top5=top5))
 
         input, target = prefetcher.next()
......