# mutator.py
from typing import (Any, Iterable, List, Optional, Tuple)

from .graph import Model
__all__ = ['Sampler', 'Mutator']


# Type alias for a single candidate value passed through `Mutator.choice()` / `Sampler.choice()`.
# Candidates are arbitrary objects chosen by each mutator subclass, hence `Any`.
Choice = Any


class Sampler:
    """
    Handles `Mutator.choice()` calls.

    `Mutator.apply()` calls `mutation_start()` before running `Mutator.mutate()`, forwards every
    `Mutator.choice()` made during the mutation to `choice()`, and calls `mutation_end()` after
    `Mutator.mutate()` returns. Subclasses must implement `choice()`; the two hooks do nothing by
    default.
    """

    def choice(self, candidates: List[Choice], mutator: 'Mutator', model: Model, index: int) -> Choice:
        """
        Make one choice on behalf of `mutator`.

        Args:
            candidates: the candidates given to `Mutator.choice()`, materialized as a list.
            mutator: the mutator asking for the choice.
            model: the forked copy of the model that is being mutated.
            index: zero-based position of this call among the `choice()` calls of the current
                `Mutator.apply()`.

        Returns:
            The value that `Mutator.choice()` hands back to the mutator.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def mutation_start(self, mutator: 'Mutator', model: Model) -> None:
        """Hook called by `Mutator.apply()` right before `mutator.mutate(model)`. No-op by default."""

    def mutation_end(self, mutator: 'Mutator', model: Model) -> None:
        """Hook called by `Mutator.apply()` right after `mutator.mutate(model)` returns. No-op by default."""


class Mutator:
    """
    Mutates graphs in a model to generate a new model.

    `Mutator` is used in two places:
      1. Inherit `Mutator` to implement graph mutation logic.
      2. Use a `Mutator` subclass to implement a NAS strategy.

    In scenario 1, the subclass implements `Mutator.mutate()` in terms of `Mutator.choice()`.
    In scenario 2, the strategy uses the constructor or `Mutator.bind_sampler()` to attach a sampler,
    and then calls `Mutator.apply()` to mutate a model.
    For certain mutator subclasses, a strategy or sampler can use `Mutator.dry_run()` to predict
    the choice candidates.

    Note: method names are open for discussion.
    """

    def __init__(self, sampler: Optional[Sampler] = None):
        self.sampler: Optional[Sampler] = sampler
        # Per-call state: only set while `apply()` is running, and read by `choice()`.
        self._cur_model: Optional[Model] = None
        self._cur_choice_idx: Optional[int] = None

    def bind_sampler(self, sampler: Sampler) -> 'Mutator':
        """
        Set the sampler which will handle `Mutator.choice()` calls.

        Returns `self`, so calls can be chained: `mutator.bind_sampler(sampler).apply(model)`.
        """
        self.sampler = sampler
        return self

    def apply(self, model: Model) -> Model:
        """
        Apply this mutator on a model and return the mutated model.

        The model is copied with `model.fork()` before mutation, and only the copy is mutated;
        the original model will not be modified. The sampler's `mutation_start()` and
        `mutation_end()` hooks bracket the call to `mutate()`.

        If a hook or `mutate()` raises, the exception propagates (`mutation_end()` is then not
        called), but the per-call state is still reset so this mutator can be reused.
        """
        assert self.sampler is not None, 'Mutator.apply() requires a sampler; use bind_sampler() first'
        copy = model.fork()
        self._cur_model = copy
        self._cur_choice_idx = 0
        try:
            self.sampler.mutation_start(self, copy)
            self.mutate(copy)
            self.sampler.mutation_end(self, copy)
        finally:
            # Without this, a failed mutation would leave a stale model and choice index behind
            # for any later `choice()` call.
            self._cur_model = None
            self._cur_choice_idx = None
        return copy

    def dry_run(self, model: Model) -> Tuple[List[List[Choice]], Model]:
        """
        Dry run the mutator on a model to collect choice candidates.

        The bound sampler is temporarily replaced by a recorder that always picks the first
        candidate. The original sampler is restored afterwards, even if the mutation raises.

        Returns:
            A tuple ``(candidates, new_model)``:
            - ``candidates`` holds the candidate list of every `Mutator.choice()` call, in call order;
            - ``new_model`` is the model produced by that mutation.

        If you invoke this method multiple times on the same or different models,
        it may or may not return identical results, depending on how the subclass implements `Mutator.mutate()`.
        """
        sampler_backup = self.sampler
        recorder = _RecorderSampler()
        self.sampler = recorder
        try:
            new_model = self.apply(model)
        finally:
            self.sampler = sampler_backup
        return recorder.recorded_candidates, new_model

    def mutate(self, model: Model) -> None:
        """
        Abstract method to be implemented by subclasses.

        Mutate `model` in place, calling `Mutator.choice()` for every decision to be made.
        """
        raise NotImplementedError()

    def choice(self, candidates: Iterable[Choice]) -> Choice:
        """
        Ask the sampler to make a choice among `candidates` and return its answer.

        Only valid while `Mutator.apply()` is running. `candidates` is materialized into a list and
        passed to `Sampler.choice()` together with this mutator, the model being mutated, and the
        zero-based index of this call within the current `apply()`.
        """
        assert self.sampler is not None and self._cur_model is not None and self._cur_choice_idx is not None, \
            'Mutator.choice() can only be called from mutate() while apply() is running'
        ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx)
        self._cur_choice_idx += 1
        return ret


class _RecorderSampler(Sampler):
    """
    Sampler used by `Mutator.dry_run()`.

    It records every candidate list it is offered and always picks the first candidate.
    """

    def __init__(self):
        super().__init__()
        # One entry per `choice()` call, in call order.
        self.recorded_candidates: List[List[Choice]] = []

    def choice(self, candidates: List[Choice], *args) -> Choice:
        # `*args` absorbs (mutator, model, index), which the recorder does not need.
        self.recorded_candidates.append(candidates)
        # Raises IndexError if a mutator offers an empty candidate list.
        return candidates[0]

# The following mutators are used for inline mutation.


class LayerChoiceMutator(Mutator):
    """
    Replaces the operation of the node named `node_name` with one of `candidates`.

    Each candidate must be subscriptable by ``'type'`` and ``'parameters'`` (e.g. a dict); the
    chosen candidate's two entries are passed to the target node's ``update_operation()``.
    """

    def __init__(self, node_name: str, candidates: List):
        super().__init__()
        self.node_name = node_name
        self.candidates = candidates

    def mutate(self, model: Model) -> None:
        target = model.get_node_by_name(self.node_name)
        # The sampler chooses among indices rather than among the candidate objects themselves.
        chosen_index = self.choice(list(range(len(self.candidates))))
        chosen_cand = self.candidates[chosen_index]
        target.update_operation(chosen_cand['type'], chosen_cand['parameters'])

class InputChoiceMutator(Mutator):
    """
    Replaces the operation of the node named `node_name` with a ``ChosenInputs`` operation.

    ``mutate()`` makes ``n_chosen`` separate choices among the indices ``0 .. n_candidates - 1``.
    It then stores the chosen indices, together with ``reduction``, as the parameters of the new
    operation.
    """

    def __init__(self, node_name: str, n_candidates: int, n_chosen: int, reduction: str):
        super().__init__()
        self.node_name = node_name
        self.n_candidates = n_candidates
        self.n_chosen = n_chosen
        self.reduction = reduction

    def mutate(self, model: Model) -> None:
        target = model.get_node_by_name(self.node_name)
        candidates = list(range(self.n_candidates))
        # One `choice()` call per chosen input; nothing here stops the sampler from returning
        # the same index more than once.
        chosen = [self.choice(candidates) for _ in range(self.n_chosen)]
        target.update_operation('__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs',
                                {'chosen': chosen, 'reduction': self.reduction})