Commit 7f00a36e authored by thomwolf's avatar thomwolf
Browse files

pruning should keep on device

parent e4b46d86
......@@ -80,7 +80,7 @@ def prune_linear_layer(layer, index, dim=0):
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None)
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True
......
......@@ -55,7 +55,7 @@ def prune_conv1d_layer(layer, index, dim=1):
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = Conv1D(new_size[1], new_size[0])
new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment