Unverified Commit 122b5b89 authored by Yuge Zhang, committed by GitHub

[Retiarii] Policy-based RL Strategy (#3650)

parent 92f6754e
@@ -10,3 +10,5 @@ pytorch-lightning >= 1.1.1
onnx
peewee
graphviz
gym
tianshou >= 0.4.1
@@ -11,3 +11,5 @@ keras == 2.1.6
onnx
peewee
graphviz
gym
tianshou >= 0.4.1
@@ -22,5 +22,7 @@ prettytable
psutil
ruamel.yaml
ipython
gym
tianshou
https://download.pytorch.org/whl/cpu/torch-1.7.1%2Bcpu-cp37-cp37m-linux_x86_64.whl
https://download.pytorch.org/whl/cpu/torchvision-0.8.2%2Bcpu-cp37-cp37m-linux_x86_64.whl
@@ -6,3 +6,4 @@ from .bruteforce import Random, GridSearch
from .evolution import RegularizedEvolution
from .tpe_strategy import TPEStrategy
from .local_debug_strategy import _LocalDebugStrategy
from .rl import PolicyBasedRL
# This file might cause an import error for users who have not installed the RL-related dependencies
import logging
import gym
import numpy as np
import torch
import torch.nn as nn
from gym import spaces
from tianshou.data import to_torch
from .utils import get_targeted_model
from ..graph import ModelStatus
from ..execution import submit_models, wait_models
_logger = logging.getLogger(__name__)
class ModelEvaluationEnv(gym.Env):
def __init__(self, base_model, mutators, search_space):
self.base_model = base_model
self.mutators = mutators
self.search_space = search_space
self.ss_keys = list(self.search_space.keys())
self.action_dim = max(map(lambda v: len(v), self.search_space.values()))
self.num_steps = len(self.search_space)
@property
def observation_space(self):
return spaces.Dict({
'action_history': spaces.MultiDiscrete([self.action_dim] * self.num_steps),
'cur_step': spaces.Discrete(self.num_steps + 1),
'action_dim': spaces.Discrete(self.action_dim + 1)
})
@property
def action_space(self):
return spaces.Discrete(self.action_dim)
def reset(self):
self.action_history = np.zeros(self.num_steps, dtype=np.int32)
self.cur_step = 0
self.sample = {}
return {
'action_history': self.action_history,
'cur_step': self.cur_step,
'action_dim': len(self.search_space[self.ss_keys[self.cur_step]])
}
def step(self, action):
cur_key = self.ss_keys[self.cur_step]
assert action < len(self.search_space[cur_key]), \
    f'Current action {action} is out of range for {self.search_space[cur_key]}.'
self.action_history[self.cur_step] = action
self.sample[cur_key] = self.search_space[cur_key][action]
self.cur_step += 1
obs = {
'action_history': self.action_history,
'cur_step': self.cur_step,
'action_dim': len(self.search_space[self.ss_keys[self.cur_step]]) \
if self.cur_step < self.num_steps else self.action_dim
}
if self.cur_step == self.num_steps:
model = get_targeted_model(self.base_model, self.mutators, self.sample)
_logger.info(f'New model created: {self.sample}')
submit_models(model)
wait_models(model)
if model.status == ModelStatus.Failed:
return self.reset(), 0., False, {}
rew = model.metric
_logger.info(f'Model metric received as reward: {rew}')
return obs, rew, True, {}
else:
return obs, 0., False, {}
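# A minimal sketch of the observation contract above, assuming only the
# hypothetical two-key search space below; ``base_model`` and ``mutators`` can
# be left empty here because they are touched only when an episode finishes and
# a model is submitted:
#
#     env = ModelEvaluationEnv(base_model=None, mutators=[],
#                              search_space={'conv_size': [3, 5, 7], 'fc2': [0, 1]})
#     env.action_space          # Discrete(3), padded to the widest choice
#     env.observation_space     # Dict of action_history / cur_step / action_dim
#     env.reset()               # {'action_history': array([0, 0], dtype=int32),
#                               #  'cur_step': 0, 'action_dim': 3}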
class Preprocessor(nn.Module):
def __init__(self, obs_space, hidden_dim=64, num_layers=1):
super().__init__()
self.action_dim = obs_space['action_history'].nvec[0]
self.hidden_dim = hidden_dim
# first token is [SOS]
self.embedding = nn.Embedding(self.action_dim + 1, hidden_dim)
self.rnn = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
def forward(self, obs):
seq = nn.functional.pad(obs['action_history'] + 1, (1, 1))  # prepend the start token and append the end token
# the end token only exists to avoid out-of-range indexing for v_s_; it does not affect backpropagation
seq = self.embedding(seq.long())
feature, _ = self.rnn(seq)
return feature[torch.arange(len(feature), device=feature.device), obs['cur_step'].long() + 1]
class Actor(nn.Module):
def __init__(self, action_space, preprocess):
super().__init__()
self.preprocess = preprocess
self.action_dim = action_space.n
self.linear = nn.Linear(self.preprocess.hidden_dim, self.action_dim)
def forward(self, obs, **kwargs):
obs = to_torch(obs, device=self.linear.weight.device)
out = self.linear(self.preprocess(obs))
# mask out invalid actions for choices that have fewer options than action_dim
mask = torch.arange(self.action_dim).expand(len(out), self.action_dim) >= obs['action_dim'].unsqueeze(1)
out[mask.to(out.device)] = float('-inf')
return nn.functional.softmax(out, dim=-1), kwargs.get('state', None)
class Critic(nn.Module):
def __init__(self, preprocess):
super().__init__()
self.preprocess = preprocess
self.linear = nn.Linear(self.preprocess.hidden_dim, 1)
def forward(self, obs, **kwargs):
obs = to_torch(obs, device=self.linear.weight.device)
return self.linear(self.preprocess(obs)).squeeze(-1)
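To make the tensor shapes of the three modules above concrete, here is a minimal sketch that pushes a fake batch of observations through them. The two-step, three-option search space and all values are made up; it assumes gym and tianshou >= 0.4.1 are installed and imports from the private module added in this commit.

# Hedged sketch: shapes of the shared-encoder actor-critic on a fake batch.
import numpy as np
from gym import spaces
from nni.retiarii.strategy._rl_impl import Preprocessor, Actor, Critic

# observation space mirroring ModelEvaluationEnv for a 2-step search space
# whose widest choice has 3 options
obs_space = spaces.Dict({
    'action_history': spaces.MultiDiscrete([3, 3]),
    'cur_step': spaces.Discrete(3),
    'action_dim': spaces.Discrete(4),
})
net = Preprocessor(obs_space)           # LSTM encoder shared by actor and critic
actor = Actor(spaces.Discrete(3), net)
critic = Critic(net)

# a fake batch of two observations, shaped the way the environment emits them
obs = {
    'action_history': np.array([[0, 0], [2, 0]]),
    'cur_step': np.array([0, 1]),
    'action_dim': np.array([3, 2]),     # second sample's current choice has only 2 options
}
probs, _ = actor(obs)                   # to_torch converts the dict internally
values = critic(obs)
print(probs.shape, values.shape)        # torch.Size([2, 3]) torch.Size([2])
print(probs[1, 2].item())               # 0.0 -- the third option is masked out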
import logging
from typing import Optional, Callable
from .base import BaseStrategy
from .utils import dry_run_for_search_space
from ..execution import query_available_resources
try:
has_tianshou = True
import torch
from tianshou.data import AsyncCollector, Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import BasePolicy, PPOPolicy # pylint: disable=unused-import
from ._rl_impl import ModelEvaluationEnv, Preprocessor, Actor, Critic
except ImportError:
has_tianshou = False
_logger = logging.getLogger(__name__)
class PolicyBasedRL(BaseStrategy):
"""
Algorithm for policy-based reinforcement learning.
This is a wrapper of the algorithms provided in tianshou (PPO by default),
and it can easily be customized with other algorithms that inherit ``BasePolicy`` (e.g., REINFORCE [1]_).
Note that RL algorithms are known to have issues on Windows and macOS; they will be supported in the future.
Parameters
----------
max_collect : int
    How many times the collector runs to collect trials for RL. Default: 100.
trial_per_collect : int
    How many trials (trajectories) the collector collects in each run.
    After each collect, the trainer samples a batch from the replay buffer and performs an update. Default: 20.
policy_fn : function
    Takes a ``ModelEvaluationEnv`` as input and returns a policy. See ``_default_policy_fn`` for an example.
asynchronous : bool
    If true, the collector will not wait for all environments to complete in each step.
    This should generally not affect the result, but might affect efficiency; note that slightly
    more trials than expected might be collected if this is enabled.
    If false, the collector waits for all parallel environments to complete in each step.
    See ``tianshou.data.AsyncCollector`` for more details.
References
----------
.. [1] Barret Zoph and Quoc V. Le, "Neural Architecture Search with Reinforcement Learning".
https://arxiv.org/abs/1611.01578
"""
def __init__(self, max_collect: int = 100, trial_per_collect: int = 20,
             policy_fn: Optional[Callable[['ModelEvaluationEnv'], 'BasePolicy']] = None, asynchronous: bool = True):
if not has_tianshou:
raise ImportError('`tianshou` is required to run RL-based strategy. '
'Please use "pip install tianshou" to install it beforehand.')
self.policy_fn = policy_fn or self._default_policy_fn
self.max_collect = max_collect
self.trial_per_collect = trial_per_collect
self.asynchronous = asynchronous
@staticmethod
def _default_policy_fn(env):
net = Preprocessor(env.observation_space)
actor = Actor(env.action_space, net)
critic = Critic(net)
optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=1e-4)
return PPOPolicy(actor, critic, optim, torch.distributions.Categorical,
discount_factor=1., action_space=env.action_space)
def run(self, base_model, applied_mutators):
search_space = dry_run_for_search_space(base_model, applied_mutators)
concurrency = query_available_resources()
env_fn = lambda: ModelEvaluationEnv(base_model, applied_mutators, search_space)
policy = self.policy_fn(env_fn())
if self.asynchronous:
# wait for half of the envs to complete in each step
env = SubprocVectorEnv([env_fn for _ in range(concurrency)], wait_num=int(concurrency * 0.5))
collector = AsyncCollector(policy, env, VectorReplayBuffer(20000, len(env)))
else:
env = SubprocVectorEnv([env_fn for _ in range(concurrency)])
collector = Collector(policy, env, VectorReplayBuffer(20000, len(env)))
for cur_collect in range(1, self.max_collect + 1):
_logger.info('Collect [%d] Running...', cur_collect)
result = collector.collect(n_episode=self.trial_per_collect)
_logger.info('Collect [%d] Result: %s', cur_collect, str(result))
policy.update(0, collector.buffer, batch_size=64, repeat=5)
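The ``policy_fn`` hook described in the docstring is not exercised elsewhere in this diff; below is a hedged sketch of a custom ``policy_fn`` that swaps PPO for tianshou's REINFORCE implementation (``PGPolicy``), reusing the encoder and actor from ``_rl_impl``. The names ``reinforce_policy_fn`` and ``rl_strategy`` are illustrative, and the ``PGPolicy`` keyword set is an assumption against tianshou 0.4.1.

# Hedged sketch: a REINFORCE-based policy_fn plugged into PolicyBasedRL.
import torch
from tianshou.policy import PGPolicy
from nni.retiarii.strategy import PolicyBasedRL
from nni.retiarii.strategy._rl_impl import Preprocessor, Actor

def reinforce_policy_fn(env):
    # same encoder and actor as the default policy, but no critic: REINFORCE is actor-only
    net = Preprocessor(env.observation_space)
    actor = Actor(env.action_space, net)
    optim = torch.optim.Adam(actor.parameters(), lr=1e-4)
    return PGPolicy(actor, optim, torch.distributions.Categorical,
                    discount_factor=1., action_space=env.action_space)

rl_strategy = PolicyBasedRL(max_collect=50, trial_per_collect=10,
                            policy_fn=reinforce_policy_fn, asynchronous=False)
# rl_strategy.run(base_model, applied_mutators) would then be driven by the
# framework, as the unit test below does directly.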
@@ -35,6 +35,7 @@ jobs:
python3 -m pip install keras==2.1.6
python3 -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0
python3 -m pip install thop
python3 -m pip install 'tianshou>=0.4.1' gym
sudo apt-get install swig -y
displayName: Install extra dependencies
......
@@ -30,6 +30,7 @@ jobs:
python -m pip install torch==1.6.0 torchvision==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install 'pytorch-lightning>=1.1.1'
python -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0
python -m pip install 'tianshou>=0.4.1' gym
displayName: Install extra dependencies
# Need del later
......
import random
import sys
import time
import threading
from typing import *
@@ -6,6 +6,7 @@ from typing import *
import nni.retiarii.execution.api
import nni.retiarii.nn.pytorch as nn
import nni.retiarii.strategy as strategy
import pytest
import torch
import torch.nn.functional as F
from nni.retiarii import Model
@@ -58,7 +60,7 @@ def _reset_execution_engine(engine=None):
class Net(nn.Module):
def __init__(self, hidden_size=32):
def __init__(self, hidden_size=32, diff_size=False):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
@@ -69,7 +71,7 @@ class Net(nn.Module):
self.fc2 = nn.LayerChoice([
nn.Linear(hidden_size, 10, bias=False),
nn.Linear(hidden_size, 10, bias=True)
], label='fc2')
] + ([] if not diff_size else [nn.Linear(hidden_size, 10, bias=False)]), label='fc2')
def forward(self, x):
x = F.relu(self.conv1(x))
@@ -82,8 +84,8 @@ class Net(nn.Module):
return F.log_softmax(x, dim=1)
def _get_model_and_mutators():
base_model = Net()
def _get_model_and_mutators(**kwargs):
base_model = Net(**kwargs)
script_module = torch.jit.script(base_model)
base_model_ir = convert_to_graph(script_module, base_model)
base_model_ir.evaluator = DebugEvaluator()
@@ -139,7 +141,25 @@ def test_evolution():
_reset_execution_engine()
@pytest.mark.skipif(sys.platform in ('win32', 'darwin'), reason='Does not run on Windows and macOS')
def test_rl():
rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
engine = MockExecutionEngine(failure_prob=0.2)
_reset_execution_engine(engine)
rl.run(*_get_model_and_mutators(diff_size=True))
wait_models(*engine.models)
_reset_execution_engine()
rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10, asynchronous=False)
engine = MockExecutionEngine(failure_prob=0.2)
_reset_execution_engine(engine)
rl.run(*_get_model_and_mutators())
wait_models(*engine.models)
_reset_execution_engine()
if __name__ == '__main__':
test_grid_search()
test_random_search()
test_evolution()
test_rl()