Unverified commit 2ee289fb, authored by Titus, committed by GitHub
Browse files

Merge pull request #867 from jph00/patch-2

Avoid double-quantizing when calling `cuda()`
parents 744d36f7 a403c0ed
@@ -165,6 +165,8 @@ class Params4bit(torch.nn.Parameter):
         return self

     def cuda(self, device):
+        if self.quant_state is not None:
+            return self
         w = self.data.contiguous().half().cuda(device)
         w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type)
         self.data = w_4bit
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.