pytorch.py 6.14 KB
Newer Older
1
2
3
import logging
from dataclasses import dataclass
from pathlib import Path
4
from subprocess import Popen
5
from threading import Thread
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
from nni.experiment import Experiment, TrainingServiceConfig
from nni.experiment.config import util
from nni.experiment.config.base import ConfigBase, PathLike
from nni.experiment.pipe import Pipe

from ..converter import convert_to_graph
from ..graph import Model, TrainingConfig
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
from ..nn.pytorch.mutator import process_inline_mutation
from ..strategies.strategy import BaseStrategy
from ..trainer.interface import BaseOneShotTrainer, BaseTrainer
from ..utils import get_records
23
24
25

_logger = logging.getLogger(__name__)

26

27
28
29
@dataclass(init=False)
class RetiariiExeConfig(ConfigBase):
    """Experiment configuration for classic (multi-trial) Retiarii search.

    This is a reduced experiment config: tuner/assessor/advisor settings are
    intentionally absent. ``init=False`` stops the dataclass machinery from
    generating an ``__init__`` from the fields. Required fields such as
    ``trial_concurrency`` follow defaulted ones, which a generated
    ``__init__`` would reject. The hand-written ``__init__`` below is used
    instead.
    """

    experiment_name: Optional[str] = None
    search_space: Any = ''  # TODO: remove
    trial_command: str = 'python3 -m nni.retiarii.trial_entry'
    trial_code_directory: PathLike = '.'
    trial_concurrency: int
    trial_gpu_number: int = 0
    max_experiment_duration: Optional[str] = None
    max_trial_number: Optional[int] = None
    nni_manager_ip: Optional[str] = None
    debug: bool = False
    log_level: Optional[str] = None
    experiment_working_directory: Optional[PathLike] = None
    # remove configuration of tuner/assessor/advisor
    training_service: TrainingServiceConfig

    def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
        """Build the config from keyword arguments.

        Parameters
        ----------
        training_service_platform
            Optional shortcut: when given, ``training_service`` is created via
            ``util.training_service_config_factory(platform=...)``. It must not be
            combined with an explicit ``training_service`` keyword argument.
        **kwargs
            Field values, forwarded unchanged to ``ConfigBase.__init__``.
        """
        super().__init__(**kwargs)
        if training_service_platform is None:
            return
        # The platform shortcut and an explicit training_service are mutually exclusive.
        assert 'training_service' not in kwargs
        factory = util.training_service_config_factory
        self.training_service = factory(platform=training_service_platform)

    def validate(self, initialized_tuner: bool = False) -> None:
        """Validate fields via ``ConfigBase.validate``.

        ``initialized_tuner`` is accepted but not used here, since this config
        has no tuner section.
        """
        super().validate()

    @property
    def _canonical_rules(self):
        # Per-field normalisation rules, shared at module level.
        return _canonical_rules

    @property
    def _validation_rules(self):
        # Per-field validation predicates, shared at module level.
        return _validation_rules

61

62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def _canonical_duration(value):
    """Normalise a duration string to ``'<n>s'`` using ``util.parse_time``; ``None`` passes through.

    NOTE(review): presumably ``util.parse_time`` returns a number of seconds — confirm.
    """
    if value is None:
        return None
    return f'{util.parse_time(value)}s'


# Field name -> function that rewrites the field's value into canonical form.
_canonical_rules = {
    'trial_code_directory': util.canonical_path,
    'max_experiment_duration': _canonical_duration,
    'experiment_working_directory': util.canonical_path,
}

_validation_rules = {
    'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not directory'),
    'trial_concurrency': lambda value: value > 0,
    'trial_gpu_number': lambda value: value >= 0,
    'max_experiment_duration': lambda value: util.parse_time(value) > 0,
    'max_trial_number': lambda value: value > 0,
    'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
    'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}


