Unverified Commit 617d9f32 authored by Ningxin Zheng's avatar Ningxin Zheng Committed by GitHub
Browse files

support directly load the mask (#4144)

parent acb627cf
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import os
import queue import queue
import logging import logging
import copy import copy
...@@ -35,8 +35,8 @@ class ModelSpeedup: ...@@ -35,8 +35,8 @@ class ModelSpeedup:
Note: The first dimension of the dummy_input should be the batchsize. Note: The first dimension of the dummy_input should be the batchsize.
The dummy input for ```jit.trace```, users should put it on the right The dummy input for ```jit.trace```, users should put it on the right
device. device.
masks_file : str masks_file : str/dict
The path of user provided mask file The path of user provided mask file, or the mask object
map_location : str map_location : str
the device on which masks are placed, same to map_location in ```torch.load``` the device on which masks are placed, same to map_location in ```torch.load```
batch_dim : int batch_dim : int
...@@ -63,9 +63,13 @@ class ModelSpeedup: ...@@ -63,9 +63,13 @@ class ModelSpeedup:
# load the mask tensor to the same device with the dummy_input # load the mask tensor to the same device with the dummy_input
# self.masks save the mask tensors pruned by the user and the infered # self.masks save the mask tensors pruned by the user and the infered
# masks of the others modules # masks of the others modules
if isinstance(masks_file, str) and os.path.exists(masks_file):
self.masks = torch.load( self.masks = torch.load(
masks_file, map_location if map_location is not None else str(self.device)) masks_file, map_location if map_location is not None else str(self.device))
elif isinstance(masks_file, dict):
self.masks = masks_file
else:
raise Exception('Please provide the mask or the path of the mask file')
self.constant = {} self.constant = {}
# self.internal_result save the internal output of the submodules # self.internal_result save the internal output of the submodules
self.internal_result = {} self.internal_result = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment