Unverified commit e5a208ba, authored by chicm-ms and committed by GitHub
Browse files

fix speedup with CUDA (#2947)

parent 1cd7ad5f
...@@ -573,7 +573,7 @@ def view_inshape(module_masks, mask, shape): ...@@ -573,7 +573,7 @@ def view_inshape(module_masks, mask, shape):
step_size = shape['in_shape'][2] * shape['in_shape'][3] step_size = shape['in_shape'][2] * shape['in_shape'][3]
for loc in mask.mask_index[1]: for loc in mask.mask_index[1]:
index.extend([loc * step_size + i for i in range(step_size)]) index.extend([loc * step_size + i for i in range(step_size)])
output_cmask.add_index_mask(dim=1, index=torch.tensor(index)) # pylint: disable=not-callable output_cmask.add_index_mask(dim=1, index=torch.tensor(index).to(mask.mask_index[1].device)) # pylint: disable=not-callable
module_masks.set_output_mask(output_cmask) module_masks.set_output_mask(output_cmask)
return output_cmask return output_cmask
...@@ -609,7 +609,7 @@ def view_outshape(module_masks, mask, shape): ...@@ -609,7 +609,7 @@ def view_outshape(module_masks, mask, shape):
step_size = shape['in_shape'][2] * shape['in_shape'][3] step_size = shape['in_shape'][2] * shape['in_shape'][3]
for loc in mask.mask_index[1]: for loc in mask.mask_index[1]:
index.extend([loc * step_size + i for i in range(step_size)]) index.extend([loc * step_size + i for i in range(step_size)])
input_cmask.add_index_mask(dim=1, index=torch.tensor(index)) # pylint: disable=not-callable input_cmask.add_index_mask(dim=1, index=torch.tensor(index).to(mask.mask_index[1].device)) # pylint: disable=not-callable
module_masks.set_input_mask(input_cmask) module_masks.set_input_mask(input_cmask)
return input_cmask return input_cmask
...@@ -870,7 +870,7 @@ def conv2d_mask(module_masks, mask): ...@@ -870,7 +870,7 @@ def conv2d_mask(module_masks, mask):
if index is None: if index is None:
return None, None, None return None, None, None
else: else:
index = torch.LongTensor(index).to(weight_mask.device) index = index.long().to(weight_mask.device)
weight_cmask = CoarseMask(num_dim=4) weight_cmask = CoarseMask(num_dim=4)
weight_cmask.add_index_mask(dim=dim, index=index) weight_cmask.add_index_mask(dim=dim, index=index)
bias_cmask = None bias_cmask = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment