Unverified Commit 08b8eec1 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

fix(server): Fixing non parameters in quantize script `bigcode/starcoder` was an example. (#661)

parent 362883f2
...@@ -812,10 +812,13 @@ def load_weights_pre_hook(module_name, weights, recursive=False): ...@@ -812,10 +812,13 @@ def load_weights_pre_hook(module_name, weights, recursive=False):
tensor = weights.get_tensor(tensor_name) tensor = weights.get_tensor(tensor_name)
setdeepattr(module, local_param, nn.Parameter(tensor)) setdeepattr(module, local_param, nn.Parameter(tensor))
else: else:
tensor = current_tensor.to(device=torch.device("cuda:0"))
if current_tensor.requires_grad:
tensor = nn.Parameter(tensor)
setdeepattr( setdeepattr(
module, module,
local_param, local_param,
nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))), tensor
) )
return inner return inner
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment