Unverified Commit 58873c46 authored by Duong Nhu's avatar Duong Nhu Committed by GitHub
Browse files

Parameterized training options for EnasTrainer implementation in TensorFlow (#2953)

parent d5036857
...@@ -13,21 +13,29 @@ from .mutator import EnasMutator ...@@ -13,21 +13,29 @@ from .mutator import EnasMutator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
log_frequency = 100
entropy_weight = 0.0001
skip_weight = 0.8
baseline_decay = 0.999
child_steps = 500
mutator_lr = 0.00035
mutator_steps = 50
mutator_steps_aggregate = 20
aux_weight = 0.4
test_arc_per_epoch = 1
class EnasTrainer: class EnasTrainer:
def __init__(self, model, loss, metrics, reward_function, optimizer, batch_size, num_epochs, def __init__(
dataset_train, dataset_valid): self,
model,
loss,
metrics,
reward_function,
optimizer,
batch_size,
num_epochs,
dataset_train,
dataset_valid,
log_frequency=100,
entropy_weight=0.0001,
skip_weight=0.8,
baseline_decay=0.999,
child_steps=500,
mutator_lr=0.00035,
mutator_steps=50,
mutator_steps_aggregate=20,
aux_weight=0.4,
test_arc_per_epoch=1,
):
self.model = model self.model = model
self.loss = loss self.loss = loss
self.metrics = metrics self.metrics = metrics
...@@ -42,11 +50,21 @@ class EnasTrainer: ...@@ -42,11 +50,21 @@ class EnasTrainer:
self.valid_set = tf.data.Dataset.from_tensor_slices((x[split:], y[split:])) self.valid_set = tf.data.Dataset.from_tensor_slices((x[split:], y[split:]))
self.test_set = tf.data.Dataset.from_tensor_slices(dataset_valid) self.test_set = tf.data.Dataset.from_tensor_slices(dataset_valid)
self.mutator = EnasMutator(model) self.log_frequency = log_frequency
self.mutator_optim = Adam(learning_rate=mutator_lr) self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
self.baseline_decay = baseline_decay
self.child_steps = child_steps
self.mutator_lr = mutator_lr
self.mutator_steps = mutator_steps
self.mutator_steps_aggregate = mutator_steps_aggregate
self.aux_weight = aux_weight
self.test_arc_per_epoch = test_arc_per_epoch
self.baseline = 0. self.mutator = EnasMutator(model)
self.mutator_optim = Adam(learning_rate=self.mutator_lr)
self.baseline = 0.0
def train(self, validate=True): def train(self, validate=True):
for epoch in range(self.num_epochs): for epoch in range(self.num_epochs):
...@@ -58,14 +76,13 @@ class EnasTrainer: ...@@ -58,14 +76,13 @@ class EnasTrainer:
def validate(self):
    """Run a standalone validation pass outside the training loop.

    Delegates to ``validate_one_epoch`` with an epoch index of -1 so the
    log output is distinguishable from per-epoch validation runs.
    """
    # -1 marks this as an on-demand validation, not tied to any training epoch.
    self.validate_one_epoch(-1)
def train_one_epoch(self, epoch): def train_one_epoch(self, epoch):
train_loader, valid_loader = self._create_train_loader() train_loader, valid_loader = self._create_train_loader()
# Sample model and train # Sample model and train
meters = AverageMeterGroup() meters = AverageMeterGroup()
for step in range(1, child_steps + 1): for step in range(1, self.child_steps + 1):
x, y = next(train_loader) x, y = next(train_loader)
self.mutator.reset() self.mutator.reset()
...@@ -75,64 +92,88 @@ class EnasTrainer: ...@@ -75,64 +92,88 @@ class EnasTrainer:
logits, aux_logits = logits logits, aux_logits = logits
aux_loss = self.loss(aux_logits, y) aux_loss = self.loss(aux_logits, y)
else: else:
aux_loss = 0. aux_loss = 0.0
metrics = self.metrics(y, logits) metrics = self.metrics(y, logits)
loss = self.loss(y, logits) + aux_weight * aux_loss loss = self.loss(y, logits) + self.aux_weight * aux_loss
grads = tape.gradient(loss, self.model.trainable_weights) grads = tape.gradient(loss, self.model.trainable_weights)
grads = fill_zero_grads(grads, self.model.trainable_weights) grads = fill_zero_grads(grads, self.model.trainable_weights)
grads, _ = tf.clip_by_global_norm(grads, 5.0) grads, _ = tf.clip_by_global_norm(grads, 5.0)
self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights)) self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
metrics['loss'] = tf.reduce_mean(loss).numpy() metrics["loss"] = tf.reduce_mean(loss).numpy()
meters.update(metrics) meters.update(metrics)
if log_frequency and step % log_frequency == 0: if self.log_frequency and step % self.log_frequency == 0:
logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, logger.info(
self.num_epochs, step, child_steps, meters) "Model Epoch [%d/%d] Step [%d/%d] %s",
epoch + 1,
self.num_epochs,
step,
self.child_steps,
meters,
)
# Train sampler (mutator) # Train sampler (mutator)
meters = AverageMeterGroup() meters = AverageMeterGroup()
for mutator_step in range(1, mutator_steps + 1): for mutator_step in range(1, self.mutator_steps + 1):
grads_list = [] grads_list = []
for step in range(1, mutator_steps_aggregate + 1): for step in range(1, self.mutator_steps_aggregate + 1):
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
x, y = next(valid_loader) x, y = next(valid_loader)
self.mutator.reset() self.mutator.reset()
logits = self.model(x, training=False) logits = self.model(x, training=False)
metrics = self.metrics(y, logits) metrics = self.metrics(y, logits)
reward = self.reward_function(y, logits) + entropy_weight * self.mutator.sample_entropy reward = (
self.baseline = self.baseline * baseline_decay + reward * (1 - baseline_decay) self.reward_function(y, logits)
+ self.entropy_weight * self.mutator.sample_entropy
)
self.baseline = self.baseline * self.baseline_decay + reward * (
1 - self.baseline_decay
)
loss = self.mutator.sample_log_prob * (reward - self.baseline) loss = self.mutator.sample_log_prob * (reward - self.baseline)
loss += skip_weight * self.mutator.sample_skip_penalty loss += self.skip_weight * self.mutator.sample_skip_penalty
meters.update({ meters.update(
'reward': reward, {
'loss': tf.reduce_mean(loss).numpy(), "reward": reward,
'ent': self.mutator.sample_entropy.numpy(), "loss": tf.reduce_mean(loss).numpy(),
'log_prob': self.mutator.sample_log_prob.numpy(), "ent": self.mutator.sample_entropy.numpy(),
'baseline': self.baseline, "log_prob": self.mutator.sample_log_prob.numpy(),
'skip': self.mutator.sample_skip_penalty, "baseline": self.baseline,
}) "skip": self.mutator.sample_skip_penalty,
}
cur_step = step + (mutator_step - 1) * mutator_steps_aggregate )
if log_frequency and cur_step % log_frequency == 0:
logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs, cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
mutator_step, mutator_steps, step, mutator_steps_aggregate, if self.log_frequency and cur_step % self.log_frequency == 0:
meters) logger.info(
"RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s",
epoch + 1,
self.num_epochs,
mutator_step,
self.mutator_steps,
step,
self.mutator_steps_aggregate,
meters,
)
grads = tape.gradient(loss, self.mutator.trainable_weights) grads = tape.gradient(loss, self.mutator.trainable_weights)
grads = fill_zero_grads(grads, self.mutator.trainable_weights) grads = fill_zero_grads(grads, self.mutator.trainable_weights)
grads_list.append(grads) grads_list.append(grads)
total_grads = [tf.math.add_n(weight_grads) for weight_grads in zip(*grads_list)] total_grads = [
tf.math.add_n(weight_grads) for weight_grads in zip(*grads_list)
]
total_grads, _ = tf.clip_by_global_norm(total_grads, 5.0) total_grads, _ = tf.clip_by_global_norm(total_grads, 5.0)
self.mutator_optim.apply_gradients(zip(total_grads, self.mutator.trainable_weights)) self.mutator_optim.apply_gradients(
zip(total_grads, self.mutator.trainable_weights)
)
def validate_one_epoch(self, epoch): def validate_one_epoch(self, epoch):
test_loader = self._create_validate_loader() test_loader = self._create_validate_loader()
for arc_id in range(test_arc_per_epoch): for arc_id in range(self.test_arc_per_epoch):
meters = AverageMeterGroup() meters = AverageMeterGroup()
for x, y in test_loader: for x, y in test_loader:
self.mutator.reset() self.mutator.reset()
...@@ -141,13 +182,17 @@ class EnasTrainer: ...@@ -141,13 +182,17 @@ class EnasTrainer:
logits, _ = logits logits, _ = logits
metrics = self.metrics(y, logits) metrics = self.metrics(y, logits)
loss = self.loss(y, logits) loss = self.loss(y, logits)
metrics['loss'] = tf.reduce_mean(loss).numpy() metrics["loss"] = tf.reduce_mean(loss).numpy()
meters.update(metrics) meters.update(metrics)
logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s", logger.info(
epoch + 1, self.num_epochs, arc_id + 1, test_arc_per_epoch, "Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
meters.summary()) epoch + 1,
self.num_epochs,
arc_id + 1,
self.test_arc_per_epoch,
meters.summary(),
)
def _create_train_loader(self): def _create_train_loader(self):
train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size) train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment