Commit cd4d904a authored by Max Ryabinin's avatar Max Ryabinin
Browse files

Raise an error when loading a quantized checkpoint before quantization

parent ac3ab281
...@@ -248,6 +248,11 @@ class Linear8bitLt(nn.Linear): ...@@ -248,6 +248,11 @@ class Linear8bitLt(nn.Linear):
for key in unexpected_keys: for key in unexpected_keys:
input_name = key[len(prefix):] input_name = key[len(prefix):]
if input_name == "SCB": if input_name == "SCB":
if self.weight.SCB is None:
# buffers not yet initialized, can't call them directly without
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
"not supported. Please call module.cuda() before module.load_state_dict()")
input_param = state_dict[key] input_param = state_dict[key]
self.weight.SCB.copy_(input_param) self.weight.SCB.copy_(input_param)
unexpected_keys.remove(key) unexpected_keys.remove(key)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment