# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import logging
from pathlib import Path
import queue

import torch
import torch.nn as nn

from nni.common.graph_utils import build_module_graph
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict
from nni.compression.pytorch.utils.utils import get_module_by_name
from .compress_modules import replace_module
from .infer_mask import AutoMaskInference
from .jit_translate import jit_to_python_function
from ..utils import rand_like_with_shape


_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)


class ModelSpeedup:
    """
    This class is to speed up the model with the provided weight masks.

    Parameters
    ----------
    model : pytorch model
        The model user wants to speed up.
    dummy_input : pytorch tensor, tuple of tensor, list of tensor
        The dummy input for ```jit.trace```; users should place it on the
        right device before passing it in. Note that the first dimension
        of the dummy_input should be the batch size.
    masks_file : str/dict
        The path of the user-provided mask file, or the mask object itself.
    map_location : str
        The device on which masks are placed, the same as ```map_location```
        in ```torch.load```.
    batch_dim : int
        The index of the batch dimension in the dummy_input.
    confidence : int
        The confidence coefficient of the sparsity inference. This value is
        actually used as the batch size of the generated dummy_input.
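
    Examples
    --------
    A minimal usage sketch (the model class, the mask path and the input
    shape below are hypothetical; the mask file is the one saved by an NNI
    pruner):

    >>> model = MyModel().to('cuda')
    >>> dummy_input = torch.rand(8, 3, 224, 224, device='cuda')
    >>> ModelSpeedup(model, dummy_input, 'mask.pth').speedup_model()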
    """

    def __init__(self, model, dummy_input, masks_file, map_location=None,
                 batch_dim=0, confidence=8):
        assert confidence > 1
        # The auto inference will change the values of the parameters in the
        # model, so we need to make a copy before the mask inference
        self.ori_state_dict = copy.deepcopy(model.state_dict())
        self.bound_model = model
        self.inferred_masks = dict()  # key: module_name, value: ModuleMasks
        self.batch_dim = batch_dim
        self.dummy_input, self.device = self._random_model_input(dummy_input, confidence, batch_dim)
        self.torch_graph = build_module_graph(model, self.dummy_input)
        # dict object to save the auto inferences objects of the submodules
        self.auto_inferences = {}
        # the index dict to find the corresponding torch._C.Value object
        # according to the debug name
        # we need the dummy_input to infer the mask automatically, so we save
        # the indexes from the tensor's debug name to the torch._C.Value object.
        self.debugname_to_value = {}
        # load the mask tensor to the same device as the dummy_input
        # self.masks saves the mask tensors pruned by the user and the inferred
        # masks of the other modules
        if isinstance(masks_file, (str, Path)) and Path(masks_file).exists():
            self.masks = torch.load(
                masks_file, map_location if map_location is not None else str(self.device))
        elif isinstance(masks_file, dict):
            self.masks = masks_file
        else:
            raise Exception('Please provide the mask or the path of the mask file')
        self.constant = {}
        # self.internal_result saves the internal outputs of the submodules
        self.internal_result = {}

    def _random_model_input(self, dummy_input, confidence, batch_dim):
        """
        Generate a new random dummy input according to the original
        dummy_input, the confidence and the batch_dim.

        Parameters
        ----------
        dummy_input: Tensor or list/dict of Tensors
            The dummy_input given by the user.
        confidence: int
            The new batch size of the generated dummy_input.
        batch_dim: int
            The index of the batch dimension.

        Returns
        -------
        new_dummy_input: Tensor or list/dict of Tensors
            The generated dummy_input for mask inference.
        device: torch.device
            The device of the generated dummy_input.
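
        Examples
        --------
        A minimal sketch of the single-tensor branch (``speedup`` is a
        hypothetical ModelSpeedup instance; the shape is illustrative):

        >>> dummy = torch.rand(1, 3, 32, 32)
        >>> new_dummy, device = speedup._random_model_input(dummy, confidence=8, batch_dim=0)
        >>> list(new_dummy.size())  # the batch dimension is replaced by `confidence`
        [8, 3, 32, 32]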
        """
        input_errmsg = 'Only support the tensor, list/tuple/dict of tensors as input'
        # Some models may use a list of tensors as input, for example transformers
        new_dummy_input, device = None, None
        if isinstance(dummy_input, torch.Tensor):
            input_shape = list(dummy_input.size())
            # set the batch size to the confidence value
            input_shape[batch_dim] = confidence
            new_dummy_input = rand_like_with_shape(input_shape, dummy_input)
            device = dummy_input.device
        elif isinstance(dummy_input, (tuple, list)):
            # the dummy input is a list/tuple of tensors
            new_dummy_input = []
            old_batchsize = dummy_input[0].size(0)
            device = dummy_input[0].device
            for _, t_input in enumerate(dummy_input):
                assert isinstance(t_input, torch.Tensor), input_errmsg
                assert t_input.size(0) == old_batchsize, \
                    'The first dimension should be the batch size and should be the same for all inputs!'
                input_shape = list(t_input.size())
                input_shape[batch_dim] = confidence
                new_dummy_input.append(
                    rand_like_with_shape(input_shape, t_input))
        elif isinstance(dummy_input, dict):
            new_dummy_input = {}
            tmp_key = list(dummy_input.keys())[0]
            old_batchsize = dummy_input[tmp_key].size(0)
            device = dummy_input[tmp_key].device
            for in_name, t_input in dummy_input.items():
                assert isinstance(t_input, torch.Tensor), input_errmsg
                assert old_batchsize == t_input.size(0), \
                    'The first dimension should be the batch size and should be the same for all inputs!'
                input_shape = list(t_input.size())
                input_shape[batch_dim] = confidence
                new_dummy_input[in_name] = rand_like_with_shape(
                    input_shape, t_input)
        else:
            raise TypeError(input_errmsg)
        return new_dummy_input, device

    def _prepare_dummy_input(self, node):
        """
        Prepare the dummy_input for the auto mask inference.

        Parameters
        ----------
        node : NodePyGroup
            The target node to prepare the dummy input for.

        Returns
        -------
        dummy_input : list
            List of tensors that will be used as input for the target node.
        debugnames : list of str
            Debug names of the dummy inputs.
        """
        _logger.debug('Prepare auto mask inference for node: %s',
                      node.unique_name)

        # prepare the input and output masks for this node;
        # if there is already a mask in self.masks, use the
        # original mask tensor, otherwise create a new one.
        inputs_name = node.inputs
        # build the dummy_input, in_masks the target node
        dummy_input = []
        debugnames = []
        for _input in inputs_name:
            if _input not in self.internal_result:
                # If the input debug name is not in self.internal_result,
                # then this tensor isn't an output tensor of any predecessor
                # node; it is an attribute of the submodule, such as the
                # weight or bias tensor, so we skip it here.
                # If we didn't need this special case, we could merge the
                # `prim::GetAttr` node of the weight/bias tensor into the key
                # node, such as `conv`. The special case is caused by the
                # module-node merging logic in graph_utils.py: in the current
                # version it only merges nodes that have the same scope name,
                # and the scope name of the `prim::GetAttr` node of a `weight`
                # tensor is None, so it is not merged into the key node.
                continue
            # The detach operation here is for in-place operations. We cannot
            # directly call backward on the output tensor of an in-place operator.
            dummy_input.append(self.internal_result[_input].detach())
            debugnames.append(_input)

        return dummy_input, debugnames

    def update_direct_sparsity(self, node):
        """
        Update the direct sparsity for the target node. Here direct sparsity
        means the sparsity in the output tensor that is caused by the sparsity
        in the input tensors/weight tensors.
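
        For example (a standalone sketch of the idea, not a call to this
        method): zeroing out an entire filter of a bias-free conv makes the
        corresponding output channel all-zero, i.e. the weight sparsity
        directly propagates to the output:

        >>> conv = nn.Conv2d(3, 4, 3, bias=False)
        >>> conv.weight.data[0] = 0.      # prune the first filter entirely
        >>> out = conv(torch.rand(8, 3, 8, 8))
        >>> bool((out[:, 0] == 0).all())  # output channel 0 is all zero
        True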
        """
        # this name is consistent with the name returned by named_modules()
        module_name = node.name
        _logger.info('Update mask for %s', module_name)
        unique_name = node.unique_name
        dummy_input, input_debugname = self._prepare_dummy_input(node)
        # get the input masks from self.masks
        # Note: the input masks of the successor nodes have already
        # been created by the predecessor nodes
        in_masks = [self.masks[debugname] for debugname in input_debugname]
        in_constants = [self.constant[debugname]
                        for debugname in input_debugname]
        if node.type == 'func':
            # we cannot get the runnable function directly from the jit traced
            # graph, so we translate it back to a python function. Note: the
            # function is applicable to both cpu/gpu devices, and the output
            # tensors will be on the same device as the input tensors
            func = jit_to_python_function(node, self)
            if func is None:
                # no need to infer the sparsity for this node
                self.auto_inferences[unique_name] = None
                return
            # function doesn't have weights
            _auto_infer = AutoMaskInference(
                func, dummy_input, in_masks, in_constants=in_constants, batch_dim=self.batch_dim)
        else:
            weight_mask = None
            if module_name in self.masks:
                weight_mask = self.masks[module_name]
            _, module = get_module_by_name(self.bound_model, module_name)
            _auto_infer = AutoMaskInference(
                module, dummy_input, in_masks, weight_mask, in_constants=in_constants,
                state_dict=copy.deepcopy(module.state_dict()), batch_dim=self.batch_dim)
        self.auto_inferences[unique_name] = _auto_infer
        _auto_infer.name = node.unique_name

        _auto_infer.update_direct_sparsity()
        # also save the input debug names into the auto_infer
        _auto_infer.input_debugname = input_debugname
        # update the mask tensor and the internal output of the submodules
        # after manually unpacking the tuple/list of tensors, the number of
        # outputs of each node should always be one (except for the TupleUnpack
        # node at the end of the whole model)
        assert len(node.outputs) == 1, \
            'The number of outputs should be one after the tuple is unpacked manually'

        out_debugname = node.outputs[0]
        # update the output mask into self.masks
        self.masks[out_debugname] = _auto_infer.output_mask
        self.constant[out_debugname] = _auto_infer.out_constant
        # update the output result into self.internal_result, so that
        # the successor nodes can take these output tensors as inputs.
        self.internal_result[out_debugname] = _auto_infer.output
        # update the parameter mask of the node
        self.masks[module_name] = _auto_infer.weight_mask

    def update_indirect_sparsity(self, node):
        """
        This function updates the indirect sparsity. To explain what indirect
        sparsity is: suppose there are two tensors TA and TB, and we perform
        the calculation TC = TA x TB, in which TC is also a tensor. Once some
        values in TA are masked to zeros, the corresponding positions in TB
        are also potential sparsity, because they have no effect on the final
        output (the gradient of these positions in TB is always 0). This
        function is to find the potential sparsity caused by other sparsity
        (we call it indirect sparsity here). Basically, we can find this
        potential sparsity through the gradient.

        Parameters
        ----------
        node : NodePyGroup
            The target node to update the indirect sparsity for.
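
        Examples
        --------
        A standalone sketch of the idea (not a call to this method): for the
        elementwise product TC = TA * TB, positions where TA is masked to zero
        receive a zero gradient in TB, which exposes them as potential
        (indirect) sparsity:

        >>> ta = torch.zeros(4)         # TA is fully masked here
        >>> tb = torch.rand(4, requires_grad=True)
        >>> (ta * tb).sum().backward()
        >>> bool((tb.grad == 0).all())  # TB's gradient reveals the sparsity
        True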
        """
        unique_name = node.unique_name
        if unique_name in self.auto_inferences and self.auto_inferences[unique_name] is not None:
            # if the auto inference object already in self.auto_inference, then
            # directly update the previous one
            # self.auto_inferences[unique_name].update()
            _logger.info(
                'Update the indirect sparsity for the %s', unique_name)
            auto_infer = self.auto_inferences[unique_name]
            auto_infer.update_indirect_sparsity()
            # pass the gradient to the predecessor nodes
            for in_id, tin in enumerate(auto_infer.dummy_input):
                debug_name = auto_infer.input_debugname[in_id]
                last_output = self.internal_result[debug_name]
                # if isinstance(last_output, torch.Tensor):
                # TODO what if last output is tuple/list of tensor
                if last_output.grad is not None and tin.grad is not None:
                    last_output.grad.data += tin.grad.data
                else:
                    last_output.grad = tin.grad
        else:
            _logger.warning('Note: %s does not have a corresponding mask inference object', node.name)

    def _vnode_to_value(self, c_node):
        """
        Translate the torch._C.Value node into values/tensors.
        """
        errmsg = "Only support the torch._C.Value type"
        assert isinstance(c_node, torch._C.Value), errmsg
        if isinstance(c_node.type(), torch._C.TensorType):
            shape = tuple(c_node.type().sizes())
            dtype = c_node.type().scalarType()
            # TODO should use a more general way to get the input
            if dtype.startswith('Float') or dtype.startswith('Double'):
                return torch.rand(shape).to(self.device)
            else:
                # This small range is due to `ReLU6`; we will add
                # manual mask inference rules for several ops in the
                # future, so that we can remove this constraint.
                return torch.randint(0, 10, shape, device=self.device)
        else:
            value = c_node.toIValue()
            # TODO support more kinds of value node
            errmsg = "Doesn't support convert %s to values", str(c_node.type())
            # currently only support the tensors and constant values
            assert value is not None, errmsg
            return value

    def infer_modules_masks(self):
        """
        Infer the masks for all layers in the module. This function can be
        divided into two steps: first, forward inference of the masks; second,
        backward inference of the masks. We keep repeating these two steps
        until the masks of the model no longer change.
        """
        # unpack the tensor tuple/list before the mask inference
        self.torch_graph.unpack_manually()
        # find the input/output tensors of the whole graph
        graph_input = []
        graph_output = []
        for name, nodeio in self.torch_graph.nodes_py.nodes_io.items():
            if nodeio.input_or_output == 'input':
                graph_input.append((name, nodeio))
                # also put the graph input tensor into the internal_result
                # TODO: if we can find the corresponding relation between the
                # value node and the dummy_input, we can use the values from
                # the dummy_input directly
                value = self._vnode_to_value(self.debugname_to_value[name])
                self.internal_result[name] = value
                # create the mask tensor for the input value
                if isinstance(self.internal_result[name], torch.Tensor):
                    self.masks[name] = torch.ones_like(value)
                    self.constant[name] = torch.zeros_like(value)
            elif nodeio.input_or_output == 'output':
                graph_output.append((name, nodeio))
        # count the in/out degree for each node in the graph
        in_degree = {}
        out_degree = {}
        visit_queue = queue.Queue()
        for node in self.torch_graph.nodes_py.nodes_op:
            successors = self.torch_graph.find_successors(node.unique_name)
            out_degree[node.unique_name] = len(successors)
            predecessors = self.torch_graph.find_predecessors(node.unique_name)
            in_degree[node.unique_name] = len(predecessors)
            if in_degree[node.unique_name] == 0:
                visit_queue.put(node)
        # Forward mask inference
        while not visit_queue.empty():
            curnode = visit_queue.get()
            # forward mask inference for curnode
            self.update_direct_sparsity(curnode)
            successors = self.torch_graph.find_successors(curnode.unique_name)
            for successor in successors:
                in_degree[successor] -= 1
                if in_degree[successor] == 0:
                    visit_queue.put(self.torch_graph.name_to_node[successor])
        # backward mask inference
        for unique_name in out_degree:
            if out_degree[unique_name] == 0:
                visit_queue.put(self.torch_graph.name_to_node[unique_name])
        while not visit_queue.empty():
            curnode = visit_queue.get()
            self.update_indirect_sparsity(curnode)
            predecessors = self.torch_graph.find_predecessors(
                curnode.unique_name)
            for predecessor in predecessors:
                out_degree[predecessor] -= 1
                if out_degree[predecessor] == 0:
                    visit_queue.put(self.torch_graph.name_to_node[predecessor])

    def replace_compressed_modules(self):
        """
        Replace all the modules whose shapes (weights/inputs/outputs) have changed.
        The new module is created using the same arguments as the to-be-replaced
        module, and correctly inherits its weights.

        NOTE: ```func``` type cannot be replaced as it is not a module; thus, one
        limitation is that ```func``` nodes should not require replacement.
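
        For example (an illustrative sketch, not output of this method): if the
        masks zero out half of the filters of a ```Conv2d(16, 32, 3)``` layer,
        the layer is rebuilt as ```Conv2d(16, 16, 3)``` and the surviving
        weights are copied into the new module.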
        """
        with torch.no_grad():
            for unique_name in self.auto_inferences:
                self.replace_submodule(unique_name)

    def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
        """
        Replace the submodule according to the inferred sparsity.

        Parameters
        ----------
        unique_name : str
            The unique_name of the submodule to replace.
        reindex_dim : int
            The dimension of the re-index operation.
        reindex : Reindex
            The index tensor. Normally this variable is None. If we want to
            reindex the output of this submodule, we can pass the index in
            through this parameter.
        """
        class ReindexModule(nn.Module):
            """
            ReindexModule is used to resolve the mask conflict when replacing
            the submodule. Basically, we can resolve the mask conflict in two
            ways: (1) unmask some values (which introduces more computation
            overhead); (2) reindex and pad the output tensor of the target op
            (which introduces more memory access overhead). Currently this
            method is disabled; in the future, we will merge these two methods
            into a graph pass that resolves the mask conflict.
            """
            def __init__(self, ori_module, reindex_dim, reindex):
                super(ReindexModule, self).__init__()
                self.ori_module = ori_module
                self.reindex_dim = reindex_dim
                self.reindex = reindex
                tmp_index = [slice(None, None) for i in range(reindex_dim+1)]
                # the index for the tensor
                tmp_index[reindex_dim] = reindex
                self.t_index = tuple(tmp_index)

            def forward(self, x):
                tmpout = self.ori_module(x)
                shape = list(tmpout.size())
                shape[self.reindex_dim] = self.reindex.size(0)
                out = torch.zeros(tuple(shape), device=tmpout.device,
                                  requires_grad=tmpout.requires_grad)
                out[self.t_index] = tmpout
                return out

        assert unique_name in self.auto_inferences
        g_node = self.torch_graph.name_to_node[unique_name]
        _logger.debug("replace %s, in %s type, with op_type %s",
                      unique_name, g_node.type, g_node.op_type)
        auto_infer = self.auto_inferences[unique_name]
        if g_node.type == 'module':
            if g_node.unique_name in self.torch_graph.reused_module:
                if reindex_dim is not None:
                    _logger.warning(
                        'Cannot replace a reused module with a padding operator!')
                    return None
            super_module, leaf_module = get_module_by_name(
                self.bound_model, g_node.name)
            m_type = g_node.op_type
            if m_type not in replace_module:
                raise RuntimeError(
                    "Replacing module of type `{}` is not supported yet".format(m_type))
            _logger.info("replace module (name: %s, op_type: %s)",
                         g_node.name, m_type)
            compressed_module = replace_module[m_type](
                leaf_module, auto_infer.get_masks())
            new_submodule = compressed_module
            if reindex_dim is None:
                setattr(super_module, g_node.name.split(
                    '.')[-1], compressed_module)
            elif reindex_dim is not None and reindex is not None:
                # reindex the output of this submodule and replace the original module
                new_submodule = ReindexModule(
                    compressed_module, reindex_dim, reindex)
                setattr(super_module, g_node.name.split(
                    '.')[-1], new_submodule)
            return new_submodule
        elif g_node.type == 'func':
            _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
                         unique_name, g_node.op_type)
            return None
        else:
            raise RuntimeError("Unsupported node type: {}".format(g_node.type))

    def initialize_speedup(self):
        """
        Do some initial work for speedup.
        """
        # initialize the self.debugname_to_value
        # build a mapping table from the debug name of the tensor
        # to its value node in the graph
        traced_graph = self.torch_graph.trace.graph
        for node in traced_graph.nodes():
            for _input in node.inputs():
                debug_name = _input.debugName()
                if debug_name not in self.debugname_to_value:
                    self.debugname_to_value[debug_name] = _input
            for _output in node.outputs():
                debug_name = _output.debugName()
                if debug_name not in self.debugname_to_value:
                    self.debugname_to_value[debug_name] = _output
        # put the model itself into internal_result to perform the
        # value inference for 'prim::GetAttr'; the first ClassType
        # of the whole graph is the model class

        for graph_input in traced_graph.inputs():
            if graph_input.type().kind() == 'ClassType':
                self.internal_result[graph_input.debugName()
                                     ] = self.bound_model
                break

    def speedup_model(self):
        """
        There are basically two steps: first, do mask/shape inference,
        second, replace modules.
        """

        _logger.info("start to speed up the model")
        self.initialize_speedup()
        training = self.bound_model.training
        # set to the evaluation mode
        self.bound_model.train(False)
        # TODO: ideally we should fix the conflict after the sparsity
        # propagation, which would be more elegant
        fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)

        _logger.info("infer module masks...")
        self.infer_modules_masks()
        _logger.info('resolve the mask conflict')

        # load the original state dict before replacing the model
        self.bound_model.load_state_dict(self.ori_state_dict)
        _logger.info("replace compressed modules...")
        # the mask conflict should be already resolved
        self.replace_compressed_modules()
        self.bound_model.train(training)
        _logger.info("speedup done")