Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
a0e2f8ef
Unverified
Commit
a0e2f8ef
authored
Dec 15, 2020
by
Zhenhua Han
Committed by
GitHub
Dec 15, 2020
Browse files
[Retiarii] add validation in base trainers (#3184)
parent
59cd3982
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
92 additions
and
59 deletions
+92
-59
nni/retiarii/execution/logical_optimizer/logical_plan.py
nni/retiarii/execution/logical_optimizer/logical_plan.py
+10
-17
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
+2
-2
nni/retiarii/trainer/pytorch/base.py
nni/retiarii/trainer/pytorch/base.py
+77
-35
test/ut/retiarii/test_cgo_engine.py
test/ut/retiarii/test_cgo_engine.py
+1
-2
test/ut/retiarii/test_dedup_input.py
test/ut/retiarii/test_dedup_input.py
+2
-3
No files found.
nni/retiarii/execution/logical_optimizer/logical_plan.py
View file @
a0e2f8ef
import
copy
from
typing
import
Dict
,
Tuple
,
List
,
Any
from
nni.retiarii.utils
import
uid
from
...graph
import
Cell
,
Edge
,
Graph
,
Model
,
Node
from
...operation
import
Operation
,
_IOPseudoOperation
...
...
@@ -14,7 +15,7 @@ class PhysicalDevice:
return
self
.
server
==
o
.
server
and
self
.
device
==
o
.
device
def
__hash__
(
self
)
->
int
:
return
hash
(
self
.
server
+
'_'
+
self
.
device
)
return
hash
(
self
.
server
+
'_'
+
self
.
device
)
class
AbstractLogicalNode
(
Node
):
...
...
@@ -181,10 +182,8 @@ class LogicalPlan:
if
isinstance
(
new_node
.
operation
,
_IOPseudoOperation
):
model_id
=
new_node
.
graph
.
model
.
model_id
if
model_id
not
in
training_config_slot
:
phy_model
.
training_config
.
kwargs
[
'model_kwargs'
].
append
(
new_node
.
graph
.
model
.
training_config
.
kwargs
.
copy
())
training_config_slot
[
model_id
]
=
\
len
(
phy_model
.
training_config
.
kwargs
[
'model_kwargs'
])
-
1
phy_model
.
training_config
.
kwargs
[
'model_kwargs'
].
append
(
new_node
.
graph
.
model
.
training_config
.
kwargs
.
copy
())
training_config_slot
[
model_id
]
=
len
(
phy_model
.
training_config
.
kwargs
[
'model_kwargs'
])
-
1
slot
=
training_config_slot
[
model_id
]
phy_model
.
training_config
.
kwargs
[
'model_kwargs'
][
slot
][
'model_id'
]
=
model_id
phy_model
.
training_config
.
kwargs
[
'model_kwargs'
][
slot
][
'use_input'
]
=
False
...
...
@@ -221,18 +220,14 @@ class LogicalPlan:
tail_placement
=
node_placements
[
edge
.
tail
]
if
head_placement
!=
tail_placement
:
if
head_placement
.
server
!=
tail_placement
.
server
:
raise
ValueError
(
'Cross-server placement is not supported.'
)
raise
ValueError
(
'Cross-server placement is not supported.'
)
# Same server different devices
if
(
edge
.
head
,
tail_placement
)
in
copied_op
:
to_node
=
copied_op
[(
edge
.
head
,
tail_placement
)]
else
:
to_operation
=
Operation
.
new
(
'ToDevice'
,
{
"device"
:
tail_placement
.
device
})
to_node
=
Node
(
phy_graph
,
phy_model
.
_uid
(),
edge
.
head
.
name
+
"_to_"
+
edge
.
tail
.
name
,
to_operation
).
_register
()
Edge
((
edge
.
head
,
edge
.
head_slot
),
(
to_node
,
None
),
_internal
=
True
).
_register
()
to_operation
=
Operation
.
new
(
'ToDevice'
,
{
"device"
:
tail_placement
.
device
})
to_node
=
Node
(
phy_graph
,
uid
(),
edge
.
head
.
name
+
"_to_"
+
edge
.
tail
.
name
,
to_operation
).
_register
()
Edge
((
edge
.
head
,
edge
.
head_slot
),
(
to_node
,
None
),
_internal
=
True
).
_register
()
copied_op
[(
edge
.
head
,
tail_placement
)]
=
to_node
edge
.
head
=
to_node
edge
.
head_slot
=
None
...
...
@@ -266,11 +261,9 @@ class LogicalPlan:
return
phy_model
,
node_placements
def
node_replace
(
self
,
old_node
:
Node
,
new_node
:
Node
,
input_slot_mapping
=
None
,
output_slot_mapping
=
None
):
def
node_replace
(
self
,
old_node
:
Node
,
new_node
:
Node
,
input_slot_mapping
=
None
,
output_slot_mapping
=
None
):
# TODO: currently, only support single input slot and output slot.
if
input_slot_mapping
!=
None
or
output_slot_mapping
!=
None
:
if
input_slot_mapping
is
not
None
or
output_slot_mapping
is
not
None
:
raise
ValueError
(
'Slot mapping is not supported'
)
phy_graph
=
old_node
.
graph
...
...
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
View file @
a0e2f8ef
from
typing
import
List
,
Dict
,
Tuple
from
nni.retiarii.utils
import
uid
from
...graph
import
Graph
,
Model
,
Node
from
.interface
import
AbstractOptimizer
from
.logical_plan
import
(
AbstractLogicalNode
,
LogicalGraph
,
LogicalPlan
,
...
...
@@ -78,8 +79,7 @@ class DedupInputOptimizer(AbstractOptimizer):
assert
(
nodes_to_dedup
[
0
]
==
root_node
)
nodes_to_skip
.
add
(
root_node
)
else
:
dedup_node
=
DedupInputNode
(
logical_plan
.
logical_graph
,
logical_plan
.
lp_model
.
_uid
(),
nodes_to_dedup
).
_register
()
dedup_node
=
DedupInputNode
(
logical_plan
.
logical_graph
,
uid
(),
nodes_to_dedup
).
_register
()
for
edge
in
logical_plan
.
logical_graph
.
edges
:
if
edge
.
head
in
nodes_to_dedup
:
edge
.
head
=
dedup_node
...
...
nni/retiarii/trainer/pytorch/base.py
View file @
a0e2f8ef
...
...
@@ -36,7 +36,8 @@ def get_default_transform(dataset: str) -> Any:
transforms
.
RandomCrop
(
32
,
padding
=
4
),
transforms
.
RandomHorizontalFlip
(),
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
])
# unsupported dataset, return None
return
None
...
...
@@ -79,20 +80,30 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
only the key ``max_epochs`` is useful.
"""
super
(
PyTorchImageClassificationTrainer
,
self
).
__init__
(
model
,
dataset_cls
,
dataset_kwargs
,
dataloader_kwargs
,
optimizer_cls
,
optimizer_kwargs
,
trainer_kwargs
)
self
.
_use_cuda
=
torch
.
cuda
.
is_available
()
self
.
model
=
model
if
self
.
_use_cuda
:
self
.
model
.
cuda
()
self
.
_loss_fn
=
nn
.
CrossEntropyLoss
()
self
.
_dataset
=
getattr
(
datasets
,
dataset_cls
)(
transform
=
get_default_transform
(
dataset_cls
),
**
(
dataset_kwargs
or
{}))
self
.
_optimizer
=
getattr
(
torch
.
optim
,
optimizer_cls
)(
model
.
parameters
(),
**
(
optimizer_kwargs
or
{}))
self
.
_train_dataset
=
getattr
(
datasets
,
dataset_cls
)(
train
=
True
,
transform
=
get_default_transform
(
dataset_cls
),
**
(
dataset_kwargs
or
{}))
self
.
_val_dataset
=
getattr
(
datasets
,
dataset_cls
)(
train
=
False
,
transform
=
get_default_transform
(
dataset_cls
),
**
(
dataset_kwargs
or
{}))
self
.
_optimizer
=
getattr
(
torch
.
optim
,
optimizer_cls
)(
model
.
parameters
(),
**
(
optimizer_kwargs
or
{}))
self
.
_trainer_kwargs
=
trainer_kwargs
or
{
'max_epochs'
:
10
}
# TODO: we will need at least two (maybe three) data loaders in future.
self
.
_dataloader
=
DataLoader
(
self
.
_dataset
,
**
(
dataloader_kwargs
or
{}))
self
.
_train_dataloader
=
DataLoader
(
self
.
_train_dataset
,
**
(
dataloader_kwargs
or
{}))
self
.
_val_dataloader
=
DataLoader
(
self
.
_val_dataset
,
**
(
dataloader_kwargs
or
{}))
def
_accuracy
(
self
,
input
,
target
):
# pylint: disable=redefined-builtin
_
,
predict
=
torch
.
max
(
input
.
data
,
1
)
...
...
@@ -137,12 +148,12 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
def
_validate
(
self
):
validation_outputs
=
[]
for
i
,
batch
in
enumerate
(
self
.
_dataloader
):
for
i
,
batch
in
enumerate
(
self
.
_
val_
dataloader
):
validation_outputs
.
append
(
self
.
validation_step
(
batch
,
i
))
return
self
.
validation_epoch_end
(
validation_outputs
)
def
_train
(
self
):
for
i
,
batch
in
enumerate
(
self
.
_dataloader
):
for
i
,
batch
in
enumerate
(
self
.
_
train_
dataloader
):
loss
=
self
.
training_step
(
batch
,
i
)
loss
.
backward
()
...
...
@@ -157,25 +168,32 @@ class PyTorchMultiModelTrainer(BaseTrainer):
def
__init__
(
self
,
multi_model
,
kwargs
=
[]):
self
.
multi_model
=
multi_model
self
.
kwargs
=
kwargs
self
.
_dataloaders
=
[]
self
.
_datasets
=
[]
self
.
_train_dataloaders
=
[]
self
.
_train_datasets
=
[]
self
.
_val_dataloaders
=
[]
self
.
_val_datasets
=
[]
self
.
_optimizers
=
[]
self
.
_trainers
=
[]
self
.
_loss_fn
=
nn
.
CrossEntropyLoss
()
self
.
max_steps
=
None
if
'max_steps'
in
self
.
kwargs
:
self
.
max_steps
=
self
.
kwargs
[
'max_steps'
]
self
.
max_steps
=
self
.
kwargs
[
'max_steps'
]
if
'makx_steps'
in
self
.
kwargs
else
None
self
.
n_model
=
len
(
self
.
kwargs
[
'model_kwargs'
])
for
m
in
self
.
kwargs
[
'model_kwargs'
]:
if
m
[
'use_input'
]:
dataset_cls
=
m
[
'dataset_cls'
]
dataset_kwargs
=
m
[
'dataset_kwargs'
]
dataloader_kwargs
=
m
[
'dataloader_kwargs'
]
dataset
=
getattr
(
datasets
,
dataset_cls
)(
transform
=
get_default_transform
(
dataset_cls
),
**
(
dataset_kwargs
or
{}))
dataloader
=
DataLoader
(
dataset
,
**
(
dataloader_kwargs
or
{}))
self
.
_datasets
.
append
(
dataset
)
self
.
_dataloaders
.
append
(
dataloader
)
train_dataset
=
getattr
(
datasets
,
dataset_cls
)(
train
=
True
,
transform
=
get_default_transform
(
dataset_cls
),
**
(
dataset_kwargs
or
{}))
val_dataset
=
getattr
(
datasets
,
dataset_cls
)(
train
=
False
,
transform
=
get_default_transform
(
dataset_cls
),
**
(
dataset_kwargs
or
{}))
train_dataloader
=
DataLoader
(
train_dataset
,
**
(
dataloader_kwargs
or
{}))
val_dataloader
=
DataLoader
(
val_dataset
,
**
(
dataloader_kwargs
or
{}))
self
.
_train_datasets
.
append
(
train_dataset
)
self
.
_train_dataloaders
.
append
(
train_dataloader
)
self
.
_val_datasets
.
append
(
val_dataset
)
self
.
_val_dataloaders
.
append
(
val_dataloader
)
if
m
[
'use_output'
]:
optimizer_cls
=
m
[
'optimizer_cls'
]
...
...
@@ -195,9 +213,10 @@ class PyTorchMultiModelTrainer(BaseTrainer):
max_epochs
=
max
([
x
[
'trainer_kwargs'
][
'max_epochs'
]
for
x
in
self
.
kwargs
[
'model_kwargs'
]])
for
_
in
range
(
max_epochs
):
self
.
_train
()
nni
.
report_final_result
(
self
.
_validate
())
def
_train
(
self
):
for
batch_idx
,
multi_model_batch
in
enumerate
(
zip
(
*
self
.
_dataloaders
)):
for
batch_idx
,
multi_model_batch
in
enumerate
(
zip
(
*
self
.
_
train_
dataloaders
)):
for
opt
in
self
.
_optimizers
:
opt
.
zero_grad
()
xs
=
[]
...
...
@@ -225,16 +244,9 @@ class PyTorchMultiModelTrainer(BaseTrainer):
summed_loss
.
backward
()
for
opt
in
self
.
_optimizers
:
opt
.
step
()
if
batch_idx
%
50
==
0
:
nni
.
report_intermediate_result
(
report_loss
)
if
self
.
max_steps
and
batch_idx
>=
self
.
max_steps
:
return
def
training_step
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
)
->
Dict
[
str
,
Any
]:
x
,
y
=
self
.
training_step_before_model
(
batch
,
batch_idx
)
y_hat
=
self
.
model
(
x
)
return
self
.
training_step_after_model
(
x
,
y
,
y_hat
)
def
training_step_before_model
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
,
device
=
None
):
x
,
y
=
batch
if
device
:
...
...
@@ -245,17 +257,47 @@ class PyTorchMultiModelTrainer(BaseTrainer):
loss
=
self
.
_loss_fn
(
y_hat
,
y
)
return
loss
def
validation_step
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
)
->
Dict
[
str
,
Any
]:
x
,
y
=
self
.
validation_step_before_model
(
batch
,
batch_idx
)
y_hat
=
self
.
model
(
x
)
return
self
.
validation_step_after_model
(
x
,
y
,
y_hat
)
def
_validate
(
self
):
all_val_outputs
=
{
idx
:
[]
for
idx
in
range
(
self
.
n_model
)}
for
batch_idx
,
multi_model_batch
in
enumerate
(
zip
(
*
self
.
_val_dataloaders
)):
xs
=
[]
ys
=
[]
for
idx
,
batch
in
enumerate
(
multi_model_batch
):
x
,
y
=
self
.
training_step_before_model
(
batch
,
batch_idx
,
f
'cuda:
{
idx
}
'
)
xs
.
append
(
x
)
ys
.
append
(
y
)
if
len
(
ys
)
!=
len
(
xs
):
raise
ValueError
(
'len(ys) should be equal to len(xs)'
)
def
validation_step_before_model
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
):
y_hats
=
self
.
multi_model
(
*
xs
)
for
output_idx
,
yhat
in
enumerate
(
y_hats
):
if
len
(
ys
)
==
len
(
y_hats
):
acc
=
self
.
validation_step_after_model
(
xs
[
output_idx
],
ys
[
output_idx
],
yhat
)
elif
len
(
ys
)
==
1
:
acc
=
self
.
validation_step_after_model
(
xs
[
0
],
ys
[
0
].
to
(
yhat
.
get_device
()),
yhat
)
else
:
raise
ValueError
(
'len(ys) should be either 1 or len(y_hats)'
)
all_val_outputs
[
output_idx
].
append
(
acc
)
report_acc
=
{}
for
idx
in
all_val_outputs
:
avg_acc
=
np
.
mean
([
x
[
'val_acc'
]
for
x
in
all_val_outputs
[
idx
]]).
item
()
report_acc
[
self
.
kwargs
[
'model_kwargs'
][
idx
][
'model_id'
]]
=
avg_acc
nni
.
report_intermediate_result
(
report_acc
)
return
report_acc
def
validation_step_before_model
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
,
device
=
None
):
x
,
y
=
batch
if
self
.
_use_cuda
:
x
,
y
=
x
.
cuda
(
),
y
.
cuda
(
)
if
device
:
x
,
y
=
x
.
cuda
(
torch
.
device
(
device
)),
y
.
cuda
(
torch
.
device
(
device
)
)
return
x
,
y
def
validation_step_after_model
(
self
,
x
,
y
,
y_hat
):
acc
=
self
.
_accuracy
(
y_hat
,
y
)
return
{
'val_acc'
:
acc
}
def
_accuracy
(
self
,
input
,
target
):
# pylint: disable=redefined-builtin
_
,
predict
=
torch
.
max
(
input
.
data
,
1
)
correct
=
predict
.
eq
(
target
.
data
).
cpu
().
sum
().
item
()
return
correct
/
input
.
size
(
0
)
test/ut/retiarii/test_cgo_engine.py
View file @
a0e2f8ef
...
...
@@ -42,8 +42,7 @@ class CGOEngineTest(unittest.TestCase):
protocol
.
_in_file
=
open
(
'generated/debug_protocol_out_file.py'
,
'rb'
)
models
=
_load_mnist
(
2
)
anything
=
lambda
:
None
advisor
=
RetiariiAdvisor
(
anything
)
advisor
=
RetiariiAdvisor
()
submit_models
(
*
models
)
if
torch
.
cuda
.
is_available
()
and
torch
.
cuda
.
device_count
()
>=
2
:
...
...
test/ut/retiarii/test_dedup_input.py
View file @
a0e2f8ef
...
...
@@ -54,9 +54,8 @@ class DedupInputTest(unittest.TestCase):
lp_dump
=
lp
.
logical_graph
.
_dump
()
self
.
assertTrue
(
correct_dump
[
0
]
==
json
.
dumps
(
lp_dump
))
anything
=
lambda
:
None
advisor
=
RetiariiAdvisor
(
anything
)
advisor
=
RetiariiAdvisor
()
cgo
=
CGOExecutionEngine
()
phy_models
=
cgo
.
_assemble
(
lp
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment