Unverified Commit 59521d33 authored by DU Hao's avatar DU Hao Committed by GitHub
Browse files

cream codes update (#3383)

parent c3b60511
import re
import math import math
import torch.nn as nn import torch.nn as nn
from copy import deepcopy
from timm.utils import * from timm.utils import *
from timm.models.layers.activations import Swish from timm.models.layers.activations import Swish
from timm.models.layers import CondConv2d, get_condconv_initializer from timm.models.layers import CondConv2d, get_condconv_initializer
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com # email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
import sys import sys
import logging
import torch
import argparse import argparse
import torch.nn as nn import torch.nn as nn
......
...@@ -83,7 +83,7 @@ def main(): ...@@ -83,7 +83,7 @@ def main():
'ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25',
'ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s2_e6_c80_se0.25',
'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
'ir_r1_k3_s2_e6_c192_se0.25'] 'ir_r1_k5_s2_e6_c192_se0.25']
arch_def = [[stem[0]]] + [[choice_block_pool[idx] arch_def = [[stem[0]]] + [[choice_block_pool[idx]
for repeat_times in range(len(arch_list[idx + 1]))] for repeat_times in range(len(arch_list[idx + 1]))]
for idx in range(len(choice_block_pool))] + [[stem[1]]] for idx in range(len(choice_block_pool))] + [[stem[1]]]
......
...@@ -230,7 +230,7 @@ class CreamSupernetTrainer(Trainer): ...@@ -230,7 +230,7 @@ class CreamSupernetTrainer(Trainer):
self.optimizer.zero_grad() self.optimizer.zero_grad()
grad_student_val = torch.autograd.grad( grad_student_val = torch.autograd.grad(
validation_loss, validation_loss,
self.model.module.rand_parameters(self.random_cand), self.model.module.rand_parameters(self.current_student_arch),
retain_graph=True) retain_graph=True)
grad_teacher = torch.autograd.grad( grad_teacher = torch.autograd.grad(
...@@ -385,7 +385,7 @@ class CreamSupernetTrainer(Trainer): ...@@ -385,7 +385,7 @@ class CreamSupernetTrainer(Trainer):
step + 1, len(self.train_loader), meters) step + 1, len(self.train_loader), meters)
if self.main_proc and self.num_epochs == epoch + 1: if self.main_proc and self.num_epochs == epoch + 1:
for idx, i in enumerate(self.best_children_pool): for idx, i in enumerate(self.prioritized_board):
logger.info("No.%s %s", idx, i[:4]) logger.info("No.%s %s", idx, i[:4])
def validate_one_epoch(self, epoch): def validate_one_epoch(self, epoch):
...@@ -396,9 +396,9 @@ class CreamSupernetTrainer(Trainer): ...@@ -396,9 +396,9 @@ class CreamSupernetTrainer(Trainer):
self.mutator.reset() self.mutator.reset()
logits = self.model(x) logits = self.model(x)
loss = self.val_loss(logits, y) loss = self.val_loss(logits, y)
prec1, prec5 = self.accuracy(logits, y, topk=(1, 5)) prec1, prec5 = accuracy(logits, y, topk=(1, 5))
metrics = {"prec1": prec1, "prec5": prec5, "loss": loss} metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
metrics = self.reduce_metrics(metrics, self.distributed) metrics = reduce_metrics(metrics)
meters.update(metrics) meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0: if self.log_frequency is not None and step % self.log_frequency == 0:
......
...@@ -17,7 +17,7 @@ def accuracy(output, target, topk=(1,)): ...@@ -17,7 +17,7 @@ def accuracy(output, target, topk=(1,)):
if target.ndimension() > 1: if target.ndimension() > 1:
target = target.max(1)[1] target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred)) correct = pred.eq(target.reshape(1, -1).expand_as(pred))
res = [] res = []
for k in topk: for k in topk:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment