Unverified Commit a0e2f8ef authored by Zhenhua Han's avatar Zhenhua Han Committed by GitHub
Browse files

[Retiarii] add validation in base trainers (#3184)

parent 59cd3982
import copy
from typing import Dict, Tuple, List, Any
from nni.retiarii.utils import uid
from ...graph import Cell, Edge, Graph, Model, Node
from ...operation import Operation, _IOPseudoOperation
......@@ -14,7 +15,7 @@ class PhysicalDevice:
return self.server == o.server and self.device == o.device
# Hash a PhysicalDevice by its (server, device) identity, consistent with __eq__ above.
# NOTE(review): the two return lines below are the pre-/post-change versions of the
# same statement from a rendered diff (only spacing differs); the real file has one.
def __hash__(self) -> int:
return hash(self.server+'_'+self.device)
return hash(self.server + '_' + self.device)
class AbstractLogicalNode(Node):
......@@ -181,10 +182,8 @@ class LogicalPlan:
if isinstance(new_node.operation, _IOPseudoOperation):
model_id = new_node.graph.model.model_id
if model_id not in training_config_slot:
phy_model.training_config.kwargs['model_kwargs'].append(
new_node.graph.model.training_config.kwargs.copy())
training_config_slot[model_id] = \
len(phy_model.training_config.kwargs['model_kwargs'])-1
phy_model.training_config.kwargs['model_kwargs'].append(new_node.graph.model.training_config.kwargs.copy())
training_config_slot[model_id] = len(phy_model.training_config.kwargs['model_kwargs']) - 1
slot = training_config_slot[model_id]
phy_model.training_config.kwargs['model_kwargs'][slot]['model_id'] = model_id
phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = False
......@@ -221,18 +220,14 @@ class LogicalPlan:
tail_placement = node_placements[edge.tail]
if head_placement != tail_placement:
if head_placement.server != tail_placement.server:
raise ValueError(
'Cross-server placement is not supported.')
raise ValueError('Cross-server placement is not supported.')
# Same server different devices
if (edge.head, tail_placement) in copied_op:
to_node = copied_op[(edge.head, tail_placement)]
else:
to_operation = Operation.new(
'ToDevice', {"device": tail_placement.device})
to_node = Node(phy_graph, phy_model._uid(),
edge.head.name+"_to_"+edge.tail.name, to_operation)._register()
Edge((edge.head, edge.head_slot),
(to_node, None), _internal=True)._register()
to_operation = Operation.new('ToDevice', {"device": tail_placement.device})
to_node = Node(phy_graph, uid(), edge.head.name + "_to_" + edge.tail.name, to_operation)._register()
Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
copied_op[(edge.head, tail_placement)] = to_node
edge.head = to_node
edge.head_slot = None
......@@ -266,11 +261,9 @@ class LogicalPlan:
return phy_model, node_placements
def node_replace(self, old_node: Node,
new_node: Node,
input_slot_mapping=None, output_slot_mapping=None):
def node_replace(self, old_node: Node, new_node: Node, input_slot_mapping=None, output_slot_mapping=None):
# TODO: currently, only support single input slot and output slot.
if input_slot_mapping != None or output_slot_mapping != None:
if input_slot_mapping is not None or output_slot_mapping is not None:
raise ValueError('Slot mapping is not supported')
phy_graph = old_node.graph
......
from typing import List, Dict, Tuple
from nni.retiarii.utils import uid
from ...graph import Graph, Model, Node
from .interface import AbstractOptimizer
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
......@@ -78,8 +79,7 @@ class DedupInputOptimizer(AbstractOptimizer):
assert(nodes_to_dedup[0] == root_node)
nodes_to_skip.add(root_node)
else:
dedup_node = DedupInputNode(logical_plan.logical_graph,
logical_plan.lp_model._uid(), nodes_to_dedup)._register()
dedup_node = DedupInputNode(logical_plan.logical_graph, uid(), nodes_to_dedup)._register()
for edge in logical_plan.logical_graph.edges:
if edge.head in nodes_to_dedup:
edge.head = dedup_node
......
......@@ -36,7 +36,8 @@ def get_default_transform(dataset: str) -> Any:
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
# unsupported dataset, return None
return None
......@@ -79,20 +80,30 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
only the key ``max_epochs`` is useful.
"""
super(
PyTorchImageClassificationTrainer,
self).__init__(
model,
dataset_cls,
dataset_kwargs,
dataloader_kwargs,
optimizer_cls,
optimizer_kwargs,
trainer_kwargs)
self._use_cuda = torch.cuda.is_available()
self.model = model
if self._use_cuda:
self.model.cuda()
self._loss_fn = nn.CrossEntropyLoss()
self._dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
self._optimizer = getattr(torch.optim, optimizer_cls)(
model.parameters(), **(optimizer_kwargs or {}))
self._train_dataset = getattr(datasets, dataset_cls)(train=True, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
self._val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
self._optimizer = getattr(torch.optim, optimizer_cls)(model.parameters(), **(optimizer_kwargs or {}))
self._trainer_kwargs = trainer_kwargs or {'max_epochs': 10}
# TODO: we will need at least two (maybe three) data loaders in future.
self._dataloader = DataLoader(
self._dataset, **(dataloader_kwargs or {}))
self._train_dataloader = DataLoader(self._train_dataset, **(dataloader_kwargs or {}))
self._val_dataloader = DataLoader(self._val_dataset, **(dataloader_kwargs or {}))
def _accuracy(self, input, target): # pylint: disable=redefined-builtin
_, predict = torch.max(input.data, 1)
......@@ -137,12 +148,12 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
# Run one full pass over the validation loader, collecting per-batch outputs
# and reducing them via validation_epoch_end (defined outside this view).
# NOTE(review): the two `for` lines are the old (self._dataloader) and new
# (self._val_dataloader) versions of the same loop from a rendered diff.
def _validate(self):
validation_outputs = []
for i, batch in enumerate(self._dataloader):
for i, batch in enumerate(self._val_dataloader):
validation_outputs.append(self.validation_step(batch, i))
return self.validation_epoch_end(validation_outputs)
# One training epoch: step through the training loader and backprop each batch loss.
# NOTE(review): the two `for` lines are the old/new versions of the same loop from a
# rendered diff; the method body continues past this hunk (optimizer step not visible).
def _train(self):
for i, batch in enumerate(self._dataloader):
for i, batch in enumerate(self._train_dataloader):
loss = self.training_step(batch, i)
loss.backward()
......@@ -157,25 +168,32 @@ class PyTorchMultiModelTrainer(BaseTrainer):
# Build a trainer that drives several models at once: per-model datasets,
# dataloaders, and optimizers are created from kwargs['model_kwargs'].
# NOTE(review): `kwargs=[]` is a mutable default argument and is actually used as a
# dict (key lookups below) — should be `kwargs=None` with a `{}` fallback.
def __init__(self, multi_model, kwargs=[]):
self.multi_model = multi_model
self.kwargs = kwargs
# NOTE(review): the single-loader attributes below are the pre-change (removed)
# lines of a rendered diff; the train/val pairs that follow replace them.
self._dataloaders = []
self._datasets = []
self._train_dataloaders = []
self._train_datasets = []
self._val_dataloaders = []
self._val_datasets = []
self._optimizers = []
self._trainers = []
self._loss_fn = nn.CrossEntropyLoss()
self.max_steps = None
if 'max_steps' in self.kwargs:
self.max_steps = self.kwargs['max_steps']
# BUG(review): 'makx_steps' is almost certainly a typo for 'max_steps' — as
# committed, self.max_steps is always None and the max-step early exit in
# _train can never trigger.
self.max_steps = self.kwargs['max_steps'] if 'makx_steps' in self.kwargs else None
self.n_model = len(self.kwargs['model_kwargs'])
for m in self.kwargs['model_kwargs']:
# Models that consume the shared input get their own dataset/loader pair.
if m['use_input']:
dataset_cls = m['dataset_cls']
dataset_kwargs = m['dataset_kwargs']
dataloader_kwargs = m['dataloader_kwargs']
# NOTE(review): the next five lines are the removed pre-change version
# (single dataset, no train/val split) kept by the diff rendering.
dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
dataloader = DataLoader(dataset, **(dataloader_kwargs or {}))
self._datasets.append(dataset)
self._dataloaders.append(dataloader)
# New behavior: separate train (train=True) and validation (train=False)
# datasets, each with the dataset's default torchvision transform.
train_dataset = getattr(datasets, dataset_cls)(train=True, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
train_dataloader = DataLoader(train_dataset, **(dataloader_kwargs or {}))
val_dataloader = DataLoader(val_dataset, **(dataloader_kwargs or {}))
self._train_datasets.append(train_dataset)
self._train_dataloaders.append(train_dataloader)
self._val_datasets.append(val_dataset)
self._val_dataloaders.append(val_dataloader)
# Models that produce the output also get an optimizer (body continues past this hunk).
if m['use_output']:
optimizer_cls = m['optimizer_cls']
......@@ -195,9 +213,10 @@ class PyTorchMultiModelTrainer(BaseTrainer):
max_epochs = max([x['trainer_kwargs']['max_epochs'] for x in self.kwargs['model_kwargs']])
for _ in range(max_epochs):
self._train()
nni.report_final_result(self._validate())
def _train(self):
for batch_idx, multi_model_batch in enumerate(zip(*self._dataloaders)):
for batch_idx, multi_model_batch in enumerate(zip(*self._train_dataloaders)):
for opt in self._optimizers:
opt.zero_grad()
xs = []
......@@ -225,16 +244,9 @@ class PyTorchMultiModelTrainer(BaseTrainer):
summed_loss.backward()
for opt in self._optimizers:
opt.step()
if batch_idx % 50 == 0:
nni.report_intermediate_result(report_loss)
if self.max_steps and batch_idx >= self.max_steps:
return
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
    """Run a single training step.

    Preprocesses *batch*, forwards it through the model, and hands the
    inputs, targets, and predictions to the post-model hook (which
    computes the loss).
    """
    inputs, targets = self.training_step_before_model(batch, batch_idx)
    predictions = self.model(inputs)
    return self.training_step_after_model(inputs, targets, predictions)
def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
x, y = batch
if device:
......@@ -245,17 +257,47 @@ class PyTorchMultiModelTrainer(BaseTrainer):
loss = self._loss_fn(y_hat, y)
return loss
def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
    """Run a single validation step.

    Mirrors training_step: preprocess the batch, forward it through the
    model, then score the predictions via the post-model hook.
    """
    inputs, targets = self.validation_step_before_model(batch, batch_idx)
    predictions = self.model(inputs)
    return self.validation_step_after_model(inputs, targets, predictions)
# Validate all models in lockstep over their zipped validation loaders, then
# report per-model average accuracy keyed by model_id.
def _validate(self):
all_val_outputs = {idx: [] for idx in range(self.n_model)}
for batch_idx, multi_model_batch in enumerate(zip(*self._val_dataloaders)):
xs = []
ys = []
# Each model's batch is moved to its own GPU (cuda:idx).
# NOTE(review): reuses training_step_before_model for device placement —
# presumably identical preprocessing for train and val; confirm intended.
for idx, batch in enumerate(multi_model_batch):
x, y = self.training_step_before_model(batch, batch_idx, f'cuda:{idx}')
xs.append(x)
ys.append(y)
if len(ys) != len(xs):
raise ValueError('len(ys) should be equal to len(xs)')
# NOTE(review): the `def validation_step_before_model(...)` line below is a
# stray removed line from the rendered diff, not part of this method.
def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
y_hats = self.multi_model(*xs)
for output_idx, yhat in enumerate(y_hats):
# Either one label set per output, or a single shared label set that is
# moved onto each prediction's device before scoring.
if len(ys) == len(y_hats):
acc = self.validation_step_after_model(xs[output_idx], ys[output_idx], yhat)
elif len(ys) == 1:
acc = self.validation_step_after_model(xs[0], ys[0].to(yhat.get_device()), yhat)
else:
raise ValueError('len(ys) should be either 1 or len(y_hats)')
all_val_outputs[output_idx].append(acc)
report_acc = {}
# Average per-batch accuracy for each model; .item() converts numpy scalar to float.
for idx in all_val_outputs:
avg_acc = np.mean([x['val_acc'] for x in all_val_outputs[idx]]).item()
report_acc[self.kwargs['model_kwargs'][idx]['model_id']] = avg_acc
nni.report_intermediate_result(report_acc)
return report_acc
# Unpack a (x, y) batch and optionally move both tensors to a specific CUDA device.
# NOTE(review): the `self._use_cuda` branch is the removed pre-change version kept
# by the diff rendering — this class defines no _use_cuda attribute; only the
# `if device:` branch belongs to the committed code.
def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
x, y = batch
if self._use_cuda:
x, y = x.cuda(), y.cuda()
if device:
x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device))
return x, y
def validation_step_after_model(self, x, y, y_hat):
    """Score predictions *y_hat* against labels *y*.

    Returns a dict with a single key ``'val_acc'`` holding the batch's
    top-1 accuracy as a float.
    """
    accuracy = self._accuracy(y_hat, y)
    return {'val_acc': accuracy}
def _accuracy(self, input, target): # pylint: disable=redefined-builtin
_, predict = torch.max(input.data, 1)
correct = predict.eq(target.data).cpu().sum().item()
return correct / input.size(0)
......@@ -42,8 +42,7 @@ class CGOEngineTest(unittest.TestCase):
protocol._in_file = open('generated/debug_protocol_out_file.py', 'rb')
models = _load_mnist(2)
anything = lambda: None
advisor = RetiariiAdvisor(anything)
advisor = RetiariiAdvisor()
submit_models(*models)
if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
......
......@@ -54,9 +54,8 @@ class DedupInputTest(unittest.TestCase):
lp_dump = lp.logical_graph._dump()
self.assertTrue(correct_dump[0] == json.dumps(lp_dump))
anything = lambda: None
advisor = RetiariiAdvisor(anything)
advisor = RetiariiAdvisor()
cgo = CGOExecutionEngine()
phy_models = cgo._assemble(lp)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment