Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
a0e2f8ef
Unverified
Commit
a0e2f8ef
authored
Dec 15, 2020
by
Zhenhua Han
Committed by
GitHub
Dec 15, 2020
Browse files
[Retiarii] add validation in base trainers (#3184)
parent
59cd3982
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
92 additions
and
59 deletions
+92
-59
nni/retiarii/execution/logical_optimizer/logical_plan.py
nni/retiarii/execution/logical_optimizer/logical_plan.py
+10
-17
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
+2
-2
nni/retiarii/trainer/pytorch/base.py
nni/retiarii/trainer/pytorch/base.py
+77
-35
test/ut/retiarii/test_cgo_engine.py
test/ut/retiarii/test_cgo_engine.py
+1
-2
test/ut/retiarii/test_dedup_input.py
test/ut/retiarii/test_dedup_input.py
+2
-3
No files found.
nni/retiarii/execution/logical_optimizer/logical_plan.py
View file @
a0e2f8ef
import
copy
from
typing
import
Dict
,
Tuple
,
List
,
Any
from
nni.retiarii.utils
import
uid
from
...graph
import
Cell
,
Edge
,
Graph
,
Model
,
Node
from
...operation
import
Operation
,
_IOPseudoOperation
...
...
@@ -14,7 +15,7 @@ class PhysicalDevice:
return
self
.
server
==
o
.
server
and
self
.
device
==
o
.
device
def
__hash__
(
self
)
->
int
:
return
hash
(
self
.
server
+
'_'
+
self
.
device
)
return
hash
(
self
.
server
+
'_'
+
self
.
device
)
class
AbstractLogicalNode
(
Node
):
...
...
@@ -181,10 +182,8 @@ class LogicalPlan:
if
isinstance
(
new_node
.
operation
,
_IOPseudoOperation
):
model_id
=
new_node
.
graph
.
model
.
model_id
if
model_id
not
in
training_config_slot
:
phy_model
.
training_config
.
kwargs
[
'model_kwargs'
].
append
(
new_node
.
graph
.
model
.
training_config
.
kwargs
.
copy
())
training_config_slot
[
model_id
]
=
\
len
(
phy_model
.
training_config
.
kwargs
[
'model_kwargs'
])
-
1
phy_model
.
training_config
.
kwargs
[
'model_kwargs'
].
append
(
new_node
.
graph
.
model
.
training_config
.
kwargs
.
copy
())
training_config_slot
[
model_id
]
=
len
(
phy_model
.
training_config
.
kwargs
[
'model_kwargs'
])
-
1
slot
=
training_config_slot
[
model_id
]
phy_model
.
training_config
.
kwargs
[
'model_kwargs'
][
slot
][
'model_id'
]
=
model_id
phy_model
.
training_config
.
kwargs
[
'model_kwargs'
][
slot
][
'use_input'
]
=
False
...
...
@@ -221,18 +220,14 @@ class LogicalPlan:
tail_placement
=
node_placements
[
edge
.
tail
]
if
head_placement
!=
tail_placement
:
if
head_placement
.
server
!=
tail_placement
.
server
:
raise
ValueError
(
'Cross-server placement is not supported.'
)
raise
ValueError
(
'Cross-server placement is not supported.'
)
# Same server different devices
if
(
edge
.
head
,
tail_placement
)
in
copied_op
:
to_node
=
copied_op
[(
edge
.
head
,
tail_placement
)]
else
:
to_operation
=
Operation
.
new
(
'ToDevice'
,
{
"device"
:
tail_placement
.
device
})
to_node
=
Node
(
phy_graph
,
phy_model
.
_uid
(),
edge
.
head
.
name
+
"_to_"
+
edge
.
tail
.
name
,
to_operation
).
_register
()
Edge
((
edge
.
head
,
edge
.
head_slot
),
(
to_node
,
None
),
_internal
=
True
).
_register
()
to_operation
=
Operation
.
new
(
'ToDevice'
,
{
"device"
:
tail_placement
.
device
})
to_node
=
Node
(
phy_graph
,
uid
(),
edge
.
head
.
name
+
"_to_"
+
edge
.
tail
.
name
,
to_operation
).
_register
()
Edge
((
edge
.
head
,
edge
.
head_slot
),
(
to_node
,
None
),
_internal
=
True
).
_register
()
copied_op
[(
edge
.
head
,
tail_placement
)]
=
to_node
edge
.
head
=
to_node
edge
.
head_slot
=
None
...
...
@@ -266,11 +261,9 @@ class LogicalPlan:
return
phy_model
,
node_placements
def
node_replace
(
self
,
old_node
:
Node
,
new_node
:
Node
,
input_slot_mapping
=
None
,
output_slot_mapping
=
None
):
def
node_replace
(
self
,
old_node
:
Node
,
new_node
:
Node
,
input_slot_mapping
=
None
,
output_slot_mapping
=
None
):
# TODO: currently, only support single input slot and output slot.
if
input_slot_mapping
!=
None
or
output_slot_mapping
!=
None
:
if
input_slot_mapping
is
not
None
or
output_slot_mapping
is
not
None
:
raise
ValueError
(
'Slot mapping is not supported'
)
phy_graph
=
old_node
.
graph
...
...
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
View file @
a0e2f8ef
from
typing
import
List
,
Dict
,
Tuple
from
nni.retiarii.utils
import
uid
from
...graph
import
Graph
,
Model
,
Node
from
.interface
import
AbstractOptimizer
from
.logical_plan
import
(
AbstractLogicalNode
,
LogicalGraph
,
LogicalPlan
,
...
...
@@ -78,8 +79,7 @@ class DedupInputOptimizer(AbstractOptimizer):
assert
(
nodes_to_dedup
[
0
]
==
root_node
)
nodes_to_skip
.
add
(
root_node
)
else
:
dedup_node
=
DedupInputNode
(
logical_plan
.
logical_graph
,
logical_plan
.
lp_model
.
_uid
(),
nodes_to_dedup
).
_register
()
dedup_node
=
DedupInputNode
(
logical_plan
.
logical_graph
,
uid
(),
nodes_to_dedup
).
_register
()
for
edge
in
logical_plan
.
logical_graph
.
edges
:
if
edge
.
head
in
nodes_to_dedup
:
edge
.
head
=
dedup_node
...
...
nni/retiarii/trainer/pytorch/base.py
View file @
a0e2f8ef
...
...
@@ -36,7 +36,8 @@ def get_default_transform(dataset: str) -> Any:
transforms
.
RandomCrop
(
32
,
padding
=
4
),
transforms
.
RandomHorizontalFlip
(),
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
])
# unsupported dataset, return None
return
None
...
...
@@ -79,20 +80,30 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
only the key ``max_epochs`` is useful.
"""
super
(
PyTorchImageClassificationTrainer
,
self
).
__init__
(
model
,
dataset_cls
,
dataset_kwargs
,
dataloader_kwargs
,
optimizer_cls
,
optimizer_kwargs
,
trainer_kwargs
)
self
.
_use_cuda
=
torch
.
cuda
.
is_available
()
self
.
model
=
model
if
self
.
_use_cuda
:
self
.
model
.
cuda
()
self
.
_loss_fn
=
nn
.
CrossEntropyLoss
()
self
.
_dataset
=
getattr
(
datasets
,
dataset_cls
)(
transform
=
get_default_transform
(
dataset_cls
),
**
(
dataset_kwargs
or
{}))
self
.
_optimizer
=
getattr
(
torch
.
optim
,
optimizer_cls
)(
model
.
parameters
(),
**
(
optimizer_kwargs
or
{}))
self
.
_train_dataset
=
getattr
(
datasets
,
dataset_cls
)(
train
=
True
,
transform
=
get_default_transform
(
dataset_cls
),
**
(
dataset_kwargs
or
{}))
self
.
_val_dataset
=
getattr
(
datasets
,
dataset_cls
)(
train
=
False
,
transform
=
get_default_transform
(
dataset_cls
),
**
(
dataset_kwargs
or
{}))
self
.
_optimizer
=
getattr
(
torch
.
optim
,
optimizer_cls
)(
model
.
parameters
(),
**
(
optimizer_kwargs
or
{}))
self
.
_trainer_kwargs
=
trainer_kwargs
or
{
'max_epochs'
:
10
}
# TODO: we will need at least two (maybe three) data loaders in future.
self
.
_dataloader
=
DataLoader
(
self
.
_dataset
,
**
(
dataloader_kwargs
or
{}))
self
.
_train_dataloader
=
DataLoader
(
self
.
_train_dataset
,
**
(
dataloader_kwargs
or
{}))
self
.
_val_dataloader
=
DataLoader
(
self
.
_val_dataset
,
**
(
dataloader_kwargs
or
{}))
def
_accuracy
(
self
,
input
,
target
):
# pylint: disable=redefined-builtin
_
,
predict
=
torch
.
max
(
input
.
data
,
1
)
...
...
@@ -137,12 +148,12 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
def
_validate
(
self
):
validation_outputs
=
[]
for
i
,
batch
in
enumerate
(
self
.
_dataloader
):
for
i
,
batch
in
enumerate
(
self
.
_
val_
dataloader
):
validation_outputs
.
append
(
self
.
validation_step
(
batch
,
i
))
return
self
.
validation_epoch_end
(
validation_outputs
)
def
_train
(
self
):
for
i
,
batch
in
enumerate
(
self
.
_dataloader
):
for
i
,
batch
in
enumerate
(
self
.
_
train_
dataloader
):
loss
=
self
.
training_step
(
batch
,
i
)
loss
.
backward
()
...
...
@@ -157,25 +168,32 @@ class PyTorchMultiModelTrainer(BaseTrainer):
def
__init__
(
self
,
multi_model
,
kwargs
=
[]):
self
.
multi_model
=
multi_model
self
.
kwargs
=
kwargs
self
.
_dataloaders
=
[]
self
.
_datasets
=
[]
self
.
_train_dataloaders
=
[]
self
.
_train_datasets
=
[]
self
.
_val_dataloaders
=
[]
self
.
_val_datasets
=
[]
self
.
_optimizers
=
[]
self
.
_trainers
=
[]
self
.
_loss_fn
=
nn
.
CrossEntropyLoss
()
self
.
max_steps
=
None
if
'max_steps'
in
self
.
kwargs
:
self
.
max_steps
=
self
.
kwargs
[
'max_steps'
]
self
.
max_steps
=
self
.
kwargs
[
'max_steps'
]
if
'makx_steps'
in
self
.
kwargs
else
None
self
.
n_model
=
len
(
self
.
kwargs
[
'model_kwargs'
])
for
m
in
self
.
kwargs
[
'model_kwargs'
]:
if
m
[
'use_input'
]:
dataset_cls
=
m
[
'dataset_cls'
]
dataset_kwargs
=
m
[
'dataset_kwargs'
]
dataloader_kwargs
=
m
[
'dataloader_kwargs'
]
dataset
=
getattr
(
datasets
,
dataset_cls
)(
transform
=
get_default_transform
(
dataset_cls
),
**
(
dataset_kwargs
or
{}))
dataloader
=
DataLoader
(
dataset
,
**
(
dataloader_kwargs
or
{}))
self
.
_datasets
.
append
(
dataset
)
self
.
_dataloaders
.
append
(
dataloader
)
train_dataset
=
getattr
(
datasets
,
dataset_cls
)(
train
=
True
,
transform
=
get_default_transform
(
dataset_cls
),
**
(
dataset_kwargs
or
{}))
val_dataset
=
getattr
(
datasets
,
dataset_cls
)(
train
=
False
,
transform
=
get_default_transform
(
dataset_cls
),
**
(
dataset_kwargs
or
{}))
train_dataloader
=
DataLoader
(
train_dataset
,
**
(
dataloader_kwargs
or
{}))
val_dataloader
=
DataLoader
(
val_dataset
,
**
(
dataloader_kwargs
or
{}))
self
.
_train_datasets
.
append
(
train_dataset
)
self
.
_train_dataloaders
.
append
(
train_dataloader
)
self
.
_val_datasets
.
append
(
val_dataset
)
self
.
_val_dataloaders
.
append
(
val_dataloader
)
if
m
[
'use_output'
]:
optimizer_cls
=
m
[
'optimizer_cls'
]
...
...
@@ -195,9 +213,10 @@ class PyTorchMultiModelTrainer(BaseTrainer):
max_epochs
=
max
([
x
[
'trainer_kwargs'
][
'max_epochs'
]
for
x
in
self
.
kwargs
[
'model_kwargs'
]])
for
_
in
range
(
max_epochs
):
self
.
_train
()
nni
.
report_final_result
(
self
.
_validate
())
def
_train
(
self
):
for
batch_idx
,
multi_model_batch
in
enumerate
(
zip
(
*
self
.
_dataloaders
)):
for
batch_idx
,
multi_model_batch
in
enumerate
(
zip
(
*
self
.
_
train_
dataloaders
)):
for
opt
in
self
.
_optimizers
:
opt
.
zero_grad
()
xs
=
[]
...
...
@@ -225,16 +244,9 @@ class PyTorchMultiModelTrainer(BaseTrainer):
summed_loss
.
backward
()
for
opt
in
self
.
_optimizers
:
opt
.
step
()
if
batch_idx
%
50
==
0
:
nni
.
report_intermediate_result
(
report_loss
)
if
self
.
max_steps
and
batch_idx
>=
self
.
max_steps
:
return
def
training_step
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
)
->
Dict
[
str
,
Any
]:
x
,
y
=
self
.
training_step_before_model
(
batch
,
batch_idx
)
y_hat
=
self
.
model
(
x
)
return
self
.
training_step_after_model
(
x
,
y
,
y_hat
)
def
training_step_before_model
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
,
device
=
None
):
x
,
y
=
batch
if
device
:
...
...
@@ -245,17 +257,47 @@ class PyTorchMultiModelTrainer(BaseTrainer):
loss
=
self
.
_loss_fn
(
y_hat
,
y
)
return
loss
def
validation_step
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
)
->
Dict
[
str
,
Any
]:
x
,
y
=
self
.
validation_step_before_model
(
batch
,
batch_idx
)
y_hat
=
self
.
model
(
x
)
return
self
.
validation_step_after_model
(
x
,
y
,
y_hat
)
def
_validate
(
self
):
all_val_outputs
=
{
idx
:
[]
for
idx
in
range
(
self
.
n_model
)}
for
batch_idx
,
multi_model_batch
in
enumerate
(
zip
(
*
self
.
_val_dataloaders
)):
xs
=
[]
ys
=
[]
for
idx
,
batch
in
enumerate
(
multi_model_batch
):
x
,
y
=
self
.
training_step_before_model
(
batch
,
batch_idx
,
f
'cuda:
{
idx
}
'
)
xs
.
append
(
x
)
ys
.
append
(
y
)
if
len
(
ys
)
!=
len
(
xs
):
raise
ValueError
(
'len(ys) should be equal to len(xs)'
)
def
validation_step_before_model
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
):
y_hats
=
self
.
multi_model
(
*
xs
)
for
output_idx
,
yhat
in
enumerate
(
y_hats
):
if
len
(
ys
)
==
len
(
y_hats
):
acc
=
self
.
validation_step_after_model
(
xs
[
output_idx
],
ys
[
output_idx
],
yhat
)
elif
len
(
ys
)
==
1
:
acc
=
self
.
validation_step_after_model
(
xs
[
0
],
ys
[
0
].
to
(
yhat
.
get_device
()),
yhat
)
else
:
raise
ValueError
(
'len(ys) should be either 1 or len(y_hats)'
)
all_val_outputs
[
output_idx
].
append
(
acc
)
report_acc
=
{}
for
idx
in
all_val_outputs
:
avg_acc
=
np
.
mean
([
x
[
'val_acc'
]
for
x
in
all_val_outputs
[
idx
]]).
item
()
report_acc
[
self
.
kwargs
[
'model_kwargs'
][
idx
][
'model_id'
]]
=
avg_acc
nni
.
report_intermediate_result
(
report_acc
)
return
report_acc
def
validation_step_before_model
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
,
device
=
None
):
x
,
y
=
batch
if
self
.
_use_cuda
:
x
,
y
=
x
.
cuda
(
),
y
.
cuda
(
)
if
device
:
x
,
y
=
x
.
cuda
(
torch
.
device
(
device
)),
y
.
cuda
(
torch
.
device
(
device
)
)
return
x
,
y
def
validation_step_after_model
(
self
,
x
,
y
,
y_hat
):
acc
=
self
.
_accuracy
(
y_hat
,
y
)
return
{
'val_acc'
:
acc
}
def
_accuracy
(
self
,
input
,
target
):
# pylint: disable=redefined-builtin
_
,
predict
=
torch
.
max
(
input
.
data
,
1
)
correct
=
predict
.
eq
(
target
.
data
).
cpu
().
sum
().
item
()
return
correct
/
input
.
size
(
0
)
test/ut/retiarii/test_cgo_engine.py
View file @
a0e2f8ef
...
...
@@ -42,8 +42,7 @@ class CGOEngineTest(unittest.TestCase):
protocol
.
_in_file
=
open
(
'generated/debug_protocol_out_file.py'
,
'rb'
)
models
=
_load_mnist
(
2
)
anything
=
lambda
:
None
advisor
=
RetiariiAdvisor
(
anything
)
advisor
=
RetiariiAdvisor
()
submit_models
(
*
models
)
if
torch
.
cuda
.
is_available
()
and
torch
.
cuda
.
device_count
()
>=
2
:
...
...
test/ut/retiarii/test_dedup_input.py
View file @
a0e2f8ef
...
...
@@ -54,9 +54,8 @@ class DedupInputTest(unittest.TestCase):
lp_dump
=
lp
.
logical_graph
.
_dump
()
self
.
assertTrue
(
correct_dump
[
0
]
==
json
.
dumps
(
lp_dump
))
anything
=
lambda
:
None
advisor
=
RetiariiAdvisor
(
anything
)
advisor
=
RetiariiAdvisor
()
cgo
=
CGOExecutionEngine
()
phy_models
=
cgo
.
_assemble
(
lp
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment