"src/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "8a60d624e0daf9a414495c34c7171fa5e36f89c0"
Unverified Commit a0e2f8ef authored by Zhenhua Han's avatar Zhenhua Han Committed by GitHub
Browse files

[Retiarii] add validation in base trainers (#3184)

parent 59cd3982
import copy import copy
from typing import Dict, Tuple, List, Any from typing import Dict, Tuple, List, Any
from nni.retiarii.utils import uid
from ...graph import Cell, Edge, Graph, Model, Node from ...graph import Cell, Edge, Graph, Model, Node
from ...operation import Operation, _IOPseudoOperation from ...operation import Operation, _IOPseudoOperation
...@@ -14,7 +15,7 @@ class PhysicalDevice: ...@@ -14,7 +15,7 @@ class PhysicalDevice:
return self.server == o.server and self.device == o.device return self.server == o.server and self.device == o.device
def __hash__(self) -> int: def __hash__(self) -> int:
return hash(self.server+'_'+self.device) return hash(self.server + '_' + self.device)
class AbstractLogicalNode(Node): class AbstractLogicalNode(Node):
...@@ -181,10 +182,8 @@ class LogicalPlan: ...@@ -181,10 +182,8 @@ class LogicalPlan:
if isinstance(new_node.operation, _IOPseudoOperation): if isinstance(new_node.operation, _IOPseudoOperation):
model_id = new_node.graph.model.model_id model_id = new_node.graph.model.model_id
if model_id not in training_config_slot: if model_id not in training_config_slot:
phy_model.training_config.kwargs['model_kwargs'].append( phy_model.training_config.kwargs['model_kwargs'].append(new_node.graph.model.training_config.kwargs.copy())
new_node.graph.model.training_config.kwargs.copy()) training_config_slot[model_id] = len(phy_model.training_config.kwargs['model_kwargs']) - 1
training_config_slot[model_id] = \
len(phy_model.training_config.kwargs['model_kwargs'])-1
slot = training_config_slot[model_id] slot = training_config_slot[model_id]
phy_model.training_config.kwargs['model_kwargs'][slot]['model_id'] = model_id phy_model.training_config.kwargs['model_kwargs'][slot]['model_id'] = model_id
phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = False phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = False
...@@ -221,18 +220,14 @@ class LogicalPlan: ...@@ -221,18 +220,14 @@ class LogicalPlan:
tail_placement = node_placements[edge.tail] tail_placement = node_placements[edge.tail]
if head_placement != tail_placement: if head_placement != tail_placement:
if head_placement.server != tail_placement.server: if head_placement.server != tail_placement.server:
raise ValueError( raise ValueError('Cross-server placement is not supported.')
'Cross-server placement is not supported.')
# Same server different devices # Same server different devices
if (edge.head, tail_placement) in copied_op: if (edge.head, tail_placement) in copied_op:
to_node = copied_op[(edge.head, tail_placement)] to_node = copied_op[(edge.head, tail_placement)]
else: else:
to_operation = Operation.new( to_operation = Operation.new('ToDevice', {"device": tail_placement.device})
'ToDevice', {"device": tail_placement.device}) to_node = Node(phy_graph, uid(), edge.head.name + "_to_" + edge.tail.name, to_operation)._register()
to_node = Node(phy_graph, phy_model._uid(), Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
edge.head.name+"_to_"+edge.tail.name, to_operation)._register()
Edge((edge.head, edge.head_slot),
(to_node, None), _internal=True)._register()
copied_op[(edge.head, tail_placement)] = to_node copied_op[(edge.head, tail_placement)] = to_node
edge.head = to_node edge.head = to_node
edge.head_slot = None edge.head_slot = None
...@@ -266,11 +261,9 @@ class LogicalPlan: ...@@ -266,11 +261,9 @@ class LogicalPlan:
return phy_model, node_placements return phy_model, node_placements
def node_replace(self, old_node: Node, def node_replace(self, old_node: Node, new_node: Node, input_slot_mapping=None, output_slot_mapping=None):
new_node: Node,
input_slot_mapping=None, output_slot_mapping=None):
# TODO: currently, only support single input slot and output slot. # TODO: currently, only support single input slot and output slot.
if input_slot_mapping != None or output_slot_mapping != None: if input_slot_mapping is not None or output_slot_mapping is not None:
raise ValueError('Slot mapping is not supported') raise ValueError('Slot mapping is not supported')
phy_graph = old_node.graph phy_graph = old_node.graph
......
from typing import List, Dict, Tuple from typing import List, Dict, Tuple
from nni.retiarii.utils import uid
from ...graph import Graph, Model, Node from ...graph import Graph, Model, Node
from .interface import AbstractOptimizer from .interface import AbstractOptimizer
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan, from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
...@@ -78,8 +79,7 @@ class DedupInputOptimizer(AbstractOptimizer): ...@@ -78,8 +79,7 @@ class DedupInputOptimizer(AbstractOptimizer):
assert(nodes_to_dedup[0] == root_node) assert(nodes_to_dedup[0] == root_node)
nodes_to_skip.add(root_node) nodes_to_skip.add(root_node)
else: else:
dedup_node = DedupInputNode(logical_plan.logical_graph, dedup_node = DedupInputNode(logical_plan.logical_graph, uid(), nodes_to_dedup)._register()
logical_plan.lp_model._uid(), nodes_to_dedup)._register()
for edge in logical_plan.logical_graph.edges: for edge in logical_plan.logical_graph.edges:
if edge.head in nodes_to_dedup: if edge.head in nodes_to_dedup:
edge.head = dedup_node edge.head = dedup_node
......
...@@ -36,7 +36,8 @@ def get_default_transform(dataset: str) -> Any: ...@@ -36,7 +36,8 @@ def get_default_transform(dataset: str) -> Any:
transforms.RandomCrop(32, padding=4), transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
]) ])
# unsupported dataset, return None # unsupported dataset, return None
return None return None
...@@ -79,20 +80,30 @@ class PyTorchImageClassificationTrainer(BaseTrainer): ...@@ -79,20 +80,30 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently, Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
only the key ``max_epochs`` is useful. only the key ``max_epochs`` is useful.
""" """
super(
PyTorchImageClassificationTrainer,
self).__init__(
model,
dataset_cls,
dataset_kwargs,
dataloader_kwargs,
optimizer_cls,
optimizer_kwargs,
trainer_kwargs)
self._use_cuda = torch.cuda.is_available() self._use_cuda = torch.cuda.is_available()
self.model = model self.model = model
if self._use_cuda: if self._use_cuda:
self.model.cuda() self.model.cuda()
self._loss_fn = nn.CrossEntropyLoss() self._loss_fn = nn.CrossEntropyLoss()
self._dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls), self._train_dataset = getattr(datasets, dataset_cls)(train=True, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {})) **(dataset_kwargs or {}))
self._optimizer = getattr(torch.optim, optimizer_cls)( self._val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls),
model.parameters(), **(optimizer_kwargs or {})) **(dataset_kwargs or {}))
self._optimizer = getattr(torch.optim, optimizer_cls)(model.parameters(), **(optimizer_kwargs or {}))
self._trainer_kwargs = trainer_kwargs or {'max_epochs': 10} self._trainer_kwargs = trainer_kwargs or {'max_epochs': 10}
# TODO: we will need at least two (maybe three) data loaders in future. self._train_dataloader = DataLoader(self._train_dataset, **(dataloader_kwargs or {}))
self._dataloader = DataLoader( self._val_dataloader = DataLoader(self._val_dataset, **(dataloader_kwargs or {}))
self._dataset, **(dataloader_kwargs or {}))
def _accuracy(self, input, target): # pylint: disable=redefined-builtin def _accuracy(self, input, target): # pylint: disable=redefined-builtin
_, predict = torch.max(input.data, 1) _, predict = torch.max(input.data, 1)
...@@ -137,12 +148,12 @@ class PyTorchImageClassificationTrainer(BaseTrainer): ...@@ -137,12 +148,12 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
def _validate(self): def _validate(self):
validation_outputs = [] validation_outputs = []
for i, batch in enumerate(self._dataloader): for i, batch in enumerate(self._val_dataloader):
validation_outputs.append(self.validation_step(batch, i)) validation_outputs.append(self.validation_step(batch, i))
return self.validation_epoch_end(validation_outputs) return self.validation_epoch_end(validation_outputs)
def _train(self): def _train(self):
for i, batch in enumerate(self._dataloader): for i, batch in enumerate(self._train_dataloader):
loss = self.training_step(batch, i) loss = self.training_step(batch, i)
loss.backward() loss.backward()
...@@ -157,25 +168,32 @@ class PyTorchMultiModelTrainer(BaseTrainer): ...@@ -157,25 +168,32 @@ class PyTorchMultiModelTrainer(BaseTrainer):
def __init__(self, multi_model, kwargs=[]): def __init__(self, multi_model, kwargs=[]):
self.multi_model = multi_model self.multi_model = multi_model
self.kwargs = kwargs self.kwargs = kwargs
self._dataloaders = [] self._train_dataloaders = []
self._datasets = [] self._train_datasets = []
self._val_dataloaders = []
self._val_datasets = []
self._optimizers = [] self._optimizers = []
self._trainers = [] self._trainers = []
self._loss_fn = nn.CrossEntropyLoss() self._loss_fn = nn.CrossEntropyLoss()
self.max_steps = None self.max_steps = self.kwargs['max_steps'] if 'makx_steps' in self.kwargs else None
if 'max_steps' in self.kwargs: self.n_model = len(self.kwargs['model_kwargs'])
self.max_steps = self.kwargs['max_steps']
for m in self.kwargs['model_kwargs']: for m in self.kwargs['model_kwargs']:
if m['use_input']: if m['use_input']:
dataset_cls = m['dataset_cls'] dataset_cls = m['dataset_cls']
dataset_kwargs = m['dataset_kwargs'] dataset_kwargs = m['dataset_kwargs']
dataloader_kwargs = m['dataloader_kwargs'] dataloader_kwargs = m['dataloader_kwargs']
dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls), train_dataset = getattr(datasets, dataset_cls)(train=True, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {})) **(dataset_kwargs or {}))
dataloader = DataLoader(dataset, **(dataloader_kwargs or {})) val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls),
self._datasets.append(dataset) **(dataset_kwargs or {}))
self._dataloaders.append(dataloader) train_dataloader = DataLoader(train_dataset, **(dataloader_kwargs or {}))
val_dataloader = DataLoader(val_dataset, **(dataloader_kwargs or {}))
self._train_datasets.append(train_dataset)
self._train_dataloaders.append(train_dataloader)
self._val_datasets.append(val_dataset)
self._val_dataloaders.append(val_dataloader)
if m['use_output']: if m['use_output']:
optimizer_cls = m['optimizer_cls'] optimizer_cls = m['optimizer_cls']
...@@ -195,9 +213,10 @@ class PyTorchMultiModelTrainer(BaseTrainer): ...@@ -195,9 +213,10 @@ class PyTorchMultiModelTrainer(BaseTrainer):
max_epochs = max([x['trainer_kwargs']['max_epochs'] for x in self.kwargs['model_kwargs']]) max_epochs = max([x['trainer_kwargs']['max_epochs'] for x in self.kwargs['model_kwargs']])
for _ in range(max_epochs): for _ in range(max_epochs):
self._train() self._train()
nni.report_final_result(self._validate())
def _train(self): def _train(self):
for batch_idx, multi_model_batch in enumerate(zip(*self._dataloaders)): for batch_idx, multi_model_batch in enumerate(zip(*self._train_dataloaders)):
for opt in self._optimizers: for opt in self._optimizers:
opt.zero_grad() opt.zero_grad()
xs = [] xs = []
...@@ -225,16 +244,9 @@ class PyTorchMultiModelTrainer(BaseTrainer): ...@@ -225,16 +244,9 @@ class PyTorchMultiModelTrainer(BaseTrainer):
summed_loss.backward() summed_loss.backward()
for opt in self._optimizers: for opt in self._optimizers:
opt.step() opt.step()
if batch_idx % 50 == 0:
nni.report_intermediate_result(report_loss)
if self.max_steps and batch_idx >= self.max_steps: if self.max_steps and batch_idx >= self.max_steps:
return return
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
x, y = self.training_step_before_model(batch, batch_idx)
y_hat = self.model(x)
return self.training_step_after_model(x, y, y_hat)
def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None): def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
x, y = batch x, y = batch
if device: if device:
...@@ -245,17 +257,47 @@ class PyTorchMultiModelTrainer(BaseTrainer): ...@@ -245,17 +257,47 @@ class PyTorchMultiModelTrainer(BaseTrainer):
loss = self._loss_fn(y_hat, y) loss = self._loss_fn(y_hat, y)
return loss return loss
def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]: def _validate(self):
x, y = self.validation_step_before_model(batch, batch_idx) all_val_outputs = {idx: [] for idx in range(self.n_model)}
y_hat = self.model(x) for batch_idx, multi_model_batch in enumerate(zip(*self._val_dataloaders)):
return self.validation_step_after_model(x, y, y_hat) xs = []
ys = []
for idx, batch in enumerate(multi_model_batch):
x, y = self.training_step_before_model(batch, batch_idx, f'cuda:{idx}')
xs.append(x)
ys.append(y)
if len(ys) != len(xs):
raise ValueError('len(ys) should be equal to len(xs)')
def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int): y_hats = self.multi_model(*xs)
for output_idx, yhat in enumerate(y_hats):
if len(ys) == len(y_hats):
acc = self.validation_step_after_model(xs[output_idx], ys[output_idx], yhat)
elif len(ys) == 1:
acc = self.validation_step_after_model(xs[0], ys[0].to(yhat.get_device()), yhat)
else:
raise ValueError('len(ys) should be either 1 or len(y_hats)')
all_val_outputs[output_idx].append(acc)
report_acc = {}
for idx in all_val_outputs:
avg_acc = np.mean([x['val_acc'] for x in all_val_outputs[idx]]).item()
report_acc[self.kwargs['model_kwargs'][idx]['model_id']] = avg_acc
nni.report_intermediate_result(report_acc)
return report_acc
def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
x, y = batch x, y = batch
if self._use_cuda: if device:
x, y = x.cuda(), y.cuda() x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device))
return x, y return x, y
def validation_step_after_model(self, x, y, y_hat): def validation_step_after_model(self, x, y, y_hat):
acc = self._accuracy(y_hat, y) acc = self._accuracy(y_hat, y)
return {'val_acc': acc} return {'val_acc': acc}
def _accuracy(self, input, target): # pylint: disable=redefined-builtin
_, predict = torch.max(input.data, 1)
correct = predict.eq(target.data).cpu().sum().item()
return correct / input.size(0)
...@@ -42,8 +42,7 @@ class CGOEngineTest(unittest.TestCase): ...@@ -42,8 +42,7 @@ class CGOEngineTest(unittest.TestCase):
protocol._in_file = open('generated/debug_protocol_out_file.py', 'rb') protocol._in_file = open('generated/debug_protocol_out_file.py', 'rb')
models = _load_mnist(2) models = _load_mnist(2)
anything = lambda: None advisor = RetiariiAdvisor()
advisor = RetiariiAdvisor(anything)
submit_models(*models) submit_models(*models)
if torch.cuda.is_available() and torch.cuda.device_count() >= 2: if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
......
...@@ -54,9 +54,8 @@ class DedupInputTest(unittest.TestCase): ...@@ -54,9 +54,8 @@ class DedupInputTest(unittest.TestCase):
lp_dump = lp.logical_graph._dump() lp_dump = lp.logical_graph._dump()
self.assertTrue(correct_dump[0] == json.dumps(lp_dump)) self.assertTrue(correct_dump[0] == json.dumps(lp_dump))
anything = lambda: None advisor = RetiariiAdvisor()
advisor = RetiariiAdvisor(anything)
cgo = CGOExecutionEngine() cgo = CGOExecutionEngine()
phy_models = cgo._assemble(lp) phy_models = cgo._assemble(lp)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment