"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "b11bc73da330355c74e16107ae1602076404178a"
Unverified Commit f2f58dbb authored by Zhenhua Han, committed by GitHub

[Retiarii] cross-graph optimization: device placement and input deduplication (#3202)

parent 6645bd33
@@ -10,6 +10,8 @@
 /ts/nni_manager/exp_profile.json
 /ts/nni_manager/metrics.json
 /ts/nni_manager/trial_jobs.json
+/test/ut/retiarii/_debug_graph_data.json
+/test/ut/retiarii/out.tmp

 # Logs
 logs
@@ -105,5 +107,3 @@ venv.bak/
 .vscode
 .vs
 .history
-generated/
-test/ut/retiarii/_debug_graph_data.json
@@ -8,7 +8,7 @@ torch == 1.9.0+cpu ; sys_platform != "darwin"
 torch == 1.9.0 ; sys_platform == "darwin"
 torchvision == 0.10.0+cpu ; sys_platform != "darwin"
 torchvision == 0.10.0 ; sys_platform == "darwin"
-pytorch-lightning >= 1.1.1
+pytorch-lightning >= 1.2.8
 onnx
 peewee
 graphviz
...
@@ -5,7 +5,7 @@ tensorflow
 keras == 2.4.3
 torch == 1.9.0+cu111
 torchvision == 0.10.0+cu111
-pytorch-lightning >= 1.1.1
+pytorch-lightning >= 1.2.8
 onnx
 peewee
 graphviz
...
@@ -9,7 +9,9 @@ class GPUDevice:
     status: Literal['idle', 'busy', 'unknown'] = 'idle'

     def __eq__(self, o) -> bool:
-        return self.node_id == o.node_id and self.gpu_id == o.gpu_id
+        if isinstance(o, GPUDevice):
+            return self.node_id == o.node_id and self.gpu_id == o.gpu_id
+        return False

     def __lt__(self, o) -> bool:
         if self.node_id < o.node_id:
@@ -23,7 +25,10 @@ class GPUDevice:
         return "{Environment %s, GPU %d, Status %s}" % (self.node_id, self.gpu_id, self.status)

     def __hash__(self) -> int:
-        return hash(self.node_id + '_' + self.gpu_id)
+        return hash(self.node_id + '_' + str(self.gpu_id))

     def set_status(self, status):
         self.status = status
+
+    def device_repr(self):
+        return f"cuda:{self.gpu_id}"
@@ -115,7 +115,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
         node_code = node.operation.to_init_code(_remove_prefix(node.name, graph_name))
         if node_code is not None:
             if placement and node in placement and len(node_code) > 0:
-                node_codes.append(f"{node_code}.to('{placement[node].device}')")
+                node_codes.append(f"{node_code}.to('{placement[node].device_repr()}')")
             else:
                 node_codes.append(node_code)
...
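To illustrate the effect of this one-line change: the generated __init__ line now comes from device_repr() rather than a raw device attribute, so for a node placed on GPU 1 the emitted script would contain something like the following (the layer name and arguments are hypothetical):

    # with placement[node] = GPUDevice('single_server', 1):
    self.fc = nn.Linear(64, 10).to('cuda:1')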
from typing import Any, Union, Optional, List

import torch
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.plugins import Plugin
from pytorch_lightning.plugins.environments import ClusterEnvironment

from ....serializer import serialize_cls


class BypassPlugin(TrainingTypePlugin):
    """Plugin that handles communication on a single device."""

    def __init__(self, device: str):
        super().__init__()
        self.device: str = device
        self.global_rank = 0
        self.local_rank = 0
        self.world_size = 1

    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
        self._model = model
        self.model_to_device()
        return self.model

    @property
    def on_tpu(self) -> bool:
        return False

    @property
    def on_gpu(self) -> bool:
        return "cuda" in self.device and torch.cuda.is_available()

    def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
        """
        Reduces a tensor from several distributed processes to one aggregated tensor.
        As this plugin only operates with a single device, the reduction is simply the identity.

        Args:
            tensor: the tensor to sync and reduce
            *args: ignored
            **kwargs: ignored

        Return:
            the unmodified input, as reduction is not needed for single-process operation
        """
        return tensor

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """Perform an all_gather on all processes."""
        return tensor

    @property
    def root_device(self) -> torch.device:
        return torch.device(self.device)

    def model_to_device(self) -> None:
        # bypass device placement from pytorch lightning
        pass

    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
        self.model_to_device()
        return self.model

    @property
    def is_global_zero(self) -> bool:
        return True

    def barrier(self, *args, **kwargs) -> None:
        pass

    def broadcast(self, obj: object, src: int = 0) -> object:
        return obj


def get_accelerator_connector(
        num_processes: int = 1,
        tpu_cores: Optional[Union[List[int], str, int]] = None,
        distributed_backend: Optional[str] = None,
        auto_select_gpus: bool = False,
        gpus: Optional[Union[List[int], str, int]] = None,
        num_nodes: int = 1,
        sync_batchnorm: bool = False,
        benchmark: bool = False,
        replace_sampler_ddp: bool = True,
        deterministic: bool = False,
        precision: int = 32,
        amp_backend: str = 'native',
        amp_level: str = 'O2',
        plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None):
    return AcceleratorConnector(
        num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, benchmark,
        replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins
    )


@serialize_cls
class BypassAccelerator(Accelerator):
    def __init__(self, precision_plugin=None, device="cpu"):
        if precision_plugin is None:
            precision_plugin = get_accelerator_connector().precision_plugin
        # pylint: disable=abstract-class-instantiated
        super().__init__(precision_plugin=precision_plugin, training_type_plugin=BypassPlugin(device))
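A minimal sketch of how this accelerator is meant to be used, assuming the pytorch-lightning >= 1.2.8 APIs pinned above (the module path is inferred from the import in the CGO trainer below):

    import pytorch_lightning as pl
    from nni.retiarii.evaluator.pytorch.cgo.accelerator import BypassAccelerator  # path assumed

    # model_to_device() is a no-op, so the generated multi-model script keeps
    # full control of which sub-model lives on which cuda device
    trainer = pl.Trainer(accelerator=BypassAccelerator(device='cpu'), max_epochs=1)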
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import warnings
from typing import Dict, List, Optional, Union

import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torch.utils.data import DataLoader

import nni

from ..lightning import LightningModule, _AccuracyWithLogits, Lightning
from .trainer import Trainer
from ....serializer import serialize_cls


@serialize_cls
class _MultiModelSupervisedLearningModule(LightningModule):
    def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
                 n_models: int = 0,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam):
        super().__init__()
        self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
        self.criterion = criterion()
        self.criterion_cls = criterion
        self.optimizer = optimizer
        self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})
        self.n_models = n_models

    def forward(self, x):
        y_hat = self.model(x)
        return y_hat
    def training_step(self, batch, batch_idx):
        x, y = batch
        multi_y_hat = self(x)
        if isinstance(multi_y_hat, tuple):
            assert len(multi_y_hat) == self.n_models
        else:
            assert self.n_models == 1
            multi_y_hat = [multi_y_hat]
        multi_loss = []
        for idx, y_hat in enumerate(multi_y_hat):
            loss = self.criterion(y_hat.to("cpu"), y.to("cpu"))
            self.log(f'train_loss_{idx}', loss, prog_bar=True)
            for name, metric in self.metrics.items():
                self.log(f'train_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
            multi_loss.append(loss)
        return sum(multi_loss)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        multi_y_hat = self(x)
        if isinstance(multi_y_hat, tuple):
            assert len(multi_y_hat) == self.n_models
        else:
            assert self.n_models == 1
            multi_y_hat = [multi_y_hat]
        for idx, y_hat in enumerate(multi_y_hat):
            self.log(f'val_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
            for name, metric in self.metrics.items():
                self.log(f'val_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        multi_y_hat = self(x)
        if isinstance(multi_y_hat, tuple):
            assert len(multi_y_hat) == self.n_models
        else:
            assert self.n_models == 1
            multi_y_hat = [multi_y_hat]
        for idx, y_hat in enumerate(multi_y_hat):
            self.log(f'test_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
            for name, metric in self.metrics.items():
                self.log(f'test_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)

    def configure_optimizers(self):
        return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)

    def on_validation_epoch_end(self):
        nni.report_intermediate_result(self._get_validation_metrics())

    def teardown(self, stage):
        if stage == 'fit':
            nni.report_final_result(self._get_validation_metrics())

    def _get_validation_metrics(self):
        # TODO: split metric of multiple models?
        if len(self.metrics) == 1:
            metric_name = next(iter(self.metrics))
            ret = []
            for idx in range(self.n_models):
                ret.append(self.trainer.callback_metrics[f'val_{idx}_' + metric_name].item())
            return ret
        else:
            warnings.warn('Multiple metrics without "default" is not supported by current framework.')
            return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}
class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
    """
    Lightning module of supervised learning for cross-graph optimization.

    Users who need cross-graph optimization should use this module.

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    """

    def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam):
        super().__init__(criterion, metrics, learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)


@serialize_cls
class _ClassificationModule(MultiModelSupervisedLearningModule):
    def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam):
        super().__init__(criterion, {'acc': _AccuracyWithLogits},
                         learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
class Classification(Lightning):
    """
    Trainer that is used for classification.

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    train_dataloaders : DataLoader
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method, this will be skipped.
    val_dataloaders : DataLoader or List of DataLoader
        Used in ``trainer.fit()``. Either a single PyTorch DataLoader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloaders method, this will be skipped.
    trainer_kwargs : dict
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
    """

    def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam,
                 train_dataloader: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 **trainer_kwargs):
        module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
                                       weight_decay=weight_decay, optimizer=optimizer)
        super().__init__(module, Trainer(use_cgo=True, **trainer_kwargs),
                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
@serialize_cls
class _RegressionModule(MultiModelSupervisedLearningModule):
    def __init__(self, criterion: nn.Module = nn.MSELoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam):
        super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
                         learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)


class Regression(Lightning):
    """
    Trainer that is used for regression.

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.MSELoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    train_dataloaders : DataLoader
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method, this will be skipped.
    val_dataloaders : DataLoader or List of DataLoader
        Used in ``trainer.fit()``. Either a single PyTorch DataLoader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloaders method, this will be skipped.
    trainer_kwargs : dict
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
    """

    def __init__(self, criterion: nn.Module = nn.MSELoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam,
                 train_dataloader: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 **trainer_kwargs):
        module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
                                   weight_decay=weight_decay, optimizer=optimizer)
        super().__init__(module, Trainer(use_cgo=True, **trainer_kwargs),
                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
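A hedged usage sketch of the CGO classification evaluator defined above. train_set and val_set are hypothetical datasets, the DataLoader wrapper path is inferred from the imports in lightning.py below, and extra keyword arguments flow into Trainer(use_cgo=True, ...):

    from nni.retiarii.evaluator.pytorch.cgo.evaluator import Classification
    from nni.retiarii.evaluator.pytorch.lightning import DataLoader  # serializable DataLoader wrapper

    evaluator = Classification(train_dataloader=DataLoader(train_set, batch_size=64),
                               val_dataloaders=DataLoader(val_set, batch_size=64),
                               max_epochs=2)  # forwarded to the CGO Trainer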
import pytorch_lightning as pl

from ....serializer import serialize_cls
from .accelerator import BypassAccelerator


@serialize_cls
class Trainer(pl.Trainer):
    """
    Trainer for cross-graph optimization.

    Parameters
    ----------
    use_cgo : bool
        Whether cross-graph optimization (CGO) is used.
        If it is True, CGO will manage device placement.
        Any device placement from PyTorch Lightning will be bypassed.
        default: False
    trainer_kwargs : dict
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
    """

    def __init__(self, use_cgo=False, **trainer_kwargs):
        if use_cgo:
            if "accelerator" in trainer_kwargs:
                raise ValueError("accelerator should not be set when cross-graph optimization is enabled.")
            trainer_kwargs['accelerator'] = BypassAccelerator(device='cpu')
        super().__init__(**trainer_kwargs)
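A short sketch of the guard above: with use_cgo=True the accelerator is forced to BypassAccelerator, so passing one explicitly is rejected, since CGO owns device placement:

    from nni.retiarii.evaluator.pytorch.cgo.trainer import Trainer

    trainer = Trainer(use_cgo=True, max_epochs=1)   # accelerator becomes BypassAccelerator(device='cpu')
    # Trainer(use_cgo=True, accelerator='ddp')      # would raise ValueError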
@@ -12,6 +12,12 @@ import torch.optim as optim
 from torch.utils.data import DataLoader

 import nni

+try:
+    import nni.retiarii.evaluator.pytorch.cgo.trainer as cgo_trainer
+    cgo_import_failed = False
+except ImportError:
+    cgo_import_failed = True
+
 from ...graph import Evaluator
 from ...serializer import serialize_cls

@@ -36,7 +42,6 @@ class LightningModule(pl.LightningModule):
 Trainer = serialize_cls(pl.Trainer)
 DataLoader = serialize_cls(DataLoader)

 class Lightning(Evaluator):
     """
     Delegate the whole training to PyTorch Lightning.

@@ -67,7 +72,11 @@ class Lightning(Evaluator):
                  train_dataloader: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
         assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
-        assert isinstance(trainer, Trainer), f'Trainer must be imported from {__name__}.'
+        if cgo_import_failed:
+            assert isinstance(trainer, Trainer), f'Trainer must be imported from {__name__}'
+        else:
+            assert isinstance(trainer, Trainer) or isinstance(trainer, cgo_trainer.Trainer), \
+                f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
         assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
         assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
         self.module = lightning_module

@@ -91,7 +100,21 @@ class Lightning(Evaluator):
         return self.fit(model_cls)

     def __eq__(self, other):
-        return self.function == other.function and self.arguments == other.arguments
+        eq_func = False
+        eq_args = False
+        if other is None:
+            return False
+        if hasattr(self, "function") and hasattr(other, "function"):
+            eq_func = (self.function == other.function)
+        elif not (hasattr(self, "function") or hasattr(other, "function")):
+            eq_func = True
+        if hasattr(self, "arguments") and hasattr(other, "arguments"):
+            eq_args = (self.arguments == other.arguments)
+        elif not (hasattr(self, "arguments") or hasattr(other, "arguments")):
+            eq_args = True
+        return eq_func and eq_args

     def fit(self, model):
         """
...
@@ -2,14 +2,22 @@
 # Licensed under the MIT license.

 import logging
+import os
+import random
+import string
+import time
+import threading
 from typing import Iterable, List, Dict, Tuple

+from nni.common.device import GPUDevice
+
 from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
 from .. import codegen, utils
-from ..graph import Model, ModelStatus, MetricData
+from ..graph import Model, ModelStatus, MetricData, Node
 from ..integration_api import send_trial, receive_trial_parameters, get_advisor
-from .logical_optimizer.logical_plan import LogicalPlan, PhysicalDevice
+from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode
 from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
+from ..evaluator.pytorch.lightning import Lightning
+from ..evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule, _MultiModelSupervisedLearningModule
 from .base import BaseGraphData

@@ -17,29 +25,93 @@ _logger = logging.getLogger(__name__)

 class CGOExecutionEngine(AbstractExecutionEngine):
+    """
+    The execution engine with cross-graph optimization (CGO).
+
+    Only models using PyTorch Lightning and MultiModelSupervisedLearningModule as the evaluator can be optimized.
+    Otherwise, a model will be submitted independently without any cross-graph optimization.
+
+    Parameters
+    ----------
+    devices : List[str] or List[GPUDevice]
+        Available devices for execution.
+        If a list of str is provided, it will build a list of GPUDevice in a server named ``single_server``.
+    max_concurrency : int
+        The maximum number of trials to run concurrently.
+    batch_waiting_time : int
+        Seconds to wait for each batch of trial submission.
+        The trials within one batch could apply cross-graph optimization.
+    """

-    def __init__(self, devices=None, n_model_per_graph=4) -> None:
+    def __init__(self, devices: List[GPUDevice] = None,
+                 max_concurrency: int = None,
+                 batch_waiting_time: int = 60,
+                 ) -> None:
         self._listeners: List[AbstractGraphListener] = []
         self._running_models: Dict[int, Model] = dict()
         self.logical_plan_counter = 0
-        self.n_model_per_graph = n_model_per_graph
+        self.available_devices: List[GPUDevice] = []
+        self.max_concurrency: int = max_concurrency
+        for device in devices:
+            self.available_devices.append(device)
+        self.all_devices = self.available_devices.copy()
+        self._batch_waiting_time = batch_waiting_time  # seconds to wait for all models in a batch to do cross-graph optimization

         self._optimizers = [DedupInputOptimizer()]
         self._original_models = {}
         self._original_model_to_multi_model = {}
-        self.devices = [] if devices is None else devices
+        self._trial_to_original_models = {}
+        self._trial_used_devices: Dict[int, List[GPUDevice]] = {}
+
+        self._history: List[Model] = []
+
+        self._queuing_jobs: List[Model] = []
+        self._queue_lock = threading.Lock()

         # register advisor callbacks
         advisor = get_advisor()
-        advisor.send_trial_callback = self._send_trial_callback
-        advisor.request_trial_jobs_callback = self._request_trial_jobs_callback
+        # advisor.send_trial_callback = self._send_trial_callback
+        # advisor.request_trial_jobs_callback = self._request_trial_jobs_callback
         advisor.trial_end_callback = self._trial_end_callback
         advisor.intermediate_metric_callback = self._intermediate_metric_callback
         advisor.final_metric_callback = self._final_metric_callback

+        self._stopped = False
+        self._consumer_thread = threading.Thread(target=self._consume_queue)
+        self._consumer_thread.start()
+
+    def join(self):
+        self._stopped = True
+        self._consumer_thread.join()
+
     def add_optimizer(self, opt):
         self._optimizers.append(opt)

     def submit_models(self, *models: List[Model]) -> None:
+        curr_time = time.time()
         _logger.info('%d models are submitted', len(models))
+        self._queue_lock.acquire()
+        self._queuing_jobs.extend([(curr_time, _) for _ in models])
+        self._queue_lock.release()
+
+    def _consume_queue(self):
+        # a thread that monitors self._queuing_jobs and consumes them in batches
+        while not self._stopped:
+            if len(self._queuing_jobs) > 0:
+                curr_time = time.time()
+                self._queue_lock.acquire()
+                if (self.max_concurrency and len(self._queuing_jobs) >= self.max_concurrency):
+                    self._submit_models_in_batch(*[_[1] for _ in self._queuing_jobs[:self.max_concurrency]])
+                    self._queuing_jobs = self._queuing_jobs[self.max_concurrency:]
+                elif len(self.available_devices) <= len(self._queuing_jobs) or \
+                        (curr_time - self._queuing_jobs[0][0] > self._batch_waiting_time):
+                    self._submit_models_in_batch(*[_[1] for _ in self._queuing_jobs])
+                    self._queuing_jobs = []
+                self._queue_lock.release()
+            time.sleep(1)
+
+    def _submit_models_in_batch(self, *models: List[Model]) -> None:
+        _logger.info('%d models are submitted in batch', len(models))
         logical = self._build_logical(models)

         for opt in self._optimizers:
@@ -47,31 +119,51 @@ class CGOExecutionEngine(AbstractExecutionEngine):
         phy_models_and_placements = self._assemble(logical)
         for model, placement, grouped_models in phy_models_and_placements:
-            data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement),
-                                 model.evaluator)
+            data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator)
+            trial_id = send_trial(data.dump())
+            # unique non-cpu devices used by the trial
+            self._trial_used_devices[trial_id] = list([_ for _ in set(placement.values()) if isinstance(_, GPUDevice)])
+            # currently, it is impossible for a search strategy to submit more models than the number of available devices
+            for used_device in self._trial_used_devices[trial_id]:
+                self.available_devices.remove(used_device)  # used_device must be in self.available_devices
+            self._running_models[trial_id] = model
+
+            self._trial_to_original_models[trial_id] = []
             for m in grouped_models:
                 self._original_models[m.model_id] = m
                 self._original_model_to_multi_model[m.model_id] = model
-                self._running_models[send_trial(data.dump())] = model
+                self._trial_to_original_models[trial_id].append(m.model_id)
+                self._history.append(m)

-        # for model in models:
-        #     data = BaseGraphData(codegen.model_to_pytorch_script(model),
-        #                          model.config['trainer_module'], model.config['trainer_kwargs'])
-        #     self._running_models[send_trial(data.dump())] = model
-
     def list_models(self) -> Iterable[Model]:
-        raise NotImplementedError
+        return self._history

-    def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
-        # unique_models = set()
-        # for node in logical_plan.graph.nodes:
-        #     if node.graph.model not in unique_models:
-        #         unique_models.add(node.graph.model)
-        # return [m for m in unique_models]
-        grouped_models: List[Dict[Model, PhysicalDevice]] = AssemblePolicy().group(logical_plan)
+    def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, Dict[Node, GPUDevice], List[Model]]]:
+        # try to use the available_devices first so that it can be launched as early as possible
+        # if free devices are not enough to assemble all models in one trial, try all devices
+        if len(self.available_devices) > 0:
+            grouped_models: List[Dict[Model, GPUDevice]] = AssemblePolicy().group(logical_plan, self.available_devices)
+
+        if len(self.available_devices) == 0 or len(grouped_models) > 1:
+            grouped_models: List[Dict[Model, GPUDevice]] = AssemblePolicy().group(logical_plan, self.all_devices)
+
         phy_models_and_placements = []
         for multi_model in grouped_models:
             model, model_placement = logical_plan.assemble(multi_model)
+            assert isinstance(model.evaluator, Lightning), \
+                "cross-graph optimization only supports PyTorch Lightning as evaluator"
+            assert isinstance(model.evaluator.module, _MultiModelSupervisedLearningModule), \
+                "cross-graph optimization only supports MultiModelSupervisedLearningModule"
+
+            # replace the module with a new instance whose n_models is set
+            # n_models must be set in __init__, otherwise it cannot be captured by serialize_cls
+            new_module_init_params = model.evaluator.module._init_parameters.copy()
+            # MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users
+            new_module_init_params['n_models'] = len(multi_model)
+            new_module = _MultiModelSupervisedLearningModule(**new_module_init_params)
+            model.evaluator.module = new_module
+
             phy_models_and_placements.append((model, model_placement, multi_model.keys()))
         return phy_models_and_placements
@@ -85,13 +177,14 @@ class CGOExecutionEngine(AbstractExecutionEngine):
     def register_graph_listener(self, listener: AbstractGraphListener) -> None:
         self._listeners.append(listener)

-    def _send_trial_callback(self, parameter: dict) -> None:
-        for listener in self._listeners:
-            listener.on_resource_used(0)  # FIXME: find the real resource id
+    # def _send_trial_callback(self, parameter: dict) -> None:
+    #     if len(self.available_devices) == 0:
+    #         _logger.warning('There is no available devices, but trial is submitted.')
+    #     _logger.debug('Resource used. Remaining: %d', len(self.available_devices))

-    def _request_trial_jobs_callback(self, num_trials: int) -> None:
-        for listener in self._listeners:
-            listener.on_resource_available([0] * num_trials)  # FIXME: find the real resource id
+    # def _request_trial_jobs_callback(self, num_trials: int) -> None:
+    #     self.resources += num_trials
+    #     _logger.info('on_resource_available: %d', self.resources)

     def _trial_end_callback(self, trial_id: int, success: bool) -> None:
         model = self._running_models[trial_id]
@@ -108,31 +201,40 @@ class CGOExecutionEngine(AbstractExecutionEngine):
                 original_model.status = ModelStatus.Failed
             for listener in self._listeners:
                 listener.on_training_end(original_model, success)
+
+        self.available_devices.extend(self._trial_used_devices[trial_id])
+        self.available_devices = sorted(list(set(self.available_devices)))
+        del self._running_models[trial_id]
     def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
-        # model = self._running_models[trial_id]
-        merged_metrics = dict(metrics)
-        for model_id in merged_metrics:
-            int_model_id = int(model_id)
-            self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
-            # model.intermediate_metrics.append(metrics)
-            for listener in self._listeners:
-                listener.on_intermediate_metric(self._original_models[int_model_id], merged_metrics[model_id])
+        merged_metrics = {}
+        for idx, _ in enumerate(metrics):
+            merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
+        for model_id in merged_metrics:
+            self._original_models[model_id].intermediate_metrics.append(merged_metrics[model_id])
+            for listener in self._listeners:
+                listener.on_intermediate_metric(self._original_models[model_id], merged_metrics[model_id])

     def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
-        merged_metrics = dict(metrics)
-        for model_id in merged_metrics:
-            int_model_id = int(model_id)
-            self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
-            # model.intermediate_metrics.append(metrics)
-            for listener in self._listeners:
-                listener.on_metric(self._original_models[int_model_id], merged_metrics[model_id])
+        _logger.debug(metrics)
+        if isinstance(metrics, float):
+            self._listeners[0].on_metric(self._running_models[trial_id], metrics)
+        else:
+            merged_metrics = {}
+            for idx, _ in enumerate(metrics):
+                merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
+            for model_id in merged_metrics:
+                self._original_models[model_id].metric = merged_metrics[model_id]
+                for listener in self._listeners:
+                    listener.on_metric(self._original_models[model_id], merged_metrics[model_id])

     def query_available_resource(self) -> List[WorkerInfo]:
-        raise NotImplementedError  # move the method from listener to here?
+        # the _queuing_jobs need to use available_devices first
+        return len(self.available_devices) - len(self._queuing_jobs)

     def budget_exhausted(self) -> bool:
-        raise NotImplementedError
+        advisor = get_advisor()
+        return advisor.stopping
     @classmethod
     def trial_execute_graph(cls) -> None:
@@ -141,20 +243,86 @@ class CGOExecutionEngine(AbstractExecutionEngine):
         """
         graph_data = BaseGraphData.load(receive_trial_parameters())
         _logger.info('CGO_ENGINE trial parameters received')
-        with open('_generated_model.py', 'w') as f:
+        random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
+        file_name = f'_generated_model/{random_str}.py'
+        os.makedirs(os.path.dirname(file_name), exist_ok=True)
+        with open(file_name, 'w') as f:
             f.write(graph_data.model_script)
-        # with open('_debug_graph_data.json', 'w') as f:
-        #     json.dump(graph_data.dump(), f)
-        trainer_cls = utils.import_(graph_data.training_module)
-        model_cls = utils.import_(f"_generated_model.{graph_data.training_kwargs['model_cls']}")
-        trainer_instance = trainer_cls(model_cls(), graph_data.training_kwargs)
-        trainer_instance.fit()
+        trainer_instance = graph_data.evaluator
+        model_cls = utils.import_(f'_generated_model.{random_str}._model')
+        trainer_instance.fit(model_cls())
+        os.remove(file_name)

+def _remap_cuda_device(group_model: Dict[Model, GPUDevice]):
+    used_devices = {}
+    for m in group_model:
+        if group_model[m].node_id not in used_devices:
+            used_devices[group_model[m].node_id] = {}
+        if isinstance(group_model[m], GPUDevice):
+            if group_model[m].gpu_id not in used_devices[group_model[m].node_id]:
+                n_used_gpu_in_server = len(used_devices[group_model[m].node_id])
+                used_devices[group_model[m].node_id][group_model[m].gpu_id] = n_used_gpu_in_server
+            group_model[m].gpu_id = used_devices[group_model[m].node_id][group_model[m].gpu_id]
+    return group_model
 class AssemblePolicy:
     @staticmethod
-    def group(logical_plan):
-        group_model = {}
-        for idx, m in enumerate(logical_plan.models):
-            group_model[m] = PhysicalDevice('server', f'cuda:{idx}')
-        return [group_model]
+    def _is_related_node(model: Model, node: Node):
+        if isinstance(node, AbstractLogicalNode):
+            if model in node.related_models:
+                return True
+        else:
+            if model == node.graph.model:
+                return True
+        return False
+
+    @staticmethod
+    def _check_graph_connectivity(model: Model,
+                                  group_model: Dict[Model, GPUDevice],
+                                  logical_plan: LogicalPlan) -> bool:
+        for edge in logical_plan.logical_graph.edges:
+            if AssemblePolicy._is_related_node(model, edge.head) or \
+                    AssemblePolicy._is_related_node(model, edge.tail):
+                for grouped_model in group_model:
+                    if AssemblePolicy._is_related_node(grouped_model, edge.head) or \
+                            AssemblePolicy._is_related_node(grouped_model, edge.tail):
+                        return True
+        return False
+
+    @staticmethod
+    def _check_evaluator(new_model: Model, group_model: Dict[Model, GPUDevice]) -> bool:
+        if not (isinstance(new_model.evaluator, Lightning)
+                and isinstance(new_model.evaluator.module, MultiModelSupervisedLearningModule)):
+            return False
+        for m in group_model:
+            if not m.evaluator == new_model.evaluator:
+                return False
+        return True
+
+    @staticmethod
+    def group(logical_plan, available_devices):
+        # TODO: packing multiple models into one GPU
+        # currently, we only support one model per GPU
+        all_grouped_models = []
+        group_model = {}
+        assert len(available_devices) > 0  # there should be at least one device, set in CGO_DEVICES
+        for idx, m in enumerate(logical_plan.models):
+            # models in one group should
+            # (1) not use more GPUs than available_devices
+            # (2) be connected in the logical plan (independent models should be assembled in multiple groups)
+            # (3) use the same MultiModelSupervisedLearningModule
+            if len(group_model) > 0 and \
+                (not AssemblePolicy._check_graph_connectivity(m, group_model, logical_plan) or
+                 not AssemblePolicy._check_evaluator(m, group_model)):
+                all_grouped_models.append(_remap_cuda_device(group_model))
+                group_model = {}
+            group_model[m] = available_devices[idx % len(available_devices)]
+            if len(group_model) == len(available_devices) or \
+                    idx == len(logical_plan.models) - 1:
+                all_grouped_models.append(_remap_cuda_device(group_model))
+                group_model = {}
+        return all_grouped_models
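A construction sketch for the engine above (GPUDevice arguments assumed as in nni/common/device.py, and an advisor must already be registered before instantiation, since __init__ hooks its callbacks):

    from nni.common.device import GPUDevice
    from nni.retiarii.execution.cgo_engine import CGOExecutionEngine

    devices = [GPUDevice('single_server', i) for i in range(4)]  # constructor argument order assumed
    engine = CGOExecutionEngine(devices, max_concurrency=2, batch_waiting_time=60)
    # submit_models() only enqueues; the consumer thread batches the queued models,
    # runs the dedup-input optimizer, and assembles one trial per device group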
@@ -2,30 +2,30 @@
 # Licensed under the MIT license.

 import copy
-from typing import Dict, Tuple, List, Any
+from typing import Dict, Tuple, Any, Union

 from nni.retiarii.utils import uid
+from nni.common.device import GPUDevice

 from ...graph import Cell, Edge, Graph, Model, Node
 from ...operation import Operation, _IOPseudoOperation

-class PhysicalDevice:
-    def __init__(self, server: str, device: str):
-        self.server = server
-        self.device = device
-
-    def __eq__(self, o) -> bool:
-        return self.server == o.server and self.device == o.device
-
-    def __hash__(self) -> int:
-        return hash(self.server + '_' + self.device)
+class CPUDevice:
+    def __init__(self, node_id):
+        self.node_id = node_id
+        self.device = 'cpu'
+
+    def device_repr(self):
+        return "cpu"

 class AbstractLogicalNode(Node):
     def __init__(self, graph, node_id, name, operation, _internal=False):
         super().__init__(graph, node_id, name, operation, _internal=_internal)
+        self.related_models = []

-    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
+    def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
         raise NotImplementedError

     def _fork_to(self, graph: Graph):
@@ -40,8 +40,7 @@ class LogicalGraph(Graph):
         nodes_dump = {}
         for node in self.hidden_nodes:
             if isinstance(node, OriginNode):
-                nodes_dump[f"{node.original_graph.model.model_id}_{node.name}"] = node._dump(
-                )
+                nodes_dump[f"{node.original_graph.model.model_id}_{node.name}"] = node._dump()
             else:
                 nodes_dump[f"{node.graph.model.model_id}_{node.name}"] = node._dump()
@@ -93,7 +92,7 @@ class OriginNode(AbstractLogicalNode):
         self.original_graph = original_graph
         self.original_node = original_node

-    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
+    def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
         model_id = self.original_node.graph.model.model_id
         new_node = Node(self.original_node.graph, self.original_node.id,
                         f"M_{model_id}_" +
@@ -137,30 +136,32 @@ class LogicalPlan:
         for edge in from_graph.edges:
             new_head = id_to_new_node[edge.head.id]
             new_tail = id_to_new_node[edge.tail.id]
-            Edge((new_head, edge.head_slot), (new_tail,
-                 edge.tail_slot), _internal=True)._register()
+            Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

-    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \
-            -> Tuple[Model, Dict[Node, PhysicalDevice], List[Model]]:
-        phy_model = Model(_internal=True)  # self.lp_model.fork()
+    def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) \
+            -> Tuple[Model, Dict[Node, Union[GPUDevice, CPUDevice]]]:
+        phy_model = Model(_internal=True)
         phy_graph = self.lp_model.root_graph._fork_to(phy_model)
+        phy_graph._rename_graph(phy_graph.name, "_model")

-        # Add a flag to mark multi-model in graph json.
-        # Multi-model has a list of training configs in kwargs['model_kwargs']
-        if len(multi_model_placement) > 1:
-            phy_model.evaluator.kwargs['is_multi_model'] = True
-            phy_model.evaluator.kwargs['model_cls'] = phy_graph.name
-            phy_model.evaluator.kwargs['model_kwargs'] = []
-            # FIXME: allow user to specify
-            phy_model.evaluator.module = 'nni.retiarii.trainer.pytorch.PyTorchMultiModelTrainer'
-
         # merge sub-graphs
         for model in multi_model_placement:
+            if phy_model.evaluator is None and model.evaluator is not None:
+                phy_model.evaluator = model.evaluator
             for graph_name in model.graphs:
                 if graph_name != model._root_graph_name:
-                    model.graphs[graph_name]._fork_to(
-                        phy_model, name_prefix=f'M_{model.model_id}_')
+                    new_graph = model.graphs[graph_name]._fork_to(
+                        phy_model, name_prefix=f'M_{model.model_id}_')
+                    # the name prefix M_ of hidden_nodes in non-root graphs is added here
+                    for new_node in new_graph.hidden_nodes:
+                        if isinstance(new_node.operation, Cell):
+                            old_cell_name = new_node.operation.cell_name
+                            new_node.operation = copy.deepcopy(new_node.operation)
+                            new_node.operation.cell_name = f'M_{model.model_id}_{old_cell_name}'
+
+        assert phy_model.evaluator is not None

         # When replacing logical nodes, merge the training configs when
         # input/output nodes are replaced.
         evaluator_slot = {}  # Model ID -> Slot ID
@@ -169,6 +170,9 @@ class LogicalPlan:
         # Replace all logical nodes with executable physical nodes
         hidden_nodes = phy_graph.hidden_nodes.copy()
         node_placements = {}
+
+        added_models = []
+
         for node in hidden_nodes:
             if isinstance(node, OriginNode):
                 model_id = node.original_graph.model.model_id
@@ -185,12 +189,9 @@ class LogicalPlan:
                 if isinstance(new_node.operation, _IOPseudoOperation):
                     model_id = new_node.graph.model.model_id
                     if model_id not in evaluator_slot:
-                        phy_model.evaluator.kwargs['model_kwargs'].append(new_node.graph.model.evaluator.kwargs.copy())
-                        evaluator_slot[model_id] = len(phy_model.evaluator.kwargs['model_kwargs']) - 1
+                        added_models.append(model_id)
+                        evaluator_slot[model_id] = len(added_models) - 1
                         slot = evaluator_slot[model_id]
-                        phy_model.evaluator.kwargs['model_kwargs'][slot]['model_id'] = model_id
-                        phy_model.evaluator.kwargs['model_kwargs'][slot]['use_input'] = False
-                        phy_model.evaluator.kwargs['model_kwargs'][slot]['use_output'] = False
                     else:
                         slot = evaluator_slot[model_id]

                     # If a model's inputs/outputs are not used in the multi-model
@@ -199,17 +200,23 @@ class LogicalPlan:
                     # an input/output of a model is used in a multi-model
                     if new_node.operation.type == '_inputs':
                         input_slot_mapping[new_node] = slot
-                        phy_model.evaluator.kwargs['model_kwargs'][slot]['use_input'] = True
                     if new_node.operation.type == '_outputs':
                         output_slot_mapping[new_node] = slot
-                        phy_model.evaluator.kwargs['model_kwargs'][slot]['use_output'] = True

                 self.node_replace(node, new_node)

+                # the name prefix M_ of cells in hidden_nodes of the root graph is added here
+                # FIXME: merge this rename with the non-root-graph pass, so it is done only once
                 if isinstance(new_node.operation, Cell):
                     old_cell_name = new_node.operation.cell_name
                     new_node.operation = copy.deepcopy(new_node.operation)
                     new_node.operation.cell_name = f'M_{model_id}_{old_cell_name}'
+
+                # inputs should stay on CPU; they are moved to GPU later if necessary
+                if isinstance(new_node.operation, _IOPseudoOperation) and new_node.operation.type == '_inputs':
+                    # hack: only supports single_server
+                    node_placements[new_node] = CPUDevice(node_id=placement.node_id)
+                else:
+                    node_placements[new_node] = placement

                 node.remove()
@@ -217,19 +224,23 @@ class LogicalPlan:
         # If two nodes are placed on different devices, use a ToDevice op to copy the node
         existing_edges = phy_graph.edges.copy()
         # Avoid copying a node multiple times onto the same device
-        copied_op: Dict[Tuple(Node, PhysicalDevice), Node] = {}
+        copied_op: Dict[Tuple(Node, Union[GPUDevice, CPUDevice]), Node] = {}
         for edge in existing_edges:
             head_placement = node_placements[edge.head]
             tail_placement = node_placements[edge.tail]
             if head_placement != tail_placement:
-                if head_placement.server != tail_placement.server:
+                if head_placement.node_id != tail_placement.node_id:
                     raise ValueError('Cross-server placement is not supported.')
                 # Same server, different devices
                 if (edge.head, tail_placement) in copied_op:
                     to_node = copied_op[(edge.head, tail_placement)]
                 else:
-                    to_operation = Operation.new('ToDevice', {"device": tail_placement.device})
-                    to_node = Node(phy_graph, uid(), edge.head.name + "_to_" + edge.tail.name, to_operation)._register()
+                    dst_name = edge.head.name + "_to_" + edge.tail.name
+                    to_operation = Operation.new(
+                        'ToDevice', {
+                            "device": tail_placement.device_repr(), "src": (
+                                edge.head.name, edge.head_slot), "dst": dst_name})
+                    to_node = Node(phy_graph, uid(), dst_name, to_operation)._register()
                     Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
                     copied_op[(edge.head, tail_placement)] = to_node
                 edge.head = to_node
...
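The two placement types can be summarized in a minimal illustration (CPUDevice is defined in this file, GPUDevice in nni/common/device.py; the GPUDevice argument order is assumed):

    cpu = CPUDevice(node_id='single_server')
    gpu = GPUDevice('single_server', 0)

    assert cpu.device_repr() == 'cpu' and gpu.device_repr() == 'cuda:0'
    # input pseudo-nodes are pinned to a CPUDevice; when an edge crosses from
    # cpu to gpu on the same node_id, assemble() inserts a ToDevice node
    # carrying device_repr() so the generated code copies the tensor over
    assert cpu != gpu  # the inequality that triggers the ToDevice insertion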
@@ -4,23 +4,28 @@
 from typing import List, Dict, Tuple

 from nni.retiarii.utils import uid
+from nni.retiarii.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule
+from nni.common.device import GPUDevice

 from ...graph import Graph, Model, Node
 from .interface import AbstractOptimizer
 from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
-                           OriginNode, PhysicalDevice)
+                           OriginNode)

-_supported_training_modules = ['nni.retiarii.trainer.pytorch.PyTorchImageClassificationTrainer']
+_supported_evaluators = [MultiModelSupervisedLearningModule]

 class DedupInputNode(AbstractLogicalNode):
     def __init__(self, logical_graph: LogicalGraph, node_id: int,
                  nodes_to_dedup: List[Node], _internal=False):
         super().__init__(logical_graph, node_id,
-                         "Dedup_"+nodes_to_dedup[0].name,
+                         "Dedup_" + nodes_to_dedup[0].name,
                          nodes_to_dedup[0].operation)
         self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()
+        self.related_models = [_.original_graph.model for _ in self.origin_nodes]

-    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
+    def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
         for node in self.origin_nodes:
             if node.original_graph.model in multi_model_placement:
                 new_node = Node(node.original_graph, node.id,
@@ -41,6 +46,12 @@ class DedupInputOptimizer(AbstractOptimizer):
     def __init__(self) -> None:
         pass

+    def _check_supported_evaluator(self, evaluator):
+        for e in _supported_evaluators:
+            if isinstance(evaluator, e):
+                return True
+        return False
+
     def _check_deduplicate_by_node(self, root_node, node_to_check):
         if root_node == node_to_check:
             return True
@@ -48,7 +59,7 @@ class DedupInputOptimizer(AbstractOptimizer):
            node_to_check.operation.type == '_inputs' and \
            isinstance(root_node, OriginNode) and \
            isinstance(node_to_check, OriginNode):
-            if root_node.original_graph.model.evaluator.module not in _supported_training_modules:
+            if self._check_supported_evaluator(root_node.original_graph.model.evaluator):
                 return False
             if root_node.original_graph.model.evaluator == node_to_check.original_graph.model.evaluator:
                 return True
@@ -68,7 +79,7 @@ class DedupInputOptimizer(AbstractOptimizer):
                     continue
                 root_node = node
                 break
-            if root_node == None:
+            if root_node is None:
                 break  # end of convert
             else:
                 nodes_to_dedup = []
...
@@ -50,8 +50,11 @@ class RetiariiExeConfig(ConfigBase):
     trial_code_directory: PathLike = '.'
     trial_concurrency: int
     trial_gpu_number: int = 0
+    devices: Optional[List[Union[str, GPUDevice]]] = None
     max_experiment_duration: Optional[str] = None
     max_trial_number: Optional[int] = None
+    max_concurrency_cgo: Optional[int] = None
+    batch_waiting_time: Optional[int] = None
     nni_manager_ip: Optional[str] = None
     debug: bool = False
     log_level: Optional[str] = None
@@ -139,6 +142,7 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_
         applied_mutators = mutators
     return base_model_ir, applied_mutators

+
 def debug_mutated_model(base_model, trainer, applied_mutators):
     """
     Locally run only one trial without launching an experiment for debug purpose, then exit.
@@ -189,7 +193,7 @@ class RetiariiExperiment(Experiment):
             self.strategy.run(base_model_ir, self.applied_mutators)
         _logger.info('Strategy exit')
         # TODO: find out a proper way to show no more trial message on WebUI
-        #self._dispatcher.mark_experiment_as_ending()
+        # self._dispatcher.mark_experiment_as_ending()

     def start(self, port: int = 8080, debug: bool = False) -> None:
         """
@@ -205,14 +209,18 @@ class RetiariiExperiment(Experiment):
         """
        atexit.register(self.stop)

-        devices = self._construct_devices()
         # we will probably need an execution engine factory to make this clean and elegant
         if self.config.execution_engine == 'base':
             from ..execution.base import BaseExecutionEngine
             engine = BaseExecutionEngine()
         elif self.config.execution_engine == 'cgo':
             from ..execution.cgo_engine import CGOExecutionEngine
-            engine = CGOExecutionEngine(devices=devices)
+
+            # assert self.config.trial_gpu_number == 1, "trial_gpu_number must be 1 to use CGOExecutionEngine"
+            assert self.config.batch_waiting_time is not None
+            devices = self._construct_devices()
+            engine = CGOExecutionEngine(devices,
+                                        max_concurrency=self.config.max_concurrency_cgo,
+                                        batch_waiting_time=self.config.batch_waiting_time)
         elif self.config.execution_engine == 'py':
             from ..execution.python import PurePythonExecutionEngine
             engine = PurePythonExecutionEngine()
...
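Putting it together, a hedged launch sketch using the new config fields (the experiment constructor and run signature follow the usual Retiarii pattern and are assumed here; base_model, evaluator, and strategy are defined elsewhere):

    from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment

    config = RetiariiExeConfig('local')    # training service argument assumed
    config.execution_engine = 'cgo'
    config.devices = ['cuda:0', 'cuda:1']  # str entries become GPUDevices on 'single_server'
    config.max_concurrency_cgo = 2
    config.batch_waiting_time = 60         # asserted non-None when the CGO engine starts
    config.trial_concurrency = 1

    exp = RetiariiExperiment(base_model, evaluator, [], strategy)
    exp.run(config, 8080)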
@@ -410,7 +410,7 @@ class Graph:
         return self is other

     def _fork_to(self, model: Model, name_prefix='') -> 'Graph':
-        new_graph = Graph(model, self.id, name_prefix+self.name, _internal=True)._register()
+        new_graph = Graph(model, self.id, name_prefix + self.name, _internal=True)._register()
         # TODO: use node copy instead
         new_graph.input_node.operation.io_names = self.input_node.operation.io_names
         new_graph.output_node.operation.io_names = self.output_node.operation.io_names
@@ -458,6 +458,11 @@ class Graph:
         self.model.graphs[self.name] = self
         return self

+    def _rename_graph(self, old_name, new_name):
+        self.model.graphs[old_name].name = new_name
+        self.model.graphs[new_name] = self.model.graphs[old_name]
+        del self.model.graphs[old_name]
+
     @staticmethod
     def _load(model: Model, name: str, ir: Any) -> 'Graph':
         graph = Graph(model, uid(), name, _internal=True)
...
...@@ -98,6 +98,10 @@ class PyTorchOperation(Operation):
            if hasattr(subclass, '_ori_type_name') and \
                    subclass_name in subclass._ori_type_name:
                return subclass
        for subclass in cls.__subclasses__():
            if hasattr(subclass, '_artificial_op_name') and \
                    subclass_name in subclass._artificial_op_name:
                return subclass
        return cls
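The second loop added here gives operators that NNI synthesizes itself, rather than parses out of TorchScript, a way to participate in subclass dispatch: they register a name in _artificial_op_name instead of _ori_type_name. A hedged sketch of the effect, assuming the enclosing classmethod is the subclass lookup used when operations are constructed (the method name _find_subclass is inferred from context):

# ToDevice (defined later in this commit) declares
#     _artificial_op_name = "ToDevice"
# so looking up the type name 'ToDevice' now resolves to that class
# instead of falling back to the generic PyTorchOperation.
op_cls = PyTorchOperation._find_subclass('ToDevice')   # method name assumed
assert op_cls.__name__ == 'ToDevice'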
    @classmethod
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import (Any, Dict, List)

import torch
...@@ -32,21 +32,27 @@ scalar_type_to_pytorch_type = [
    'torch.bool',  # 11
]
class NoOpIdentity(PyTorchOperation):
    """
    this operator type is added by us
    """
    _ori_type_name = ['noop_identity']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = {", ".join(inputs)}'


class ModuleOperator(PyTorchOperation):
    _ori_type_name = ['ModuleOperator', 'shared']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = self.{field}({", ".join(inputs)})'


class FunctionalOperator(PyTorchOperation):
    _ori_type_name = ['FunctionalOperator']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        func_name = self.type[len('Function.'):]
        if not hasattr(torch.nn.functional, func_name):
...@@ -54,8 +60,10 @@ class FunctionalOperator(PyTorchOperation):
                               f'{func_name} is not in it.')
        return f'{output} = F.{func_name}({", ".join(inputs)})'


class PrimConstant(PyTorchOperation):
    _ori_type_name = ['prim::Constant']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        # TODO: refactor this part, maybe we can remove the code gen of prim::Constant
        # TODO: deal with all the types
...@@ -75,63 +83,83 @@ class PrimConstant(PyTorchOperation):
        else:
            raise RuntimeError(f'unsupported type of prim::Constant: {self.parameters["type"]}')


class PrimListConstruct(PyTorchOperation):
    _ori_type_name = ['prim::ListConstruct']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = [{", ".join(inputs)}]'


class PrimListUnpack(PyTorchOperation):
    _ori_type_name = ['prim::ListUnpack']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = {inputs[0]}'


class PrimTupleConstruct(PyTorchOperation):
    _ori_type_name = ['prim::TupleConstruct']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = ({", ".join(inputs)})'


class PrimTupleUnpack(PyTorchOperation):
    _ori_type_name = ['prim::TupleUnpack']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        # have single output here, because the following code uses index to access the unpacked values
        assert len(inputs) == 1
        return f'{output} = {inputs[0]}'


class PrimGetAttr(PyTorchOperation):
    _ori_type_name = ['prim::GetAttr']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        if self.parameters['value'] is not None:
            return f"{output} = {self.parameters['value']}"
        else:
            return f"{output} = {self.parameters['input']}.{self.parameters['name']}"


class SimpleMember(PyTorchOperation):
    _ori_type_name = ['prim::is_cuda', 'prim::data']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        member_name = self.type.split('::')[-1]
        return f'{output} = {inputs[0]}.{member_name}'


class AtenContiguous(PyTorchOperation):
    _ori_type_name = ['aten::contiguous']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        # defined in pytorch/c10/core/MemoryFormat.h
        assert inputs_value[1] in [0, 1, 2]
        return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'


class AtenGetitem(PyTorchOperation):
    _ori_type_name = ['aten::__getitem__']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        assert len(inputs) == 2
        return f'{output} = {inputs[0]}[{inputs[1]}]'


class AtenAppend(PyTorchOperation):
    _ori_type_name = ['aten::append']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        assert len(inputs) == 2
        # list.append returns None, so the generated tuple assignment runs the
        # append for its side effect and rebinds the output to the mutated list
        return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
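The emitted idiom is worth a second look; outside the generated code it behaves like this (the names are illustrative):

lst = [1, 2]
_, out = lst.append(3), lst   # append runs for its side effect and returns None
assert out == [1, 2, 3]       # out is the same (now extended) list object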
class MergedSlice(PyTorchOperation):
    _ori_type_name = ['MergedSlice']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        if (len(inputs) - 1) % 4 == 0:
            slices = []
...@@ -148,23 +176,30 @@ class MergedSlice(PyTorchOperation):


# the following Aten classes mean that these aten ops are not in torch.Tensor
class AtenBool(PyTorchOperation):
    _ori_type_name = ['aten::Bool']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = bool({inputs[0]})'


class AtenNot(PyTorchOperation):
    _ori_type_name = ['aten::__not__']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = not {inputs[0]}'


class AtenCat(PyTorchOperation):
    _ori_type_name = ['aten::cat']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        assert len(inputs) == 2
        return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'


# ====================================
class AtenTensors(PyTorchOperation):
    _ori_type_name = ['aten::full', 'aten::full_like', 'aten::empty_like',
...@@ -209,20 +244,26 @@ class AtenTensors(PyTorchOperation):
        else:
            return f'{output} = {inputs[0]}.{op_name}({", ".join(args_list[1:])})'


# ====================================
class AtenFloordiv(PyTorchOperation):
    _ori_type_name = ['aten::floordiv']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = {inputs[0]} // {inputs[1]}'


class AtenLen(PyTorchOperation):
    _ori_type_name = ['aten::len']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = len({inputs[0]})'


class AtenIntImplicit(PyTorchOperation):
    _ori_type_name = ['aten::IntImplicit', 'aten::Float', 'aten::Int', 'aten::ScalarImplicit']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        if self.type.endswith('Implicit'):
            return f'{output} = {inputs[0]}'
...@@ -231,11 +272,14 @@ class AtenIntImplicit(PyTorchOperation):
        elif self.type == 'aten::Float':
            return f'{output} = float({inputs[0]})'


class AtenIndex(PyTorchOperation):
    _ori_type_name = ['aten::index']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = {inputs[0]}[{inputs[1]}]'


ManuallyChooseDef = {
    'aten::flatten': [('start_dim', 'int', '0'), ('end_dim', 'int', '-1')],
    'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')],
...@@ -255,15 +299,18 @@ TensorOpExceptions = {
TorchOpExclude = ['aten::Size', 'aten::as_tensor', 'aten::device',
                  'aten::manual_seed', 'aten::quantized_gru', 'aten::quantized_lstm',
                  'aten::save', 'aten::tensor', 'aten::wait'
                  ]


def _hidden(name):
    return name.startswith('_') and not name.startswith('__')


def _emit_args(args):
    # filter out the `out` argument here
    return [(arg.name, str(arg.type), str(arg.default_value)) for arg in args]  # if arg.name != 'out'


def _get_tensor_ops():
    def is_tensor_method(schema):
        if len(schema.arguments) == 0:
...@@ -291,6 +338,7 @@ def _get_tensor_ops():
    return op_args.keys(), op_args


def _get_torch_ops():
    torch_op_args = {}
    for mod in torch.jit._builtins._modules_containing_builtins:
...@@ -316,6 +364,7 @@ def _get_torch_ops():
    return torch_op_args.keys(), torch_op_args


def _get_torch_ops_exclude_tensor_ops():
    tensor_op_names, _ = _get_tensor_ops()
    torch_op_names, torch_ops = _get_torch_ops()
...@@ -330,6 +379,7 @@ def _get_torch_ops_exclude_tensor_ops():
    return torch_exclude_ops.keys(), torch_exclude_ops
class TensorOps(PyTorchOperation):
    """
    corresponding to _get_tensor_ops in torch.jit.supported_ops
...@@ -346,7 +396,7 @@ class TensorOps(PyTorchOperation):
            name = ','.join([arg[0] for arg in each])
            concated_names.append(name)
        # all overloads must agree on argument names, otherwise they cannot be matched safely
        for i in range(len(concated_names) - 1):
            if concated_names[i] != concated_names[i + 1]:
                return False
        return True
...@@ -383,6 +433,7 @@ class TensorOps(PyTorchOperation):
        args_str = ', '.join([f'{name}={inputs[i+1]}' for i, (name, t, default) in enumerate(matched_args)])
        return f'{output} = {inputs[0]}.{op_name}({args_str})'


class TorchOps(PyTorchOperation):
    """
    corresponding to _get_nn_functional_ops in torch.jit.supported_ops
...@@ -400,7 +451,7 @@ class TorchOps(PyTorchOperation):
            name = ','.join([arg[0] for arg in each])
            concated_names.append(name)
        for i in range(len(concated_names) - 1):
            if concated_names[i] != concated_names[i + 1]:
                return False
        return True
...@@ -424,16 +475,33 @@ class TorchOps(PyTorchOperation):
    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        matched_args = TorchOps._get_matched_args(self.type, inputs)
        op_name = self.type.split('::')[-1]
        args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
                              for i, (name, t, default) in enumerate(matched_args)])
        return f'{output} = torch.{op_name}({args_str})'
class AtenAvgpool2d(PyTorchOperation):
    # NOTE: it is not included in the above aten ops for an unknown reason
    _ori_type_name = ['aten::avg_pool2d']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = F.avg_pool2d({", ".join(inputs)})'
# artificial op inserted by cross-graph optimization (device placement) to
# move a tensor between devices
class ToDevice(PyTorchOperation):
    _artificial_op_name = "ToDevice"

    def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False):
        self.type = "ToDevice"
        self.device = parameters['device']
        self.src = parameters['src']
        self.dst = parameters['dst']

    def __repr__(self):
        return f'to("{self.device}")'

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = {inputs[0]}.to("{self.device}")'
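For reference, a small illustration of the forward line this op generates. The node names, the field value, and direct construction of the op (which the __init__ above appears to permit) are assumptions for illustration:

op = ToDevice('ToDevice', {'device': 'cuda:1', 'src': 'fc1', 'dst': 'fc2'})
line = op.to_forward_code(field='to_1', output='h_on_gpu1', inputs=['h'], inputs_value=None)
print(line)   # h_on_gpu1 = h.to("cuda:1")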
class AtenDet(PyTorchOperation):
    # for torch 1.9
    # NOTE: it is not included in the above aten ops, maybe because torch.det is an alias for torch.linalg.det
......
from collections import OrderedDict
from typing import (List, Optional)

import torch
import torch.nn as torch_nn

# sys.path.append(str(Path(__file__).resolve().parents[2]))
import ops

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit

@basic_unit
class AuxiliaryHead(nn.Module):
    """ Auxiliary head at the 2/3 position of the network, to let the gradient flow well """

    def __init__(self, input_size, C, n_classes):
        """ assuming input size 7x7 or 8x8 """
        assert input_size in [7, 8]
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False),  # 2x2 out
            nn.Conv2d(C, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, kernel_size=2, bias=False),  # 1x1 out
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.linear = nn.Linear(768, n_classes)

    def forward(self, x):
        out = self.net(x)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)
        return logits
class Node(nn.Module):
    def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
        super().__init__()
        self.ops = nn.ModuleList()
        choice_keys = []
        for i in range(num_prev_nodes):
            stride = 2 if i < num_downsample_connect else 1
            choice_keys.append("{}_p{}".format(node_id, i))
            self.ops.append(
                nn.LayerChoice([
                    ops.PoolBN('max', channels, 3, stride, 1, affine=False),
                    ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
                    nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),
                    ops.SepConv(channels, channels, 3, stride, 1, affine=False),
                    ops.SepConv(channels, channels, 5, stride, 2, affine=False),
                    ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
                    ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
                ]))
        self.drop_path = ops.DropPath()
        self.input_switch = nn.InputChoice(n_candidates=num_prev_nodes, n_chosen=2)

    def forward(self, prev_nodes: List['Tensor']) -> 'Tensor':
        # assert self.ops.__len__() == len(prev_nodes)
        # out = [op(node) for op, node in zip(self.ops, prev_nodes)]
        out = []
        for i, op in enumerate(self.ops):
            out.append(op(prev_nodes[i]))
        # out = [self.drop_path(o) if o is not None else None for o in out]
        return self.input_switch(out)
class Cell(nn.Module):
    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes

        # If the previous cell is a reduction cell, the current input size does not match the
        # output size of cell[k-2]. So the output of cell[k-2] should be reduced by preprocessing.
        if reduction_p:
            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        self.mutable_ops = nn.ModuleList()
        for depth in range(2, self.n_nodes + 2):
            self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
                                         depth, channels, 2 if reduction else 0))

    def forward(self, s0, s1):
        # s0, s1 are the outputs of the previous-previous cell and the previous cell, respectively.
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        new_tensors = []
        for node in self.mutable_ops:
            tmp = tensors + new_tensors
            cur_tensor = node(tmp)
            new_tensors.append(cur_tensor)

        output = torch.cat(new_tensors, dim=1)
        return output
class CNN(nn.Module):
    def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
                 stem_multiplier=3, auxiliary=False):
        super().__init__()
        self.in_channels = in_channels
        self.channels = channels
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.aux_pos = 2 * n_layers // 3 if auxiliary else -1

        c_cur = stem_multiplier * self.channels
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c_cur)
        )

        # for the first cell, stem is used for both s0 and s1
        # [!] channels_pp and channels_p are output channel sizes, but c_cur is the input channel size.
        channels_pp, channels_p, c_cur = c_cur, c_cur, channels

        self.cells = nn.ModuleList()
        reduction_p, reduction = False, False
        for i in range(n_layers):
            reduction_p, reduction = reduction, False
            # Reduce feature map size and double channels in the 1/3 and 2/3 layers.
            if i in [n_layers // 3, 2 * n_layers // 3]:
                c_cur *= 2
                reduction = True

            cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
            self.cells.append(cell)
            c_cur_out = c_cur * n_nodes
            channels_pp, channels_p = channels_p, c_cur_out

            # if i == self.aux_pos:
            #     self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes)

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(channels_p, n_classes)

    def forward(self, x):
        s0 = s1 = self.stem(x)

        # aux_logits = None
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1)
            # if i == self.aux_pos and self.training:
            #     aux_logits = self.aux_head(s1)

        out = self.gap(s1)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)

        # if aux_logits is not None:
        #     return logits, aux_logits
        return logits

    def drop_path_prob(self, p):
        for module in self.modules():
            if isinstance(module, ops.DropPath):
                module.p = p
if __name__ == '__main__':
    base_model = CNN(32, 3, 16, 10, 8)
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit
@basic_unit
class DropPath(nn.Module):
    def __init__(self, p=0.):
        """
        Drop path with probability.

        Parameters
        ----------
        p : float
            Probability of a path to be zeroed.
        """
        super().__init__()
        self.p = p

    def forward(self, x):
        if self.training and self.p > 0.:
            keep_prob = 1. - self.p
            # per data point mask
            mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob)
            return x / keep_prob * mask
        return x
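The division by keep_prob is inverted scaling: it keeps the expected activation unchanged, since E[x / keep_prob * mask] = x * E[mask] / keep_prob = x. A quick illustrative check (batch size chosen arbitrarily):

import torch

drop = DropPath(p=0.2)
drop.train()
x = torch.ones(100000, 1, 1, 1)
y = drop(x)
print(y.mean())   # close to 1.0: surviving paths are scaled up by 1 / keep_prob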
@basic_unit
class PoolBN(nn.Module):
    """
    AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
    """
    def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
        super().__init__()
        if pool_type.lower() == 'max':
            self.pool = nn.MaxPool2d(kernel_size, stride, padding)
        elif pool_type.lower() == 'avg':
            self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
        else:
            raise ValueError(f'unsupported pool_type: {pool_type}')
        self.bn = nn.BatchNorm2d(C, affine=affine)

    def forward(self, x):
        out = self.pool(x)
        out = self.bn(out)
        return out
@basic_unit
class StdConv(nn.Module):
    """
    Standard conv: ReLU - Conv - BN
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.net(x)
@basic_unit
class FacConv(nn.Module):
    """
    Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
    """
    def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
            nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.net(x)
@basic_unit
class DilConv(nn.Module):
    """
    (Dilated) depthwise separable conv.
    ReLU - (Dilated) depthwise separable - Pointwise - BN.

    If dilation == 2, a 3x3 conv gives a 5x5 receptive field and a 5x5 conv gives a 9x9
    receptive field (effective kernel size k + (k - 1) * (dilation - 1)).
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
                      bias=False),
            nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.net(x)
@basic_unit
class SepConv(nn.Module):
    """
    Depthwise separable conv.
    DilConv(dilation=1) * 2.
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
            DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
        )

    def forward(self, x):
        return self.net(x)
@basic_unit
class FactorizedReduce(nn.Module):
    """
    Reduce feature map size by factorized pointwise convolutions (stride=2).
    """
    def __init__(self, C_in, C_out, affine=True):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        x = self.relu(x)
        # the second conv sees the input shifted by one pixel, so the two
        # stride-2 convs together cover both even and odd spatial positions
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out
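A short illustrative shape check (the sizes are arbitrary): each branch halves the spatial resolution and produces C_out // 2 channels, so their concatenation yields C_out channels at half resolution.

import torch

fr = FactorizedReduce(C_in=16, C_out=32)
print(fr(torch.randn(1, 16, 8, 8)).shape)   # torch.Size([1, 32, 4, 4])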