"...multigpu/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "c997434cdb97d386ba2ac1dfa0226e98df20c92d"
Unverified commit a403c0ed, authored by Jeremy Howard, committed by GitHub
Browse files

Avoid double-quantizing when calling `cuda()`

parent 726f1470
...@@ -165,6 +165,8 @@ class Params4bit(torch.nn.Parameter): ...@@ -165,6 +165,8 @@ class Params4bit(torch.nn.Parameter):
return self return self
def cuda(self, device): def cuda(self, device):
if self.quant_state is not None:
return self
w = self.data.contiguous().half().cuda(device) w = self.data.contiguous().half().cuda(device)
w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type)
self.data = w_4bit self.data = w_4bit
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment