architect.py 3.91 KB
Newer Older
huchen's avatar
huchen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Variable


def _concat(xs):
  return torch.cat([x.view(-1) for x in xs])


def _clip(grads, max_norm):
    total_norm = 0
    for g in grads:
        param_norm = g.data.norm(2)
        total_norm += param_norm ** 2
    total_norm = total_norm ** 0.5
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for g in grads:
            g.data.mul_(clip_coef)
    return clip_coef


class Architect(object):

  def __init__(self, model, args):
    self.network_weight_decay = args.wdecay
    self.network_clip = args.clip
    self.model = model
    self.optimizer = torch.optim.Adam(self.model.arch_parameters(), lr=args.arch_lr, weight_decay=args.arch_wdecay)

  def _compute_unrolled_model(self, hidden, input, target, eta):
    loss, hidden_next = self.model._loss(hidden, input, target)
    theta = _concat(self.model.parameters()).data
    grads = torch.autograd.grad(loss, self.model.parameters())
    clip_coef = _clip(grads, self.network_clip)
    dtheta = _concat(grads).data + self.network_weight_decay*theta
    unrolled_model = self._construct_model_from_theta(theta.sub(eta, dtheta))
    return unrolled_model, clip_coef

  def step(self,
          hidden_train, input_train, target_train,
          hidden_valid, input_valid, target_valid,
          network_optimizer, unrolled):
    eta = network_optimizer.param_groups[0]['lr']
    self.optimizer.zero_grad()
    if unrolled:
        hidden = self._backward_step_unrolled(hidden_train, input_train, target_train, hidden_valid, input_valid, target_valid, eta)
    else:
        hidden = self._backward_step(hidden_valid, input_valid, target_valid)
    self.optimizer.step()
    return hidden, None

  def _backward_step(self, hidden, input, target):
    loss, hidden_next = self.model._loss(hidden, input, target)
    loss.backward()
    return hidden_next

  def _backward_step_unrolled(self,
          hidden_train, input_train, target_train,
          hidden_valid, input_valid, target_valid, eta):
    unrolled_model, clip_coef = self._compute_unrolled_model(hidden_train, input_train, target_train, eta)
    unrolled_loss, hidden_next = unrolled_model._loss(hidden_valid, input_valid, target_valid)

    unrolled_loss.backward()
    dalpha = [v.grad for v in unrolled_model.arch_parameters()]
    dtheta = [v.grad for v in unrolled_model.parameters()]
    _clip(dtheta, self.network_clip)
    vector = [dt.data for dt in dtheta]
    implicit_grads = self._hessian_vector_product(vector, hidden_train, input_train, target_train, r=1e-2)

    for g, ig in zip(dalpha, implicit_grads):
      g.data.sub_(eta * clip_coef, ig.data)

    for v, g in zip(self.model.arch_parameters(), dalpha):
      if v.grad is None:
        v.grad = Variable(g.data)
      else:
        v.grad.data.copy_(g.data)
    return hidden_next

  def _construct_model_from_theta(self, theta):
    model_new = self.model.new()
    model_dict = self.model.state_dict()

    params, offset = {}, 0
    for k, v in self.model.named_parameters():
      v_length = np.prod(v.size())
      params[k] = theta[offset: offset+v_length].view(v.size())
      offset += v_length

    assert offset == len(theta)
    model_dict.update(params)
    model_new.load_state_dict(model_dict)
    return model_new.cuda()

  def _hessian_vector_product(self, vector, hidden, input, target, r=1e-2):
    R = r / _concat(vector).norm()
    for p, v in zip(self.model.parameters(), vector):
      p.data.add_(R, v)
    loss, _ = self.model._loss(hidden, input, target)
    grads_p = torch.autograd.grad(loss, self.model.arch_parameters())

    for p, v in zip(self.model.parameters(), vector):
      p.data.sub_(2*R, v)
    loss, _ = self.model._loss(hidden, input, target)
    grads_n = torch.autograd.grad(loss, self.model.arch_parameters())

    for p, v in zip(self.model.parameters(), vector):
      p.data.add_(R, v)

    return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)]