Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
443ba8c1
Unverified
Commit
443ba8c1
authored
Dec 06, 2021
by
Yuge Zhang
Committed by
GitHub
Dec 06, 2021
Browse files
Serialization infrastructure V2 (#4337)
parent
896c516f
Changes
40
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
367 additions
and
148 deletions
+367
-148
docs/en_US/NAS/ApiReference.rst
docs/en_US/NAS/ApiReference.rst
+3
-1
docs/en_US/autotune_ref.rst
docs/en_US/autotune_ref.rst
+6
-0
examples/nas/multi-trial/nasbench101/network.py
examples/nas/multi-trial/nasbench101/network.py
+2
-2
examples/nas/multi-trial/nasbench201/network.py
examples/nas/multi-trial/nasbench201/network.py
+2
-2
nni/__init__.py
nni/__init__.py
+1
-1
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
+8
-8
nni/algorithms/hpo/hyperband_advisor.py
nni/algorithms/hpo/hyperband_advisor.py
+6
-6
nni/common/__init__.py
nni/common/__init__.py
+1
-0
nni/common/serializer.py
nni/common/serializer.py
+305
-94
nni/experiment/experiment.py
nni/experiment/experiment.py
+2
-2
nni/retiarii/__init__.py
nni/retiarii/__init__.py
+1
-1
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+1
-1
nni/retiarii/evaluator/pytorch/cgo/accelerator.py
nni/retiarii/evaluator/pytorch/cgo/accelerator.py
+2
-2
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
+3
-4
nni/retiarii/evaluator/pytorch/cgo/trainer.py
nni/retiarii/evaluator/pytorch/cgo/trainer.py
+4
-3
nni/retiarii/evaluator/pytorch/lightning.py
nni/retiarii/evaluator/pytorch/lightning.py
+12
-10
nni/retiarii/execution/cgo_engine.py
nni/retiarii/execution/cgo_engine.py
+1
-1
nni/retiarii/integration.py
nni/retiarii/integration.py
+4
-4
nni/retiarii/integration_api.py
nni/retiarii/integration_api.py
+0
-4
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+3
-2
No files found.
docs/en_US/NAS/ApiReference.rst
View file @
443ba8c1
...
...
@@ -114,7 +114,9 @@ CGO Execution
Utilities
---------
.. autofunction:: nni.retiarii.serialize
.. autofunction:: nni.retiarii.basic_unit
.. autofunction:: nni.retiarii.model_wrapper
.. autofunction:: nni.retiarii.fixed_arch
...
...
docs/en_US/autotune_ref.rst
View file @
443ba8c1
...
...
@@ -78,3 +78,9 @@ Utilities
---------
.. autofunction:: nni.utils.merge_parameter
.. autofunction:: nni.trace
.. autofunction:: nni.dump
.. autofunction:: nni.load
examples/nas/multi-trial/nasbench101/network.py
View file @
443ba8c1
...
...
@@ -3,7 +3,7 @@ import nni
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
import
torch.nn
as
nn
import
torchmetrics
from
nni.retiarii
import
model_wrapper
,
serialize
,
serialize_cls
from
nni.retiarii
import
model_wrapper
,
serialize
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.nn.pytorch
import
NasBench101Cell
from
nni.retiarii.strategy
import
Random
...
...
@@ -82,7 +82,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy):
return
super
().
update
(
nn
.
functional
.
softmax
(
pred
),
target
)
@
serialize_cls
@
nni
.
trace
class
NasBench101TrainingModule
(
pl
.
LightningModule
):
def
__init__
(
self
,
max_epochs
=
108
,
learning_rate
=
0.1
,
weight_decay
=
1e-4
):
super
().
__init__
()
...
...
examples/nas/multi-trial/nasbench201/network.py
View file @
443ba8c1
...
...
@@ -3,7 +3,7 @@ import nni
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
import
torch.nn
as
nn
import
torchmetrics
from
nni.retiarii
import
model_wrapper
,
serialize
,
serialize_cls
from
nni.retiarii
import
model_wrapper
,
serialize
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.nn.pytorch
import
NasBench201Cell
from
nni.retiarii.strategy
import
Random
...
...
@@ -71,7 +71,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy):
return
super
().
update
(
nn
.
functional
.
softmax
(
pred
),
target
)
@
serialize_cls
@
nni
.
trace
class
NasBench201TrainingModule
(
pl
.
LightningModule
):
def
__init__
(
self
,
max_epochs
=
200
,
learning_rate
=
0.1
,
weight_decay
=
5e-4
):
super
().
__init__
()
...
...
nni/__init__.py
View file @
443ba8c1
...
...
@@ -9,7 +9,7 @@ except ModuleNotFoundError:
from
.runtime.log
import
init_logger
init_logger
()
from
.common.serializer
import
*
from
.common.serializer
import
trace
,
dump
,
load
from
.runtime.env_vars
import
dispatcher_env_vars
from
.utils
import
ClassArgsValidator
...
...
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
View file @
443ba8c1
...
...
@@ -7,12 +7,12 @@ bohb_advisor.py
import
sys
import
math
import
logging
import
json_tricks
from
schema
import
Schema
,
Optional
import
ConfigSpace
as
CS
import
ConfigSpace.hyperparameters
as
CSH
from
ConfigSpace.read_and_write
import
pcs_new
import
nni
from
nni
import
ClassArgsValidator
from
nni.runtime.protocol
import
CommandType
,
send
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
...
...
@@ -428,7 +428,7 @@ class BOHB(MsgDispatcherBase):
'parameter_source'
:
'algorithm'
,
'parameters'
:
''
}
send
(
CommandType
.
NoMoreTrialJobs
,
json_tricks
.
dump
s
(
ret
))
send
(
CommandType
.
NoMoreTrialJobs
,
nni
.
dump
(
ret
))
return
None
assert
self
.
generated_hyper_configs
params
=
self
.
generated_hyper_configs
.
pop
(
0
)
...
...
@@ -459,7 +459,7 @@ class BOHB(MsgDispatcherBase):
"""
ret
=
self
.
_get_one_trial_job
()
if
ret
is
not
None
:
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dump
s
(
ret
))
send
(
CommandType
.
NewTrialJob
,
nni
.
dump
(
ret
))
self
.
credit
-=
1
def
handle_update_search_space
(
self
,
data
):
...
...
@@ -536,7 +536,7 @@ class BOHB(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
logger
.
debug
(
'Tuner handle trial end, result is %s'
,
data
)
hyper_params
=
json_tricks
.
load
s
(
data
[
'hyper_params'
])
hyper_params
=
nni
.
load
(
data
[
'hyper_params'
])
self
.
_handle_trial_end
(
hyper_params
[
'parameter_id'
])
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
del
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
...
...
@@ -551,7 +551,7 @@ class BOHB(MsgDispatcherBase):
ret
[
'parameter_index'
]
=
one_unsatisfied
[
'parameter_index'
]
# update parameter_id in self.job_id_para_id_map
self
.
job_id_para_id_map
[
ret
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
send
(
CommandType
.
SendTrialJobParameter
,
json_tricks
.
dump
s
(
ret
))
send
(
CommandType
.
SendTrialJobParameter
,
nni
.
dump
(
ret
))
for
_
in
range
(
self
.
credit
):
self
.
_request_one_trial_job
()
...
...
@@ -584,7 +584,7 @@ class BOHB(MsgDispatcherBase):
"""
logger
.
debug
(
'handle report metric data = %s'
,
data
)
if
'value'
in
data
:
data
[
'value'
]
=
json_tricks
.
load
s
(
data
[
'value'
])
data
[
'value'
]
=
nni
.
load
(
data
[
'value'
])
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
assert
multi_phase_enabled
()
assert
data
[
'trial_job_id'
]
is
not
None
...
...
@@ -599,7 +599,7 @@ class BOHB(MsgDispatcherBase):
ret
[
'parameter_index'
]
=
data
[
'parameter_index'
]
# update parameter_id in self.job_id_para_id_map
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
send
(
CommandType
.
SendTrialJobParameter
,
json_tricks
.
dump
s
(
ret
))
send
(
CommandType
.
SendTrialJobParameter
,
nni
.
dump
(
ret
))
else
:
assert
'value'
in
data
value
=
extract_scalar_reward
(
data
[
'value'
])
...
...
@@ -655,7 +655,7 @@ class BOHB(MsgDispatcherBase):
data doesn't have required key 'parameter' and 'value'
"""
for
entry
in
data
:
entry
[
'value'
]
=
json_tricks
.
load
s
(
entry
[
'value'
])
entry
[
'value'
]
=
nni
.
load
(
entry
[
'value'
])
_completed_num
=
0
for
trial_info
in
data
:
logger
.
info
(
"Importing data, current processing progress %s / %s"
,
_completed_num
,
len
(
data
))
...
...
nni/algorithms/hpo/hyperband_advisor.py
View file @
443ba8c1
...
...
@@ -10,10 +10,10 @@ import logging
import
math
import
sys
import
json_tricks
import
numpy
as
np
from
schema
import
Schema
,
Optional
import
nni
from
nni
import
ClassArgsValidator
from
nni.common.hpo_utils
import
validate_search_space
from
nni.runtime.common
import
multi_phase_enabled
...
...
@@ -336,7 +336,7 @@ class Hyperband(MsgDispatcherBase):
def
_request_one_trial_job
(
self
):
ret
=
self
.
_get_one_trial_job
()
if
ret
is
not
None
:
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dump
s
(
ret
))
send
(
CommandType
.
NewTrialJob
,
nni
.
dump
(
ret
))
self
.
credit
-=
1
def
_get_one_trial_job
(
self
):
...
...
@@ -365,7 +365,7 @@ class Hyperband(MsgDispatcherBase):
'parameter_source'
:
'algorithm'
,
'parameters'
:
''
}
send
(
CommandType
.
NoMoreTrialJobs
,
json_tricks
.
dump
s
(
ret
))
send
(
CommandType
.
NoMoreTrialJobs
,
nni
.
dump
(
ret
))
return
None
assert
self
.
generated_hyper_configs
...
...
@@ -408,7 +408,7 @@ class Hyperband(MsgDispatcherBase):
event: the job's state
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
hyper_params
=
json_tricks
.
load
s
(
data
[
'hyper_params'
])
hyper_params
=
nni
.
load
(
data
[
'hyper_params'
])
self
.
_handle_trial_end
(
hyper_params
[
'parameter_id'
])
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
del
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
...
...
@@ -426,7 +426,7 @@ class Hyperband(MsgDispatcherBase):
Data type not supported
"""
if
'value'
in
data
:
data
[
'value'
]
=
json_tricks
.
load
s
(
data
[
'value'
])
data
[
'value'
]
=
nni
.
load
(
data
[
'value'
])
# multiphase? need to check
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
assert
multi_phase_enabled
()
...
...
@@ -440,7 +440,7 @@ class Hyperband(MsgDispatcherBase):
if
data
[
'parameter_index'
]
is
not
None
:
ret
[
'parameter_index'
]
=
data
[
'parameter_index'
]
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
send
(
CommandType
.
SendTrialJobParameter
,
json_tricks
.
dump
s
(
ret
))
send
(
CommandType
.
SendTrialJobParameter
,
nni
.
dump
(
ret
))
else
:
value
=
extract_scalar_reward
(
data
[
'value'
])
bracket_id
,
i
,
_
=
data
[
'parameter_id'
].
split
(
'_'
)
...
...
nni/common/__init__.py
View file @
443ba8c1
from
.serializer
import
trace
,
dump
,
load
,
is_traceable
nni/common/serializer.py
View file @
443ba8c1
This diff is collapsed.
Click to expand it.
nni/experiment/experiment.py
View file @
443ba8c1
...
...
@@ -6,11 +6,11 @@ from subprocess import Popen
import
time
from
typing
import
Optional
,
Union
,
List
,
overload
,
Any
import
json_tricks
import
colorama
import
psutil
import
nni.runtime.log
from
nni.common
import
dump
from
.config
import
ExperimentConfig
,
AlgorithmConfig
from
.data
import
TrialJob
,
TrialMetricData
,
TrialResult
...
...
@@ -439,7 +439,7 @@ class Experiment:
value: dict
New search_space.
"""
value
=
json_tricks
.
dump
s
(
value
)
value
=
dump
(
value
)
self
.
_update_experiment_profile
(
'searchSpace'
,
value
)
def
update_max_trial_number
(
self
,
value
:
int
):
...
...
nni/retiarii/__init__.py
View file @
443ba8c1
...
...
@@ -6,4 +6,4 @@ from .graph import *
from
.execution
import
*
from
.fixed
import
fixed_arch
from
.mutator
import
*
from
.serializer
import
basic_unit
,
json_dump
,
json_dumps
,
json_load
,
json_loads
,
serialize
,
serialize_cls
,
model_wrapper
from
.serializer
import
basic_unit
,
model_wrapper
,
serialize
,
serialize_cls
nni/retiarii/converter/graph_gen.py
View file @
443ba8c1
...
...
@@ -637,7 +637,7 @@ class GraphConverter:
original_type_name
not
in
MODULE_EXCEPT_LIST
:
# this is a basic module from pytorch, no need to parse its graph
m_attrs
=
get_init_parameters_or_fail
(
module
)
elif
getattr
(
module
,
'_
stop_parsing
'
,
False
):
elif
getattr
(
module
,
'_
nni_basic_unit
'
,
False
):
# this module is marked as serialize, won't continue to parse
m_attrs
=
get_init_parameters_or_fail
(
module
)
if
m_attrs
is
not
None
:
...
...
nni/retiarii/evaluator/pytorch/cgo/accelerator.py
View file @
443ba8c1
...
...
@@ -10,7 +10,7 @@ from pytorch_lightning.plugins.training_type.training_type_plugin import Trainin
from
pytorch_lightning.trainer
import
Trainer
from
pytorch_lightning.trainer.connectors.accelerator_connector
import
AcceleratorConnector
from
....serializer
import
serialize_cls
import
nni
class
BypassPlugin
(
TrainingTypePlugin
):
...
...
@@ -126,7 +126,7 @@ def get_accelerator_connector(
)
@
serialize_cls
@
nni
.
trace
class
BypassAccelerator
(
Accelerator
):
def
__init__
(
self
,
precision_plugin
=
None
,
device
=
"cpu"
,
**
trainer_kwargs
):
if
precision_plugin
is
None
:
...
...
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
View file @
443ba8c1
...
...
@@ -14,10 +14,9 @@ import nni
from
..lightning
import
LightningModule
,
_AccuracyWithLogits
,
Lightning
from
.trainer
import
Trainer
from
....serializer
import
serialize_cls
@
serialize_cls
@
nni
.
trace
class
_MultiModelSupervisedLearningModule
(
LightningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
,
metrics
:
Dict
[
str
,
torchmetrics
.
Metric
],
n_models
:
int
=
0
,
...
...
@@ -126,7 +125,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
super
().
__init__
(
criterion
,
metrics
,
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
)
@
serialize_cls
@
nni
.
trace
class
_ClassificationModule
(
MultiModelSupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
=
nn
.
CrossEntropyLoss
,
learning_rate
:
float
=
0.001
,
...
...
@@ -174,7 +173,7 @@ class Classification(Lightning):
train_dataloader
=
train_dataloader
,
val_dataloaders
=
val_dataloaders
)
@
serialize_cls
@
nni
.
trace
class
_RegressionModule
(
MultiModelSupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
=
nn
.
MSELoss
,
learning_rate
:
float
=
0.001
,
...
...
nni/retiarii/evaluator/pytorch/cgo/trainer.py
View file @
443ba8c1
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
pytorch_lightning
as
pl
from
....serializer
import
serialize_cls
import
nni
from
.accelerator
import
BypassAccelerator
@
serialize_cls
@
nni
.
trace
class
Trainer
(
pl
.
Trainer
):
"""
Trainer for cross-graph optimization.
...
...
nni/retiarii/evaluator/pytorch/lightning.py
View file @
443ba8c1
...
...
@@ -10,17 +10,17 @@ import pytorch_lightning as pl
import
torch.nn
as
nn
import
torch.optim
as
optim
import
torchmetrics
from
torch.utils.data
import
DataLoader
import
torch.utils.data
as
torch_data
import
nni
from
nni.common.serializer
import
is_traceable
try
:
from
.cgo
import
trainer
as
cgo_trainer
cgo_import_failed
=
False
except
ImportError
:
cgo_import_failed
=
True
from
...graph
import
Evaluator
from
...serializer
import
serialize_cls
from
nni.retiarii.graph
import
Evaluator
__all__
=
[
'LightningModule'
,
'Trainer'
,
'DataLoader'
,
'Lightning'
,
'Classification'
,
'Regression'
]
...
...
@@ -40,9 +40,10 @@ class LightningModule(pl.LightningModule):
self
.
model
=
model
Trainer
=
serialize_cls
(
pl
.
Trainer
)
DataLoader
=
serialize_cls
(
DataLoader
)
Trainer
=
nni
.
trace
(
pl
.
Trainer
)
DataLoader
=
nni
.
trace
(
torch_data
.
DataLoader
)
@
nni
.
trace
class
Lightning
(
Evaluator
):
"""
Delegate the whole training to PyTorch Lightning.
...
...
@@ -74,9 +75,10 @@ class Lightning(Evaluator):
val_dataloaders
:
Union
[
DataLoader
,
List
[
DataLoader
],
None
]
=
None
):
assert
isinstance
(
lightning_module
,
LightningModule
),
f
'Lightning module must be an instance of
{
__name__
}
.LightningModule.'
if
cgo_import_failed
:
assert
isinstance
(
trainer
,
T
rainer
),
f
'Trainer must be imported from
{
__name__
}
'
assert
isinstance
(
trainer
,
pl
.
Trainer
)
and
is_traceable
(
t
rainer
),
f
'Trainer must be imported from
{
__name__
}
'
else
:
assert
isinstance
(
trainer
,
Trainer
)
or
isinstance
(
trainer
,
cgo_trainer
.
Trainer
),
\
# this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
assert
(
isinstance
(
trainer
,
pl
.
Trainer
)
and
is_traceable
(
trainer
))
or
isinstance
(
trainer
,
cgo_trainer
.
Trainer
),
\
f
'Trainer must be imported from
{
__name__
}
or nni.retiarii.evaluator.pytorch.cgo.trainer'
assert
_check_dataloader
(
train_dataloader
),
f
'Wrong dataloader type. Try import DataLoader from
{
__name__
}
.'
assert
_check_dataloader
(
val_dataloaders
),
f
'Wrong dataloader type. Try import DataLoader from
{
__name__
}
.'
...
...
@@ -135,7 +137,7 @@ def _check_dataloader(dataloader):
return
True
if
isinstance
(
dataloader
,
list
):
return
all
([
_check_dataloader
(
d
)
for
d
in
dataloader
])
return
isinstance
(
dataloader
,
DataL
oader
)
return
isinstance
(
dataloader
,
torch_data
.
DataLoader
)
and
is_traceable
(
datal
oader
)
### The following are some commonly used Lightning modules ###
...
...
@@ -219,7 +221,7 @@ class _AccuracyWithLogits(torchmetrics.Accuracy):
return
super
().
update
(
nn
.
functional
.
softmax
(
pred
),
target
)
@
serialize_cls
@
nni
.
trace
class
_ClassificationModule
(
_SupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
=
nn
.
CrossEntropyLoss
,
learning_rate
:
float
=
0.001
,
...
...
@@ -272,7 +274,7 @@ class Classification(Lightning):
train_dataloader
=
train_dataloader
,
val_dataloaders
=
val_dataloaders
)
@
serialize_cls
@
nni
.
trace
class
_RegressionModule
(
_SupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
=
nn
.
MSELoss
,
learning_rate
:
float
=
0.001
,
...
...
nni/retiarii/execution/cgo_engine.py
View file @
443ba8c1
...
...
@@ -200,7 +200,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
# replace the module with a new instance whose n_models is set
# n_models must be set in __init__, otherwise it cannot be captured by serialize_cls
new_module_init_params
=
model
.
evaluator
.
module
.
_init_parameter
s
.
copy
()
new_module_init_params
=
model
.
evaluator
.
module
.
trace_kwarg
s
.
copy
()
# MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users
new_module_init_params
[
'n_models'
]
=
len
(
multi_model
)
...
...
nni/retiarii/integration.py
View file @
443ba8c1
...
...
@@ -4,13 +4,13 @@
import
logging
from
typing
import
Any
,
Callable
import
nni
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.protocol
import
CommandType
,
send
from
nni.utils
import
MetricType
from
.graph
import
MetricData
from
.integration_api
import
register_advisor
from
.serializer
import
json_dumps
,
json_loads
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -121,7 +121,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
'placement_constraint'
:
placement_constraint
}
_logger
.
debug
(
'New trial sent: %s'
,
new_trial
)
send
(
CommandType
.
NewTrialJob
,
json_
dump
s
(
new_trial
))
send
(
CommandType
.
NewTrialJob
,
nni
.
dump
(
new_trial
))
if
self
.
send_trial_callback
is
not
None
:
self
.
send_trial_callback
(
parameters
)
# pylint: disable=not-callable
return
self
.
parameters_count
...
...
@@ -140,7 +140,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
def
handle_trial_end
(
self
,
data
):
_logger
.
debug
(
'Trial end: %s'
,
data
)
self
.
trial_end_callback
(
json_
load
s
(
data
[
'hyper_params'
])[
'parameter_id'
],
# pylint: disable=not-callable
self
.
trial_end_callback
(
nni
.
load
(
data
[
'hyper_params'
])[
'parameter_id'
],
# pylint: disable=not-callable
data
[
'event'
]
==
'SUCCEEDED'
)
def
handle_report_metric_data
(
self
,
data
):
...
...
@@ -156,7 +156,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
@
staticmethod
def
_process_value
(
value
)
->
Any
:
# hopefully a float
value
=
json_
load
s
(
value
)
value
=
nni
.
load
(
value
)
if
isinstance
(
value
,
dict
):
if
'default'
in
value
:
return
value
[
'default'
]
...
...
nni/retiarii/integration_api.py
View file @
443ba8c1
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
json
from
typing
import
NewType
,
Any
import
nni
from
.serializer
import
json_loads
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
RetiariiAdvisor
=
NewType
(
'RetiariiAdvisor'
,
Any
)
...
...
@@ -41,7 +38,6 @@ def receive_trial_parameters() -> dict:
Reload with our json loads because NNI didn't use Retiarii serializer to load the data.
"""
params
=
nni
.
get_next_parameter
()
params
=
json_loads
(
json
.
dumps
(
params
))
return
params
...
...
nni/retiarii/nn/pytorch/api.py
View file @
443ba8c1
...
...
@@ -8,8 +8,9 @@ from typing import Any, List, Union, Dict, Optional
import
torch
import
torch.nn
as
nn
from
...serializer
import
Translatable
,
basic_unit
from
...utils
import
NoContextError
from
nni.common.serializer
import
Translatable
from
nni.retiarii.serializer
import
basic_unit
from
nni.retiarii.utils
import
NoContextError
from
.utils
import
generate_new_label
,
get_fixed_value
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment