OpenDAS / nni · Commit a0e2f8ef

Unverified commit a0e2f8ef, authored Dec 15, 2020 by Zhenhua Han; committed by GitHub on Dec 15, 2020.

[Retiarii] add validation in base trainers (#3184)

In short, as the diffs below show, the commit gives both PyTorch base trainers a real validation pass: the single dataset/dataloader is split into train and validation pairs, PyTorchMultiModelTrainer gains a `_validate` method that reports per-model accuracy through NNI, node-id generation in the logical-plan optimizers is reworked, and the CGO unit tests construct RetiariiAdvisor without the placeholder callback.
Parent: 59cd3982

Showing 5 changed files with 92 additions and 59 deletions (+92 -59; whitespace-only changes are hidden in this view):
- nni/retiarii/execution/logical_optimizer/logical_plan.py (+10 -17)
- nni/retiarii/execution/logical_optimizer/opt_dedup_input.py (+2 -2)
- nni/retiarii/trainer/pytorch/base.py (+77 -35)
- test/ut/retiarii/test_cgo_engine.py (+1 -2)
- test/ut/retiarii/test_dedup_input.py (+2 -3)
nni/retiarii/execution/logical_optimizer/logical_plan.py

```diff
 import copy
 from typing import Dict, Tuple, List, Any
-from nni.retiarii.utils import uid
 from ...graph import Cell, Edge, Graph, Model, Node
 from ...operation import Operation, _IOPseudoOperation

@@ -14,7 +15,7 @@ class PhysicalDevice:
         return self.server == o.server and self.device == o.device

     def __hash__(self) -> int:
         return hash(self.server + '_' + self.device)

 class AbstractLogicalNode(Node):

@@ -181,10 +182,8 @@ class LogicalPlan:
             if isinstance(new_node.operation, _IOPseudoOperation):
                 model_id = new_node.graph.model.model_id
                 if model_id not in training_config_slot:
                     phy_model.training_config.kwargs['model_kwargs'].append(
                         new_node.graph.model.training_config.kwargs.copy())
-                    training_config_slot[model_id] = len(phy_model.training_config.kwargs['model_kwargs']) - 1
+                    training_config_slot[model_id] = \
+                        len(phy_model.training_config.kwargs['model_kwargs']) - 1
                     slot = training_config_slot[model_id]
                     phy_model.training_config.kwargs['model_kwargs'][slot]['model_id'] = model_id
                     phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = False

@@ -221,18 +220,14 @@ class LogicalPlan:
                 tail_placement = node_placements[edge.tail]
                 if head_placement != tail_placement:
                     if head_placement.server != tail_placement.server:
                         raise ValueError(
                             'Cross-server placement is not supported.')
                     # Same server different devices
                     if (edge.head, tail_placement) in copied_op:
                         to_node = copied_op[(edge.head, tail_placement)]
                     else:
                         to_operation = Operation.new(
                             'ToDevice', {"device": tail_placement.device})
-                        to_node = Node(phy_graph, uid(), edge.head.name + "_to_" + edge.tail.name, to_operation)._register()
-                        Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
+                        to_node = Node(phy_graph, phy_model._uid(),
+                                       edge.head.name + "_to_" + edge.tail.name, to_operation)._register()
+                        Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
                         copied_op[(edge.head, tail_placement)] = to_node
                 edge.head = to_node
                 edge.head_slot = None

@@ -266,11 +261,9 @@ class LogicalPlan:
         return phy_model, node_placements

     def node_replace(self, old_node: Node,
                      new_node: Node, input_slot_mapping=None, output_slot_mapping=None):
         # TODO: currently, only support single input slot and output slot.
-        if input_slot_mapping != None or output_slot_mapping != None:
+        if input_slot_mapping is not None or output_slot_mapping is not None:
             raise ValueError('Slot mapping is not supported')
         phy_graph = old_node.graph
```
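A note on the `node_replace` change above: swapping `!= None` for `is not None` is not purely stylistic. Equality operators dispatch to `__eq__`/`__ne__`, which node classes are free to override (PhysicalDevice in this very file overrides `__eq__`), while `is not None` always tests object identity. A standalone illustration, not NNI code:

```python
# A class whose __eq__ claims equality with everything makes `!= None` lie,
# while `is not None` stays reliable.
class AlwaysEqual:
    def __eq__(self, other):
        return True   # equal to anything, including None

obj = AlwaysEqual()
print(obj != None)      # False: __ne__ falls back to negating __eq__
print(obj is not None)  # True: identity check ignores operator overloading
```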
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py

```diff
 from typing import List, Dict, Tuple
-from nni.retiarii.utils import uid
 from ...graph import Graph, Model, Node
 from .interface import AbstractOptimizer
 from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,

@@ -78,8 +79,7 @@ class DedupInputOptimizer(AbstractOptimizer):
                     assert(nodes_to_dedup[0] == root_node)
                     nodes_to_skip.add(root_node)
                 else:
                     dedup_node = DedupInputNode(logical_plan.logical_graph,
-                                                uid(), nodes_to_dedup)._register()
+                                                logical_plan.lp_model._uid(), nodes_to_dedup)._register()
                     for edge in logical_plan.logical_graph.edges:
                         if edge.head in nodes_to_dedup:
                             edge.head = dedup_node
```
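The edge rewiring in `DedupInputOptimizer` above is the whole trick: duplicated input nodes across trial models are collapsed into one `DedupInputNode`, and every edge that used to leave a duplicate is repointed at the merged node. A minimal sketch of that rewiring, using hypothetical names (`SimpleEdge`, `merge_heads`) rather than NNI's graph types:

```python
from dataclasses import dataclass

@dataclass
class SimpleEdge:
    head: str   # producing node
    tail: str   # consuming node

def merge_heads(edges, duplicates, merged):
    # Mirror of the loop in DedupInputOptimizer: only edge heads are repointed.
    for edge in edges:
        if edge.head in duplicates:
            edge.head = merged
    return edges

edges = [SimpleEdge('input_0', 'model_0/conv'), SimpleEdge('input_1', 'model_1/conv')]
print(merge_heads(edges, {'input_0', 'input_1'}, 'dedup_input'))
# both convs now read from the single 'dedup_input' node
```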
nni/retiarii/trainer/pytorch/base.py

```diff
@@ -36,7 +36,8 @@ def get_default_transform(dataset: str) -> Any:
             transforms.RandomCrop(32, padding=4),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
         ])
     # unsupported dataset, return None
     return None
```
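For context, the hunk above sits inside `get_default_transform`; the branch shown builds a `transforms.Compose` with the usual CIFAR-10 normalization statistics, and the function returns `None` for any dataset name it does not recognize. A hedged usage sketch, assuming the helper is importable from the module being diffed and that `'CIFAR10'` is one of the recognized names:

```python
from nni.retiarii.trainer.pytorch.base import get_default_transform

transform = get_default_transform('CIFAR10')    # a transforms.Compose for CIFAR-style data
print(get_default_transform('UnknownDataset'))  # None: unsupported dataset, per the comment above
```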
```diff
@@ -79,20 +80,30 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
             Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
             only the key ``max_epochs`` is useful.
         """
+        super(PyTorchImageClassificationTrainer, self).__init__(model, dataset_cls, dataset_kwargs, dataloader_kwargs, optimizer_cls, optimizer_kwargs, trainer_kwargs)
         self._use_cuda = torch.cuda.is_available()
         self.model = model
         if self._use_cuda:
             self.model.cuda()
         self._loss_fn = nn.CrossEntropyLoss()
-        self._dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls),
+        self._train_dataset = getattr(datasets, dataset_cls)(train=True, transform=get_default_transform(dataset_cls),
                                                        **(dataset_kwargs or {}))
-        self._optimizer = getattr(torch.optim, optimizer_cls)(
-            model.parameters(), **(optimizer_kwargs or {}))
+        self._val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls),
+                                                           **(dataset_kwargs or {}))
+        self._optimizer = getattr(torch.optim, optimizer_cls)(model.parameters(), **(optimizer_kwargs or {}))
         self._trainer_kwargs = trainer_kwargs or {'max_epochs': 10}
-        # TODO: we will need at least two (maybe three) data loaders in future.
-        self._dataloader = DataLoader(
-            self._dataset, **(dataloader_kwargs or {}))
+        self._train_dataloader = DataLoader(self._train_dataset, **(dataloader_kwargs or {}))
+        self._val_dataloader = DataLoader(self._val_dataset, **(dataloader_kwargs or {}))

     def _accuracy(self, input, target):  # pylint: disable=redefined-builtin
         _, predict = torch.max(input.data, 1)
```
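With the constructor above, the same torchvision dataset class is now instantiated twice (`train=True` and `train=False`), so validation finally runs on held-out data instead of the training split. A hypothetical instantiation follows; the argument names mirror the `super().__init__` call in the diff, and `'MNIST'`/`'SGD'` are illustrative choices resolved via `getattr` exactly as the code does, so treat this as a sketch rather than documented NNI API:

```python
import torch.nn as nn
from nni.retiarii.trainer.pytorch.base import PyTorchImageClassificationTrainer

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
trainer = PyTorchImageClassificationTrainer(
    model,
    dataset_cls='MNIST',                                   # getattr(torchvision.datasets, 'MNIST')
    dataset_kwargs={'root': 'data/mnist', 'download': True},
    dataloader_kwargs={'batch_size': 32},
    optimizer_cls='SGD',                                   # getattr(torch.optim, 'SGD')
    optimizer_kwargs={'lr': 1e-3},
    trainer_kwargs={'max_epochs': 1},
)
```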
```diff
@@ -137,12 +148,12 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
     def _validate(self):
         validation_outputs = []
-        for i, batch in enumerate(self._dataloader):
+        for i, batch in enumerate(self._val_dataloader):
             validation_outputs.append(self.validation_step(batch, i))
         return self.validation_epoch_end(validation_outputs)

     def _train(self):
-        for i, batch in enumerate(self._dataloader):
+        for i, batch in enumerate(self._train_dataloader):
             loss = self.training_step(batch, i)
             loss.backward()
```
```diff
@@ -157,25 +168,32 @@ class PyTorchMultiModelTrainer(BaseTrainer):
     def __init__(self, multi_model, kwargs=[]):
         self.multi_model = multi_model
         self.kwargs = kwargs
-        self._dataloaders = []
-        self._datasets = []
+        self._train_dataloaders = []
+        self._train_datasets = []
+        self._val_dataloaders = []
+        self._val_datasets = []
         self._optimizers = []
         self._trainers = []
         self._loss_fn = nn.CrossEntropyLoss()
-        self.max_steps = None
-        if 'max_steps' in self.kwargs:
-            self.max_steps = self.kwargs['max_steps']
+        self.max_steps = self.kwargs['max_steps'] if 'makx_steps' in self.kwargs else None
+        self.n_model = len(self.kwargs['model_kwargs'])

         for m in self.kwargs['model_kwargs']:
             if m['use_input']:
                 dataset_cls = m['dataset_cls']
                 dataset_kwargs = m['dataset_kwargs']
                 dataloader_kwargs = m['dataloader_kwargs']
-                dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls),
-                                                         **(dataset_kwargs or {}))
-                dataloader = DataLoader(dataset, **(dataloader_kwargs or {}))
-                self._datasets.append(dataset)
-                self._dataloaders.append(dataloader)
+                train_dataset = getattr(datasets, dataset_cls)(train=True, transform=get_default_transform(dataset_cls), **(dataset_kwargs or {}))
+                val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls),
+                                                             **(dataset_kwargs or {}))
+                train_dataloader = DataLoader(train_dataset, **(dataloader_kwargs or {}))
+                val_dataloader = DataLoader(val_dataset, **(dataloader_kwargs or {}))
+                self._train_datasets.append(train_dataset)
+                self._train_dataloaders.append(train_dataloader)
+                self._val_datasets.append(val_dataset)
+                self._val_dataloaders.append(val_dataloader)
             if m['use_output']:
                 optimizer_cls = m['optimizer_cls']
```
...
@@ -195,9 +213,10 @@ class PyTorchMultiModelTrainer(BaseTrainer):
max_epochs
=
max
([
x
[
'trainer_kwargs'
][
'max_epochs'
]
for
x
in
self
.
kwargs
[
'model_kwargs'
]])
max_epochs
=
max
([
x
[
'trainer_kwargs'
][
'max_epochs'
]
for
x
in
self
.
kwargs
[
'model_kwargs'
]])
for
_
in
range
(
max_epochs
):
for
_
in
range
(
max_epochs
):
self
.
_train
()
self
.
_train
()
nni
.
report_final_result
(
self
.
_validate
())
def
_train
(
self
):
def
_train
(
self
):
for
batch_idx
,
multi_model_batch
in
enumerate
(
zip
(
*
self
.
_dataloaders
)):
for
batch_idx
,
multi_model_batch
in
enumerate
(
zip
(
*
self
.
_
train_
dataloaders
)):
for
opt
in
self
.
_optimizers
:
for
opt
in
self
.
_optimizers
:
opt
.
zero_grad
()
opt
.
zero_grad
()
xs
=
[]
xs
=
[]
...
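`nni.report_intermediate_result` and `nni.report_final_result` are the standard NNI trial APIs; the change above wires the new validation pass into them, with the per-model accuracy dict reported as an intermediate result inside `_validate` and again as the final result once training finishes. A minimal sketch of the pattern (in a bare process, recent NNI versions fall back to standalone mode and simply log the values):

```python
import nni

def validate():
    report = {'model_0': 0.91, 'model_1': 0.86}   # placeholder per-model accuracies
    nni.report_intermediate_result(report)        # what _validate does internally
    return report

# after the training loop completes, mirror the diff's final call:
nni.report_final_result(validate())
```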
```diff
@@ -225,16 +244,9 @@ class PyTorchMultiModelTrainer(BaseTrainer):
             summed_loss.backward()
             for opt in self._optimizers:
                 opt.step()
-            if batch_idx % 50 == 0:
-                nni.report_intermediate_result(report_loss)
             if self.max_steps and batch_idx >= self.max_steps:
                 return

-    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
-        x, y = self.training_step_before_model(batch, batch_idx)
-        y_hat = self.model(x)
-        return self.training_step_after_model(x, y, y_hat)
-
     def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
         x, y = batch
         if device:
```
```diff
@@ -245,17 +257,47 @@ class PyTorchMultiModelTrainer(BaseTrainer):
         loss = self._loss_fn(y_hat, y)
         return loss

-    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
-        x, y = self.validation_step_before_model(batch, batch_idx)
-        y_hat = self.model(x)
-        return self.validation_step_after_model(x, y, y_hat)
+    def _validate(self):
+        all_val_outputs = {idx: [] for idx in range(self.n_model)}
+        for batch_idx, multi_model_batch in enumerate(zip(*self._val_dataloaders)):
+            xs = []
+            ys = []
+            for idx, batch in enumerate(multi_model_batch):
+                x, y = self.training_step_before_model(batch, batch_idx, f'cuda:{idx}')
+                xs.append(x)
+                ys.append(y)
+            if len(ys) != len(xs):
+                raise ValueError('len(ys) should be equal to len(xs)')
+            y_hats = self.multi_model(*xs)
+            for output_idx, yhat in enumerate(y_hats):
+                if len(ys) == len(y_hats):
+                    acc = self.validation_step_after_model(xs[output_idx], ys[output_idx], yhat)
+                elif len(ys) == 1:
+                    acc = self.validation_step_after_model(xs[0], ys[0].to(yhat.get_device()), yhat)
+                else:
+                    raise ValueError('len(ys) should be either 1 or len(y_hats)')
+                all_val_outputs[output_idx].append(acc)
+        report_acc = {}
+        for idx in all_val_outputs:
+            avg_acc = np.mean([x['val_acc'] for x in all_val_outputs[idx]]).item()
+            report_acc[self.kwargs['model_kwargs'][idx]['model_id']] = avg_acc
+        nni.report_intermediate_result(report_acc)
+        return report_acc

-    def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
+    def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
         x, y = batch
-        if self._use_cuda:
-            x, y = x.cuda(), y.cuda()
+        if device:
+            x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device))
         return x, y

     def validation_step_after_model(self, x, y, y_hat):
         acc = self._accuracy(y_hat, y)
         return {'val_acc': acc}
+
+    def _accuracy(self, input, target):  # pylint: disable=redefined-builtin
+        _, predict = torch.max(input.data, 1)
+        correct = predict.eq(target.data).cpu().sum().item()
+        return correct / input.size(0)
```
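The core of the new `_validate` is a per-model reduction: each model's per-batch `{'val_acc': ...}` dicts are averaged with `np.mean` and keyed by that model's `model_id`. Isolated from the trainer, the aggregation looks like this (numbers are illustrative only):

```python
import numpy as np

all_val_outputs = {
    0: [{'val_acc': 0.90}, {'val_acc': 0.92}],
    1: [{'val_acc': 0.85}, {'val_acc': 0.87}],
}
model_ids = {0: 'model_a', 1: 'model_b'}   # stand-in for kwargs['model_kwargs'][idx]['model_id']

report_acc = {}
for idx in all_val_outputs:
    avg_acc = np.mean([x['val_acc'] for x in all_val_outputs[idx]]).item()
    report_acc[model_ids[idx]] = avg_acc
print(report_acc)   # roughly {'model_a': 0.91, 'model_b': 0.86}
```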
test/ut/retiarii/test_cgo_engine.py

```diff
@@ -42,8 +42,7 @@ class CGOEngineTest(unittest.TestCase):
         protocol._in_file = open('generated/debug_protocol_out_file.py', 'rb')

         models = _load_mnist(2)
-        anything = lambda: None
-        advisor = RetiariiAdvisor(anything)
+        advisor = RetiariiAdvisor()
         submit_models(*models)

         if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
```
test/ut/retiarii/test_dedup_input.py

```diff
@@ -55,8 +55,7 @@ class DedupInputTest(unittest.TestCase):
         self.assertTrue(correct_dump[0] == json.dumps(lp_dump))
-        anything = lambda: None
-        advisor = RetiariiAdvisor(anything)
+        advisor = RetiariiAdvisor()
         cgo = CGOExecutionEngine()
         phy_models = cgo._assemble(lp)
```