Unverified Commit 1a5c0172 authored by SparkSnail, committed by GitHub

Merge pull request #224 from microsoft/master

merge master
parents b9a7a95d ae81ec47
[Documentation](https://nni.readthedocs.io/en/latest/NAS/ENAS.html)
This is a naive example that demonstrates how to use the NNI interface to implement a NAS search space.
[Documentation](https://nni.readthedocs.io/en/latest/NAS/PDARTS.html)
# Single Path One-Shot Neural Architecture Search with Uniform Sampling
Single Path One-Shot by Megvii Research. [Paper link](https://arxiv.org/abs/1904.00420). [Official repo](https://github.com/megvii-model/SinglePathOneShot).
Block search only; channel search is not supported yet.
Only the GPU version is provided here.
## Preparation
### Requirements
* PyTorch >= 1.2
* NVIDIA DALI >= 0.16 as we use DALI to accelerate the data loading of ImageNet. [Installation guide](https://docs.nvidia.com/deeplearning/sdk/dali-developer-guide/docs/installation.html)
### Data
Download the FLOPs lookup table from [here](https://1drv.ms/u/s!Am_mmG2-KsrnajesvSdfsq_cN48?e=aHVppN).
Put `op_flops_dict.pkl` and `checkpoint-150000.pth.tar` (if you don't want to retrain the supernet) under the `data` directory.
Prepare ImageNet in the standard format (follow the script [here](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4)). Linking it to `data/imagenet` will be more convenient.
After preparation, the directory is expected to have the following structure:
```
spos
├── architecture_final.json
├── blocks.py
├── config_search.yml
├── data
│   ├── imagenet
│   │   ├── train
│   │   └── val
│   └── op_flops_dict.pkl
├── dataloader.py
├── network.py
├── readme.md
├── scratch.py
├── supernet.py
├── tester.py
├── tuner.py
└── utils.py
```
## Step 1. Train Supernet
```
python supernet.py
```
This exports the checkpoint to the `checkpoints` directory for the next step.
NOTE: The data loading used in the official repo is [slightly different from usual](https://github.com/megvii-model/SinglePathOneShot/issues/5): they use BGR tensors and intentionally keep the values between 0 and 255 to align with their own DL framework. The option `--spos-preprocessing` simulates the original behavior and enables you to use the pretrained checkpoints.
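For intuition only, here is a minimal sketch of what such preprocessing amounts to. This is not the pipeline used by the example (which loads data with DALI), and the helper name is hypothetical:
```
import numpy as np
import torch

def spos_style_preprocess(pil_image):
    # Keep raw pixel values in [0, 255] (no mean/std normalization) and
    # swap channels from RGB to BGR, as described in the note above.
    arr = np.asarray(pil_image, dtype=np.float32)   # (H, W, 3), RGB, values 0-255
    arr = arr[:, :, ::-1].copy()                    # RGB -> BGR
    return torch.from_numpy(arr).permute(2, 0, 1)   # (3, H, W) float tensor
```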
## Step 2. Evolution Search
Single Path One-Shot leverages an evolutionary algorithm to search for the best architecture. The tester, which is responsible for testing a sampled architecture, recalculates all the batch norm statistics on a subset of training images and evaluates the architecture on the full validation set.
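The batch norm recalculation can be sketched as below. This is a simplified illustration of the idea, not the exact logic in `tester.py`; the function and argument names are made up:
```
import torch

def recalibrate_bn(model, train_subset_loader, num_batches, device="cuda"):
    # Reset running statistics and re-estimate them with a cumulative average
    # over a subset of training images before evaluating on the validation set.
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            module.momentum = None  # None = cumulative moving average
    model.train()
    with torch.no_grad():
        for step, (images, _) in enumerate(train_subset_loader):
            if step >= num_batches:
                break
            model(images.to(device))
```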
To have a search space ready for the NNI framework, first run
```
nnictl ss_gen -t "python tester.py"
```
This will generate a file called `nni_auto_gen_search_space.json`, which is a serialized representation of your search space.
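If you are curious about its content, it is an ordinary JSON file and can be inspected, for instance (assuming it was generated in the current directory):
```
import json

with open("nni_auto_gen_search_space.json") as fp:
    search_space = json.load(fp)

# Typically, each top-level key corresponds to one mutable decision in the supernet.
print(len(search_space), "entries")
print(list(search_space)[:3])
```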
Then launch the search with the evolution tuner.
```
nnictl create --config config_search.yml
```
The final architecture exported from every epoch of evolution can be found in `checkpoints` under the working directory of your tuner, which, by default, is `$HOME/nni/experiments/your_experiment_id/log`.
## Step 3. Train from Scratch
```
python scratch.py
```
By default, it will use `architecture_final.json`. This architecture is provided by the official repo (converted into NNI format). You can use any other architecture (e.g., the one found in step 2) with the `--fixed-arc` option.
## Current Reproduction Results
Reproduction is still in progress. Due to the gap between the official release and the original paper, we compare our current results with the official repo (as run by us) and with the paper.
* The evolution phase is almost aligned with the official repo. Our evolution algorithm shows a converging trend and reaches ~65% accuracy at the end of the search. Nevertheless, this result is not on par with the paper. For details, please refer to [this issue](https://github.com/megvii-model/SinglePathOneShot/issues/6).
* The retrain phase is not aligned. Our retraining code, which uses the architecture released by the authors, reaches 72.14% accuracy, which still leaves a gap to the 73.61% of the official release and the 74.3% reported in the original paper.
[Documentation](https://nni.readthedocs.io/en/latest/NAS/SPOS.html)
......@@ -11,6 +11,6 @@ tuner:
classFileName: tuner.py
className: EvolutionWithFlops
trial:
command: python tester.py --imagenet-dir /path/to/your/imagenet --spos-prep
command: python tester.py --spos-prep
codeDir: .
gpuNum: 1
# TextNAS: A Neural Architecture Search Space tailored for Text Representation
TextNAS by MSRA. Official Release.
[Paper link](https://arxiv.org/abs/1912.10729)
## Preparation
Prepare the word vectors and the SST dataset, and organize them in the `data` directory as shown below:
```
textnas
├── data
│ ├── sst
│ │ └── trees
│ │ ├── dev.txt
│ │ ├── test.txt
│ │ └── train.txt
│ └── glove.840B.300d.txt
├── dataloader.py
├── model.py
├── ops.py
├── README.md
├── search.py
└── utils.py
```
The following links might be helpful for finding and downloading the corresponding datasets:
* [GloVe: Global Vectors for Word Representation](https://nlp.stanford.edu/projects/glove/)
* [Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank](https://nlp.stanford.edu/sentiment/)
## Search
```
python search.py
```
After each search epoch, 10 sampled architectures will be tested directly. Their performance is expected to be 40%–42% after 10 epochs.
By default, 20 sampled architectures will be exported into the `checkpoints` directory for the next step.
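As an optional sketch, the exported files can be inspected after the search completes (paths assume the layout above):
```
import json
import os

for name in sorted(os.listdir("checkpoints")):
    if not name.endswith(".json"):
        continue
    with open(os.path.join("checkpoints", name)) as fp:
        architecture = json.load(fp)
    print(name, "->", len(architecture), "architecture decisions")
```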
## Retrain
Not ready.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import pickle
from collections import Counter
import numpy as np
import torch
from torch.utils import data
logger = logging.getLogger("nni.textnas")
class PTBTree:
WORD_TO_WORD_MAPPING = {
"{": "-LCB-",
"}": "-RCB-"
}
def __init__(self):
self.subtrees = []
self.word = None
self.label = ""
self.parent = None
self.span = (-1, -1)
self.word_vector = None # HOS, store dx1 RNN word vector
self.prediction = None # HOS, store Kx1 prediction vector
def is_leaf(self):
return len(self.subtrees) == 0
def set_by_text(self, text, pos=0, left=0):
depth = 0
right = left
for i in range(pos + 1, len(text)):
char = text[i]
# update the depth
if char == "(":
depth += 1
if depth == 1:
subtree = PTBTree()
subtree.parent = self
subtree.set_by_text(text, i, right)
right = subtree.span[1]
self.span = (left, right)
self.subtrees.append(subtree)
elif char == ")":
depth -= 1
if len(self.subtrees) == 0:
pos = i
for j in range(i, 0, -1):
if text[j] == " ":
pos = j
break
self.word = text[pos + 1:i]
self.span = (left, left + 1)
# we've reached the end of the category that is the root of this subtree
if depth == 0 and char == " " and self.label == "":
self.label = text[pos + 1:i]
# we've reached the end of the scope for this bracket
if depth < 0:
break
# Fix some issues with variation in output, and one error in the treebank
# for a word with a punctuation POS
self.standardise_node()
def standardise_node(self):
if self.word in self.WORD_TO_WORD_MAPPING:
self.word = self.WORD_TO_WORD_MAPPING[self.word]
def __repr__(self, single_line=True, depth=0):
ans = ""
if not single_line and depth > 0:
ans = "\n" + depth * "\t"
ans += "(" + self.label
if self.word is not None:
ans += " " + self.word
for subtree in self.subtrees:
if single_line:
ans += " "
ans += subtree.__repr__(single_line, depth + 1)
ans += ")"
return ans
def read_tree(source):
cur_text = []
depth = 0
while True:
line = source.readline()
# Check if we are out of input
if line == "":
return None
# strip whitespace and only use if this contains something
line = line.strip()
if line == "":
continue
cur_text.append(line)
# Update depth
for char in line:
if char == "(":
depth += 1
elif char == ")":
depth -= 1
# At depth 0 we have a complete tree
if depth == 0:
tree = PTBTree()
tree.set_by_text(" ".join(cur_text))
return tree
return None
def read_trees(source, max_sents=-1):
with open(source) as fp:
trees = []
while True:
tree = read_tree(fp)
if tree is None:
break
trees.append(tree)
if len(trees) >= max_sents > 0:
break
return trees
class SSTDataset(data.Dataset):
def __init__(self, sents, mask, labels):
self.sents = sents
self.labels = labels
self.mask = mask
def __getitem__(self, index):
return (self.sents[index], self.mask[index]), self.labels[index]
def __len__(self):
return len(self.sents)
def sst_get_id_input(content, word_id_dict, max_input_length):
words = content.split(" ")
sentence = [word_id_dict["<pad>"]] * max_input_length
mask = [0] * max_input_length
unknown = word_id_dict["<unknown>"]
for i, word in enumerate(words[:max_input_length]):
sentence[i] = word_id_dict.get(word, unknown)
mask[i] = 1
return sentence, mask
def sst_get_phrases(trees, sample_ratio=1.0, is_binary=False, only_sentence=False):
all_phrases = []
for tree in trees:
if only_sentence:
sentence = get_sentence_by_tree(tree)
label = int(tree.label)
pair = (sentence, label)
all_phrases.append(pair)
else:
phrases = get_phrases_by_tree(tree)
sentence = get_sentence_by_tree(tree)
pair = (sentence, int(tree.label))
all_phrases.append(pair)
all_phrases += phrases
if sample_ratio < 1.:
np.random.shuffle(all_phrases)
result_phrases = []
for pair in all_phrases:
if is_binary:
phrase, label = pair
if label <= 1:
pair = (phrase, 0)
elif label >= 3:
pair = (phrase, 1)
else:
continue
if sample_ratio == 1.:
result_phrases.append(pair)
else:
rand_portion = np.random.random()
if rand_portion < sample_ratio:
result_phrases.append(pair)
return result_phrases
def get_phrases_by_tree(tree):
phrases = []
if tree is None:
return phrases
if tree.is_leaf():
pair = (tree.word, int(tree.label))
phrases.append(pair)
return phrases
left_child_phrases = get_phrases_by_tree(tree.subtrees[0])
right_child_phrases = get_phrases_by_tree(tree.subtrees[1])
phrases.extend(left_child_phrases)
phrases.extend(right_child_phrases)
sentence = get_sentence_by_tree(tree)
pair = (sentence, int(tree.label))
phrases.append(pair)
return phrases
def get_sentence_by_tree(tree):
if tree is None:
return ""
if tree.is_leaf():
return tree.word
left_sentence = get_sentence_by_tree(tree.subtrees[0])
right_sentence = get_sentence_by_tree(tree.subtrees[1])
sentence = left_sentence + " " + right_sentence
return sentence.strip()
def get_word_id_dict(word_num_dict, word_id_dict, min_count):
z = [k for k in sorted(word_num_dict.keys())]
for word in z:
count = word_num_dict[word]
if count >= min_count:
index = len(word_id_dict)
if word not in word_id_dict:
word_id_dict[word] = index
return word_id_dict
def load_word_num_dict(phrases, word_num_dict):
for sentence, _ in phrases:
words = sentence.split(" ")
for cur_word in words:
word = cur_word.strip()
word_num_dict[word] += 1
return word_num_dict
def init_trainable_embedding(embedding_path, word_id_dict, embed_dim=300):
word_embed_model = load_glove_model(embedding_path, embed_dim)
assert word_embed_model["pool"].shape[1] == embed_dim
embedding = np.random.random([len(word_id_dict), embed_dim]).astype(np.float32) / 2.0 - 0.25
embedding[0] = np.zeros(embed_dim) # PAD
embedding[1] = (np.random.rand(embed_dim) - 0.5) / 2 # UNK
for word, idx in word_id_dict.items():
if idx == 0 or idx == 1:
continue
if word in word_embed_model["mapping"]:
embedding[idx] = word_embed_model["pool"][word_embed_model["mapping"][word]]
else:
embedding[idx] = np.random.rand(embed_dim) / 2.0 - 0.25
return embedding
def sst_get_trainable_data(phrases, word_id_dict, max_input_length):
texts, labels, mask = [], [], []
for phrase, label in phrases:
if not phrase.split():
continue
phrase_split, mask_split = sst_get_id_input(phrase, word_id_dict, max_input_length)
texts.append(phrase_split)
labels.append(int(label))
mask.append(mask_split) # field_input is mask
labels = np.array(labels, dtype=np.int64)
texts = np.reshape(texts, [-1, max_input_length]).astype(np.int32)
mask = np.reshape(mask, [-1, max_input_length]).astype(np.int32)
return SSTDataset(texts, mask, labels)
def load_glove_model(filename, embed_dim):
if os.path.exists(filename + ".cache"):
logger.info("Found cache. Loading...")
with open(filename + ".cache", "rb") as fp:
return pickle.load(fp)
embedding = {"mapping": dict(), "pool": []}
with open(filename) as f:
for i, line in enumerate(f):
line = line.rstrip("\n")
vocab_word, *vec = line.rsplit(" ", maxsplit=embed_dim)
            assert len(vec) == embed_dim, "Unexpected line: '%s'" % line
embedding["pool"].append(np.array(list(map(float, vec)), dtype=np.float32))
embedding["mapping"][vocab_word] = i
embedding["pool"] = np.stack(embedding["pool"])
with open(filename + ".cache", "wb") as fp:
pickle.dump(embedding, fp)
return embedding
def read_data_sst(data_path, max_input_length=64, min_count=1, train_with_valid=False,
train_ratio=1., valid_ratio=1., is_binary=False, only_sentence=False):
word_id_dict = dict()
word_num_dict = Counter()
sst_path = os.path.join(data_path, "sst")
logger.info("Reading SST data...")
train_file_name = os.path.join(sst_path, "trees", "train.txt")
valid_file_name = os.path.join(sst_path, "trees", "dev.txt")
test_file_name = os.path.join(sst_path, "trees", "test.txt")
train_trees = read_trees(train_file_name)
train_phrases = sst_get_phrases(train_trees, train_ratio, is_binary, only_sentence)
logger.info("Finish load train phrases.")
valid_trees = read_trees(valid_file_name)
valid_phrases = sst_get_phrases(valid_trees, valid_ratio, is_binary, only_sentence)
logger.info("Finish load valid phrases.")
if train_with_valid:
train_phrases += valid_phrases
test_trees = read_trees(test_file_name)
test_phrases = sst_get_phrases(test_trees, valid_ratio, is_binary, only_sentence=True)
logger.info("Finish load test phrases.")
# get word_id_dict
word_id_dict["<pad>"] = 0
word_id_dict["<unknown>"] = 1
load_word_num_dict(train_phrases, word_num_dict)
logger.info("Finish load train words: %d.", len(word_num_dict))
load_word_num_dict(valid_phrases, word_num_dict)
load_word_num_dict(test_phrases, word_num_dict)
logger.info("Finish load valid+test words: %d.", len(word_num_dict))
word_id_dict = get_word_id_dict(word_num_dict, word_id_dict, min_count)
logger.info("After trim vocab length: %d.", len(word_id_dict))
logger.info("Loading embedding...")
embedding = init_trainable_embedding(os.path.join(data_path, "glove.840B.300d.txt"), word_id_dict)
logger.info("Finish initialize word embedding.")
dataset_train = sst_get_trainable_data(train_phrases, word_id_dict, max_input_length)
logger.info("Loaded %d training samples.", len(dataset_train))
dataset_valid = sst_get_trainable_data(valid_phrases, word_id_dict, max_input_length)
logger.info("Loaded %d validation samples.", len(dataset_valid))
dataset_test = sst_get_trainable_data(test_phrases, word_id_dict, max_input_length)
logger.info("Loaded %d test samples.", len(dataset_test))
return dataset_train, dataset_valid, dataset_test, torch.from_numpy(embedding)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import numpy as np
import torch
import torch.nn as nn
from nni.nas.pytorch import mutables
from ops import ConvBN, LinearCombine, AvgPool, MaxPool, RNN, Attention, BatchNorm
from utils import GlobalMaxPool, GlobalAvgPool
class Layer(mutables.MutableScope):
def __init__(self, key, prev_keys, hidden_units, choose_from_k, cnn_keep_prob, lstm_keep_prob, att_keep_prob, att_mask):
super(Layer, self).__init__(key)
def conv_shortcut(kernel_size):
return ConvBN(kernel_size, hidden_units, hidden_units, cnn_keep_prob, False, True)
self.n_candidates = len(prev_keys)
if self.n_candidates:
self.prec = mutables.InputChoice(choose_from=prev_keys[-choose_from_k:], n_chosen=1)
else:
# first layer, skip input choice
self.prec = None
self.op = mutables.LayerChoice([
conv_shortcut(1),
conv_shortcut(3),
conv_shortcut(5),
conv_shortcut(7),
AvgPool(3, False, True),
MaxPool(3, False, True),
RNN(hidden_units, lstm_keep_prob),
Attention(hidden_units, 4, att_keep_prob, att_mask)
])
if self.n_candidates:
self.skipconnect = mutables.InputChoice(choose_from=prev_keys)
else:
self.skipconnect = None
self.bn = BatchNorm(hidden_units, False, True)
def forward(self, last_layer, prev_layers, mask):
# pass an extra last_layer to deal with layer 0 (prev_layers is empty)
if self.prec is None:
prec = last_layer
else:
prec = self.prec(prev_layers[-self.prec.n_candidates:]) # skip first
out = self.op(prec, mask)
if self.skipconnect is not None:
connection = self.skipconnect(prev_layers[-self.skipconnect.n_candidates:])
if connection is not None:
out += connection
out = self.bn(out, mask)
return out
class Model(nn.Module):
def __init__(self, embedding, hidden_units=256, num_layers=24, num_classes=5, choose_from_k=5,
lstm_keep_prob=0.5, cnn_keep_prob=0.5, att_keep_prob=0.5, att_mask=True,
embed_keep_prob=0.5, final_output_keep_prob=1.0, global_pool="avg"):
super(Model, self).__init__()
self.embedding = nn.Embedding.from_pretrained(embedding, freeze=False)
self.hidden_units = hidden_units
self.num_layers = num_layers
self.num_classes = num_classes
self.init_conv = ConvBN(1, self.embedding.embedding_dim, hidden_units, cnn_keep_prob, False, True)
self.layers = nn.ModuleList()
candidate_keys_pool = []
for layer_id in range(self.num_layers):
k = "layer_{}".format(layer_id)
self.layers.append(Layer(k, candidate_keys_pool, hidden_units, choose_from_k,
cnn_keep_prob, lstm_keep_prob, att_keep_prob, att_mask))
candidate_keys_pool.append(k)
self.linear_combine = LinearCombine(self.num_layers)
self.linear_out = nn.Linear(self.hidden_units, self.num_classes)
self.embed_dropout = nn.Dropout(p=1 - embed_keep_prob)
self.output_dropout = nn.Dropout(p=1 - final_output_keep_prob)
assert global_pool in ["max", "avg"]
if global_pool == "max":
self.global_pool = GlobalMaxPool()
elif global_pool == "avg":
self.global_pool = GlobalAvgPool()
def forward(self, inputs):
sent_ids, mask = inputs
seq = self.embedding(sent_ids.long())
seq = self.embed_dropout(seq)
seq = torch.transpose(seq, 1, 2) # from (N, L, C) -> (N, C, L)
x = self.init_conv(seq, mask)
prev_layers = []
for layer in self.layers:
x = layer(x, prev_layers, mask)
prev_layers.append(x)
x = self.linear_combine(torch.stack(prev_layers))
x = self.global_pool(x, mask)
x = self.output_dropout(x)
x = self.linear_out(x)
return x
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn.functional as F
from torch import nn
from utils import get_length, INF
class Mask(nn.Module):
def forward(self, seq, mask):
# seq: (N, C, L)
# mask: (N, L)
seq_mask = torch.unsqueeze(mask, 2)
seq_mask = torch.transpose(seq_mask.repeat(1, 1, seq.size()[1]), 1, 2)
return seq.where(torch.eq(seq_mask, 1), torch.zeros_like(seq))
class BatchNorm(nn.Module):
def __init__(self, num_features, pre_mask, post_mask, eps=1e-5, decay=0.9, affine=True):
super(BatchNorm, self).__init__()
self.mask_opt = Mask()
self.pre_mask = pre_mask
self.post_mask = post_mask
self.bn = nn.BatchNorm1d(num_features, eps=eps, momentum=1.0 - decay, affine=affine)
def forward(self, seq, mask):
if self.pre_mask:
seq = self.mask_opt(seq, mask)
seq = self.bn(seq)
if self.post_mask:
seq = self.mask_opt(seq, mask)
return seq
class ConvBN(nn.Module):
def __init__(self, kernal_size, in_channels, out_channels, cnn_keep_prob,
pre_mask, post_mask, with_bn=True, with_relu=True):
super(ConvBN, self).__init__()
self.mask_opt = Mask()
self.pre_mask = pre_mask
self.post_mask = post_mask
self.with_bn = with_bn
self.with_relu = with_relu
self.conv = nn.Conv1d(in_channels, out_channels, kernal_size, 1, bias=True, padding=(kernal_size - 1) // 2)
self.dropout = nn.Dropout(p=(1 - cnn_keep_prob))
if with_bn:
self.bn = BatchNorm(out_channels, not post_mask, True)
if with_relu:
self.relu = nn.ReLU()
def forward(self, seq, mask):
if self.pre_mask:
seq = self.mask_opt(seq, mask)
seq = self.conv(seq)
if self.post_mask:
seq = self.mask_opt(seq, mask)
if self.with_bn:
seq = self.bn(seq, mask)
if self.with_relu:
seq = self.relu(seq)
seq = self.dropout(seq)
return seq
class AvgPool(nn.Module):
def __init__(self, kernal_size, pre_mask, post_mask):
super(AvgPool, self).__init__()
self.avg_pool = nn.AvgPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
self.pre_mask = pre_mask
self.post_mask = post_mask
self.mask_opt = Mask()
def forward(self, seq, mask):
if self.pre_mask:
seq = self.mask_opt(seq, mask)
seq = self.avg_pool(seq)
if self.post_mask:
seq = self.mask_opt(seq, mask)
return seq
class MaxPool(nn.Module):
def __init__(self, kernal_size, pre_mask, post_mask):
super(MaxPool, self).__init__()
self.max_pool = nn.MaxPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
self.pre_mask = pre_mask
self.post_mask = post_mask
self.mask_opt = Mask()
def forward(self, seq, mask):
if self.pre_mask:
seq = self.mask_opt(seq, mask)
seq = self.max_pool(seq)
if self.post_mask:
seq = self.mask_opt(seq, mask)
return seq
class Attention(nn.Module):
def __init__(self, num_units, num_heads, keep_prob, is_mask):
super(Attention, self).__init__()
self.num_heads = num_heads
self.keep_prob = keep_prob
self.linear_q = nn.Linear(num_units, num_units)
self.linear_k = nn.Linear(num_units, num_units)
self.linear_v = nn.Linear(num_units, num_units)
self.bn = BatchNorm(num_units, True, is_mask)
self.dropout = nn.Dropout(p=1 - self.keep_prob)
def forward(self, seq, mask):
in_c = seq.size()[1]
seq = torch.transpose(seq, 1, 2) # (N, L, C)
queries = seq
keys = seq
num_heads = self.num_heads
# T_q = T_k = L
Q = F.relu(self.linear_q(seq)) # (N, T_q, C)
K = F.relu(self.linear_k(seq)) # (N, T_k, C)
V = F.relu(self.linear_v(seq)) # (N, T_k, C)
# Split and concat
Q_ = torch.cat(torch.split(Q, in_c // num_heads, dim=2), dim=0) # (h*N, T_q, C/h)
K_ = torch.cat(torch.split(K, in_c // num_heads, dim=2), dim=0) # (h*N, T_k, C/h)
V_ = torch.cat(torch.split(V, in_c // num_heads, dim=2), dim=0) # (h*N, T_k, C/h)
# Multiplication
outputs = torch.matmul(Q_, K_.transpose(1, 2)) # (h*N, T_q, T_k)
# Scale
outputs = outputs / (K_.size()[-1] ** 0.5)
# Key Masking
key_masks = mask.repeat(num_heads, 1) # (h*N, T_k)
key_masks = torch.unsqueeze(key_masks, 1) # (h*N, 1, T_k)
key_masks = key_masks.repeat(1, queries.size()[1], 1) # (h*N, T_q, T_k)
paddings = torch.ones_like(outputs) * (-INF) # extremely small value
outputs = torch.where(torch.eq(key_masks, 0), paddings, outputs)
query_masks = mask.repeat(num_heads, 1) # (h*N, T_q)
query_masks = torch.unsqueeze(query_masks, -1) # (h*N, T_q, 1)
query_masks = query_masks.repeat(1, 1, keys.size()[1]).float() # (h*N, T_q, T_k)
att_scores = F.softmax(outputs, dim=-1) * query_masks # (h*N, T_q, T_k)
att_scores = self.dropout(att_scores)
# Weighted sum
x_outputs = torch.matmul(att_scores, V_) # (h*N, T_q, C/h)
# Restore shape
x_outputs = torch.cat(
torch.split(x_outputs, x_outputs.size()[0] // num_heads, dim=0),
dim=2) # (N, T_q, C)
x = torch.transpose(x_outputs, 1, 2) # (N, C, L)
x = self.bn(x, mask)
return x
class RNN(nn.Module):
def __init__(self, hidden_size, output_keep_prob):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.bid_rnn = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
self.output_keep_prob = output_keep_prob
self.out_dropout = nn.Dropout(p=(1 - self.output_keep_prob))
def forward(self, seq, mask):
# seq: (N, C, L)
# mask: (N, L)
max_len = seq.size()[2]
length = get_length(mask)
seq = torch.transpose(seq, 1, 2) # to (N, L, C)
packed_seq = nn.utils.rnn.pack_padded_sequence(seq, length, batch_first=True,
enforce_sorted=False)
outputs, _ = self.bid_rnn(packed_seq)
outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True,
total_length=max_len)[0]
outputs = outputs.view(-1, max_len, 2, self.hidden_size).sum(2) # (N, L, C)
outputs = self.out_dropout(outputs) # output dropout
return torch.transpose(outputs, 1, 2) # back to: (N, C, L)
class LinearCombine(nn.Module):
def __init__(self, layers_num, trainable=True, input_aware=False, word_level=False):
super(LinearCombine, self).__init__()
self.input_aware = input_aware
self.word_level = word_level
if input_aware:
raise NotImplementedError("Input aware is not supported.")
self.w = nn.Parameter(torch.full((layers_num, 1, 1, 1), 1.0 / layers_num),
requires_grad=trainable)
def forward(self, seq):
nw = F.softmax(self.w, dim=0)
seq = torch.mul(seq, nw)
seq = torch.sum(seq, dim=0)
return seq
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import random
from argparse import ArgumentParser
from itertools import cycle
import numpy as np
import torch
import torch.nn as nn
from nni.nas.pytorch.enas import EnasMutator, EnasTrainer
from nni.nas.pytorch.callbacks import LRSchedulerCallback
from dataloader import read_data_sst
from model import Model
from utils import accuracy
logger = logging.getLogger("nni.textnas")
class TextNASTrainer(EnasTrainer):
def __init__(self, *args, train_loader=None, valid_loader=None, test_loader=None, **kwargs):
super().__init__(*args, **kwargs)
self.train_loader = train_loader
self.valid_loader = valid_loader
self.test_loader = test_loader
def init_dataloader(self):
pass
if __name__ == "__main__":
parser = ArgumentParser("textnas")
parser.add_argument("--batch-size", default=128, type=int)
parser.add_argument("--log-frequency", default=50, type=int)
parser.add_argument("--seed", default=1234, type=int)
parser.add_argument("--epochs", default=10, type=int)
parser.add_argument("--lr", default=5e-3, type=float)
args = parser.parse_args()
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
torch.backends.cudnn.deterministic = True
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
train_dataset, valid_dataset, test_dataset, embedding = read_data_sst("data")
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, num_workers=4)
train_loader, valid_loader = cycle(train_loader), cycle(valid_loader)
model = Model(embedding)
mutator = EnasMutator(model, temperature=None, tanh_constant=None, entropy_reduction="mean")
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=1e-3, weight_decay=2e-6)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-5)
trainer = TextNASTrainer(model,
loss=criterion,
metrics=lambda output, target: {"acc": accuracy(output, target)},
reward_function=accuracy,
optimizer=optimizer,
callbacks=[LRSchedulerCallback(lr_scheduler)],
batch_size=args.batch_size,
num_epochs=args.epochs,
dataset_train=None,
dataset_valid=None,
train_loader=train_loader,
valid_loader=valid_loader,
test_loader=test_loader,
log_frequency=args.log_frequency,
mutator=mutator,
mutator_lr=2e-3,
mutator_steps=500,
mutator_steps_aggregate=1,
child_steps=3000,
baseline_decay=0.99,
test_arc_per_epoch=10)
trainer.train()
os.makedirs("checkpoints", exist_ok=True)
for i in range(20):
trainer.export(os.path.join("checkpoints", "architecture_%02d.json" % i))
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
import torch.nn as nn
INF = 1E10
EPS = 1E-12
logger = logging.getLogger("nni.textnas")
def get_length(mask):
length = torch.sum(mask, 1)
length = length.long()
return length
class GlobalAvgPool(nn.Module):
def forward(self, x, mask):
x = torch.sum(x, 2)
length = torch.sum(mask, 1, keepdim=True).float()
length += torch.eq(length, 0.0).float() * EPS
length = length.repeat(1, x.size()[1])
x /= length
return x
class GlobalMaxPool(nn.Module):
def forward(self, x, mask):
mask = torch.eq(mask.float(), 0.0).long()
mask = torch.unsqueeze(mask, dim=1).repeat(1, x.size()[1], 1)
mask *= -INF
        # add the (very negative) mask once so padded positions cannot win the max
        x, _ = torch.max(x + mask, 2)
return x
class IteratorWrapper:
def __init__(self, loader):
self.loader = loader
self.iterator = None
def __iter__(self):
self.iterator = iter(self.loader)
return self
def __len__(self):
return len(self.loader)
def __next__(self):
data = next(self.iterator)
text, length = data.text
max_length = text.size(1)
label = data.label - 1
bs = label.size(0)
mask = torch.arange(max_length, device=length.device).unsqueeze(0).repeat(bs, 1)
mask = mask < length.unsqueeze(-1).repeat(1, max_length)
return (text, mask), label
def accuracy(output, target):
batch_size = target.size(0)
_, predicted = torch.max(output.data, 1)
return (predicted == target).sum().item() / batch_size
......@@ -58,7 +58,7 @@
},
"resolutions": {
"mem": "^4.0.0",
"handlebars": "^4.1.0",
"handlebars": "^4.5.3",
"lodash": "^4.17.13",
"lodash.merge": "^4.6.2",
"node.extend": "^1.1.7",
......
......@@ -9,7 +9,7 @@ import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../
export class PAIClusterConfig {
public readonly userName: string;
public readonly passWord?: string;
public readonly host: string;
public host: string;
public readonly token?: string;
/**
......
......@@ -25,7 +25,7 @@ export class PAIJobInfoCollector {
this.finalStatuses = ['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED'];
}
public async retrieveTrialStatus(token? : string, paiBaseClusterConfig?: PAIClusterConfig): Promise<void> {
public async retrieveTrialStatus(protocol: string, token? : string, paiBaseClusterConfig?: PAIClusterConfig): Promise<void> {
if (paiBaseClusterConfig === undefined || token === undefined) {
return Promise.resolve();
}
......@@ -35,13 +35,13 @@ export class PAIJobInfoCollector {
if (paiTrialJob === undefined) {
throw new NNIError(NNIErrorNames.NOT_FOUND, `trial job id ${trialJobId} not found`);
}
updatePaiTrialJobs.push(this.getSinglePAITrialJobInfo(paiTrialJob, token, paiBaseClusterConfig));
updatePaiTrialJobs.push(this.getSinglePAITrialJobInfo(protocol, paiTrialJob, token, paiBaseClusterConfig));
}
await Promise.all(updatePaiTrialJobs);
}
private getSinglePAITrialJobInfo(paiTrialJob: PAITrialJobDetail, paiToken: string, paiClusterConfig: PAIClusterConfig): Promise<void> {
private getSinglePAITrialJobInfo(protocol: string, paiTrialJob: PAITrialJobDetail, paiToken: string, paiClusterConfig: PAIClusterConfig): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
if (!this.statusesNeedToCheck.includes(paiTrialJob.status)) {
deferred.resolve();
......@@ -52,7 +52,7 @@ export class PAIJobInfoCollector {
// Rest call to get PAI job info and update status
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const getJobInfoRequest: request.Options = {
uri: `http://${paiClusterConfig.host}/rest-server/api/v1/user/${paiClusterConfig.userName}/jobs/${paiTrialJob.paiJobName}`,
uri: `${protocol}://${paiClusterConfig.host}/rest-server/api/v1/user/${paiClusterConfig.userName}/jobs/${paiTrialJob.paiJobName}`,
method: 'GET',
json: true,
headers: {
......@@ -81,7 +81,11 @@ export class PAIJobInfoCollector {
paiTrialJob.startTime = response.body.jobStatus.appLaunchedTime;
}
if (paiTrialJob.url === undefined) {
if (response.body.jobStatus.appTrackingUrl) {
paiTrialJob.url = response.body.jobStatus.appTrackingUrl;
} else {
paiTrialJob.url = paiTrialJob.logPath;
}
}
break;
case 'SUCCEEDED':
......@@ -114,7 +118,7 @@ export class PAIJobInfoCollector {
}
// Set pai trial job's url to WebHDFS output path
if (paiTrialJob.logPath !== undefined) {
if (paiTrialJob.url) {
if (paiTrialJob.url && paiTrialJob.url !== paiTrialJob.logPath) {
paiTrialJob.url += `,${paiTrialJob.logPath}`;
} else {
paiTrialJob.url = `${paiTrialJob.logPath}`;
......
......@@ -62,6 +62,7 @@ class PAIK8STrainingService extends PAITrainingService {
case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG:
this.paiJobRestServer = new PAIJobRestServer(component.get(PAIK8STrainingService));
this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);
this.paiClusterConfig.host = this.formatPAIHost(this.paiClusterConfig.host);
if(this.paiClusterConfig.passWord) {
// Get PAI authentication token
await this.updatePaiToken();
......@@ -257,7 +258,7 @@ class PAIK8STrainingService extends PAITrainingService {
// Step 3. Submit PAI job via Rest call
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const submitJobRequest: request.Options = {
uri: `http://${this.paiClusterConfig.host}/rest-server/api/v2/jobs`,
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs`,
method: 'POST',
body: paiJobConfig,
headers: {
......
......@@ -52,6 +52,7 @@ abstract class PAITrainingService implements TrainingService {
protected authFileHdfsPath: string | undefined = undefined;
protected portList?: string | undefined;
protected paiJobRestServer?: PAIJobRestServer;
protected protocol: string = 'http';
constructor() {
this.log = getLogger();
......@@ -165,7 +166,7 @@ abstract class PAITrainingService implements TrainingService {
}
const stopJobRequest: request.Options = {
uri: `http://${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}\
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}\
/jobs/${trialJobDetail.paiJobName}/executionType`,
method: 'PUT',
json: true,
......@@ -216,6 +217,20 @@ abstract class PAITrainingService implements TrainingService {
return this.metricsEmitter;
}
protected formatPAIHost(host: string): string {
        // If the user's host starts with 'http://' or 'https://', strip the protocol
        // prefix, remember the protocol, and return the bare host; otherwise return
        // the host unchanged (the protocol defaults to 'http').
if (host.startsWith('http://')) {
this.protocol = 'http';
return host.replace('http://', '');
} else if (host.startsWith('https://')) {
this.protocol = 'https';
return host.replace('https://', '');
} else {
return host;
}
}
protected async statusCheckingLoop(): Promise<void> {
while (!this.stopping) {
if(this.paiClusterConfig && this.paiClusterConfig.passWord) {
......@@ -229,7 +244,7 @@ abstract class PAITrainingService implements TrainingService {
}
}
}
await this.paiJobCollector.retrieveTrialStatus(this.paiToken, this.paiClusterConfig);
await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.paiClusterConfig);
if (this.paiJobRestServer === undefined) {
throw new Error('paiBaseJobRestServer not implemented!');
}
......@@ -259,7 +274,7 @@ abstract class PAITrainingService implements TrainingService {
}
const authenticationReq: request.Options = {
uri: `http://${this.paiClusterConfig.host}/rest-server/api/v1/token`,
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/token`,
method: 'POST',
json: true,
body: {
......
......@@ -91,6 +91,7 @@ class PAIYarnTrainingService extends PAITrainingService {
case TrialConfigMetadataKey.PAI_YARN_CLUSTER_CONFIG:
this.paiJobRestServer = new PAIJobRestServer(component.get(PAIYarnTrainingService));
this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);
this.paiClusterConfig.host = this.formatPAIHost(this.paiClusterConfig.host);
this.hdfsClient = WebHDFS.createClient({
user: this.paiClusterConfig.userName,
......@@ -98,7 +99,9 @@ class PAIYarnTrainingService extends PAITrainingService {
port: 80,
path: '/webhdfs/api/v1',
host: this.paiClusterConfig.host
});
this.paiClusterConfig.host = this.formatPAIHost(this.paiClusterConfig.host);
if(this.paiClusterConfig.passWord) {
// Get PAI authentication token
await this.updatePaiToken();
......@@ -107,7 +110,6 @@ class PAIYarnTrainingService extends PAITrainingService {
} else {
throw new Error('pai cluster config format error, please set password or token!');
}
break;
case TrialConfigMetadataKey.TRIAL_CONFIG:
......@@ -272,7 +274,7 @@ class PAIYarnTrainingService extends PAITrainingService {
// Step 3. Submit PAI job via Rest call
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const submitJobRequest: request.Options = {
uri: `http://${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}/jobs`,
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}/jobs`,
method: 'POST',
json: true,
body: paiJobConfig,
......
......@@ -1072,7 +1072,7 @@ debug@^4.0.1, debug@^4.1.0, debug@^4.1.1:
dependencies:
ms "^2.1.1"
debuglog@*, debuglog@^1.0.1:
debuglog@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/debuglog/-/debuglog-1.0.1.tgz#aa24ffb9ac3df9a2351837cfb2d279360cd78492"
......@@ -1840,9 +1840,9 @@ growl@1.10.5:
version "1.10.5"
resolved "https://registry.yarnpkg.com/growl/-/growl-1.10.5.tgz#f2735dc2283674fa67478b10181059355c369e5e"
handlebars@^4.0.11, handlebars@^4.1.0:
version "4.1.2"
resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.1.2.tgz#b6b37c1ced0306b221e094fc7aca3ec23b131b67"
handlebars@^4.0.11, handlebars@^4.3.0:
version "4.5.3"
resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.5.3.tgz#5cf75bd8714f7605713511a56be7c349becb0482"
dependencies:
neo-async "^2.6.0"
optimist "^0.6.1"
......@@ -2014,7 +2014,7 @@ import-lazy@^2.1.0:
version "2.1.0"
resolved "https://registry.yarnpkg.com/import-lazy/-/import-lazy-2.1.0.tgz#05698e3d45c88e8d7e9d92cb0584e77f096f3e43"
imurmurhash@*, imurmurhash@^0.1.4:
imurmurhash@^0.1.4:
version "0.1.4"
resolved "https://registry.yarnpkg.com/imurmurhash/-/imurmurhash-0.1.4.tgz#9218b9b2b928a238b13dc4fb6b6d576f231453ea"
......@@ -2595,10 +2595,6 @@ lockfile@^1.0.4:
dependencies:
signal-exit "^3.0.2"
lodash._baseindexof@*:
version "3.1.0"
resolved "https://registry.yarnpkg.com/lodash._baseindexof/-/lodash._baseindexof-3.1.0.tgz#fe52b53a1c6761e42618d654e4a25789ed61822c"
lodash._baseuniq@~4.6.0:
version "4.6.0"
resolved "https://registry.yarnpkg.com/lodash._baseuniq/-/lodash._baseuniq-4.6.0.tgz#0ebb44e456814af7905c6212fa2c9b2d51b841e8"
......@@ -2606,28 +2602,10 @@ lodash._baseuniq@~4.6.0:
lodash._createset "~4.0.0"
lodash._root "~3.0.0"
lodash._bindcallback@*:
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._bindcallback/-/lodash._bindcallback-3.0.1.tgz#e531c27644cf8b57a99e17ed95b35c748789392e"
lodash._cacheindexof@*:
version "3.0.2"
resolved "https://registry.yarnpkg.com/lodash._cacheindexof/-/lodash._cacheindexof-3.0.2.tgz#3dc69ac82498d2ee5e3ce56091bafd2adc7bde92"
lodash._createcache@*:
version "3.1.2"
resolved "https://registry.yarnpkg.com/lodash._createcache/-/lodash._createcache-3.1.2.tgz#56d6a064017625e79ebca6b8018e17440bdcf093"
dependencies:
lodash._getnative "^3.0.0"
lodash._createset@~4.0.0:
version "4.0.3"
resolved "https://registry.yarnpkg.com/lodash._createset/-/lodash._createset-4.0.3.tgz#0f4659fbb09d75194fa9e2b88a6644d363c9fe26"
lodash._getnative@*, lodash._getnative@^3.0.0:
version "3.9.1"
resolved "https://registry.yarnpkg.com/lodash._getnative/-/lodash._getnative-3.9.1.tgz#570bc7dede46d61cdcde687d65d3eecbaa3aaff5"
lodash._root@~3.0.0:
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._root/-/lodash._root-3.0.1.tgz#fba1c4524c19ee9a5f8136b4609f017cf4ded692"
......@@ -2676,10 +2654,6 @@ lodash.pick@^4.4.0:
version "4.4.0"
resolved "https://registry.yarnpkg.com/lodash.pick/-/lodash.pick-4.4.0.tgz#52f05610fff9ded422611441ed1fc123a03001b3"
lodash.restparam@*:
version "3.6.1"
resolved "https://registry.yarnpkg.com/lodash.restparam/-/lodash.restparam-3.6.1.tgz#936a4e309ef330a7645ed4145986c85ae5b20805"
lodash.unescape@4.0.1:
version "4.0.1"
resolved "https://registry.yarnpkg.com/lodash.unescape/-/lodash.unescape-4.0.1.tgz#bf2249886ce514cda112fae9218cdc065211fc9c"
......
......@@ -2,6 +2,7 @@
# Licensed under the MIT license.
from .compressor import LayerInfo, Compressor, Pruner, Quantizer
from .builtin_pruners import *
from .builtin_quantizers import *
from .lottery_ticket import LotteryTicketPruner
from .pruners import *
from .weight_rank_filter_pruners import *
from .activation_rank_filter_pruners import *
from .quantizers import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from .compressor import Pruner
__all__ = ['ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']
logger = logging.getLogger('torch activation rank filter pruners')
class ActivationRankFilterPruner(Pruner):
"""
A structured pruning base class that prunes the filters with the smallest
importance criterion in convolution layers (using activation values)
to achieve a preset level of network sparsity.
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
self.statistics_batch_num = statistics_batch_num
self.collected_activation = {}
self.hooks = {}
assert activation in ['relu', 'relu6']
if activation == 'relu':
self.activation = torch.nn.functional.relu
elif activation == 'relu6':
self.activation = torch.nn.functional.relu6
else:
self.activation = None
def compress(self):
"""
Compress the model, register a hook for collecting activations.
"""
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
self._instrument_layer(layer, config)
self.collected_activation[layer.name] = []
def _hook(module_, input_, output, name=layer.name):
if len(self.collected_activation[name]) < self.statistics_batch_num:
self.collected_activation[name].append(self.activation(output.detach().cpu()))
layer.module.register_forward_hook(_hook)
return self.bound_model
def get_mask(self, base_mask, activations, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
def calc_mask(self, layer, config):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion which is calculated from the activation are masked.
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
        assert 0 <= config.get('sparsity') < 1, "sparsity must be in the range [0, 1)"
assert op_type in ['Conv2d'], "only support Conv2d"
assert op_type in config.get('op_types')
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, 'bias') and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
else:
mask_bias = None
mask = {'weight': mask_weight, 'bias': mask_bias}
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1 or len(self.collected_activation[layer.name]) < self.statistics_batch_num:
return mask
mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
finally:
if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
return mask
class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
"""
A structured pruning algorithm that prunes the filters with the
smallest APoZ(average percentage of zeros) of output activations.
Hengyuan Hu, Rui Peng, Yu-Wing Tai and Chi-Keung Tang,
"Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures", ICLR 2016.
https://arxiv.org/abs/1607.03250
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, activation, statistics_batch_num)
def get_mask(self, base_mask, activations, num_prune):
"""
Calculate the mask of given layer.
Filters with the smallest APoZ(average percentage of zeros) of output activations are masked.
Parameters
----------
base_mask : dict
The basic mask with the same shape of weight, all item in the basic mask is 1.
activations : list
Layer's output activations
num_prune : int
Num of filters to prune
Returns
-------
dict
dictionary for storing masks
"""
apoz = self._calc_apoz(activations)
prune_indices = torch.argsort(apoz, descending=True)[:num_prune]
for idx in prune_indices:
base_mask['weight'][idx] = 0.
if base_mask['bias'] is not None:
base_mask['bias'][idx] = 0.
return base_mask
def _calc_apoz(self, activations):
"""
Calculate APoZ(average percentage of zeros) of activations.
Parameters
----------
activations : list
Layer's output activations
Returns
-------
torch.Tensor
Filter's APoZ(average percentage of zeros) of the activations
"""
activations = torch.cat(activations, 0)
_eq_zero = torch.eq(activations, torch.zeros_like(activations))
_apoz = torch.sum(_eq_zero, dim=(0, 2, 3)) / torch.numel(_eq_zero[:, 0, :, :])
return _apoz
class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
"""
A structured pruning algorithm that prunes the filters with the
smallest mean value of output activations.
Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila and Jan Kautz,
"Pruning Convolutional Neural Networks for Resource Efficient Inference", ICLR 2017.
https://arxiv.org/abs/1611.06440
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, activation, statistics_batch_num)
def get_mask(self, base_mask, activations, num_prune):
"""
Calculate the mask of given layer.
        Filters with the smallest mean value of output activations are masked.
Parameters
----------
base_mask : dict
The basic mask with the same shape of weight, all item in the basic mask is 1.
activations : list
Layer's output activations
num_prune : int
Num of filters to prune
Returns
-------
dict
dictionary for storing masks
"""
mean_activation = self._cal_mean_activation(activations)
prune_indices = torch.argsort(mean_activation)[:num_prune]
for idx in prune_indices:
base_mask['weight'][idx] = 0.
if base_mask['bias'] is not None:
base_mask['bias'][idx] = 0.
return base_mask
def _cal_mean_activation(self, activations):
"""
Calculate mean value of activations.
Parameters
----------
activations : list
Layer's output activations
Returns
-------
torch.Tensor
Filter's mean value of the output activations
"""
activations = torch.cat(activations, 0)
mean_activation = torch.mean(activations, dim=(0, 2, 3))
return mean_activation
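A hedged usage sketch for the pruners above: the import path reflects NNI releases of this era, and `model` / `data_loader` are placeholders to be replaced with your own objects.
```
import torch
from nni.compression.torch import ActivationAPoZRankFilterPruner  # assumed import path

# Prune 50% of the filters in every Conv2d layer, ranked by APoZ.
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
pruner = ActivationAPoZRankFilterPruner(model, config_list,
                                        activation='relu', statistics_batch_num=1)
model = pruner.compress()

# Forward a few batches so the registered hooks can collect output activations;
# masks are filled in once statistics_batch_num batches have been observed.
with torch.no_grad():
    for step, (data, _) in enumerate(data_loader):
        model(data)
        if step + 1 >= pruner.statistics_batch_num:
            break
```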