"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "bc410d52cc02fbf6b4e5eab9eb4e2c0a0ae6aa47"
Unverified commit f1013390, authored by Yuge Zhang, committed by GitHub

Adding back the missing softmax in DARTS and support deduplication (#3224)

parent 9e5e0e3c
@@ -6,6 +6,7 @@ import logging

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from ..interface import BaseOneShotTrainer
 from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
@@ -17,13 +18,14 @@ _logger = logging.getLogger(__name__)
 class DartsLayerChoice(nn.Module):
     def __init__(self, layer_choice):
         super(DartsLayerChoice, self).__init__()
+        self.name = layer_choice.key
         self.op_choices = nn.ModuleDict(layer_choice.named_children())
         self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)

     def forward(self, *args, **kwargs):
         op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
         alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
-        return torch.sum(op_results * self.alpha.view(*alpha_shape), 0)
+        return torch.sum(op_results * F.softmax(self.alpha, -1).view(*alpha_shape), 0)

     def parameters(self):
         for _, p in self.named_parameters():
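
The softmax is the substantive fix here: without it, the raw `alpha` values weight the candidate operations directly, whereas DARTS defines the mixed operation as a softmax-weighted (convex) combination of the candidates. Below is a minimal standalone sketch of the new forward computation; the `toy_ops` dict and the tensor shapes are illustrative, not from the commit.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-ins for the candidate operations a LayerChoice would hold.
toy_ops = nn.ModuleDict({
    'linear_a': nn.Linear(4, 4, bias=False),
    'identity': nn.Identity(),
    'linear_b': nn.Linear(4, 4, bias=False),
})
alpha = nn.Parameter(torch.randn(len(toy_ops)) * 1e-3)

x = torch.randn(2, 4)
op_results = torch.stack([op(x) for op in toy_ops.values()])   # (num_ops, batch, features)
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)        # broadcast alpha over batch/features
weights = F.softmax(alpha, -1)                                 # convex weights, sum to 1
mixed = torch.sum(op_results * weights.view(*alpha_shape), 0)  # softmax-weighted mixture, (batch, features)

With the near-zero initialization (`randn * 1e-3`), the softmax starts out almost uniform, so every candidate contributes roughly equally at the beginning of the search.
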
@@ -42,13 +44,14 @@ class DartsLayerChoice(nn.Module):
 class DartsInputChoice(nn.Module):
     def __init__(self, input_choice):
         super(DartsInputChoice, self).__init__()
+        self.name = input_choice.key
         self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
         self.n_chosen = input_choice.n_chosen or 1

     def forward(self, inputs):
         inputs = torch.stack(inputs)
         alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
-        return torch.sum(inputs * self.alpha.view(*alpha_shape), 0)
+        return torch.sum(inputs * F.softmax(self.alpha, -1).view(*alpha_shape), 0)

     def parameters(self):
         for _, p in self.named_parameters():
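
`DartsInputChoice` gets the same treatment, except the mixture is taken over candidate input tensors rather than candidate operations. A small sketch, assuming three candidate inputs of shape (batch, channels); the `candidates` list and shapes are illustrative.

import torch
import torch.nn.functional as F

candidates = [torch.randn(2, 8) for _ in range(3)]  # e.g. outputs of three earlier nodes
alpha = torch.randn(3) * 1e-3                       # architecture parameters for this input choice

inputs = torch.stack(candidates)                    # (n_candidates, batch, channels)
alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
mixed = torch.sum(inputs * F.softmax(alpha, -1).view(*alpha_shape), 0)  # (batch, channels)

Note that the forward pass always mixes all candidates; `n_chosen` is only stored so the top candidates by `alpha` can be picked later (e.g. when the final architecture is exported).
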
@@ -123,7 +126,15 @@ class DartsTrainer(BaseOneShotTrainer):
             module.to(self.device)

         self.model_optim = optimizer
-        self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate, betas=(0.5, 0.999),
+        # use the same architecture weight for modules with duplicated names
+        ctrl_params = {}
+        for _, m in self.nas_modules:
+            if m.name in ctrl_params:
+                assert m.alpha.size() == ctrl_params[m.name].size(), 'Size of parameters with the same label should be same.'
+                m.alpha = ctrl_params[m.name]
+            else:
+                ctrl_params[m.name] = m.alpha
+        self.ctrl_optim = torch.optim.Adam(list(ctrl_params.values()), arc_learning_rate, betas=(0.5, 0.999),
                                            weight_decay=1.0E-3)
         self.unrolled = unrolled
         self.grad_clip = 5.
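
The deduplication block above keys the architecture parameters by module name, so choices that share a label (for example, the same cell structure repeated across the network) end up sharing a single `alpha`, and the controller optimizer sees that tensor exactly once. Here is a self-contained sketch of the same dict-based sharing, using a made-up `FakeChoice` module, label, and learning rate (all placeholders) in place of the real `DartsLayerChoice`.

import torch
import torch.nn as nn

class FakeChoice(nn.Module):
    """Illustrative stand-in: just a label and an architecture parameter."""
    def __init__(self, name, num_ops):
        super().__init__()
        self.name = name
        self.alpha = nn.Parameter(torch.randn(num_ops) * 1e-3)

# Two modules carrying the same label, as happens when a cell is repeated.
nas_modules = [('cell1', FakeChoice('op_choice_1', 3)), ('cell2', FakeChoice('op_choice_1', 3))]

ctrl_params = {}
for _, m in nas_modules:
    if m.name in ctrl_params:
        assert m.alpha.size() == ctrl_params[m.name].size()
        m.alpha = ctrl_params[m.name]   # re-register the shared parameter on this module
    else:
        ctrl_params[m.name] = m.alpha

assert nas_modules[0][1].alpha is nas_modules[1][1].alpha   # both point at one tensor
ctrl_optim = torch.optim.Adam(list(ctrl_params.values()), 3e-4, betas=(0.5, 0.999), weight_decay=1e-3)
print(len(ctrl_optim.param_groups[0]['params']))            # 1 -> no duplicate updates

Because assigning an `nn.Parameter` to an existing parameter attribute replaces it in the module's parameter registry, gradients from every duplicate flow into the one shared tensor, and Adam updates it once per step instead of once per duplicate.
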
@@ -157,6 +157,7 @@ class ProxylessTrainer(BaseOneShotTrainer):
             module.to(self.device)

         self.optimizer = optimizer
+        # we do not support deduplicate control parameters with same label (like DARTS) yet.
         self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate,
                                            weight_decay=0, betas=(0, 0.999), eps=1e-8)
         self._init_dataloader()