Commit 70cee7d8 authored by Yuge Zhang, committed by GitHub

TextNAS without retrain (#1890)

parent eb39749f
# TextNAS: A Neural Architecture Search Space tailored for Text Representation
The official release of TextNAS by MSRA.
[Paper link](https://arxiv.org/abs/1912.10729)
## Preparation
Prepare the word vectors and the SST dataset, and organize them in the `data` directory as shown below:
```
textnas
├── data
│ ├── sst
│ │ └── trees
│ │ ├── dev.txt
│ │ ├── test.txt
│ │ └── train.txt
│ └── glove.840B.300d.txt
├── dataloader.py
├── model.py
├── ops.py
├── README.md
├── search.py
└── utils.py
```
The following links might be helpful for finding and downloading the corresponding datasets:
* [GloVe: Global Vectors for Word Representation](https://nlp.stanford.edu/projects/glove/)
* [Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank](https://nlp.stanford.edu/sentiment/)
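Once the files are downloaded, a quick sanity check along the lines of the sketch below (paths assumed to follow the layout above) can confirm that everything is in place:
```
import os

# Hypothetical check; adjust the paths if your data directory lives elsewhere.
required = [os.path.join("data", "sst", "trees", split + ".txt") for split in ("train", "dev", "test")]
required.append(os.path.join("data", "glove.840B.300d.txt"))
for path in required:
    print("%-40s %s" % (path, "ok" if os.path.exists(path) else "MISSING"))
```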
## Search
```
python search.py
```
After each search epoch, 10 sampled architectures are tested directly. Their accuracy is expected to be between 40% and 42% after 10 epochs.
By default, 20 sampled architectures are exported into the `checkpoints` directory for the next step.
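Each exported file is a JSON document written by the trainer's `export` call, mapping mutable keys (layer choices and input choices) to the sampled decisions. A minimal sketch for inspecting one, assuming the default file names used by `search.py`:
```
import json

with open("checkpoints/architecture_00.json") as fp:
    arc = json.load(fp)
for key, choice in arc.items():
    print(key, "->", choice)
```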
## Retrain
Not ready.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
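# dataloader.py: SST tree parsing, vocabulary building, and GloVe embedding initialization.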
import logging
import os
import pickle
from collections import Counter
import numpy as np
import torch
from torch.utils import data
logger = logging.getLogger("nni.textnas")
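# Penn Treebank-style parse tree, built from the bracketed text format used by SST.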
class PTBTree:
WORD_TO_WORD_MAPPING = {
"{": "-LCB-",
"}": "-RCB-"
}
def __init__(self):
self.subtrees = []
self.word = None
self.label = ""
self.parent = None
self.span = (-1, -1)
self.word_vector = None # HOS, store dx1 RNN word vector
self.prediction = None # HOS, store Kx1 prediction vector
def is_leaf(self):
return len(self.subtrees) == 0
def set_by_text(self, text, pos=0, left=0):
depth = 0
right = left
for i in range(pos + 1, len(text)):
char = text[i]
# update the depth
if char == "(":
depth += 1
if depth == 1:
subtree = PTBTree()
subtree.parent = self
subtree.set_by_text(text, i, right)
right = subtree.span[1]
self.span = (left, right)
self.subtrees.append(subtree)
elif char == ")":
depth -= 1
if len(self.subtrees) == 0:
pos = i
for j in range(i, 0, -1):
if text[j] == " ":
pos = j
break
self.word = text[pos + 1:i]
self.span = (left, left + 1)
# we've reached the end of the category that is the root of this subtree
if depth == 0 and char == " " and self.label == "":
self.label = text[pos + 1:i]
# we've reached the end of the scope for this bracket
if depth < 0:
break
# Fix some issues with variation in output, and one error in the treebank
# for a word with a punctuation POS
self.standardise_node()
def standardise_node(self):
if self.word in self.WORD_TO_WORD_MAPPING:
self.word = self.WORD_TO_WORD_MAPPING[self.word]
def __repr__(self, single_line=True, depth=0):
ans = ""
if not single_line and depth > 0:
ans = "\n" + depth * "\t"
ans += "(" + self.label
if self.word is not None:
ans += " " + self.word
for subtree in self.subtrees:
if single_line:
ans += " "
ans += subtree.__repr__(single_line, depth + 1)
ans += ")"
return ans
def read_tree(source):
cur_text = []
depth = 0
while True:
line = source.readline()
# Check if we are out of input
if line == "":
return None
# strip whitespace and only use if this contains something
line = line.strip()
if line == "":
continue
cur_text.append(line)
# Update depth
for char in line:
if char == "(":
depth += 1
elif char == ")":
depth -= 1
# At depth 0 we have a complete tree
if depth == 0:
tree = PTBTree()
tree.set_by_text(" ".join(cur_text))
return tree
return None
def read_trees(source, max_sents=-1):
with open(source) as fp:
trees = []
while True:
tree = read_tree(fp)
if tree is None:
break
trees.append(tree)
if len(trees) >= max_sents > 0:
break
return trees
class SSTDataset(data.Dataset):
def __init__(self, sents, mask, labels):
self.sents = sents
self.labels = labels
self.mask = mask
def __getitem__(self, index):
return (self.sents[index], self.mask[index]), self.labels[index]
def __len__(self):
return len(self.sents)
def sst_get_id_input(content, word_id_dict, max_input_length):
words = content.split(" ")
sentence = [word_id_dict["<pad>"]] * max_input_length
mask = [0] * max_input_length
unknown = word_id_dict["<unknown>"]
for i, word in enumerate(words[:max_input_length]):
sentence[i] = word_id_dict.get(word, unknown)
mask[i] = 1
return sentence, mask
def sst_get_phrases(trees, sample_ratio=1.0, is_binary=False, only_sentence=False):
all_phrases = []
for tree in trees:
if only_sentence:
sentence = get_sentence_by_tree(tree)
label = int(tree.label)
pair = (sentence, label)
all_phrases.append(pair)
else:
phrases = get_phrases_by_tree(tree)
sentence = get_sentence_by_tree(tree)
pair = (sentence, int(tree.label))
all_phrases.append(pair)
all_phrases += phrases
if sample_ratio < 1.:
np.random.shuffle(all_phrases)
result_phrases = []
for pair in all_phrases:
if is_binary:
phrase, label = pair
if label <= 1:
pair = (phrase, 0)
elif label >= 3:
pair = (phrase, 1)
else:
continue
if sample_ratio == 1.:
result_phrases.append(pair)
else:
rand_portion = np.random.random()
if rand_portion < sample_ratio:
result_phrases.append(pair)
return result_phrases
def get_phrases_by_tree(tree):
phrases = []
if tree is None:
return phrases
if tree.is_leaf():
pair = (tree.word, int(tree.label))
phrases.append(pair)
return phrases
left_child_phrases = get_phrases_by_tree(tree.subtrees[0])
right_child_phrases = get_phrases_by_tree(tree.subtrees[1])
phrases.extend(left_child_phrases)
phrases.extend(right_child_phrases)
sentence = get_sentence_by_tree(tree)
pair = (sentence, int(tree.label))
phrases.append(pair)
return phrases
def get_sentence_by_tree(tree):
if tree is None:
return ""
if tree.is_leaf():
return tree.word
left_sentence = get_sentence_by_tree(tree.subtrees[0])
right_sentence = get_sentence_by_tree(tree.subtrees[1])
sentence = left_sentence + " " + right_sentence
return sentence.strip()
def get_word_id_dict(word_num_dict, word_id_dict, min_count):
z = [k for k in sorted(word_num_dict.keys())]
for word in z:
count = word_num_dict[word]
if count >= min_count:
index = len(word_id_dict)
if word not in word_id_dict:
word_id_dict[word] = index
return word_id_dict
def load_word_num_dict(phrases, word_num_dict):
for sentence, _ in phrases:
words = sentence.split(" ")
for cur_word in words:
word = cur_word.strip()
word_num_dict[word] += 1
return word_num_dict
def init_trainable_embedding(embedding_path, word_id_dict, embed_dim=300):
word_embed_model = load_glove_model(embedding_path, embed_dim)
assert word_embed_model["pool"].shape[1] == embed_dim
embedding = np.random.random([len(word_id_dict), embed_dim]).astype(np.float32) / 2.0 - 0.25
embedding[0] = np.zeros(embed_dim) # PAD
embedding[1] = (np.random.rand(embed_dim) - 0.5) / 2 # UNK
for word, idx in word_id_dict.items():
if idx == 0 or idx == 1:
continue
if word in word_embed_model["mapping"]:
embedding[idx] = word_embed_model["pool"][word_embed_model["mapping"][word]]
else:
embedding[idx] = np.random.rand(embed_dim) / 2.0 - 0.25
return embedding
def sst_get_trainable_data(phrases, word_id_dict, max_input_length):
texts, labels, mask = [], [], []
for phrase, label in phrases:
if not phrase.split():
continue
phrase_split, mask_split = sst_get_id_input(phrase, word_id_dict, max_input_length)
texts.append(phrase_split)
labels.append(int(label))
        mask.append(mask_split)  # mask marks the valid (non-padding) positions
labels = np.array(labels, dtype=np.int64)
texts = np.reshape(texts, [-1, max_input_length]).astype(np.int32)
mask = np.reshape(mask, [-1, max_input_length]).astype(np.int32)
return SSTDataset(texts, mask, labels)
def load_glove_model(filename, embed_dim):
if os.path.exists(filename + ".cache"):
logger.info("Found cache. Loading...")
with open(filename + ".cache", "rb") as fp:
return pickle.load(fp)
embedding = {"mapping": dict(), "pool": []}
with open(filename) as f:
for i, line in enumerate(f):
line = line.rstrip("\n")
vocab_word, *vec = line.rsplit(" ", maxsplit=embed_dim)
            assert len(vec) == embed_dim, "Unexpected line: '%s'" % line
embedding["pool"].append(np.array(list(map(float, vec)), dtype=np.float32))
embedding["mapping"][vocab_word] = i
embedding["pool"] = np.stack(embedding["pool"])
with open(filename + ".cache", "wb") as fp:
pickle.dump(embedding, fp)
return embedding
def read_data_sst(data_path, max_input_length=64, min_count=1, train_with_valid=False,
train_ratio=1., valid_ratio=1., is_binary=False, only_sentence=False):
word_id_dict = dict()
word_num_dict = Counter()
sst_path = os.path.join(data_path, "sst")
logger.info("Reading SST data...")
train_file_name = os.path.join(sst_path, "trees", "train.txt")
valid_file_name = os.path.join(sst_path, "trees", "dev.txt")
test_file_name = os.path.join(sst_path, "trees", "test.txt")
train_trees = read_trees(train_file_name)
train_phrases = sst_get_phrases(train_trees, train_ratio, is_binary, only_sentence)
    logger.info("Finished loading train phrases.")
valid_trees = read_trees(valid_file_name)
valid_phrases = sst_get_phrases(valid_trees, valid_ratio, is_binary, only_sentence)
    logger.info("Finished loading valid phrases.")
if train_with_valid:
train_phrases += valid_phrases
test_trees = read_trees(test_file_name)
test_phrases = sst_get_phrases(test_trees, valid_ratio, is_binary, only_sentence=True)
    logger.info("Finished loading test phrases.")
# get word_id_dict
word_id_dict["<pad>"] = 0
word_id_dict["<unknown>"] = 1
load_word_num_dict(train_phrases, word_num_dict)
    logger.info("Finished loading train words: %d.", len(word_num_dict))
load_word_num_dict(valid_phrases, word_num_dict)
load_word_num_dict(test_phrases, word_num_dict)
    logger.info("Finished loading valid+test words: %d.", len(word_num_dict))
word_id_dict = get_word_id_dict(word_num_dict, word_id_dict, min_count)
    logger.info("Vocabulary length after trimming: %d.", len(word_id_dict))
logger.info("Loading embedding...")
embedding = init_trainable_embedding(os.path.join(data_path, "glove.840B.300d.txt"), word_id_dict)
    logger.info("Finished initializing word embeddings.")
dataset_train = sst_get_trainable_data(train_phrases, word_id_dict, max_input_length)
logger.info("Loaded %d training samples.", len(dataset_train))
dataset_valid = sst_get_trainable_data(valid_phrases, word_id_dict, max_input_length)
logger.info("Loaded %d validation samples.", len(dataset_valid))
dataset_test = sst_get_trainable_data(test_phrases, word_id_dict, max_input_length)
logger.info("Loaded %d test samples.", len(dataset_test))
return dataset_train, dataset_valid, dataset_test, torch.from_numpy(embedding)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
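# model.py: the TextNAS child network assembled from searchable layers.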
import numpy as np
import torch
import torch.nn as nn
from nni.nas.pytorch import mutables
from ops import ConvBN, LinearCombine, AvgPool, MaxPool, RNN, Attention, BatchNorm
from utils import GlobalMaxPool, GlobalAvgPool
class Layer(mutables.MutableScope):
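    """
    One searchable layer: an InputChoice over (at most) the last ``choose_from_k`` layer outputs,
    a LayerChoice over eight candidate ops (convolutions with kernel sizes 1/3/5/7, average and
    max pooling, a bidirectional GRU, and multi-head attention), optional skip connections from
    earlier layers, and a final batch normalization.
    """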
def __init__(self, key, prev_keys, hidden_units, choose_from_k, cnn_keep_prob, lstm_keep_prob, att_keep_prob, att_mask):
super(Layer, self).__init__(key)
def conv_shortcut(kernel_size):
return ConvBN(kernel_size, hidden_units, hidden_units, cnn_keep_prob, False, True)
self.n_candidates = len(prev_keys)
if self.n_candidates:
self.prec = mutables.InputChoice(choose_from=prev_keys[-choose_from_k:], n_chosen=1)
else:
# first layer, skip input choice
self.prec = None
self.op = mutables.LayerChoice([
conv_shortcut(1),
conv_shortcut(3),
conv_shortcut(5),
conv_shortcut(7),
AvgPool(3, False, True),
MaxPool(3, False, True),
RNN(hidden_units, lstm_keep_prob),
Attention(hidden_units, 4, att_keep_prob, att_mask)
])
if self.n_candidates:
self.skipconnect = mutables.InputChoice(choose_from=prev_keys)
else:
self.skipconnect = None
self.bn = BatchNorm(hidden_units, False, True)
def forward(self, last_layer, prev_layers, mask):
# pass an extra last_layer to deal with layer 0 (prev_layers is empty)
if self.prec is None:
prec = last_layer
else:
            prec = self.prec(prev_layers[-self.prec.n_candidates:])  # choose among the last k layer outputs
out = self.op(prec, mask)
if self.skipconnect is not None:
connection = self.skipconnect(prev_layers[-self.skipconnect.n_candidates:])
if connection is not None:
out += connection
out = self.bn(out, mask)
return out
class Model(nn.Module):
def __init__(self, embedding, hidden_units=256, num_layers=24, num_classes=5, choose_from_k=5,
lstm_keep_prob=0.5, cnn_keep_prob=0.5, att_keep_prob=0.5, att_mask=True,
embed_keep_prob=0.5, final_output_keep_prob=1.0, global_pool="avg"):
super(Model, self).__init__()
self.embedding = nn.Embedding.from_pretrained(embedding, freeze=False)
self.hidden_units = hidden_units
self.num_layers = num_layers
self.num_classes = num_classes
self.init_conv = ConvBN(1, self.embedding.embedding_dim, hidden_units, cnn_keep_prob, False, True)
self.layers = nn.ModuleList()
candidate_keys_pool = []
for layer_id in range(self.num_layers):
k = "layer_{}".format(layer_id)
self.layers.append(Layer(k, candidate_keys_pool, hidden_units, choose_from_k,
cnn_keep_prob, lstm_keep_prob, att_keep_prob, att_mask))
candidate_keys_pool.append(k)
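        # LinearCombine mixes the outputs of all layers with learnable, softmax-normalized weights.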
self.linear_combine = LinearCombine(self.num_layers)
self.linear_out = nn.Linear(self.hidden_units, self.num_classes)
self.embed_dropout = nn.Dropout(p=1 - embed_keep_prob)
self.output_dropout = nn.Dropout(p=1 - final_output_keep_prob)
assert global_pool in ["max", "avg"]
if global_pool == "max":
self.global_pool = GlobalMaxPool()
elif global_pool == "avg":
self.global_pool = GlobalAvgPool()
def forward(self, inputs):
sent_ids, mask = inputs
seq = self.embedding(sent_ids.long())
seq = self.embed_dropout(seq)
seq = torch.transpose(seq, 1, 2) # from (N, L, C) -> (N, C, L)
x = self.init_conv(seq, mask)
prev_layers = []
for layer in self.layers:
x = layer(x, prev_layers, mask)
prev_layers.append(x)
x = self.linear_combine(torch.stack(prev_layers))
x = self.global_pool(x, mask)
x = self.output_dropout(x)
x = self.linear_out(x)
return x
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
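# ops.py: candidate operations (convolution, pooling, GRU, attention) and masking utilities.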
import torch
import torch.nn.functional as F
from torch import nn
from utils import get_length, INF
class Mask(nn.Module):
def forward(self, seq, mask):
# seq: (N, C, L)
# mask: (N, L)
seq_mask = torch.unsqueeze(mask, 2)
seq_mask = torch.transpose(seq_mask.repeat(1, 1, seq.size()[1]), 1, 2)
return seq.where(torch.eq(seq_mask, 1), torch.zeros_like(seq))
class BatchNorm(nn.Module):
def __init__(self, num_features, pre_mask, post_mask, eps=1e-5, decay=0.9, affine=True):
super(BatchNorm, self).__init__()
self.mask_opt = Mask()
self.pre_mask = pre_mask
self.post_mask = post_mask
self.bn = nn.BatchNorm1d(num_features, eps=eps, momentum=1.0 - decay, affine=affine)
def forward(self, seq, mask):
if self.pre_mask:
seq = self.mask_opt(seq, mask)
seq = self.bn(seq)
if self.post_mask:
seq = self.mask_opt(seq, mask)
return seq
class ConvBN(nn.Module):
def __init__(self, kernal_size, in_channels, out_channels, cnn_keep_prob,
pre_mask, post_mask, with_bn=True, with_relu=True):
super(ConvBN, self).__init__()
self.mask_opt = Mask()
self.pre_mask = pre_mask
self.post_mask = post_mask
self.with_bn = with_bn
self.with_relu = with_relu
self.conv = nn.Conv1d(in_channels, out_channels, kernal_size, 1, bias=True, padding=(kernal_size - 1) // 2)
self.dropout = nn.Dropout(p=(1 - cnn_keep_prob))
if with_bn:
self.bn = BatchNorm(out_channels, not post_mask, True)
if with_relu:
self.relu = nn.ReLU()
def forward(self, seq, mask):
if self.pre_mask:
seq = self.mask_opt(seq, mask)
seq = self.conv(seq)
if self.post_mask:
seq = self.mask_opt(seq, mask)
if self.with_bn:
seq = self.bn(seq, mask)
if self.with_relu:
seq = self.relu(seq)
seq = self.dropout(seq)
return seq
class AvgPool(nn.Module):
def __init__(self, kernal_size, pre_mask, post_mask):
super(AvgPool, self).__init__()
self.avg_pool = nn.AvgPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
self.pre_mask = pre_mask
self.post_mask = post_mask
self.mask_opt = Mask()
def forward(self, seq, mask):
if self.pre_mask:
seq = self.mask_opt(seq, mask)
seq = self.avg_pool(seq)
if self.post_mask:
seq = self.mask_opt(seq, mask)
return seq
class MaxPool(nn.Module):
def __init__(self, kernal_size, pre_mask, post_mask):
super(MaxPool, self).__init__()
self.max_pool = nn.MaxPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
self.pre_mask = pre_mask
self.post_mask = post_mask
self.mask_opt = Mask()
def forward(self, seq, mask):
if self.pre_mask:
seq = self.mask_opt(seq, mask)
seq = self.max_pool(seq)
if self.post_mask:
seq = self.mask_opt(seq, mask)
return seq
class Attention(nn.Module):
def __init__(self, num_units, num_heads, keep_prob, is_mask):
super(Attention, self).__init__()
self.num_heads = num_heads
self.keep_prob = keep_prob
self.linear_q = nn.Linear(num_units, num_units)
self.linear_k = nn.Linear(num_units, num_units)
self.linear_v = nn.Linear(num_units, num_units)
self.bn = BatchNorm(num_units, True, is_mask)
self.dropout = nn.Dropout(p=1 - self.keep_prob)
def forward(self, seq, mask):
in_c = seq.size()[1]
seq = torch.transpose(seq, 1, 2) # (N, L, C)
queries = seq
keys = seq
num_heads = self.num_heads
# T_q = T_k = L
Q = F.relu(self.linear_q(seq)) # (N, T_q, C)
K = F.relu(self.linear_k(seq)) # (N, T_k, C)
V = F.relu(self.linear_v(seq)) # (N, T_k, C)
# Split and concat
Q_ = torch.cat(torch.split(Q, in_c // num_heads, dim=2), dim=0) # (h*N, T_q, C/h)
K_ = torch.cat(torch.split(K, in_c // num_heads, dim=2), dim=0) # (h*N, T_k, C/h)
V_ = torch.cat(torch.split(V, in_c // num_heads, dim=2), dim=0) # (h*N, T_k, C/h)
# Multiplication
outputs = torch.matmul(Q_, K_.transpose(1, 2)) # (h*N, T_q, T_k)
# Scale
outputs = outputs / (K_.size()[-1] ** 0.5)
# Key Masking
key_masks = mask.repeat(num_heads, 1) # (h*N, T_k)
key_masks = torch.unsqueeze(key_masks, 1) # (h*N, 1, T_k)
key_masks = key_masks.repeat(1, queries.size()[1], 1) # (h*N, T_q, T_k)
        paddings = torch.ones_like(outputs) * (-INF)  # very negative value so masked keys get ~zero weight after softmax
outputs = torch.where(torch.eq(key_masks, 0), paddings, outputs)
query_masks = mask.repeat(num_heads, 1) # (h*N, T_q)
query_masks = torch.unsqueeze(query_masks, -1) # (h*N, T_q, 1)
query_masks = query_masks.repeat(1, 1, keys.size()[1]).float() # (h*N, T_q, T_k)
att_scores = F.softmax(outputs, dim=-1) * query_masks # (h*N, T_q, T_k)
att_scores = self.dropout(att_scores)
# Weighted sum
x_outputs = torch.matmul(att_scores, V_) # (h*N, T_q, C/h)
# Restore shape
x_outputs = torch.cat(
torch.split(x_outputs, x_outputs.size()[0] // num_heads, dim=0),
dim=2) # (N, T_q, C)
x = torch.transpose(x_outputs, 1, 2) # (N, C, L)
x = self.bn(x, mask)
return x
class RNN(nn.Module):
def __init__(self, hidden_size, output_keep_prob):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.bid_rnn = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
self.output_keep_prob = output_keep_prob
self.out_dropout = nn.Dropout(p=(1 - self.output_keep_prob))
def forward(self, seq, mask):
# seq: (N, C, L)
# mask: (N, L)
max_len = seq.size()[2]
length = get_length(mask)
seq = torch.transpose(seq, 1, 2) # to (N, L, C)
packed_seq = nn.utils.rnn.pack_padded_sequence(seq, length, batch_first=True,
enforce_sorted=False)
outputs, _ = self.bid_rnn(packed_seq)
outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True,
total_length=max_len)[0]
outputs = outputs.view(-1, max_len, 2, self.hidden_size).sum(2) # (N, L, C)
outputs = self.out_dropout(outputs) # output dropout
return torch.transpose(outputs, 1, 2) # back to: (N, C, L)
class LinearCombine(nn.Module):
def __init__(self, layers_num, trainable=True, input_aware=False, word_level=False):
super(LinearCombine, self).__init__()
self.input_aware = input_aware
self.word_level = word_level
if input_aware:
raise NotImplementedError("Input aware is not supported.")
self.w = nn.Parameter(torch.full((layers_num, 1, 1, 1), 1.0 / layers_num),
requires_grad=trainable)
def forward(self, seq):
nw = F.softmax(self.w, dim=0)
seq = torch.mul(seq, nw)
seq = torch.sum(seq, dim=0)
return seq
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
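# search.py: ENAS-based architecture search entry point.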
import logging
import os
import random
from argparse import ArgumentParser
from itertools import cycle
import numpy as np
import torch
import torch.nn as nn
from nni.nas.pytorch.enas import EnasMutator, EnasTrainer
from nni.nas.pytorch.callbacks import LRSchedulerCallback
from dataloader import read_data_sst
from model import Model
from utils import accuracy
logger = logging.getLogger("nni.textnas")
class TextNASTrainer(EnasTrainer):
def __init__(self, *args, train_loader=None, valid_loader=None, test_loader=None, **kwargs):
super().__init__(*args, **kwargs)
self.train_loader = train_loader
self.valid_loader = valid_loader
self.test_loader = test_loader
def init_dataloader(self):
pass
if __name__ == "__main__":
parser = ArgumentParser("textnas")
parser.add_argument("--batch-size", default=128, type=int)
parser.add_argument("--log-frequency", default=50, type=int)
parser.add_argument("--seed", default=1234, type=int)
parser.add_argument("--epochs", default=10, type=int)
parser.add_argument("--lr", default=5e-3, type=float)
args = parser.parse_args()
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
torch.backends.cudnn.deterministic = True
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
train_dataset, valid_dataset, test_dataset, embedding = read_data_sst("data")
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, num_workers=4)
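    # Wrap the train/valid loaders in cycle() so TextNASTrainer can draw a fixed number of batches per epoch via next().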
train_loader, valid_loader = cycle(train_loader), cycle(valid_loader)
model = Model(embedding)
mutator = EnasMutator(model, temperature=None, tanh_constant=None, entropy_reduction="mean")
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=1e-3, weight_decay=2e-6)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-5)
trainer = TextNASTrainer(model,
loss=criterion,
metrics=lambda output, target: {"acc": accuracy(output, target)},
reward_function=accuracy,
optimizer=optimizer,
callbacks=[LRSchedulerCallback(lr_scheduler)],
batch_size=args.batch_size,
num_epochs=args.epochs,
dataset_train=None,
dataset_valid=None,
train_loader=train_loader,
valid_loader=valid_loader,
test_loader=test_loader,
log_frequency=args.log_frequency,
mutator=mutator,
mutator_lr=2e-3,
mutator_steps=500,
mutator_steps_aggregate=1,
child_steps=3000,
baseline_decay=0.99,
test_arc_per_epoch=10)
trainer.train()
os.makedirs("checkpoints", exist_ok=True)
for i in range(20):
trainer.export(os.path.join("checkpoints", "architecture_%02d.json" % i))
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
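# utils.py: masked global pooling ops, the accuracy metric, and helper constants.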
import logging
import torch
import torch.nn as nn
INF = 1E10
EPS = 1E-12
logger = logging.getLogger("nni.textnas")
def get_length(mask):
length = torch.sum(mask, 1)
length = length.long()
return length
class GlobalAvgPool(nn.Module):
def forward(self, x, mask):
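        # Sum over time and divide by the number of valid (non-padded) positions given by mask.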
x = torch.sum(x, 2)
length = torch.sum(mask, 1, keepdim=True).float()
length += torch.eq(length, 0.0).float() * EPS
length = length.repeat(1, x.size()[1])
x /= length
return x
class GlobalMaxPool(nn.Module):
    def forward(self, x, mask):
        # Push padded positions to -INF so they can never win the max.
        mask = torch.eq(mask.float(), 0.0).float() * (-INF)
        mask = torch.unsqueeze(mask, dim=1).repeat(1, x.size()[1], 1)
        x, _ = torch.max(x + mask, 2)
        return x
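# Wraps a torchtext-style iterator whose batches expose .text and .label; not referenced by the search script above.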
class IteratorWrapper:
def __init__(self, loader):
self.loader = loader
self.iterator = None
def __iter__(self):
self.iterator = iter(self.loader)
return self
def __len__(self):
return len(self.loader)
def __next__(self):
data = next(self.iterator)
text, length = data.text
max_length = text.size(1)
label = data.label - 1
bs = label.size(0)
mask = torch.arange(max_length, device=length.device).unsqueeze(0).repeat(bs, 1)
mask = mask < length.unsqueeze(-1).repeat(1, max_length)
return (text, mask), label
def accuracy(output, target):
batch_size = target.size(0)
_, predicted = torch.max(output.data, 1)
return (predicted == target).sum().item() / batch_size
@@ -30,7 +30,7 @@ class StackedLSTMCell(nn.Module):
class EnasMutator(Mutator):
def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
skip_target=0.4, branch_bias=0.25):
skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"):
"""
        Initialize an EnasMutator.
@@ -49,17 +49,22 @@ class EnasMutator(Mutator):
            and mark it as the hidden state of this MutableScope. This is to align with the original implementation of the paper.
skip_target : float
Target probability that skipconnect will appear.
temperature : float
Temperature constant that divides the logits.
branch_bias : float
Manual bias applied to make some operations more likely to be chosen.
            Currently this is implemented with a hardcoded match rule that aligns with the original repo.
            If a mutable has ``reduce`` in its key, all of its op choices
            that contain ``conv`` in their typename will initially receive a bias of ``+self.branch_bias``, while the others
            receive a bias of ``-self.branch_bias``.
entropy_reduction : str
            One of ``sum`` or ``mean``; how the entropy of a multi-input choice is reduced.
"""
super().__init__(model)
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
self.cell_exit_extra_step = cell_exit_extra_step
self.skip_target = skip_target
self.branch_bias = branch_bias
@@ -70,6 +75,8 @@ class EnasMutator(Mutator):
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) # pylint: disable=not-callable
assert entropy_reduction in ["sum", "mean"], "Entropy reduction must be one of sum and mean."
self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
self.bias_dict = nn.ParameterDict()
@@ -135,15 +142,17 @@ class EnasMutator(Mutator):
def _sample_layer_choice(self, mutable):
self._lstm_next_step()
logit = self.soft(self._h[-1])
if self.temperature is not None:
logit /= self.temperature
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
if mutable.key in self.bias_dict:
logit += self.bias_dict[mutable.key]
branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, branch_id)
self.sample_log_prob += torch.sum(log_prob)
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += torch.sum(entropy)
self.sample_entropy += self.entropy_reduction(entropy)
self._inputs = self.embedding(branch_id)
return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)
@@ -158,6 +167,8 @@ class EnasMutator(Mutator):
query = torch.cat(query, 0)
query = torch.tanh(query + self.attn_query(self._h[-1]))
query = self.v_attn(query)
if self.temperature is not None:
query /= self.temperature
if self.tanh_constant is not None:
query = self.tanh_constant * torch.tanh(query)
@@ -178,7 +189,7 @@ class EnasMutator(Mutator):
log_prob = self.cross_entropy_loss(logit, index)
self._inputs = anchors[index.item()]
self.sample_log_prob += torch.sum(log_prob)
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += torch.sum(entropy)
self.sample_entropy += self.entropy_reduction(entropy)
return skip.bool()
@@ -2,11 +2,14 @@
# Licensed under the MIT license.
import logging
from itertools import cycle
import torch
import torch.nn as nn
import torch.optim as optim
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from nni.nas.pytorch.utils import AverageMeterGroup, to_device
from .mutator import EnasMutator
logger = logging.getLogger(__name__)
@@ -16,8 +19,9 @@ class EnasTrainer(Trainer):
def __init__(self, model, loss, metrics, reward_function,
optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4):
entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500,
mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4,
test_arc_per_epoch=1):
"""
Initialize an EnasTrainer.
@@ -57,6 +61,8 @@ class EnasTrainer(Trainer):
Weight of skip penalty loss.
baseline_decay : float
Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
child_steps : int
            Number of mini-batches used to train the child model in each epoch.
mutator_lr : float
Learning rate for RL controller.
mutator_steps_aggregate : int
@@ -65,12 +71,16 @@ class EnasTrainer(Trainer):
Number of mini-batches for each epoch of RL controller learning.
aux_weight : float
Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss.
test_arc_per_epoch : int
            Number of architectures sampled and tested directly after each epoch.
"""
super().__init__(model, mutator if mutator is not None else EnasMutator(model),
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
batch_size, workers, device, log_frequency, callbacks)
self.reward_function = reward_function
self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
self.batch_size = batch_size
self.workers = workers
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
@@ -78,32 +88,40 @@ class EnasTrainer(Trainer):
self.baseline = 0.
self.mutator_steps_aggregate = mutator_steps_aggregate
self.mutator_steps = mutator_steps
self.child_steps = child_steps
self.aux_weight = aux_weight
self.test_arc_per_epoch = test_arc_per_epoch
self.init_dataloader()
def init_dataloader(self):
n_train = len(self.dataset_train)
split = n_train // 10
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
batch_size=self.batch_size,
sampler=train_sampler,
num_workers=workers)
num_workers=self.workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
batch_size=self.batch_size,
sampler=valid_sampler,
num_workers=workers)
num_workers=self.workers)
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
batch_size=batch_size,
num_workers=workers)
batch_size=self.batch_size,
num_workers=self.workers)
self.train_loader = cycle(self.train_loader)
self.valid_loader = cycle(self.valid_loader)
def train_one_epoch(self, epoch):
# Sample model and train
self.model.train()
self.mutator.eval()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
x, y = x.to(self.device), y.to(self.device)
for step in range(1, self.child_steps + 1):
x, y = next(self.train_loader)
x, y = to_device(x, self.device), to_device(y, self.device)
self.optimizer.zero_grad()
with torch.no_grad():
@@ -119,55 +137,71 @@ class EnasTrainer(Trainer):
loss = self.loss(logits, y)
loss = loss + self.aux_weight * aux_loss
loss.backward()
nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
self.optimizer.step()
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Model Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
self.num_epochs, step, self.child_steps, meters)
# Train sampler (mutator)
self.model.eval()
self.mutator.train()
meters = AverageMeterGroup()
mutator_step, total_mutator_steps = 0, self.mutator_steps * self.mutator_steps_aggregate
while mutator_step < total_mutator_steps:
for step, (x, y) in enumerate(self.valid_loader):
x, y = x.to(self.device), y.to(self.device)
for mutator_step in range(1, self.mutator_steps + 1):
self.mutator_optim.zero_grad()
for step in range(1, self.mutator_steps_aggregate + 1):
x, y = next(self.valid_loader)
x, y = to_device(x, self.device), to_device(y, self.device)
self.mutator.reset()
with torch.no_grad():
logits = self.model(x)
metrics = self.metrics(logits, y)
reward = self.reward_function(logits, y)
if self.entropy_weight is not None:
reward += self.entropy_weight * self.mutator.sample_entropy
if self.entropy_weight:
reward += self.entropy_weight * self.mutator.sample_entropy.item()
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
self.baseline = self.baseline.detach().item()
loss = self.mutator.sample_log_prob * (reward - self.baseline)
if self.skip_weight:
loss += self.skip_weight * self.mutator.sample_skip_penalty
metrics["reward"] = reward
metrics["loss"] = loss.item()
metrics["ent"] = self.mutator.sample_entropy.item()
metrics["log_prob"] = self.mutator.sample_log_prob.item()
metrics["baseline"] = self.baseline
metrics["skip"] = self.mutator.sample_skip_penalty
loss = loss / self.mutator_steps_aggregate
loss /= self.mutator_steps_aggregate
loss.backward()
meters.update(metrics)
if mutator_step % self.mutator_steps_aggregate == 0:
self.mutator_optim.step()
self.mutator_optim.zero_grad()
cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
if self.log_frequency is not None and cur_step % self.log_frequency == 0:
logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs,
mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate,
meters)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("RL Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, self.num_epochs,
mutator_step // self.mutator_steps_aggregate + 1, self.mutator_steps, meters)
mutator_step += 1
if mutator_step >= total_mutator_steps:
break
nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
self.mutator_optim.step()
def validate_one_epoch(self, epoch):
pass
with torch.no_grad():
for arc_id in range(self.test_arc_per_epoch):
meters = AverageMeterGroup()
for x, y in self.test_loader:
x, y = to_device(x, self.device), to_device(y, self.device)
self.mutator.reset()
logits = self.model(x)
if isinstance(logits, tuple):
logits, _ = logits
metrics = self.metrics(logits, y)
loss = self.loss(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
epoch + 1, self.num_epochs, arc_id + 1, self.test_arc_per_epoch,
meters.summary())
@@ -159,7 +159,7 @@ class InputChoice(Mutable):
"than number of candidates."
self.n_candidates = n_candidates
self.choose_from = choose_from
self.choose_from = choose_from.copy()
self.n_chosen = n_chosen
self.reduction = reduction
self.return_mask = return_mask
@@ -96,12 +96,12 @@ class Trainer(BaseTrainer):
callback.on_epoch_begin(epoch)
# training
_logger.info("Epoch %d Training", epoch)
_logger.info("Epoch %d Training", epoch + 1)
self.train_one_epoch(epoch)
if validate:
# validation
_logger.info("Epoch %d Validating", epoch)
_logger.info("Epoch %d Validating", epoch + 1)
self.validate_one_epoch(epoch)
for callback in self.callbacks:
@@ -4,6 +4,8 @@
import logging
from collections import OrderedDict
import torch
_counter = 0
_logger = logging.getLogger(__name__)
@@ -15,7 +17,22 @@ def global_mutable_counting():
return _counter
def to_device(obj, device):
if torch.is_tensor(obj):
return obj.to(device)
if isinstance(obj, tuple):
return tuple(to_device(t, device) for t in obj)
if isinstance(obj, list):
return [to_device(t, device) for t in obj]
if isinstance(obj, dict):
return {k: to_device(v, device) for k, v in obj.items()}
if isinstance(obj, (int, float, str)):
return obj
raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))
class AverageMeterGroup:
"""Average meter group for multiple average meters"""
def __init__(self):
self.meters = OrderedDict()
@@ -33,7 +50,10 @@ class AverageMeterGroup:
return self.meters[item]
def __str__(self):
return " ".join(str(v) for _, v in self.meters.items())
return " ".join(str(v) for v in self.meters.values())
def summary(self):
return " ".join(v.summary() for v in self.meters.values())
class AverageMeter:
@@ -72,6 +92,10 @@ class AverageMeter:
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)
def summary(self):
fmtstr = '{name}: {avg' + self.fmt + '}'
return fmtstr.format(**self.__dict__)
class StructuredMutableTreeNode:
"""