Unverified Commit f1013390 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Adding back the missing softmax in DARTS and support deduplication (#3224)

parent 9e5e0e3c
...@@ -6,6 +6,7 @@ import logging ...@@ -6,6 +6,7 @@ import logging
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from ..interface import BaseOneShotTrainer from ..interface import BaseOneShotTrainer
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
...@@ -17,13 +18,14 @@ _logger = logging.getLogger(__name__) ...@@ -17,13 +18,14 @@ _logger = logging.getLogger(__name__)
class DartsLayerChoice(nn.Module):
    """One-shot DARTS wrapper around a ``LayerChoice``.

    Keeps every candidate op and returns a weighted sum of their outputs,
    with the weights given by a softmax over the architecture parameter
    ``alpha`` (one logit per candidate op).
    """

    def __init__(self, layer_choice):
        super(DartsLayerChoice, self).__init__()
        # Label used to deduplicate architecture parameters across modules
        # that share the same mutable key.
        self.name = layer_choice.key
        self.op_choices = nn.ModuleDict(layer_choice.named_children())
        # Small random init keeps the initial mixture close to uniform.
        self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)

    def forward(self, *args, **kwargs):
        # Shape: (n_candidates, *op_output_shape) -- assumes every candidate
        # op produces the same output shape.
        op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
        # Reshape the per-candidate weights so they broadcast over the
        # op output dimensions.
        alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
        # Softmax (not raw alpha) so the mixture weights are positive and
        # sum to 1, as in the DARTS formulation.
        return torch.sum(op_results * F.softmax(self.alpha, -1).view(*alpha_shape), 0)
def parameters(self): def parameters(self):
for _, p in self.named_parameters(): for _, p in self.named_parameters():
...@@ -42,13 +44,14 @@ class DartsLayerChoice(nn.Module): ...@@ -42,13 +44,14 @@ class DartsLayerChoice(nn.Module):
class DartsInputChoice(nn.Module):
    """One-shot DARTS wrapper around an ``InputChoice``.

    Mixes the candidate input tensors with softmax-normalized weights held
    in the architecture parameter ``alpha`` (one logit per candidate).
    """

    def __init__(self, input_choice):
        super(DartsInputChoice, self).__init__()
        # Label used to deduplicate architecture parameters across modules
        # that share the same mutable key.
        self.name = input_choice.key
        # Small random init keeps the initial mixture close to uniform.
        self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
        # Fall back to selecting a single input when n_chosen is unspecified.
        self.n_chosen = input_choice.n_chosen or 1

    def forward(self, inputs):
        # Shape: (n_candidates, *input_shape) -- assumes all candidate
        # inputs share the same shape.
        inputs = torch.stack(inputs)
        # Reshape the per-candidate weights so they broadcast over the
        # input dimensions.
        alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
        # Softmax normalizes the mixture weights to sum to 1.
        return torch.sum(inputs * F.softmax(self.alpha, -1).view(*alpha_shape), 0)
def parameters(self): def parameters(self):
for _, p in self.named_parameters(): for _, p in self.named_parameters():
...@@ -123,7 +126,15 @@ class DartsTrainer(BaseOneShotTrainer): ...@@ -123,7 +126,15 @@ class DartsTrainer(BaseOneShotTrainer):
module.to(self.device) module.to(self.device)
self.model_optim = optimizer self.model_optim = optimizer
self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate, betas=(0.5, 0.999), # use the same architecture weight for modules with duplicated names
ctrl_params = {}
for _, m in self.nas_modules:
if m.name in ctrl_params:
assert m.alpha.size() == ctrl_params[m.name].size(), 'Size of parameters with the same label should be same.'
m.alpha = ctrl_params[m.name]
else:
ctrl_params[m.name] = m.alpha
self.ctrl_optim = torch.optim.Adam(list(ctrl_params.values()), arc_learning_rate, betas=(0.5, 0.999),
weight_decay=1.0E-3) weight_decay=1.0E-3)
self.unrolled = unrolled self.unrolled = unrolled
self.grad_clip = 5. self.grad_clip = 5.
......
...@@ -157,6 +157,7 @@ class ProxylessTrainer(BaseOneShotTrainer): ...@@ -157,6 +157,7 @@ class ProxylessTrainer(BaseOneShotTrainer):
module.to(self.device) module.to(self.device)
self.optimizer = optimizer self.optimizer = optimizer
# We do not yet support deduplicating control parameters with the same label (as DARTS does).
self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate, self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate,
weight_decay=0, betas=(0, 0.999), eps=1e-8) weight_decay=0, betas=(0, 0.999), eps=1e-8)
self._init_dataloader() self._init_dataloader()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment