Commit 6a934d4f authored by Ruslan Svirschevski's avatar Ruslan Svirschevski
Browse files

reorder state_dict

parent 1d541b50
...@@ -229,6 +229,8 @@ class Linear4bit(nn.Linear): ...@@ -229,6 +229,8 @@ class Linear4bit(nn.Linear):
besides weight and bias, besides weight and bias,
fill state_dict with components of quant_state fill state_dict with components of quant_state
""" """
super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias
if getattr(self.weight, "quant_state", None) is not None: if getattr(self.weight, "quant_state", None) is not None:
quant_state_dict = self.weight.quant_state.as_dict() quant_state_dict = self.weight.quant_state.as_dict()
tensor_keys = [k for k, v in quant_state_dict.items() if isinstance(v, torch.Tensor)] tensor_keys = [k for k, v in quant_state_dict.items() if isinstance(v, torch.Tensor)]
...@@ -236,7 +238,6 @@ class Linear4bit(nn.Linear): ...@@ -236,7 +238,6 @@ class Linear4bit(nn.Linear):
destination[prefix + "weight." + k] = quant_state_dict.pop(k) if keep_vars else quant_state_dict.pop(k).detach() destination[prefix + "weight." + k] = quant_state_dict.pop(k) if keep_vars else quant_state_dict.pop(k).detach()
destination[prefix + "weight." + "quant_state_dict"] = quant_state_dict destination[prefix + "weight." + "quant_state_dict"] = quant_state_dict
destination[prefix + "weight." + "quantization_method"] = "bitsandbytes." + quant_state_dict["quant_type"] destination[prefix + "weight." + "quantization_method"] = "bitsandbytes." + quant_state_dict["quant_type"]
super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
# weights are cast automatically as Int8Params, but the bias has to be cast manually # weights are cast automatically as Int8Params, but the bias has to be cast manually
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment