Unverified Commit 97d067e6 authored by Ningxin Zheng, committed by GitHub

Speedup enhancement (#4925)

parent 4ab85d3d
@@ -45,6 +45,7 @@ replace_module = {
     'Upsample': lambda module, masks: no_replace(module, masks),
     'LayerNorm': lambda module, masks: replace_layernorm(module, masks),
     'ConvTranspose2d': lambda module, masks: replace_convtranspose2d(module, masks),
+    'Embedding': lambda module, masks: replace_embedding(module, masks),
     'PixelShuffle': lambda module, masks: replace_pixelshuffle(module, masks),
     'Flatten': lambda module, masks: no_replace(module, masks)
 }
@@ -85,6 +86,30 @@ def convert_to_coarse_mask(t_mask, dim):
     return indexes, remained_indexes


+def convert_dense_shape(mask):
+    """
+    Get the dense shape of the tensor after removing the sparsity
+    values.
+
+    Parameters
+    ----------
+    mask: torch.Tensor
+        The mask tensor.
+
+    Returns
+    -------
+    dense_shape: tuple
+        The dense shape after removing the sparsity values.
+    """
+    assert isinstance(mask, torch.Tensor)
+    n_dim = len(mask.size())
+    dense_shape = []
+    for dim in range(n_dim):
+        _, remained = convert_to_coarse_mask(mask, dim)
+        dense_shape.append(remained.size(0))
+    return tuple(dense_shape)
+
+
 def no_replace(module, masks):
     """
     No need to replace
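For intuition, here is a minimal sketch of what the new convert_dense_shape helper computes; the mask values below are made up for illustration and are not part of this diff:

import torch

# hypothetical 3x4 mask with row 1 and column 2 fully pruned
mask = torch.ones(3, 4)
mask[1, :] = 0
mask[:, 2] = 0
# convert_to_coarse_mask keeps an index along `dim` when any entry in its
# slice is nonzero, so dim 0 keeps rows {0, 2} and dim 1 keeps columns
# {0, 1, 3}; convert_dense_shape(mask) would therefore return (2, 3)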
@@ -165,9 +190,12 @@ def replace_linear(linear, masks):
     in_mask = in_masks[0]
     weight_mask = weight_mask['weight']
+    # the input of the linear may have two dimensions (CV models) or three
+    # dimensions (Bert, for example)
+    n_dim = len(in_mask.size())
     # N C K
-    pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
-    pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
+    pruned_in, remained_in = convert_to_coarse_mask(in_mask, n_dim-1)
+    pruned_out, remained_out = convert_to_coarse_mask(output_mask, n_dim-1)
     n_remained_in = weight_mask.size(1) - pruned_in.size(0)
     n_remained_out = weight_mask.size(0) - pruned_out.size(0)
     remained_in, remained_out = remained_in.to(
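Why replace_linear now uses n_dim-1 instead of the hard-coded dim 1: the pruned feature dimension of a linear layer is always the last mask dimension. A sketch with assumed mask shapes:

import torch

cv_mask = torch.ones(32, 512)         # [batch, in_features], n_dim == 2
bert_mask = torch.ones(32, 128, 768)  # [batch, seq_len, hidden], n_dim == 3
assert len(cv_mask.size()) - 1 == 1    # the old hard-coded dim worked here
assert len(bert_mask.size()) - 1 == 2  # but not for three-dimensional input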
@@ -582,16 +610,20 @@ def replace_layernorm(layernorm, masks):
     if len(in_masks) != 1:
         raise InputsNumberError()
     in_mask = in_masks[0]
-    dim_n = len(in_mask.size())
-    new_shape = []
-    for i in range(1, dim_n):
-        sum_dims = list(range(0, dim_n))
-        sum_dims.remove(i)
-        reduced = torch.sum(in_mask, sum_dims)
-        n_remained = torch.sum(reduced > 0)
-        new_shape.append(n_remained)
-    return nn.LayerNorm(tuple(new_shape), layernorm.eps, layernorm.elementwise_affine)
+    dense_shape = convert_dense_shape(in_mask)
+    norm_shape = layernorm.normalized_shape
+    dim_n = len(dense_shape) - len(norm_shape)
+    return nn.LayerNorm(dense_shape[dim_n:], layernorm.eps, layernorm.elementwise_affine)
+
+
+def replace_embedding(embedding, masks):
+    """
+    Replace the embedding layer according to the inferred weight masks.
+    """
+    # currently we do not support replacing the embedding layer,
+    # because we do not have the corresponding pruner
+    return embedding
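To make the new replace_layernorm arithmetic concrete, a sketch with assumed transformer-like numbers (not part of the diff):

import torch.nn as nn

dense_shape = (32, 128, 600)  # convert_dense_shape(in_mask) after pruning
norm_shape = (768,)           # the original layernorm.normalized_shape
dim_n = len(dense_shape) - len(norm_shape)  # 3 - 1 = 2
new_ln = nn.LayerNorm(dense_shape[dim_n:])  # normalizes over (600,)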
 def replace_pixelshuffle(pixelshuffle, masks):
...
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.

 import copy
 import logging
+
 from pathlib import Path
 import queue
@@ -66,6 +67,7 @@ class ModelSpeedup:
         self.bound_model = model
         self.inferred_masks = dict()  # key: module_name, value: ModuleMasks
         self.batch_dim = batch_dim
+        self.confidence = confidence
         self.dummy_input, self.device = self._random_model_input(
             dummy_input, confidence, batch_dim)
         self.torch_graph = build_module_graph(model, self.dummy_input)
@@ -196,6 +198,7 @@ class ModelSpeedup:
                 # The detach operation here is for the in-place operation. We cannot
                 # directly call backward on the output tensor of an in-place operator.
                 dummy_input.append(self.internal_result[_input].detach())
+
                 debugnames.append(_input)
         return dummy_input, debugnames
@@ -229,15 +232,15 @@ class ModelSpeedup:
                 return
             # function doesn't have weights
             _auto_infer = AutoMaskInference(
-                func, dummy_input, in_masks, in_constants=in_constants, batch_dim=self.batch_dim)
+                func, dummy_input, self, in_masks, in_constants=in_constants)
         else:
             weight_mask = None
             if module_name in self.masks:
                 weight_mask = self.masks[module_name]
             _, module = get_module_by_name(self.bound_model, module_name)
             _auto_infer = AutoMaskInference(
-                module, dummy_input, in_masks, weight_mask, in_constants=in_constants,
-                state_dict=copy.deepcopy(module.state_dict()), batch_dim=self.batch_dim)
+                module, dummy_input, self, in_masks, weight_mask, in_constants=in_constants,
+                state_dict=copy.deepcopy(module.state_dict()))
         self.auto_inferences[unique_name] = _auto_infer
         _auto_infer.name = node.unique_name
@@ -280,6 +283,7 @@
            The target node to update the indirect sparsity
        """
        unique_name = node.unique_name
+
        if unique_name in self.auto_inferences and self.auto_inferences[unique_name] is not None:
            # if the auto inference object is already in self.auto_inferences, then
            # directly update the previous one
@@ -291,13 +295,18 @@
             # pass the gradient to the predecessor nodes
             for in_id, tin in enumerate(auto_infer.dummy_input):
                 debug_name = auto_infer.input_debugname[in_id]
+
                 last_output = self.internal_result[debug_name]
                 # if isinstance(last_output, torch.Tensor):
                 # TODO what if last output is tuple/list of tensor
                 if last_output.grad is not None and tin.grad is not None:
                     last_output.grad.data += tin.grad.data
-                else:
+                elif last_output.grad is None:
                     last_output.grad = tin.grad
+                elif last_output.grad is not None and tin.grad is None:
+                    # for example, in tin.view(batch, tin.size(1)//2, tin.size(2)*2)
+                    # the size operations on tin will have no gradient
+                    continue
         else:
             _logger.warning(
                 'Note: %s does not have corresponding mask inference object', node.name)
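A minimal repro of the case the new elif branch guards against; the shapes are illustrative:

import torch

x = torch.rand(4, 8, 2, requires_grad=True)
y = x.view(x.size(0), x.size(1) // 2, x.size(2) * 2)
y.sum().backward()
# x.grad is populated, but x.size(0) etc. are plain Python ints with no
# gradient; the dummy input mirroring such a size() output therefore has
# tin.grad is None and must simply be skipped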
@@ -388,6 +397,7 @@
             if out_degree[predecessor] == 0:
                 visit_queue.put(self.torch_graph.name_to_node[predecessor])

+
     def replace_compressed_modules(self):
         """
         Replace all the modules that have changed (weights/inputs/output) shape.
@@ -401,6 +411,7 @@
         for unique_name in self.auto_inferences:
             self.replace_submodule(unique_name)

+
     def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
         """
         Replace the submodule according to the inferred sparsity.
@@ -443,7 +454,6 @@
                                       requires_grad=tmpout.requires_grad)
                 out[self.t_index] = tmpout
                 return out
-
         assert unique_name in self.auto_inferences
         g_node = self.torch_graph.name_to_node[unique_name]
         _logger.debug("replace %s, in %s type, with op_type %s",
@@ -483,12 +493,9 @@
             setattr(super_module, g_node.name.split(
                 '.')[-1], new_submodule)
             return new_submodule
-        elif g_node.type == 'func':
-            _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
-                         unique_name, g_node.op_type)
-            return None
         else:
-            raise RuntimeError("Unsupported node type: {}".format(g_node.type))
+            return None

     def initialize_speedup(self):
         """
...
@@ -12,8 +12,8 @@ STD_DELTA = 1e-6

 class AutoMaskInference:
-    def __init__(self, module, dummy_input, in_masks=None, weight_mask=None, \
-                 output_mask=None, name=None, in_constants=None, state_dict=None, batch_dim=0):
+    def __init__(self, module, dummy_input, speedup, in_masks=None, weight_mask=None,
+                 output_mask=None, name=None, in_constants=None, state_dict=None):
         """
         This class will infer the mask of the target module automatically.
         This update_direct_sparsity will infer the output mask according
@@ -28,6 +28,8 @@ class AutoMaskInference:
             The target module to infer the mask. Need to be callable.
         dummy_input: torch.Tensor/list of Tensor
             The dummy_input of the target module.
+        speedup: ModelSpeedup
+            The reference of the ModelSpeedup object.
         in_masks: list of torch.Tensor
             The input masks of the target module, if in_masks is not None, then
             update_direct_sparsity and update_indirect_sparsity will incrementally
@@ -47,8 +49,6 @@
             The corresponding constant values of the in_masks.
         state_dict: dict of torch.Tensor
             The original values of the weights.
-        batch_dim: int
-            The index of the batch dimension of the input tensors.
         """
         errmsg = '%s is not callable, should pass the nn.Module/function' % str(
@@ -112,7 +112,8 @@
                 self.weight_mask[name] = torch.ones_like(para.data)
         self.state_dict = state_dict
         # TODO support the other batch dimension in the future
-        self.batch_dim = batch_dim
+        self.batch_dim = speedup.batch_dim
+        self.batch_size = speedup.confidence

     def random_init(self, start=0.1, end=8.0):
""" """
...@@ -125,13 +126,17 @@ class AutoMaskInference: ...@@ -125,13 +126,17 @@ class AutoMaskInference:
# rules for ReLU6 to break this range constraint. # rules for ReLU6 to break this range constraint.
with torch.no_grad(): with torch.no_grad():
for tensor in self.dummy_input: for tensor in self.dummy_input:
if isinstance(tensor, torch.Tensor) and len(tensor.size()) > 0: if isinstance(tensor, torch.Tensor) and len(tensor.size()) > self.batch_dim\
# if the tensor is a scalar, then skip this tensor and tensor.size(self.batch_dim) == self.batch_size:
# if the input tensor only has one dimension, which means
# it doesn't have the batch dimension, then we don't randomize
# this tensor, because our tensor scrambling is on the batch
# dimention. For example, if the tensor is a scalar(returned
# by the size operator), then we will skip this tensor
randomize_tensor(tensor, start, end) randomize_tensor(tensor, start, end)
for para in self.weights: for para in self.weights:
randomize_tensor(self.weights[para].data, start, end) randomize_tensor(self.weights[para].data, start, end)
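The tightened condition in random_init only scrambles tensors that actually carry the batch dimension; a sketch, where confidence plays the role of the dummy batch size:

import torch

batch_dim, batch_size = 0, 8      # speedup.batch_dim, speedup.confidence
image = torch.rand(8, 3, 32, 32)  # has the batch dim -> randomized
scalar = torch.tensor(4)          # e.g. output of a size() op -> skipped

def has_batch(t):
    return len(t.size()) > batch_dim and t.size(batch_dim) == batch_size

assert has_batch(image) and not has_batch(scalar)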
     def zero_grad(self):
         """
         Set the gradient of the weight, input tensor to be zeros.
@@ -240,7 +245,6 @@
         constant[:, mask_pos] = mean[mask_pos]
         return out_mask, constant

-
     def update_indirect_sparsity(self):
         """
         This function will update the indirect sparsity. To explain what's
@@ -379,4 +383,3 @@

     def get_masks(self):
         return (self.in_masks, self.output_mask, self.weight_mask)
-
...
@@ -4,11 +4,13 @@
 import re
 import logging
 from functools import partial
+import copy

 import torch

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)

+jitid_2_dtype = {4: torch.long, 6: torch.float32}

 # to exclude partial
@@ -243,7 +245,7 @@ def softmax_python(node, speedup):

 def contiguous_python(node, speedup):
     class contiguousModule(torch.nn.Module):
         def forward(self, x):
-            return x.contiguous()
+            return x.contiguous().clone()
     return contiguousModule()
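Why the extra .clone() matters: on an already-contiguous tensor, .contiguous() returns the very same tensor, so in-place scrambling of the "copy" would corrupt the original recorded value:

import torch

x = torch.rand(4, 4)
assert x.contiguous() is x               # no copy is made
assert x.contiguous().clone() is not x   # clone forces a real copy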
@@ -297,6 +299,7 @@ def squeeze_python(node, speedup):
     new_squeeze = partial(torch.squeeze, dim=dim)
     return new_squeeze

+
 def unsqueeze_python(node, speedup):
     c_node = node.key_node
     inputs = list(c_node.inputs())
@@ -324,7 +327,10 @@ def slice_python(node, speedup):
     class SliceMoudle(torch.nn.Module):
         def __init__(self, sliceobj):
             super(SliceMoudle, self).__init__()
-            self.sliceobj = sliceobj
+            # we need to deepcopy the value here, because, in the
+            # following steps, we may randomize the input tensor,
+            # which will change the values of the sliceobj
+            self.sliceobj = copy.deepcopy(sliceobj)

         def forward(self, x, *args):
             # args is for the slice dimension and indexes, however,
@@ -344,11 +350,22 @@ def slice_python(node, speedup):
     slice_end = parse_constant(inputs[3], speedup)
     slice_step = parse_constant(inputs[4], speedup)
     slice_obj = slice(slice_start, slice_end, slice_step)
+
     slice_list = []
     for _ in range(slice_dim):
         slice_list.append(slice(None, None))
     logger.info('Slice dim:%s, Slice obj:%s', str(slice_dim), str(slice_obj))
     slice_list.append(slice_obj)
-    return SliceMoudle(tuple(slice_list))
+    if inputs[0].debugName() not in speedup.internal_result:
+        # the input of the slice operator may be a constant
+        target_tensor = parse_constant(inputs[0], speedup)
+        slice_list = tuple(slice_list)
+
+        def constant_slice(*args):
+            return target_tensor[slice_list]
+        return constant_slice
+    else:
+        return SliceMoudle(tuple(slice_list))
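For reference, what the assembled slice_list amounts to for a typical aten::slice node; the arguments below are assumed:

import torch

# aten::slice(t, dim=2, start=0, end=None, step=1)  ~  t[:, :, 0::1]
slice_dim, slice_obj = 2, slice(0, None, 1)
slice_list = [slice(None, None)] * slice_dim + [slice_obj]
t = torch.rand(2, 3, 5)
assert t[tuple(slice_list)].shape == t.shape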
@@ -356,8 +373,8 @@ def select_python(node, speedup):
     class SelectModule(torch.nn.Module):
         def __init__(self, dim, index):
             super(SelectModule, self).__init__()
-            self.dim = dim
-            self.index = index
+            self.dim = copy.deepcopy(dim)
+            self.index = copy.deepcopy(index)

         def forward(self, x):
             return x.select(self.dim, self.index)
@@ -425,7 +442,9 @@ def permute_python(node, speedup):
     class PermuteModule(torch.nn.Module):
         def __init__(self, dimlist):
             super(PermuteModule, self).__init__()
-            self.dimlist = dimlist
+            # deepcopy the values here, because the following randomize operation
+            # will change the value of the dimlist
+            self.dimlist = copy.deepcopy(dimlist)

         def forward(self, x):
             return x.permute(self.dimlist)
@@ -439,6 +458,7 @@ def getattr_python(node, speedup):
     """
     Note: ops starting with prim:: are not taken as key nodes,
     so we directly pass the Cpp node into this function.
+
     Parameters
     ----------
     node: torch._C.Node
@@ -462,6 +482,44 @@ def getattr_python(node, speedup):
     assert len(key_words) == 1
     return GetModule(key_words[0])
+
+
+def constant_python(node, speedup):
+    """
+    Get the constant value of a constant operator node.
+
+    Parameters
+    ----------
+    node: torch._C.Node
+        The cpp node of prim::Constant
+    speedup: ModelSpeedup
+        The corresponding speedup object.
+    """
+    class ConstantModule(torch.nn.Module):
+        def __init__(self, constant):
+            super(ConstantModule, self).__init__()
+            self.constant = constant
+
+        def forward(self):
+            return self.constant
+
+    assert node.kind() == 'prim::Constant'
+    pattern = r'\[value=(.*?)\]'
+    key_words = re.findall(pattern, str(node))
+    if len(key_words) == 0:
+        return ConstantModule(None)
+    assert len(key_words) == 1
+    # parse the constant value
+    value = key_words[0]
+    if value.startswith("\""):
+        value = torch.device(value[1:-1])
+    elif value.startswith('{'):
+        # TODO support set values in the future
+        value = set()
+    elif '.' in value:
+        # float value
+        value = float(value)
+    else:
+        # integer value
+        value = int(value)
+    return ConstantModule(value)
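The regex in constant_python parses the textual dump of the jit node; a sketch of the kind of string it sees (the node text below is illustrative):

import re

node_str = '%4 : int = prim::Constant[value=1]()'  # typical jit node dump
value = re.findall(r'\[value=(.*?)\]', node_str)[0]
assert value == '1'  # the branches above then convert it to int(1)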
 def upsample_bilinear2d_python(node, speedup):
     class UpsampleModule(torch.nn.Module):
@@ -539,16 +597,25 @@ def typeas_python(node, speedup):

 def to_python(node, speedup):
     # for the time being, only device parameters are supported
     class ToModule(torch.nn.Module):
-        def __init__(self, device):
+        def __init__(self, device, dtype):
             super(ToModule, self).__init__()
+            self.device = device
+            self.dtype = dtype

         def forward(self, x):
-            return x.to(device)
+            return x.to(self.device, dtype=self.dtype)

     c_node = node.key_node
     inputs = list(c_node.inputs())
-    device = inputs[3].toIValue()
-    return ToModule(device)
+    in_debugname = inputs[0].debugName()
+    # device of the input tensor
+    device = speedup.internal_result[in_debugname].device
+    for _, _node in enumerate(inputs[1:]):
+        val = parse_constant(_node, speedup)
+        if isinstance(val, torch.device):
+            device = val
+    dtype = jitid_2_dtype[parse_constant(inputs[1], speedup)]
+    return ToModule(device, dtype)
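The new jitid_2_dtype table maps TorchScript ScalarType ids to torch dtypes; in the c10 ScalarType enum, 4 is int64 and 6 is float32. A sketch of what the rewritten ToModule.forward executes:

import torch

jitid_2_dtype = {4: torch.long, 6: torch.float32}
x = torch.rand(2, 2)
y = x.to(torch.device('cpu'), dtype=jitid_2_dtype[6])
assert y.dtype == torch.float32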

 def cat_python(node, speedup):
@@ -566,6 +633,77 @@ def cat_python(node, speedup):
     return CatModule(dim)
+
+def ones_python(node, speedup):
+    class OnesModule(torch.nn.Module):
+        def __init__(self, out_size, dtype_id, device, require_grad):
+            super(OnesModule, self).__init__()
+            self.out_size = out_size
+            self.device = device
+            self.require_grad = require_grad
+            self.dtype = jitid_2_dtype[dtype_id]
+
+        def forward(self, *args):
+            return torch.ones(size=self.out_size, dtype=self.dtype, device=self.device, requires_grad=self.require_grad)
+
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    output_shape = translate_list(inputs[0], speedup)
+    dtype_id = parse_constant(inputs[1], speedup)
+    # layout = parse_constant(inputs[2], speedup)
+    device = parse_constant(inputs[3], speedup)
+    require_grad = parse_constant(inputs[4], speedup)
+    return OnesModule(output_shape, dtype_id, device, require_grad)
+
+
+def zeros_python(node, speedup):
+    class ZerosModule(torch.nn.Module):
+        def __init__(self, out_size, dtype_id, device, require_grad):
+            super(ZerosModule, self).__init__()
+            self.out_size = out_size
+            self.device = device
+            self.require_grad = require_grad
+            self.dtype = jitid_2_dtype[dtype_id]
+
+        def forward(self, *args):
+            return torch.zeros(size=self.out_size, dtype=self.dtype, device=self.device, requires_grad=self.require_grad)
+
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    output_shape = translate_list(inputs[0], speedup)
+    dtype_id = parse_constant(inputs[1], speedup)
+    # layout = parse_constant(inputs[2], speedup)
+    device = parse_constant(inputs[3], speedup)
+    require_grad = parse_constant(inputs[4], speedup)
+    return ZerosModule(output_shape, dtype_id, device, require_grad)
+
+
+def rsub_python(node, speedup):
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    constant = None
+    other_name = inputs[1].debugName()
+    alpha = parse_constant(inputs[2], speedup)
+    if other_name not in speedup.internal_result:
+        constant = parse_constant(inputs[1], speedup)
+    if constant is None:
+        # `other` is a real tensor input: return the callable itself
+        return torch.sub
+    else:
+        new_sub = partial(torch.sub, other=constant, alpha=alpha)
+        return new_sub
+
+
+def expand_python(node, speedup):
+    class ExpandModule(torch.nn.Module):
+        def __init__(self, new_size):
+            super(ExpandModule, self).__init__()
+            # need deepcopy when the input is size-related
+            self.new_size = copy.deepcopy(new_size)
+
+        def forward(self, *args):
+            return args[0].expand(self.new_size).clone()
+
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    new_size = translate_list(inputs[1], speedup)
+    return ExpandModule(new_size)
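Why ExpandModule clones its output: expand returns a stride-0 view over the original storage, which the later in-place randomization would otherwise alias:

import torch

x = torch.rand(1, 3)
view = x.expand(4, 3)
assert view.stride(0) == 0      # all rows share the same memory as x
copy = x.expand(4, 3).clone()
copy[0, 0] = -1.0               # safe: does not touch x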

 def expandas_python(node, speedup):
     class ExpandasModule(torch.nn.Module):
         def forward(self, x, y):
@@ -616,13 +754,18 @@ trans_from_jit_to_python = {
     'aten::exp': exp_python,
     'aten::squeeze': squeeze_python,
     'aten::unsqueeze': unsqueeze_python,
-    'aten::constant_pad_nd': constant_pad_nd_python,
-    'aten::silu': silu_python,
-    'aten::expand_as': expandas_python,
+    'aten::ones': ones_python,
+    'aten::zeros': zeros_python,
+    'aten::rsub': rsub_python,
+    'aten::expand': expand_python,
     'prim::TupleUnpack': tupleunpack_python,
     'prim::ListUnpack': tupleunpack_python,
     'prim::NumToTensor': num2tensor_python,
-    'prim::GetAttr': getattr_python
+    'prim::GetAttr': getattr_python,
+    'prim::Constant': constant_python,
+    'aten::constant_pad_nd': constant_pad_nd_python,
+    'aten::silu': silu_python,
+    'aten::expand_as': expandas_python
 }
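Roughly how this table is consumed (the lookup site is an assumption, not shown in this diff):

# translator = trans_from_jit_to_python[node.op_type]  # hypothetical call site
# python_op = translator(node, speedup)
# python_op is then called with the node's unpacked input tensors to
# emulate the original jit operator during mask inference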
...

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+import re
+import logging
+
 import torch

-torch_float_dtype = [torch.float, torch.float16, torch.float32, torch.float64, torch.half, torch.double]
-torch_integer_dtype = [torch.uint8, torch.int16, torch.short, torch.int32, torch.long, torch.bool]
+torch_float_dtype = [torch.float, torch.float16,
+                     torch.float32, torch.float64, torch.half, torch.double]
+torch_integer_dtype = [torch.uint8, torch.int16,
+                       torch.short, torch.int32, torch.long, torch.bool]
+
+_logger = logging.getLogger(__name__)
+_logger.setLevel(logging.INFO)
 def get_module_by_name(model, module_name):
     """
@@ -46,11 +54,13 @@ def rand_like_with_shape(shape, ori_t):
     require_grad = ori_t.requires_grad
     lower_bound = torch.min(ori_t)
     higher_bound = torch.max(ori_t)
+
     if dtype in [torch.uint8, torch.int16, torch.short, torch.int16, torch.long, torch.bool]:
         return torch.randint(lower_bound, higher_bound+1, shape, dtype=dtype, device=device)
     else:
         return torch.rand(shape, dtype=dtype, device=device, requires_grad=require_grad)
 def randomize_tensor(tensor, start=1, end=100):
     """
     Randomize the target tensor according to the given
@@ -59,11 +69,79 @@ def randomize_tensor(tensor, start=1, end=100):
     assert isinstance(tensor, torch.Tensor)
     if tensor.dtype in torch_integer_dtype:
         # an integer tensor can only be randomized by torch.randint
-        # torch.randint(int(start), int(end), tensor.size(), out=tensor.data, dtype=tensor.dtype)
-        pass
+        torch.randint(int(start), int(end), tensor.size(),
+                      out=tensor.data, dtype=tensor.dtype)
     else:
         # we can use nn.init.uniform_ to randomize this tensor
         # note: a tensor with an integer dtype cannot be randomized
         # with nn.init.uniform_
         torch.nn.init.uniform_(tensor.data, start, end)
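The integer branch is now enabled and refills the tensor in place through torch.randint's out= argument; a quick check:

import torch

t = torch.zeros(5, dtype=torch.long)
torch.randint(1, 100, t.size(), out=t.data, dtype=t.dtype)
assert int(t.min()) >= 1 and int(t.max()) < 100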
+
+jit_python_code_replacement = {
+    'torch.slice': lambda tmpstr: python_slice_replace(tmpstr)
+}
+
+
+def translate_jit_code(code):
+    pattern = r'torch\.(.*?)\('
+    func_names = re.findall(pattern, code)
+    modules = {'torch.': torch, 'torch.nn.functional.': torch.nn.functional,
+               'torch.Tensor.': torch.Tensor, 'torch._C._nn.': torch._C._nn}
+    replace = {}
+    # rebase the namespace to get the runnable python code
+    for full_name in func_names:
+        func = re.split(r'\.', full_name)[-1]
+        for module_name in modules:
+            torch_module = modules[module_name]
+            if hasattr(torch_module, func):
+                replace['torch.'+full_name] = module_name + func
+                break
+    # assert found == True, 'Cannot find the function call %s' % full_name
+    for key, value in replace.items():
+        code = code.replace(key, value)
+    # several functions cannot be found under the namespaces torch.Tensor
+    # and torch. (for example torch.slice), so we need to handle these
+    # functions manually
+    lines = code.split('\n')
+    for i, line in enumerate(lines):
+        for fname in jit_python_code_replacement:
+            if fname in line:
+                lines[i] = jit_python_code_replacement[fname](line)
+    code = '\n'.join(lines)
+    code = 'import torch\nfrom torch import Tensor, tensor\nfrom typing import *\n' + code
+    with open('nni_jit_tmp_forward.py', 'w') as f:
+        f.write(code)
+    from nni_jit_tmp_forward import forward  # pylint: disable=import-error
+    return forward
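A hypothetical way to drive translate_jit_code; .code is the standard TorchScript attribute of a traced module, but this call site is an assumption:

import torch

model = torch.nn.Linear(4, 4)
traced = torch.jit.trace(model, torch.rand(1, 4))
forward = translate_jit_code(traced.code)  # writes and imports nni_jit_tmp_forward.py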
+
+
+def python_slice_replace(funcstr):
+    """
+    Translate torch.slice into the equivalent python string that can be
+    substituted back into the forward function string.
+
+    Parameters
+    ----------
+    funcstr: str
+        the str that calls torch.slice, for example:
+        _8 = torch.slice(attention_mask, 0, 0, 9223372036854775807, 1)
+
+    Returns
+    -------
+    new_str: str
+        the string that should replace the original one
+    """
+    # parse the input parameters
+    pattern = r'torch\.slice\((.*)\)'
+    parameter_str = re.findall(pattern, funcstr)
+    parameters = re.split(',', parameter_str[0])
+    target_tensor = parameters[0]
+    dim = int(parameters[1])
+    dim_str = ','.join([':'] * dim + [':'.join(parameters[2:])])
+
+    print('%s[%s]' % (target_tensor, dim_str))
+    new_str = funcstr.replace(
+        'torch.slice(%s)' % parameter_str[0], '%s[%s]' % (target_tensor, dim_str))
+    return new_str
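A worked example of the rewrite, reusing the call from the docstring:

funcstr = '_8 = torch.slice(attention_mask, 0, 0, 9223372036854775807, 1)'
# parameters -> ['attention_mask', ' 0', ' 0', ' 9223372036854775807', ' 1']
# dim == 0, so no leading ':' slots are prepended and the result is
# '_8 = attention_mask[ 0: 9223372036854775807: 1]'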