OpenDAS / nni — Commit f2f58dbb

Unverified commit f2f58dbb, authored Jul 30, 2021 by Zhenhua Han, committed by GitHub on Jul 30, 2021.

[Retiarii] cross-graph optimization: device placement and input deduplication (#3202)

Parent: 6645bd33

Showing 20 changed files with 1091 additions and 131 deletions (+1091, −131).
Changed files:

  .gitignore                                                    +2    −2
  dependencies/recommended.txt                                  +1    −1
  dependencies/recommended_gpu.txt                              +1    −1
  nni/common/device.py                                          +7    −2
  nni/retiarii/codegen/pytorch.py                               +1    −1
  nni/retiarii/evaluator/pytorch/cgo/__init__.py                +0    −0
  nni/retiarii/evaluator/pytorch/cgo/accelerator.py             +106  −0
  nni/retiarii/evaluator/pytorch/cgo/evaluator.py               +222  −0
  nni/retiarii/evaluator/pytorch/cgo/trainer.py                 +31   −0
  nni/retiarii/evaluator/pytorch/lightning.py                   +26   −3
  nni/retiarii/execution/cgo_engine.py                          +223  −55
  nni/retiarii/execution/logical_optimizer/logical_plan.py      +52   −41
  nni/retiarii/execution/logical_optimizer/opt_dedup_input.py   +17   −6
  nni/retiarii/experiment/pytorch.py                            +13   −5
  nni/retiarii/graph.py                                         +6    −1
  nni/retiarii/integration.py                                   +1    −1
  nni/retiarii/operation.py                                     +4    −0
  nni/retiarii/operation_def/torch_op_def.py                    +80   −12
  test/retiarii_test/cgo/darts_model.py                         +165  −0
  test/retiarii_test/cgo/ops.py                                 +133  −0
.gitignore — View file @ f2f58dbb

@@ -10,6 +10,8 @@
 /ts/nni_manager/exp_profile.json
 /ts/nni_manager/metrics.json
 /ts/nni_manager/trial_jobs.json
+/test/ut/retiarii/_debug_graph_data.json
+/test/ut/retiarii/out.tmp

 # Logs
 logs
@@ -105,5 +107,3 @@ venv.bak/
 .vscode
 .vs
 .history
-generated/
-test/ut/retiarii/_debug_graph_data.json
dependencies/recommended.txt — View file @ f2f58dbb

@@ -8,7 +8,7 @@ torch == 1.9.0+cpu ; sys_platform != "darwin"
 torch == 1.9.0 ; sys_platform == "darwin"
 torchvision == 0.10.0+cpu ; sys_platform != "darwin"
 torchvision == 0.10.0 ; sys_platform == "darwin"
-pytorch-lightning >= 1.1.1
+pytorch-lightning >= 1.2.8
 onnx
 peewee
 graphviz
dependencies/recommended_gpu.txt — View file @ f2f58dbb

@@ -5,7 +5,7 @@ tensorflow
 keras == 2.4.3
 torch == 1.9.0+cu111
 torchvision == 0.10.0+cu111
-pytorch-lightning >= 1.1.1
+pytorch-lightning >= 1.2.8
 onnx
 peewee
 graphviz
nni/common/device.py — View file @ f2f58dbb

@@ -9,7 +9,9 @@ class GPUDevice:
     status: Literal['idle', 'busy', 'unknown'] = 'idle'

     def __eq__(self, o) -> bool:
-        return self.node_id == o.node_id and self.gpu_id == o.gpu_id
+        if isinstance(o, GPUDevice):
+            return self.node_id == o.node_id and self.gpu_id == o.gpu_id
+        return False

     def __lt__(self, o) -> bool:
         if self.node_id < o.node_id:
@@ -23,7 +25,10 @@ class GPUDevice:
         return "{Environment %s, GPU %d, Status %s}" % (self.node_id, self.gpu_id, self.status)

     def __hash__(self) -> int:
-        return hash(self.node_id + '_' + self.gpu_id)
+        return hash(self.node_id + '_' + str(self.gpu_id))

     def set_status(self, status):
         self.status = status

+    def device_repr(self,):
+        return f"cuda:{self.gpu_id}"
nni/retiarii/codegen/pytorch.py — View file @ f2f58dbb

@@ -115,7 +115,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
         node_code = node.operation.to_init_code(_remove_prefix(node.name, graph_name))
         if node_code is not None:
             if placement and node in placement and len(node_code) > 0:
-                node_codes.append(f"{node_code}.to('{placement[node].device}')")
+                node_codes.append(f"{node_code}.to('{placement[node].device_repr()}')")
             else:
                 node_codes.append(node_code)
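To make the change concrete: with a node placed on GPUDevice("single_server", 1), the generated __init__ line now queries device_repr() instead of the old PhysicalDevice.device attribute. An illustrative generated line (not taken from the diff, the module name is hypothetical) would look like:

    # placement[node].device_repr() returns 'cuda:1' for a GPUDevice, 'cpu' for a CPUDevice
    self.conv = nn.Conv2d(3, 16, kernel_size=3).to('cuda:1')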
nni/retiarii/evaluator/pytorch/cgo/__init__.py — new empty file (0 → 100644) @ f2f58dbb
nni/retiarii/evaluator/pytorch/cgo/accelerator.py — new file (0 → 100644) @ f2f58dbb

from typing import Any, Union, Optional, List

import torch
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.plugins import Plugin
from pytorch_lightning.plugins.environments import ClusterEnvironment

from ....serializer import serialize_cls


class BypassPlugin(TrainingTypePlugin):
    """ Plugin that handles communication on a single device. """

    def __init__(self, device: str):
        super().__init__()
        self.device: str = device
        self.global_rank = 0
        self.local_rank = 0
        self.world_size = 1

    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
        self._model = model
        self.model_to_device()
        return self.model

    @property
    def on_tpu(self) -> bool:
        return False

    @property
    def on_gpu(self) -> bool:
        return "cuda" in self.device and torch.cuda.is_available()

    def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
        """
        Reduces a tensor from several distributed processes to one aggregated tensor.
        As this plugin only operates with a single device, the reduction is simply the identity.

        Args:
            tensor: the tensor to sync and reduce
            *args: ignored
            **kwargs: ignored

        Return:
            the unmodified input as reduction is not needed for single process operation
        """
        return tensor

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """Perform a all_gather on all processes """
        return tensor

    @property
    def root_device(self) -> torch.device:
        return torch.device(self.device)

    def model_to_device(self) -> None:
        # bypass device placement from pytorch lightning
        pass

    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
        self.model_to_device()
        return self.model

    @property
    def is_global_zero(self) -> bool:
        return True

    def barrier(self, *args, **kwargs) -> None:
        pass

    def broadcast(self, obj: object, src: int = 0) -> object:
        return obj


def get_accelerator_connector(
        num_processes: int = 1,
        tpu_cores: Optional[Union[List[int], str, int]] = None,
        distributed_backend: Optional[str] = None,
        auto_select_gpus: bool = False,
        gpus: Optional[Union[List[int], str, int]] = None,
        num_nodes: int = 1,
        sync_batchnorm: bool = False,
        benchmark: bool = False,
        replace_sampler_ddp: bool = True,
        deterministic: bool = False,
        precision: int = 32,
        amp_backend: str = 'native',
        amp_level: str = 'O2',
        plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None):
    return AcceleratorConnector(
        num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus,
        num_nodes, sync_batchnorm, benchmark, replace_sampler_ddp, deterministic,
        precision, amp_backend, amp_level, plugins)


@serialize_cls
class BypassAccelerator(Accelerator):
    def __init__(self, precision_plugin=None, device="cpu"):
        if precision_plugin is None:
            precision_plugin = get_accelerator_connector().precision_plugin

        # pylint: disable=abstract-class-instantiated
        super().__init__(precision_plugin=precision_plugin, training_type_plugin=BypassPlugin(device))
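For orientation, a hypothetical usage sketch (assuming the pytorch-lightning 1.2.x APIs imported above, where pl.Trainer accepts an Accelerator instance) shows how the bypass accelerator would be handed to a Lightning trainer so that CGO, not Lightning, decides where modules live:

    import pytorch_lightning as pl
    from nni.retiarii.evaluator.pytorch.cgo.accelerator import BypassAccelerator

    # The accelerator reports the device it was given but never moves the model,
    # because BypassPlugin.model_to_device() is a no-op.
    accelerator = BypassAccelerator(device="cpu")
    trainer = pl.Trainer(accelerator=accelerator, max_epochs=1)
    # trainer.fit(...) would then train the model wherever the generated code placed it,
    # e.g. via the ``.to('cuda:0')`` calls emitted by the codegen above.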
nni/retiarii/evaluator/pytorch/cgo/evaluator.py — new file (0 → 100644) @ f2f58dbb

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import warnings
from typing import Dict, List, Optional, Union

import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torch.utils.data import DataLoader

import nni

from ..lightning import LightningModule, _AccuracyWithLogits, Lightning
from .trainer import Trainer

from ....serializer import serialize_cls


@serialize_cls
class _MultiModelSupervisedLearningModule(LightningModule):
    def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
                 n_models: int = 0,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam):
        super().__init__()
        self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
        self.criterion = criterion()
        self.criterion_cls = criterion
        self.optimizer = optimizer
        self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})
        self.n_models = n_models

    def forward(self, x):
        y_hat = self.model(x)
        return y_hat

    def training_step(self, batch, batch_idx):
        x, y = batch
        multi_y_hat = self(x)
        if isinstance(multi_y_hat, tuple):
            assert len(multi_y_hat) == self.n_models
        else:
            assert self.n_models == 1
            multi_y_hat = [multi_y_hat]
        multi_loss = []
        for idx, y_hat in enumerate(multi_y_hat):
            loss = self.criterion(y_hat.to("cpu"), y.to("cpu"))
            self.log(f'train_loss_{idx}', loss, prog_bar=True)
            for name, metric in self.metrics.items():
                self.log(f'train_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
            multi_loss.append(loss)
        return sum(multi_loss)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        multi_y_hat = self(x)
        if isinstance(multi_y_hat, tuple):
            assert len(multi_y_hat) == self.n_models
        else:
            assert self.n_models == 1
            multi_y_hat = [multi_y_hat]
        for idx, y_hat in enumerate(multi_y_hat):
            self.log(f'val_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
            for name, metric in self.metrics.items():
                self.log(f'val_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        multi_y_hat = self(x)
        if isinstance(multi_y_hat, tuple):
            assert len(multi_y_hat) == self.n_models
        else:
            assert self.n_models == 1
            multi_y_hat = [multi_y_hat]
        for idx, y_hat in enumerate(multi_y_hat):
            self.log(f'test_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
            for name, metric in self.metrics.items():
                self.log(f'test_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)

    def configure_optimizers(self):
        return self.optimizer(self.parameters(), lr=self.hparams.learning_rate,
                              weight_decay=self.hparams.weight_decay)

    def on_validation_epoch_end(self):
        nni.report_intermediate_result(self._get_validation_metrics())

    def teardown(self, stage):
        if stage == 'fit':
            nni.report_final_result(self._get_validation_metrics())

    def _get_validation_metrics(self):
        # TODO: split metric of multiple models?
        if len(self.metrics) == 1:
            metric_name = next(iter(self.metrics))
            ret = []
            for idx in range(self.n_models):
                ret.append(self.trainer.callback_metrics[f'val_{idx}_' + metric_name].item())
            return ret
        else:
            warnings.warn('Multiple metrics without "default" is not supported by current framework.')
            return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}


class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
    """
    Lightning Module of SupervisedLearning for Cross-Graph Optimization.
    Users who needs cross-graph optimization should use this module.

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    """

    def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam):
        super().__init__(criterion, metrics, learning_rate=learning_rate,
                         weight_decay=weight_decay, optimizer=optimizer)


@serialize_cls
class _ClassificationModule(MultiModelSupervisedLearningModule):
    def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam):
        super().__init__(criterion, {'acc': _AccuracyWithLogits},
                         learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)


class Classification(Lightning):
    """
    Trainer that is used for classification.

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    train_dataloders : DataLoader
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
    val_dataloaders : DataLoader or List of DataLoader
        Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
    trainer_kwargs : dict
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
    """

    def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam,
                 train_dataloader: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 **trainer_kwargs):
        module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
                                       weight_decay=weight_decay, optimizer=optimizer)
        super().__init__(module, Trainer(use_cgo=True, **trainer_kwargs),
                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)


@serialize_cls
class _RegressionModule(MultiModelSupervisedLearningModule):
    def __init__(self, criterion: nn.Module = nn.MSELoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam):
        super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
                         learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)


class Regression(Lightning):
    """
    Trainer that is used for regression.

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.MSELoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    train_dataloders : DataLoader
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
    val_dataloaders : DataLoader or List of DataLoader
        Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
    trainer_kwargs : dict
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
    """

    def __init__(self, criterion: nn.Module = nn.MSELoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam,
                 train_dataloader: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 **trainer_kwargs):
        module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
                                   weight_decay=weight_decay, optimizer=optimizer)
        super().__init__(module, Trainer(use_cgo=True, **trainer_kwargs),
                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
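A sketch of how this evaluator would be instantiated on the user side (the datasets are hypothetical placeholders; note that the serialized DataLoader wrapper from nni.retiarii.evaluator.pytorch.lightning must be used rather than torch's DataLoader directly):

    import nni.retiarii.evaluator.pytorch.lightning as pl_eval
    from nni.retiarii.evaluator.pytorch.cgo.evaluator import Classification

    train_loader = pl_eval.DataLoader(train_dataset, batch_size=100)   # train_dataset: any torch Dataset
    valid_loader = pl_eval.DataLoader(valid_dataset, batch_size=100)

    evaluator = Classification(
        train_dataloader=train_loader,
        val_dataloaders=valid_loader,
        max_epochs=1,              # forwarded to the CGO Trainer via **trainer_kwargs
        limit_train_batches=0.2,
    )

Internally this builds a Trainer(use_cgo=True, ...), so device placement is left entirely to the CGO execution engine.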
nni/retiarii/evaluator/pytorch/cgo/trainer.py — new file (0 → 100644) @ f2f58dbb

import pytorch_lightning as pl

from ....serializer import serialize_cls
from .accelerator import BypassAccelerator


@serialize_cls
class Trainer(pl.Trainer):
    """
    Trainer for cross-graph optimization.

    Parameters
    ----------
    use_cgo : bool
        Whether cross-graph optimization (CGO) is used.
        If it is True, CGO will manage device placement.
        Any device placement from pytorch lightning will be bypassed.
        default: False
    trainer_kwargs : dict
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
    """

    def __init__(self, use_cgo=False, **trainer_kwargs):
        if use_cgo:
            if "accelerator" in trainer_kwargs:
                raise ValueError("accelerator should not be set when cross-graph optimization is enabled.")

            trainer_kwargs['accelerator'] = BypassAccelerator(device='cpu')

        super().__init__(**trainer_kwargs)
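A quick, hypothetical sanity check of the guard above: an explicit accelerator cannot be combined with use_cgo=True, because CGO installs its own BypassAccelerator.

    from nni.retiarii.evaluator.pytorch.cgo.trainer import Trainer

    trainer = Trainer(use_cgo=True, max_epochs=1)   # ok: accelerator becomes BypassAccelerator(device='cpu')
    try:
        Trainer(use_cgo=True, accelerator='ddp')    # rejected with ValueError
    except ValueError as err:
        print(err)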
nni/retiarii/evaluator/pytorch/lightning.py — View file @ f2f58dbb

@@ -12,6 +12,12 @@ import torch.optim as optim
 from torch.utils.data import DataLoader

 import nni
+
+try:
+    import nni.retiarii.evaluator.pytorch.cgo.trainer as cgo_trainer
+    cgo_import_failed = False
+except ImportError:
+    cgo_import_failed = True

 from ...graph import Evaluator
 from ...serializer import serialize_cls
@@ -36,7 +42,6 @@ class LightningModule(pl.LightningModule):
 Trainer = serialize_cls(pl.Trainer)
 DataLoader = serialize_cls(DataLoader)
-

 class Lightning(Evaluator):
     """
     Delegate the whole training to PyTorch Lightning.
@@ -67,7 +72,11 @@ class Lightning(Evaluator):
                  train_dataloader: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
         assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
-        assert isinstance(trainer, Trainer), f'Trainer must be imported from {__name__}.'
+        if cgo_import_failed:
+            assert isinstance(trainer, Trainer), f'Trainer must be imported from {__name__}'
+        else:
+            assert isinstance(trainer, Trainer) or isinstance(trainer, cgo_trainer.Trainer), \
+                f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
         assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
         assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
         self.module = lightning_module
@@ -91,7 +100,21 @@ class Lightning(Evaluator):
         return self.fit(model_cls)

     def __eq__(self, other):
-        return self.function == other.function and self.arguments == other.arguments
+        eq_func = False
+        eq_args = False
+        if other is None:
+            return False
+        if hasattr(self, "function") and hasattr(other, "function"):
+            eq_func = (self.function == other.function)
+        elif not (hasattr(self, "function") or hasattr(other, "function")):
+            eq_func = True
+        if hasattr(self, "arguments") and hasattr(other, "arguments"):
+            eq_args = (self.arguments == other.arguments)
+        elif not (hasattr(self, "arguments") or hasattr(other, "arguments")):
+            eq_args = True
+        return eq_func and eq_args

     def fit(self, model):
         """
nni/retiarii/execution/cgo_engine.py — View file @ f2f58dbb

@@ -2,14 +2,22 @@
 # Licensed under the MIT license.

 import logging
+import os
+import random
+import string
 import time
 import threading
 from typing import Iterable, List, Dict, Tuple

+from nni.common.device import GPUDevice
+
 from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
 from .. import codegen, utils
-from ..graph import Model, ModelStatus, MetricData
+from ..graph import Model, ModelStatus, MetricData, Node
 from ..integration_api import send_trial, receive_trial_parameters, get_advisor
-from .logical_optimizer.logical_plan import LogicalPlan, PhysicalDevice
+from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode
 from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
+from ..evaluator.pytorch.lightning import Lightning
+from ..evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule, _MultiModelSupervisedLearningModule
+
 from .base import BaseGraphData

@@ -17,29 +25,93 @@ _logger = logging.getLogger(__name__)

 class CGOExecutionEngine(AbstractExecutionEngine):
-    def __init__(self, devices=None, n_model_per_graph=4) -> None:
+    """
+    The execution engine with Cross-Graph Optimization (CGO).
+
+    Only models using PyTorch Lighting and MultiModelSupervisedLearningModule as the evaluator can be optimized.
+    Otherwise, a model will be submitted independently without any cross-graph optimization.
+
+    Parameters
+    ----------
+    devices : List[str] or List[GPUDevice]
+        Available devices for execution.
+        If a list of str is provided, it will build a list of GPUDevice in a server named ``single_server``
+    max_concurrency : int
+        The maximum number of trials to run concurrently.
+    batch_waiting_time: int
+        Seconds to wait for each batch of trial submission.
+        The trials within one batch could apply cross-graph optimization.
+    """
+
+    def __init__(self, devices: List[GPUDevice] = None,
+                 max_concurrency: int = None,
+                 batch_waiting_time: int = 60,
+                 ) -> None:
         self._listeners: List[AbstractGraphListener] = []
         self._running_models: Dict[int, Model] = dict()
         self.logical_plan_counter = 0
-        self.n_model_per_graph = n_model_per_graph
         self.available_devices: List[GPUDevice] = []
+        self.max_concurrency: int = max_concurrency
+
+        for device in devices:
+            self.available_devices.append(device)
+        self.all_devices = self.available_devices.copy()
+
+        self._batch_waiting_time = batch_waiting_time  # seconds to wait for all models in a batch to do cross-graph optimization
         self._optimizers = [DedupInputOptimizer()]
         self._original_models = {}
         self._original_model_to_multi_model = {}
-        self.devices = [] if devices is None else devices
+        self._trial_to_original_models = {}
+        self._trial_used_devices: Dict[int, List[GPUDevice]] = {}
+
+        self._history: List[Model] = []
+
+        self._queuing_jobs: List[Model] = []
+        self._queue_lock = threading.Lock()

         # register advisor callbacks
         advisor = get_advisor()
-        advisor.send_trial_callback = self._send_trial_callback
-        advisor.request_trial_jobs_callback = self._request_trial_jobs_callback
+        # advisor.send_trial_callback = self._send_trial_callback
+        # advisor.request_trial_jobs_callback = self._request_trial_jobs_callback
         advisor.trial_end_callback = self._trial_end_callback
         advisor.intermediate_metric_callback = self._intermediate_metric_callback
         advisor.final_metric_callback = self._final_metric_callback

+        self._stopped = False
+        self._consumer_thread = threading.Thread(target=self._consume_queue)
+        self._consumer_thread.start()
+
+    def join(self):
+        self._stopped = True
+        self._consumer_thread.join()
+
     def add_optimizer(self, opt):
         self._optimizers.append(opt)

     def submit_models(self, *models: List[Model]) -> None:
+        curr_time = time.time()
         _logger.info('%d models are submitted', len(models))
+        self._queue_lock.acquire()
+        self._queuing_jobs.extend([(curr_time, _) for _ in models])
+        self._queue_lock.release()
+
+    def _consume_queue(self):
+        # a thread to monitor self.queuing_jobs to consume them in batch
+        while not self._stopped:
+            if len(self._queuing_jobs) > 0:
+                curr_time = time.time()
+                self._queue_lock.acquire()
+                if (self.max_concurrency and len(self._queuing_jobs) >= self.max_concurrency):
+                    self._submit_models_in_batch(*[_[1] for _ in self._queuing_jobs[:self.max_concurrency]])
+                    self._queuing_jobs = self._queuing_jobs[self.max_concurrency:]
+                elif len(self.available_devices) <= len(self._queuing_jobs) or \
+                        (curr_time - self._queuing_jobs[0][0] > self._batch_waiting_time):
+                    self._submit_models_in_batch(*[_[1] for _ in self._queuing_jobs])
+                    self._queuing_jobs = []
+                self._queue_lock.release()
+            time.sleep(1)
+
+    def _submit_models_in_batch(self, *models: List[Model]) -> None:
+        _logger.info('%d models are submitted in batch', len(models))
         logical = self._build_logical(models)

         for opt in self._optimizers:
@@ -47,31 +119,51 @@ class CGOExecutionEngine(AbstractExecutionEngine):
         phy_models_and_placements = self._assemble(logical)
         for model, placement, grouped_models in phy_models_and_placements:
             data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator)
             trial_id = send_trial(data.dump())
             # unique non-cpu devices used by the trial
             self._trial_used_devices[trial_id] = list([_ for _ in set(placement.values()) if isinstance(_, GPUDevice)])
             # currently, it is impossible for search strategy to submit models more than the number of available devices
             for used_device in self._trial_used_devices[trial_id]:
                 self.available_devices.remove(used_device)  # used_device must be in self.available_devices
             self._running_models[trial_id] = model

             self._trial_to_original_models[trial_id] = []
             for m in grouped_models:
                 self._original_models[m.model_id] = m
                 self._original_model_to_multi_model[m.model_id] = model
-                self._running_models[send_trial(data.dump())] = model
-        # for model in models:
-        #     data = BaseGraphData(codegen.model_to_pytorch_script(model),
-        #                          model.config['trainer_module'], model.config['trainer_kwargs'])
-        #     self._running_models[send_trial(data.dump())] = model
+                self._trial_to_original_models[trial_id].append(m.model_id)
+                self._history.append(m)

     def list_models(self) -> Iterable[Model]:
-        raise NotImplementedError
-
-    def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
-        # unique_models = set()
-        # for node in logical_plan.graph.nodes:
-        #     if node.graph.model not in unique_models:
-        #         unique_models.add(node.graph.model)
-        # return [m for m in unique_models]
-        grouped_models: List[Dict[Model, PhysicalDevice]] = AssemblePolicy().group(logical_plan)
+        return self._history
+
+    def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, Dict[Node, GPUDevice], List[Model]]]:
+        # try to use the available_devices first so that it can be launched as early as possible
+        # if free devices are not enough to assemble all models in one trial, try all devices
+        if len(self.available_devices) > 0:
+            grouped_models: List[Dict[Model, GPUDevice]] = AssemblePolicy().group(logical_plan, self.available_devices)
+
+        if len(self.available_devices) == 0 or len(grouped_models) > 1:
+            grouped_models: List[Dict[Model, GPUDevice]] = AssemblePolicy().group(logical_plan, self.all_devices)
+
         phy_models_and_placements = []
         for multi_model in grouped_models:
             model, model_placement = logical_plan.assemble(multi_model)
+            assert isinstance(model.evaluator, Lightning), \
+                "cross-graph optimization only supports pytorch lighting as evaluator"
+            assert isinstance(model.evaluator.module, _MultiModelSupervisedLearningModule), \
+                "cross-graph optimization only support MultiModelSupervisedLearningModule"
+
+            # replace the module with a new instance whose n_models is set
+            # n_models must be set in __init__, otherwise it cannot be captured by serialize_cls
+            new_module_init_params = model.evaluator.module._init_parameters.copy()
+            # MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users
+            new_module_init_params['n_models'] = len(multi_model)
+            new_module = _MultiModelSupervisedLearningModule(**new_module_init_params)
+            model.evaluator.module = new_module
+            phy_models_and_placements.append((model, model_placement, multi_model.keys()))
         return phy_models_and_placements
@@ -85,13 +177,14 @@ class CGOExecutionEngine(AbstractExecutionEngine):
     def register_graph_listener(self, listener: AbstractGraphListener) -> None:
         self._listeners.append(listener)

-    def _send_trial_callback(self, paramater: dict) -> None:
-        for listener in self._listeners:
-            listener.on_resource_used(0)  # FIXME: find the real resource id
+    # def _send_trial_callback(self, paramater: dict) -> None:
+    #     if len(self.available_devices) == 0:
+    #         _logger.warning('There is no available devices, but trial is submitted.')
+    #     _logger.debug('Resource used. Remaining: %d', len(self.available_devices))

-    def _request_trial_jobs_callback(self, num_trials: int) -> None:
-        for listener in self._listeners:
-            listener.on_resource_available([0] * num_trials)  # FIXME: find the real resource id
+    # def _request_trial_jobs_callback(self, num_trials: int) -> None:
+    #     self.resources += num_trials
+    #     _logger.info('on_resource_available: %d', self.resources)

     def _trial_end_callback(self, trial_id: int, success: bool) -> None:
         model = self._running_models[trial_id]
@@ -108,31 +201,40 @@ class CGOExecutionEngine(AbstractExecutionEngine):
                 original_model.status = ModelStatus.Failed
             for listener in self._listeners:
                 listener.on_training_end(original_model, success)
+
+        self.available_devices.extend(self._trial_used_devices[trial_id])
+        self.available_devices = sorted(list(set(self.available_devices)))
+        del self._running_models[trial_id]

     def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
-        # model = self._running_models[trial_id]
-        merged_metrics = dict(metrics)
+        merged_metrics = {}
+        for idx, _ in enumerate(metrics):
+            merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
         for model_id in merged_metrics:
-            int_model_id = int(model_id)
-            self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
-            # model.intermediate_metrics.append(metrics)
+            self._original_models[model_id].intermediate_metrics.append(merged_metrics[model_id])
             for listener in self._listeners:
-                listener.on_intermediate_metric(self._original_models[int_model_id], merged_metrics[model_id])
+                listener.on_intermediate_metric(self._original_models[model_id], merged_metrics[model_id])

     def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
-        merged_metrics = dict(metrics)
-        for model_id in merged_metrics:
-            int_model_id = int(model_id)
-            self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
-            # model.intermediate_metrics.append(metrics)
-            for listener in self._listeners:
-                listener.on_metric(self._original_models[int_model_id], merged_metrics[model_id])
+        _logger.debug(metrics)
+        if isinstance(metrics, float):
+            self._listeners[0].on_metric(self._running_models[trial_id], metrics)
+        else:
+            merged_metrics = {}
+            for idx, _ in enumerate(metrics):
+                merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
+            for model_id in merged_metrics:
+                self._original_models[model_id].metric = merged_metrics[model_id]
+                for listener in self._listeners:
+                    listener.on_metric(self._original_models[model_id], merged_metrics[model_id])

     def query_available_resource(self) -> List[WorkerInfo]:
-        raise NotImplementedError
+        # move the method from listener to here?
+        # the _queuing_jobs need to use available_devices first
+        return len(self.available_devices) - len(self._queuing_jobs)

     def budget_exhausted(self) -> bool:
-        raise NotImplementedError
+        advisor = get_advisor()
+        return advisor.stopping

     @classmethod
     def trial_execute_graph(cls) -> None:
@@ -141,20 +243,86 @@ class CGOExecutionEngine(AbstractExecutionEngine):
         """
         graph_data = BaseGraphData.load(receive_trial_parameters())
         _logger.info('CGO_ENGINE trial parameters received')
-        with open('_generated_model.py', 'w') as f:
+        random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
+        file_name = f'_generated_model/{random_str}.py'
+        os.makedirs(os.path.dirname(file_name), exist_ok=True)
+        with open(file_name, 'w') as f:
             f.write(graph_data.model_script)
-        # with open('_debug_graph_data.json', 'w') as f:
-        #     json.dump(graph_data.dump(), f)
-        trainer_cls = utils.import_(graph_data.training_module)
-        model_cls = utils.import_(f"_generated_model.{graph_data.training_kwargs['model_cls']}")
-        trainer_instance = trainer_cls(model_cls(), graph_data.training_kwargs)
-        trainer_instance.fit()
+        trainer_instance = graph_data.evaluator
+        model_cls = utils.import_(f'_generated_model.{random_str}._model')
+        trainer_instance.fit(model_cls())
+
+        os.remove(file_name)


+def _remap_cuda_device(group_model: Dict[Model, GPUDevice]):
+    used_devices = {}
+    for m in group_model:
+        if group_model[m].node_id not in used_devices:
+            used_devices[group_model[m].node_id] = {}
+        if isinstance(group_model[m], GPUDevice):
+            if group_model[m].gpu_id not in used_devices[group_model[m].node_id]:
+                n_used_gpu_in_server = len(used_devices[group_model[m].node_id])
+                used_devices[group_model[m].node_id][group_model[m].gpu_id] = n_used_gpu_in_server
+            group_model[m].gpu_id = used_devices[group_model[m].node_id][group_model[m].gpu_id]
+    return group_model
+
+
 class AssemblePolicy:
     @staticmethod
-    def group(logical_plan):
+    def _is_related_node(model: Model, node: Node):
+        if isinstance(node, AbstractLogicalNode):
+            if model in node.related_models:
+                return True
+        else:
+            if model == node.graph.model:
+                return True
+        return False
+
+    @staticmethod
+    def _check_graph_connectivity(model: Model,
+                                  group_model: Dict[Model, GPUDevice],
+                                  logical_plan: LogicalPlan) -> bool:
+        for edge in logical_plan.logical_graph.edges:
+            if AssemblePolicy._is_related_node(model, edge.head) or \
+                    AssemblePolicy._is_related_node(model, edge.tail):
+                for grouped_model in group_model:
+                    if AssemblePolicy._is_related_node(grouped_model, edge.head) or \
+                            AssemblePolicy._is_related_node(grouped_model, edge.tail):
+                        return True
+        return False
+
+    @staticmethod
+    def _check_evaluator(new_model: Model, group_model: Dict[Model, GPUDevice]) -> bool:
+        if not (isinstance(new_model.evaluator, Lightning)
+                and isinstance(new_model.evaluator.module, MultiModelSupervisedLearningModule)):
+            return False
+        for m in group_model:
+            if not m.evaluator == new_model.evaluator:
+                return False
+        return True
+
+    @staticmethod
+    def group(logical_plan, available_devices):
+        # TODO: Packing multiple model in one GPU
+        # Currently, we only support one model per GPU
         all_grouped_models = []
         group_model = {}
+        assert(len(available_devices) > 0)  # There should be at least 1 device, set in CGO_DEVICES
         for idx, m in enumerate(logical_plan.models):
-            group_model[m] = PhysicalDevice('server', f'cuda:{idx}')
-        return [group_model]
+            # models in one group should
+            # (1) not use more GPUs than available_devices
+            # (2) be connected in the logical plan (independent models should be assembled in multiple groups)
+            # (3) use same MultiModelSupervisedLearningModule
+            if len(group_model) > 0 and \
+                (AssemblePolicy._check_graph_connectivity(m, group_model, logical_plan) == False or
+                    AssemblePolicy._check_evaluator(m, group_model) == False):
+                all_grouped_models.append(_remap_cuda_device(group_model))
+                group_model = {}
+            group_model[m] = available_devices[idx % len(available_devices)]
+            if len(group_model) == len(available_devices) or \
+                    idx == len(logical_plan.models) - 1:
+                all_grouped_models.append(_remap_cuda_device(group_model))
+                group_model = {}
+        return all_grouped_models
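The gpu-id remapping inside a group can be easier to see on plain data. The sketch below is a simplified, standalone re-implementation on (node_id, gpu_id) tuples, not the engine's own code; it renumbers the physical GPUs used within one assembled trial to consecutive local indices, which is what the generated .to('cuda:k') calls expect inside a single process:

    def remap(placements):
        # placements: model name -> (node_id, physical_gpu_id)
        used = {}   # node_id -> {physical_gpu_id: local_index}
        out = {}
        for name, (node, gpu) in placements.items():
            local = used.setdefault(node, {})
            if gpu not in local:
                local[gpu] = len(local)        # next free local index on this server
            out[name] = (node, local[gpu])
        return out

    print(remap({'m0': ('srv', 2), 'm1': ('srv', 5), 'm2': ('srv', 2)}))
    # {'m0': ('srv', 0), 'm1': ('srv', 1), 'm2': ('srv', 0)}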
nni/retiarii/execution/logical_optimizer/logical_plan.py — View file @ f2f58dbb

@@ -2,30 +2,30 @@
 # Licensed under the MIT license.

 import copy
-from typing import Dict, Tuple, List, Any
+from typing import Dict, Tuple, Any, Union

 from nni.retiarii.utils import uid
+from nni.common.device import GPUDevice

 from ...graph import Cell, Edge, Graph, Model, Node
 from ...operation import Operation, _IOPseudoOperation


-class PhysicalDevice:
-    def __init__(self, server: str, device: str):
-        self.server = server
-        self.device = device
-
-    def __eq__(self, o) -> bool:
-        return self.server == o.server and self.device == o.device
+class CPUDevice:
+    def __init__(self, node_id):
+        self.node_id = node_id
+        self.device = 'cpu'

     def __hash__(self) -> int:
         return hash(self.server + '_' + self.device)

+    def device_repr(self):
+        return "cpu"


 class AbstractLogicalNode(Node):
     def __init__(self, graph, node_id, name, operation, _internal=False):
         super().__init__(graph, node_id, name, operation, _internal=_internal)
         self.related_models = []

-    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
+    def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
         raise NotImplementedError

     def _fork_to(self, graph: Graph):
@@ -40,8 +40,7 @@ class LogicalGraph(Graph):
         nodes_dump = {}
         for node in self.hidden_nodes:
             if isinstance(node, OriginNode):
-                nodes_dump[f"{node.original_graph.model.model_id}_{node.name}"] = node._dump(
-                )
+                nodes_dump[f"{node.original_graph.model.model_id}_{node.name}"] = node._dump()
             else:
                 nodes_dump[f"{node.graph.model.model_id}_{node.name}"] = node._dump()
@@ -93,7 +92,7 @@ class OriginNode(AbstractLogicalNode):
         self.original_graph = original_graph
         self.original_node = original_node

-    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
+    def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
         model_id = self.original_node.graph.model.model_id
         new_node = Node(self.original_node.graph, self.original_node.id,
                         f"M_{model_id}_" +
@@ -137,30 +136,32 @@ class LogicalPlan:
         for edge in from_graph.edges:
             new_head = id_to_new_node[edge.head.id]
             new_tail = id_to_new_node[edge.tail.id]
             Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

-    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \
-            -> Tuple[Model, Dict[Node, PhysicalDevice], List[Model]]:
-        phy_model = Model(_internal=True)  # self.lp_model.fork()
+    def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) \
+            -> Tuple[Model, Dict[Node, Union[GPUDevice, CPUDevice]]]:
+        phy_model = Model(_internal=True)
         phy_graph = self.lp_model.root_graph._fork_to(phy_model)

         # Add a flag to mark multi-model in graph json.
         # Multi-model has a list of training configs in kwargs['model_kwargs']
         if len(multi_model_placement) > 1:
             phy_model.evaluator.kwargs['is_multi_model'] = True
             phy_model.evaluator.kwargs['model_cls'] = phy_graph.name
             phy_model.evaluator.kwargs['model_kwargs'] = []
             # FIXME: allow user to specify
             phy_model.evaluator.module = 'nni.retiarii.trainer.pytorch.PyTorchMultiModelTrainer'
         phy_graph._rename_graph(phy_graph.name, "_model")

         # merge sub-graphs
         for model in multi_model_placement:
+            if phy_model.evaluator is None and model.evaluator is not None:
+                phy_model.evaluator = model.evaluator
             for graph_name in model.graphs:
                 if graph_name != model._root_graph_name:
-                    model.graphs[graph_name]._fork_to(
+                    new_graph = model.graphs[graph_name]._fork_to(
                         phy_model, name_prefix=f'M_{model.model_id}_')
+                    # prefix of M_ of hidden_nodes name in non-root graphs is added here
+                    for new_node in new_graph.hidden_nodes:
+                        if isinstance(new_node.operation, Cell):
+                            old_cell_name = new_node.operation.cell_name
+                            new_node.operation = copy.deepcopy(new_node.operation)
+                            new_node.operation.cell_name = f'M_{model.model_id}_{old_cell_name}'
+
+        assert(phy_model.evaluator is not None)

         # When replace logical nodes, merge the training configs when
         # input/output nodes are replaced.
         evaluator_slot = {}  # Model ID -> Slot ID
@@ -169,6 +170,9 @@ class LogicalPlan:
         # Replace all logical nodes to executable physical nodes
         hidden_nodes = phy_graph.hidden_nodes.copy()
         node_placements = {}
+
+        added_models = []
+
         for node in hidden_nodes:
             if isinstance(node, OriginNode):
                 model_id = node.original_graph.model.model_id
@@ -185,12 +189,9 @@ class LogicalPlan:
                 if isinstance(new_node.operation, _IOPseudoOperation):
                     model_id = new_node.graph.model.model_id
                     if model_id not in evaluator_slot:
-                        phy_model.evaluator.kwargs['model_kwargs'].append(new_node.graph.model.evaluator.kwargs.copy())
-                        evaluator_slot[model_id] = len(phy_model.evaluator.kwargs['model_kwargs']) - 1
+                        added_models.append(model_id)
+                        evaluator_slot[model_id] = len(added_models) - 1
                         slot = evaluator_slot[model_id]
-                        phy_model.evaluator.kwargs['model_kwargs'][slot]['model_id'] = model_id
-                        phy_model.evaluator.kwargs['model_kwargs'][slot]['use_input'] = False
-                        phy_model.evaluator.kwargs['model_kwargs'][slot]['use_output'] = False
                     else:
                         slot = evaluator_slot[model_id]

                 # If a model's inputs/outputs are not used in the multi-model
@@ -199,37 +200,47 @@ class LogicalPlan:
                 # an input/output of a model is used in a multi-model
                 if new_node.operation.type == '_inputs':
                     input_slot_mapping[new_node] = slot
-                    phy_model.evaluator.kwargs['model_kwargs'][slot]['use_input'] = True
                 if new_node.operation.type == '_outputs':
                     output_slot_mapping[new_node] = slot
-                    phy_model.evaluator.kwargs['model_kwargs'][slot]['use_output'] = True

                 self.node_replace(node, new_node)
+
+                # name prefix of M_ of cells in hidden_nodes of root graphs is added here
+                # FIXME: merge this rename with non-root graph, only do once.
+                if isinstance(new_node.operation, Cell):
+                    old_cell_name = new_node.operation.cell_name
+                    new_node.operation = copy.deepcopy(new_node.operation)
+                    new_node.operation.cell_name = f'M_{model_id}_{old_cell_name}'
+
-                node_placements[new_node] = placement
+                # input should be at CPU, move it to GPU first if necessary
+                if isinstance(new_node.operation, _IOPseudoOperation) and new_node.operation.type == '_inputs':
+                    # hack: only support single_server
+                    node_placements[new_node] = CPUDevice(node_id=placement.node_id)
+                else:
+                    node_placements[new_node] = placement
+
                 node.remove()

         # If two nodes are placed on different devices, use ToDevice op to copy the node
         existing_edges = phy_graph.edges.copy()
         # Avoid a node is copied multiple times on the same device
-        copied_op: Dict[Tuple(Node, PhysicalDevice), Node] = {}
+        copied_op: Dict[Tuple(Node, Union[GPUDevice, CPUDevice]), Node] = {}
         for edge in existing_edges:
             head_placement = node_placements[edge.head]
             tail_placement = node_placements[edge.tail]
             if head_placement != tail_placement:
-                if head_placement.server != tail_placement.server:
+                if head_placement.node_id != tail_placement.node_id:
                     raise ValueError('Cross-server placement is not supported.')
                 # Same server different devices
                 if (edge.head, tail_placement) in copied_op:
                     to_node = copied_op[(edge.head, tail_placement)]
                 else:
-                    to_operation = Operation.new('ToDevice', {"device": tail_placement.device})
-                    to_node = Node(phy_graph, uid(), edge.head.name + "_to_" + edge.tail.name, to_operation)._register()
+                    dst_name = edge.head.name + "_to_" + edge.tail.name
+                    to_operation = Operation.new('ToDevice', {
+                        "device": tail_placement.device_repr(), "src": (edge.head.name, edge.head_slot), "dst": dst_name})
+                    to_node = Node(phy_graph, uid(), dst_name, to_operation)._register()
                     Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
                     copied_op[(edge.head, tail_placement)] = to_node
                 edge.head = to_node
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py — View file @ f2f58dbb

@@ -4,23 +4,28 @@
 from typing import List, Dict, Tuple

 from nni.retiarii.utils import uid
+from nni.retiarii.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule
+from nni.common.device import GPUDevice

 from ...graph import Graph, Model, Node
 from .interface import AbstractOptimizer
-from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
-                           OriginNode, PhysicalDevice)
+from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
+                           OriginNode)

-_supported_training_modules = ['nni.retiarii.trainer.pytorch.PyTorchImageClassificationTrainer']
+_supported_evaluators = [MultiModelSupervisedLearningModule]


 class DedupInputNode(AbstractLogicalNode):
     def __init__(self, logical_graph: LogicalGraph, node_id: int,
                  nodes_to_dedup: List[Node], _internal=False):
         super().__init__(logical_graph, node_id,
                          "Dedup_" + nodes_to_dedup[0].name,
                          nodes_to_dedup[0].operation)
         self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()
+        self.related_models = [_.original_graph.model for _ in self.origin_nodes]

-    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
+    def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
         for node in self.origin_nodes:
             if node.original_graph.model in multi_model_placement:
                 new_node = Node(node.original_graph, node.id,
@@ -41,6 +46,12 @@ class DedupInputOptimizer(AbstractOptimizer):
     def __init__(self) -> None:
         pass

+    def _check_supported_evaluator(self, evaluator):
+        for e in _supported_evaluators:
+            if isinstance(evaluator, e):
+                return True
+        return False
+
     def _check_deduplicate_by_node(self, root_node, node_to_check):
         if root_node == node_to_check:
             return True
@@ -48,7 +59,7 @@ class DedupInputOptimizer(AbstractOptimizer):
             node_to_check.operation.type == '_inputs' and \
             isinstance(root_node, OriginNode) and \
             isinstance(node_to_check, OriginNode):
-            if root_node.original_graph.model.evaluator.module not in _supported_training_modules:
+            if self._check_supported_evaluator(root_node.original_graph.model.evaluator):
                 return False
             if root_node.original_graph.model.evaluator == node_to_check.original_graph.model.evaluator:
                 return True
@@ -68,7 +79,7 @@ class DedupInputOptimizer(AbstractOptimizer):
                     continue
                 root_node = node
                 break
-            if root_node == None:
+            if root_node is None:
                 break  # end of convert
             else:
                 nodes_to_dedup = []
nni/retiarii/experiment/pytorch.py — View file @ f2f58dbb

@@ -50,8 +50,11 @@ class RetiariiExeConfig(ConfigBase):
     trial_code_directory: PathLike = '.'
     trial_concurrency: int
     trial_gpu_number: int = 0
+    devices: Optional[List[Union[str, GPUDevice]]] = None
     max_experiment_duration: Optional[str] = None
     max_trial_number: Optional[int] = None
+    max_concurrency_cgo: Optional[int] = None
+    batch_waiting_time: Optional[int] = None
     nni_manager_ip: Optional[str] = None
     debug: bool = False
     log_level: Optional[str] = None
@@ -134,11 +137,12 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_
     if mutators is not None and applied_mutators:
         raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
                            'do not use mutators when you use LayerChoice/InputChoice')
     if mutators is not None:
         applied_mutators = mutators
     return base_model_ir, applied_mutators

+
 def debug_mutated_model(base_model, trainer, applied_mutators):
     """
     Locally run only one trial without launching an experiment for debug purpose, then exit.
@@ -189,7 +193,7 @@ class RetiariiExperiment(Experiment):
             self.strategy.run(base_model_ir, self.applied_mutators)
             _logger.info('Strategy exit')
             # TODO: find out a proper way to show no more trial message on WebUI
-            #self._dispatcher.mark_experiment_as_ending()
+            # self._dispatcher.mark_experiment_as_ending()

     def start(self, port: int = 8080, debug: bool = False) -> None:
         """
@@ -205,14 +209,18 @@ class RetiariiExperiment(Experiment):
         """
         atexit.register(self.stop)

-        devices = self._construct_devices()
         # we will probably need a execution engine factory to make this clean and elegant
         if self.config.execution_engine == 'base':
             from ..execution.base import BaseExecutionEngine
             engine = BaseExecutionEngine()
         elif self.config.execution_engine == 'cgo':
             from ..execution.cgo_engine import CGOExecutionEngine
-            engine = CGOExecutionEngine(devices=devices)
+
+            # assert self.config.trial_gpu_number==1, "trial_gpu_number must be 1 to use CGOExecutionEngine"
+            assert self.config.batch_waiting_time is not None
+            devices = self._construct_devices()
+            engine = CGOExecutionEngine(devices,
+                                        max_concurrency=self.config.max_concurrency_cgo,
+                                        batch_waiting_time=self.config.batch_waiting_time)
         elif self.config.execution_engine == 'py':
             from ..execution.python import PurePythonExecutionEngine
             engine = PurePythonExecutionEngine()
@@ -315,7 +323,7 @@ class RetiariiExperiment(Experiment):
         if self._dispatcher_thread is not None:
             self._dispatcher.stopping = True
             self._dispatcher_thread.join(timeout=1)
         if self.id is not None:
             nni.runtime.log.stop_experiment_log(self.id)
         if self._proc is not None:
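Putting the new config fields together, a CGO experiment would be configured roughly as follows. This is a sketch based on the fields added above; base_model, evaluator and strategy are assumed to be defined elsewhere (e.g. the CGO test models under test/retiarii_test/cgo):

    from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig

    exp = RetiariiExperiment(base_model, evaluator, [], strategy)

    config = RetiariiExeConfig('local')
    config.execution_engine = 'cgo'
    config.devices = ['cuda:0', 'cuda:1']   # or a list of GPUDevice
    config.max_concurrency_cgo = 3
    config.batch_waiting_time = 10          # required when execution_engine == 'cgo'
    config.trial_concurrency = 3
    config.trial_gpu_number = 1

    exp.run(config, 8081)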
nni/retiarii/graph.py — View file @ f2f58dbb

@@ -410,7 +410,7 @@ class Graph:
         return self is other

     def _fork_to(self, model: Model, name_prefix='') -> 'Graph':
         new_graph = Graph(model, self.id, name_prefix + self.name, _internal=True)._register()
         # TODO: use node copy instead
         new_graph.input_node.operation.io_names = self.input_node.operation.io_names
         new_graph.output_node.operation.io_names = self.output_node.operation.io_names
@@ -458,6 +458,11 @@ class Graph:
         self.model.graphs[self.name] = self
         return self

+    def _rename_graph(self, old_name, new_name):
+        self.model.graphs[old_name].name = new_name
+        self.model.graphs[new_name] = self.model.graphs[old_name]
+        del self.model.graphs[old_name]
+
     @staticmethod
     def _load(model: Model, name: str, ir: Any) -> 'Graph':
         graph = Graph(model, uid(), name, _internal=True)
nni/retiarii/integration.py — View file @ f2f58dbb

@@ -158,4 +158,4 @@ class RetiariiAdvisor(MsgDispatcherBase):
             return value['default']
         else:
             return value
-        return value
+        return value
\ No newline at end of file
nni/retiarii/operation.py — View file @ f2f58dbb

@@ -98,6 +98,10 @@ class PyTorchOperation(Operation):
             if hasattr(subclass, '_ori_type_name') and \
                 subclass_name in subclass._ori_type_name:
                 return subclass
+        for subclass in cls.__subclasses__():
+            if hasattr(subclass, '_artificial_op_name') and \
+                subclass_name in subclass._artificial_op_name:
+                return subclass
         return cls

     @classmethod
nni/retiarii/operation_def/torch_op_def.py
View file @
f2f58dbb
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
typing
import
(
Any
,
List
)
from
typing
import
(
Any
,
Dict
,
List
)
import
torch
...
...
@@ -32,21 +32,27 @@ scalar_type_to_pytorch_type = [
'torch.bool'
,
# 11
]
class
NoOpIdentity
(
PyTorchOperation
):
"""
this operator type is added by us
"""
_ori_type_name
=
[
'noop_identity'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
return
f
'
{
output
}
=
{
", "
.
join
(
inputs
)
}
'
class
ModuleOperator
(
PyTorchOperation
):
_ori_type_name
=
[
'ModuleOperator'
,
'shared'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
return
f
'
{
output
}
= self.
{
field
}
(
{
", "
.
join
(
inputs
)
}
)'
class
FunctionalOperator
(
PyTorchOperation
):
_ori_type_name
=
[
'FunctionalOperator'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
func_name
=
self
.
type
[
len
(
'Function.'
):]
if
not
hasattr
(
torch
.
nn
.
functional
,
func_name
):
...
...
@@ -54,8 +60,10 @@ class FunctionalOperator(PyTorchOperation):
f
'
{
func_name
}
is not in it.'
)
return
f
'
{
output
}
= F.
{
func_name
}
(
{
", "
.
join
(
inputs
)
}
)'
class
PrimConstant
(
PyTorchOperation
):
_ori_type_name
=
[
'prim::Constant'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
# TODO: refactor this part, maybe we can remove the code gen of prim::Constant
# TODO: deal with all the types
...
...
@@ -75,63 +83,83 @@ class PrimConstant(PyTorchOperation):
else
:
raise
RuntimeError
(
f
'unsupported type of prim::Constant:
{
self
.
parameters
[
"type"
]
}
'
)
class
PrimListConstruct
(
PyTorchOperation
):
_ori_type_name
=
[
'prim::ListConstruct'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
return
f
'
{
output
}
= [
{
", "
.
join
(
inputs
)
}
]'
class
PrimListUnpack
(
PyTorchOperation
):
_ori_type_name
=
[
'prim::ListUnpack'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
return
f
'
{
output
}
=
{
inputs
[
0
]
}
'
class
PrimTupleConstruct
(
PyTorchOperation
):
_ori_type_name
=
[
'prim::TupleConstruct'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
return
f
'
{
output
}
= (
{
", "
.
join
(
inputs
)
}
)'
class
PrimTupleUnpack
(
PyTorchOperation
):
_ori_type_name
=
[
'prim::TupleUnpack'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
# have single output here, because the following code uses index to access the unpacked values
assert
len
(
inputs
)
==
1
return
f
'
{
output
}
=
{
inputs
[
0
]
}
'
class PrimGetAttr(PyTorchOperation):
    _ori_type_name = ['prim::GetAttr']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        if self.parameters['value'] is not None:
            return f"{output} = {self.parameters['value']}"
        else:
            return f"{output} = {self.parameters['input']}.{self.parameters['name']}"


class SimpleMember(PyTorchOperation):
    _ori_type_name = ['prim::is_cuda', 'prim::data']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        member_name = self.type.split('::')[-1]
        return f'{output} = {inputs[0]}.{member_name}'


class AtenContiguous(PyTorchOperation):
    _ori_type_name = ['aten::contiguous']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        # defined in pytorch/c10/core/MemoryFormat.h
        assert inputs_value[1] in [0, 1, 2]
        return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'


class AtenGetitem(PyTorchOperation):
    _ori_type_name = ['aten::__getitem__']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        assert len(inputs) == 2
        return f'{output} = {inputs[0]}[{inputs[1]}]'


class AtenAppend(PyTorchOperation):
    _ori_type_name = ['aten::append']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        assert len(inputs) == 2
        return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
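Since Python's list.append returns None, AtenAppend emits a tuple assignment so that the list itself stays available as the op's output. A small, hypothetical illustration (names are invented, not taken from the diff):

# Hypothetical values for the line AtenAppend generates
output, inputs = 'lst_1', ['lst_0', 'item']
print(f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}')
# prints: _, lst_1 = lst_0.append(item), lst_0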
class MergedSlice(PyTorchOperation):
    _ori_type_name = ['MergedSlice']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        if (len(inputs) - 1) % 4 == 0:
            slices = []
            ...
@@ -148,23 +176,30 @@ class MergedSlice(PyTorchOperation):
# the following Aten classes mean these aten ops are not in torch.Tensor
class AtenBool(PyTorchOperation):
    _ori_type_name = ['aten::Bool']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = bool({inputs[0]})'


class AtenNot(PyTorchOperation):
    _ori_type_name = ['aten::__not__']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = not {inputs[0]}'


class AtenCat(PyTorchOperation):
    _ori_type_name = ['aten::cat']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        assert len(inputs) == 2
        return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'


# ====================================
class AtenTensors(PyTorchOperation):
    _ori_type_name = ['aten::full', 'aten::full_like', 'aten::empty_like',
                      ...
@@ -209,20 +244,26 @@ class AtenTensors(PyTorchOperation):
        else:
            return f'{output} = {inputs[0]}.{op_name}({", ".join(args_list[1:])})'


# ====================================
class AtenFloordiv(PyTorchOperation):
    _ori_type_name = ['aten::floordiv']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = {inputs[0]} // {inputs[1]}'


class AtenLen(PyTorchOperation):
    _ori_type_name = ['aten::len']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = len({inputs[0]})'


class AtenIntImplicit(PyTorchOperation):
    _ori_type_name = ['aten::IntImplicit', 'aten::Float', 'aten::Int', 'aten::ScalarImplicit']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        if self.type.endswith('Implicit'):
            return f'{output} = {inputs[0]}'
        ...
@@ -231,11 +272,14 @@ class AtenIntImplicit(PyTorchOperation):
        elif self.type == 'aten::Float':
            return f'{output} = float({inputs[0]})'


class AtenIndex(PyTorchOperation):
    _ori_type_name = ['aten::index']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = {inputs[0]}[{inputs[1]}]'
ManuallyChooseDef = {
    'aten::flatten': [('start_dim', 'int', '0'), ('end_dim', 'int', '-1')],
    'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')],
    ...
@@ -248,21 +292,24 @@ ManuallyChooseDef = {
}

TensorOpExceptions = {
    'aten::sub': lambda output, inputs: f'{output} = {inputs[0]} - {inputs[1]}',  # example: x.size(1) - 3
    'aten::add': lambda output, inputs: f'{output} = {inputs[0]} + {inputs[1]}'  # example: input.shape[0] + 5
}
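These two exceptions are rendered as infix expressions rather than tensor method calls, since the generated line may act on plain ints such as shapes (see the inline examples above). A quick, hypothetical check of the 'aten::add' lambda, with invented operand names:

# Hypothetical inputs: the exception turns the op into an infix '+' expression
emit_add = lambda output, inputs: f'{output} = {inputs[0]} + {inputs[1]}'
print(emit_add('h', ['x.shape[0]', '5']))
# prints: h = x.shape[0] + 5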
TorchOpExclude = ['aten::Size', 'aten::as_tensor', 'aten::device',
                  'aten::manual_seed', 'aten::quantized_gru', 'aten::quantized_lstm',
                  'aten::save', 'aten::tensor', 'aten::wait'
                  ]


def _hidden(name):
    return name.startswith('_') and not name.startswith('__')


def _emit_args(args):
    # filter out the `out` argument here
    return [(arg.name, str(arg.type), str(arg.default_value)) for arg in args]  # if arg.name != 'out'
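A quick sanity check of the helper's intent, using _hidden as defined just above (a single leading underscore counts as hidden, dunder names and public names do not):

print(_hidden('_private'), _hidden('__magic__'), _hidden('public'))
# prints: True False False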
def _get_tensor_ops():
    def is_tensor_method(schema):
        ...
@@ -291,6 +338,7 @@ def _get_tensor_ops():
    return op_args.keys(), op_args


def _get_torch_ops():
    torch_op_args = {}
    for mod in torch.jit._builtins._modules_containing_builtins:
        ...
@@ -316,6 +364,7 @@ def _get_torch_ops():
    return torch_op_args.keys(), torch_op_args


def _get_torch_ops_exclude_tensor_ops():
    tensor_op_names, _ = _get_tensor_ops()
    torch_op_names, torch_ops = _get_torch_ops()
    ...
@@ -330,6 +379,7 @@ def _get_torch_ops_exclude_tensor_ops():
    return torch_exclude_ops.keys(), torch_exclude_ops


class TensorOps(PyTorchOperation):
    """
    corresponding to _get_tensor_ops in torch.jit.supported_ops
    ...
@@ -346,7 +396,7 @@ class TensorOps(PyTorchOperation):
            name = ','.join([arg[0] for arg in each])
            concated_names.append(name)
        for i in range(len(concated_names) - 1):
            if concated_names[i] != concated_names[i + 1]:
                return False
        return True
        ...
@@ -383,6 +433,7 @@ class TensorOps(PyTorchOperation):
        args_str = ', '.join([f'{name}={inputs[i + 1]}' for i, (name, t, default) in enumerate(matched_args)])
        return f'{output} = {inputs[0]}.{op_name}({args_str})'


class TorchOps(PyTorchOperation):
    """
    corresponding to _get_nn_functional_ops in torch.jit.supported_ops
    ...
@@ -400,7 +451,7 @@ class TorchOps(PyTorchOperation):
            name = ','.join([arg[0] for arg in each])
            concated_names.append(name)
        for i in range(len(concated_names) - 1):
            if concated_names[i] != concated_names[i + 1]:
                return False
        return True
        ...
@@ -424,19 +475,36 @@ class TorchOps(PyTorchOperation):
    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        matched_args = TorchOps._get_matched_args(self.type, inputs)
        op_name = self.type.split('::')[-1]
        args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
                              for i, (name, t, default) in enumerate(matched_args)])
        return f'{output} = torch.{op_name}({args_str})'
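A hypothetical walk-through of this emission (the schema tuples and names below are invented for illustration, not read from torch.jit.supported_ops): arguments typed Optional[...] are emitted as keyword arguments, the rest by value.

# Hypothetical matched_args for a flatten-like torch op; the second argument is pretended to be Optional
op_name, output, inputs = 'flatten', 'out', ['x', '1']
matched_args = [('input', 'Tensor', 'None'), ('start_dim', 'Optional[int]', '0')]
args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
                      for i, (name, t, default) in enumerate(matched_args)])
print(f'{output} = torch.{op_name}({args_str})')
# prints: out = torch.flatten(x, start_dim=1)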
class AtenAvgpool2d(PyTorchOperation):
    # NOTE: it is not included in the above aten ops for unknown reason
    _ori_type_name = ['aten::avg_pool2d']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = F.avg_pool2d({", ".join(inputs)})'
class ToDevice(PyTorchOperation):
    _artificial_op_name = "ToDevice"

    def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False):
        self.type = "ToDevice"
        self.device = parameters['device']
        self.src = parameters['src']
        self.dst = parameters['dst']

    def __repr__(self):
        return f'to("{self.device}")'

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = {inputs[0]}.to("{self.device}")'
class AtenDet(PyTorchOperation):
    # for torch 1.9
    # NOTE: it is not included in the above aten ops, maybe because torch.det is alias for torch.linalg.det
    _ori_type_name = ['aten::linalg_det']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = torch.det({inputs[0]})'
test/retiarii_test/cgo/darts_model.py
0 → 100644
View file @ f2f58dbb
from collections import OrderedDict
from typing import (List, Optional)

import torch
import torch.nn as torch_nn
# sys.path.append(str(Path(__file__).resolve().parents[2]))
import ops

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit


@basic_unit
class AuxiliaryHead(nn.Module):
    """ Auxiliary head in 2/3 place of network to let the gradient flow well """

    def __init__(self, input_size, C, n_classes):
        """ assuming input size 7x7 or 8x8 """
        assert input_size in [7, 8]
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False),  # 2x2 out
            nn.Conv2d(C, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, kernel_size=2, bias=False),  # 1x1 out
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.linear = nn.Linear(768, n_classes)

    def forward(self, x):
        out = self.net(x)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)
        return logits
class Node(nn.Module):
    def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
        super().__init__()
        self.ops = nn.ModuleList()
        choice_keys = []
        for i in range(num_prev_nodes):
            stride = 2 if i < num_downsample_connect else 1
            choice_keys.append("{}_p{}".format(node_id, i))
            self.ops.append(nn.LayerChoice([
                ops.PoolBN('max', channels, 3, stride, 1, affine=False),
                ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
                nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),
                ops.SepConv(channels, channels, 3, stride, 1, affine=False),
                ops.SepConv(channels, channels, 5, stride, 2, affine=False),
                ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
                ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
            ]))
        self.drop_path = ops.DropPath()
        self.input_switch = nn.InputChoice(n_candidates=num_prev_nodes, n_chosen=2)

    def forward(self, prev_nodes: List['Tensor']) -> 'Tensor':
        # assert self.ops.__len__() == len(prev_nodes)
        # out = [op(node) for op, node in zip(self.ops, prev_nodes)]
        out = []
        for i, op in enumerate(self.ops):
            out.append(op(prev_nodes[i]))
        # out = [self.drop_path(o) if o is not None else None for o in out]
        return self.input_switch(out)
class Cell(nn.Module):
    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes

        # If previous cell is reduction cell, current input size does not match with
        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
        if reduction_p:
            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        self.mutable_ops = nn.ModuleList()
        for depth in range(2, self.n_nodes + 2):
            self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
                                         depth, channels, 2 if reduction else 0))

    def forward(self, s0, s1):
        # s0, s1 are the outputs of previous previous cell and previous cell, respectively.
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        new_tensors = []
        for node in self.mutable_ops:
            tmp = tensors + new_tensors
            cur_tensor = node(tmp)
            new_tensors.append(cur_tensor)

        output = torch.cat(new_tensors, dim=1)
        return output
class CNN(nn.Module):

    def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
                 stem_multiplier=3, auxiliary=False):
        super().__init__()
        self.in_channels = in_channels
        self.channels = channels
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.aux_pos = 2 * n_layers // 3 if auxiliary else -1

        c_cur = stem_multiplier * self.channels
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c_cur)
        )

        # for the first cell, stem is used for both s0 and s1
        # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size.
        channels_pp, channels_p, c_cur = c_cur, c_cur, channels

        self.cells = nn.ModuleList()
        reduction_p, reduction = False, False
        for i in range(n_layers):
            reduction_p, reduction = reduction, False
            # Reduce featuremap size and double channels in 1/3 and 2/3 layer.
            if i in [n_layers // 3, 2 * n_layers // 3]:
                c_cur *= 2
                reduction = True

            cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
            self.cells.append(cell)
            c_cur_out = c_cur * n_nodes
            channels_pp, channels_p = channels_p, c_cur_out

            # if i == self.aux_pos:
            #     self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes)

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(channels_p, n_classes)

    def forward(self, x):
        s0 = s1 = self.stem(x)

        # aux_logits = None
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1)
            # if i == self.aux_pos and self.training:
            #     aux_logits = self.aux_head(s1)

        out = self.gap(s1)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)

        # if aux_logits is not None:
        #     return logits, aux_logits
        return logits

    def drop_path_prob(self, p):
        for module in self.modules():
            if isinstance(module, ops.DropPath):
                module.p = p


if __name__ == '__main__':
    base_model = CNN(32, 3, 16, 10, 8)
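As a rough smoke test outside the Retiarii engine, a smaller hypothetical variant of the same search space can be instantiated directly; the sketch below only constructs the supernet and counts its parameters (sizes are arbitrary, not taken from the test above):

# Sketch: build a small hypothetical variant of the DARTS search space defined in this file
tiny = CNN(input_size=32, in_channels=3, channels=8, n_classes=10, n_layers=2)
print(sum(p.numel() for p in tiny.parameters()))  # rough parameter count of the supernet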
test/retiarii_test/cgo/ops.py
0 → 100644
View file @ f2f58dbb
import torch

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit


@basic_unit
class DropPath(nn.Module):
    def __init__(self, p=0.):
        """
        Drop path with probability.

        Parameters
        ----------
        p : float
            Probability of a path to be zeroed.
        """
        super().__init__()
        self.p = p

    def forward(self, x):
        if self.training and self.p > 0.:
            keep_prob = 1. - self.p
            # per data point mask
            mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob)
            return x / keep_prob * mask
        return x
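A minimal sketch of the intended behaviour (the mask is random, so the exact output differs between runs): in training mode roughly p of the samples are zeroed and the survivors are rescaled by 1 / (1 - p), keeping the expected value unchanged.

# Sketch: apply DropPath to a batch of four feature maps (values here are arbitrary)
drop = DropPath(p=0.5)
drop.train()
x = torch.ones(4, 1, 1, 1)
print(drop(x).squeeze())  # e.g. tensor([2., 0., 2., 0.]) -- the zero pattern is random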
@basic_unit
class PoolBN(nn.Module):
    """
    AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
    """

    def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
        super().__init__()
        if pool_type.lower() == 'max':
            self.pool = nn.MaxPool2d(kernel_size, stride, padding)
        elif pool_type.lower() == 'avg':
            self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
        else:
            raise ValueError()

        self.bn = nn.BatchNorm2d(C, affine=affine)

    def forward(self, x):
        out = self.pool(x)
        out = self.bn(out)
        return out


@basic_unit
class StdConv(nn.Module):
    """
    Standard conv: ReLU - Conv - BN
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.net(x)


@basic_unit
class FacConv(nn.Module):
    """
    Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
    """

    def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
            nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.net(x)


@basic_unit
class DilConv(nn.Module):
    """
    (Dilated) depthwise separable conv.
    ReLU - (Dilated) depthwise separable - Pointwise - BN.
    If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.net(x)
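A shape sketch under assumed, made-up sizes: with kernel_size=3, dilation=2 and padding=2 the dilated depthwise stage preserves the spatial size, and the pointwise stage maps C_in to C_out.

# Sketch: hypothetical sizes only; the printed shape is expected to be torch.Size([1, 16, 16, 16])
dil = DilConv(C_in=8, C_out=16, kernel_size=3, stride=1, padding=2, dilation=2)
print(dil(torch.randn(1, 8, 16, 16)).shape)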
@basic_unit
class SepConv(nn.Module):
    """
    Depthwise separable conv.
    DilConv(dilation=1) * 2.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
            DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
        )

    def forward(self, x):
        return self.net(x)


@basic_unit
class FactorizedReduce(nn.Module):
    """
    Reduce feature map size by factorized pointwise (stride=2).
    """

    def __init__(self, C_in, C_out, affine=True):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        x = self.relu(x)
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out
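A shape sketch with assumed sizes: each stride-2 pointwise conv produces C_out // 2 channels, and their concatenation halves H and W while reaching C_out.

# Sketch: hypothetical sizes only; the printed shape is expected to be torch.Size([1, 32, 4, 4])
fr = FactorizedReduce(C_in=16, C_out=32)
print(fr(torch.randn(1, 16, 8, 8)).shape)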