class RetiariiExperiment(Experiment):
    """Experiment driver for Retiarii neural architecture search.

    The mode is chosen by the type of ``trainer``:

    * one-shot: ``trainer`` is a ``BaseOneShotTrainer``. :meth:`run` calls
      ``trainer.fit()`` in-process and does not call the base ``Experiment.run``.
    * classic (multi-trial): any other ``trainer`` (annotated as ``TrainingConfig``).
      :meth:`run` delegates to ``Experiment.run``. :meth:`start` launches the base
      experiment and then starts ``strategy`` in a background thread.
    """

    def __init__(self, base_model: nn.Module, trainer: Union[TrainingConfig, BaseOneShotTrainer],
                 applied_mutators: Optional[List[Mutator]] = None, strategy: Optional[BaseStrategy] = None):
        """
        Parameters
        ----------
        base_model
            The PyTorch model that defines the search space. In classic mode it must be
            scriptable by ``torch.jit.script``.
        trainer
            A ``TrainingConfig`` (classic mode) or a ``BaseOneShotTrainer`` (one-shot mode).
        applied_mutators
            Explicit mutators. These must not be mixed with inline LayerChoice/InputChoice
            mutations.
        strategy
            Search strategy whose ``run(base_model_ir, mutators)`` is invoked in classic mode.
        """
        # TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
        # NOTE(review): Experiment.__init__ is intentionally not called. The attributes below
        # are presumably the ones the base class expects; re-check when upgrading NNI.
        self.config: Optional[RetiariiExeConfig] = None
        self.port: Optional[int] = None

        self.base_model = base_model
        self.trainer = trainer
        self.applied_mutators = applied_mutators
        self.strategy = strategy
        self.recorded_module_args = get_records()

        # Created eagerly; handed to the base class through _create_dispatcher().
        self._dispatcher = RetiariiAdvisor()
        self._dispatcher_thread: Optional[Thread] = None
        self._proc: Optional[Popen] = None
        self._pipe: Optional[Pipe] = None

    def _start_strategy(self):
        """Convert the base model to graph IR and launch the strategy in a new thread.

        Raises
        ------
        Exception
            Whatever ``torch.jit.script`` raises if the base model cannot be scripted.
            The exception is re-raised unchanged after logging a hint.
        RuntimeError
            If inline mutations (LayerChoice/InputChoice) are combined with explicit mutators.
        """
        try:
            script_module = torch.jit.script(self.base_model)
        except Exception:
            _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
            # Bare raise keeps the original traceback intact.
            raise
        base_model_ir = convert_to_graph(script_module, self.base_model)
        base_model_ir.training_config = self.trainer

        # handle inline mutations
        mutators = process_inline_mutation(base_model_ir)
        if mutators is not None and self.applied_mutators:
            raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
                               'do not use mutators when you use LayerChoice/InputChoice')
        if mutators is not None:
            self.applied_mutators = mutators

        _logger.info('Starting strategy...')
        # The thread handle is not stored, so this object never joins the strategy thread.
        Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators)).start()
        _logger.info('Strategy started!')

    def start(self, port: int = 8080, debug: bool = False) -> None:
        """
        Start the experiment in background.

        This method will raise an exception on failure.
        If it returns, the experiment should have been successfully started.

        Parameters
        ----------
        port
            The port of web UI.
        debug
            Whether to start in debug mode.
        """
        super().start(port, debug)
        self._start_strategy()

    def _create_dispatcher(self):
        """Return the ``RetiariiAdvisor`` created in ``__init__`` as this experiment's dispatcher."""
        return self._dispatcher

    def run(self, config: Optional[RetiariiExeConfig] = None, port: int = 8080, debug: bool = False) -> Any:
        """
        Run the experiment.

        This function will block until the experiment finishes or errors.

        Parameters
        ----------
        config
            Required in classic mode; ignored in one-shot mode.
        port
            The port of web UI (classic mode only).
        debug
            Whether to start in debug mode (classic mode only).

        Returns
        -------
        In classic mode, whatever ``Experiment.run`` returns. In one-shot mode, ``None``.
        """
        if isinstance(self.trainer, BaseOneShotTrainer):
            self.trainer.fit()
            return None
        assert config is not None, 'You are using classic search mode, config cannot be None!'
        self.config = config
        # Propagate the base result instead of silently discarding it.
        return super().run(port, debug)

    def export_top_models(self, top_n: int = 1):
        """
        Export several top performing models.

        Only ``top_n == 1`` is honoured; other values log a warning. In one-shot mode this
        returns ``trainer.export()``. In classic mode it only logs a hint and returns ``None``.
        """
        if top_n != 1:
            _logger.warning('Only support top_n is 1 for now.')
        if isinstance(self.trainer, BaseOneShotTrainer):
            return self.trainer.export()
        _logger.info('For this experiment, you can find out the best one from WebUI.')
        return None

    def retrain_model(self, model):
        """
        Retrain the exported model and test it to output test accuracy.

        Not implemented yet; always raises ``NotImplementedError``.
        """
        raise NotImplementedError