tuner.py 743 Bytes
Newer Older
Yuge Zhang's avatar
Yuge Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from nni.nas.pytorch.spos import SPOSEvolution

from network import ShuffleNetV2OneShot


class EvolutionWithFlops(SPOSEvolution):
    """
    SPOS evolution tuner that additionally rejects candidates over a FLOPs budget.

    A candidate is considered legal only when the parent ``SPOSEvolution``
    legality check accepts it AND the FLOPs reported by
    ``ShuffleNetV2OneShot.get_candidate_flops`` do not exceed ``flops_limit``.

    Parameters
    ----------
    flops_limit : float
        Upper bound on a candidate's FLOPs; a value equal to the limit is
        still accepted. Defaults to 330E6.
    **kwargs
        Passed through unchanged to ``SPOSEvolution.__init__``.
    """

    def __init__(self, flops_limit=330E6, **kwargs):
        super().__init__(**kwargs)
        # Network instance queried by ``_is_legal`` for a candidate's FLOPs.
        self.flops_limit = flops_limit
        self.model = ShuffleNetV2OneShot()

    def _is_legal(self, cand):
        """
        Return ``True`` if ``cand`` passes the parent check and fits the FLOPs budget.

        The parent check runs first, so the FLOPs query is only made for
        candidates the parent already accepts.
        """
        if super()._is_legal(cand):
            # Expressed as ``not (flops > limit)`` on purpose: it keeps the exact
            # semantics of the original comparison (e.g. for NaN values).
            exceeds_budget = self.model.get_candidate_flops(cand) > self.flops_limit
            return not exceeds_budget
        return False