Commit 7f00a36e authored by thomwolf's avatar thomwolf
Browse files

pruning should keep on device

parent e4b46d86
...@@ -80,7 +80,7 @@ def prune_linear_layer(layer, index, dim=0): ...@@ -80,7 +80,7 @@ def prune_linear_layer(layer, index, dim=0):
b = layer.bias[index].clone().detach() b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size()) new_size = list(layer.weight.size())
new_size[dim] = len(index) new_size[dim] = len(index)
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None) new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
new_layer.weight.requires_grad = False new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous()) new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True new_layer.weight.requires_grad = True
......
...@@ -55,7 +55,7 @@ def prune_conv1d_layer(layer, index, dim=1): ...@@ -55,7 +55,7 @@ def prune_conv1d_layer(layer, index, dim=1):
b = layer.bias[index].clone().detach() b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size()) new_size = list(layer.weight.size())
new_size[dim] = len(index) new_size[dim] = len(index)
new_layer = Conv1D(new_size[1], new_size[0]) new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
new_layer.weight.requires_grad = False new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous()) new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True new_layer.weight.requires_grad = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment