"vscode:/vscode.git/clone" did not exist on "12714155c7e5388306b2b3c01f6520b4ceec3cb7"
Unverified Commit 617d9f32 authored by Ningxin Zheng's avatar Ningxin Zheng Committed by GitHub
Browse files

support directly load the mask (#4144)

parent acb627cf
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import queue
import logging
import copy
......@@ -35,8 +35,8 @@ class ModelSpeedup:
Note: The first dimension of the dummy_input should be the batchsize.
The dummy input for ```jit.trace```, users should put it on the right
device.
masks_file : str
The path of user provided mask file
masks_file : str/dict
The path of user provided mask file, or the mask object
map_location : str
the device on which masks are placed, same to map_location in ```torch.load```
batch_dim : int
......@@ -63,9 +63,13 @@ class ModelSpeedup:
# load the mask tensor to the same device with the dummy_input
# self.masks save the mask tensors pruned by the user and the infered
# masks of the others modules
self.masks = torch.load(
masks_file, map_location if map_location is not None else str(self.device))
if isinstance(masks_file, str) and os.path.exists(masks_file):
self.masks = torch.load(
masks_file, map_location if map_location is not None else str(self.device))
elif isinstance(masks_file, dict):
self.masks = masks_file
else:
raise Exception('Please provide the mask or the path of the mask file')
self.constant = {}
# self.internal_result save the internal output of the submodules
self.internal_result = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment