Unverified commit 7449d713 authored by Younes Belkada, committed by GitHub

[`Core`] Change the 8-bit serialization `weight_format` storage format (#1164)



* change 8-bit serialization `weight_format` storage format

* pre-commit

* pre-commit

* fix

* Update bitsandbytes/nn/modules.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/nn/modules.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/utils.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* address feedback

* lint

---------
Co-authored-by: Aarni Koskela <akx@iki.fi>
parent c54053d3
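
In short, this commit stops storing the `weight_format` entry of a `Linear8bitLt` state dict as a raw string and stores it as a `torch.uint8` scalar tensor instead, encoding the layout name through a string-to-int mapping and decoding it again at load time (old string-valued checkpoints still load). A minimal sketch of that round trip; the helper names are hypothetical, and the mapping mirrors `LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING` from this diff:

import torch

# Mirrors LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING added in bitsandbytes/utils.py.
FORMAT_TO_INT = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}
INT_TO_FORMAT = {v: k for k, v in FORMAT_TO_INT.items()}

def encode_weight_format(fmt: str) -> torch.Tensor:
    # Hypothetical helper: serialize the layout name as a uint8 scalar
    # tensor so the state dict contains only tensors.
    if fmt not in FORMAT_TO_INT:
        raise ValueError(f"Unrecognized weights format {fmt}")
    return torch.tensor(FORMAT_TO_INT[fmt], dtype=torch.uint8)

def decode_weight_format(value) -> str:
    # Hypothetical helper: accept both the legacy string form and the
    # new tensor/int form, as the load hook in this diff does.
    if isinstance(value, torch.Tensor):
        value = value.item()
    if isinstance(value, int):
        if value not in INT_TO_FORMAT:
            raise ValueError(f"Expected supported weight format - got {value}")
        return INT_TO_FORMAT[value]
    return value  # already a string, e.g. "row"

assert decode_weight_format(encode_weight_format("col_turing")) == "col_turing"
assert decode_weight_format("row") == "row"  # legacy checkpoints keep working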
bitsandbytes/nn/modules.py
@@ -14,7 +14,11 @@ import bitsandbytes as bnb
 from bitsandbytes.autograd._functions import get_tile_inds, undo_layout
 from bitsandbytes.functional import QuantState
 from bitsandbytes.optim import GlobalOptimManager
-from bitsandbytes.utils import OutlierTracer
+from bitsandbytes.utils import (
+    INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
+    LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
+    OutlierTracer,
+)
 
 T = TypeVar("T", bound="torch.nn.Module")
@@ -619,6 +623,16 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
         return
 
     weight_format = state_dict.pop(f"{prefix}weight_format", "row")
+    if isinstance(weight_format, torch.Tensor):
+        weight_format = weight_format.item()
+
+    # For the new weight_format storage type, we explicitly check
+    # whether weight_format is in the mapping
+    if isinstance(weight_format, int) and weight_format not in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
+        raise ValueError(f"Expected supported weight format - got {weight_format}")
+    elif isinstance(weight_format, int) and weight_format in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
+        weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format]
+
     if weight_format != "row":
         tile_indices = get_tile_inds(weight_format, weight.device)
         state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
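
This hunk runs as a `load_state_dict` pre-hook: bitsandbytes wires `maybe_rearrange_weight` into `Linear8bitLt` through PyTorch's load-state-dict pre-hook mechanism, so the `weight_format` entry is popped and normalized before the usual key checks run. A toy sketch of the same pattern, assuming only PyTorch; the module and hook below are stand-ins for illustration, not the library's code:

import torch
import torch.nn as nn

INVERSE_MAP = {0: "row", 1: "col32", 2: "col_turing", 3: "col_ampere"}

def decode_format_hook(state_dict, prefix, local_metadata, strict,
                       missing_keys, unexpected_keys, error_msgs):
    # Pop the metadata entry before load_state_dict inspects the keys,
    # accepting both the legacy string and the new uint8 tensor form.
    fmt = state_dict.pop(f"{prefix}weight_format", "row")
    if isinstance(fmt, torch.Tensor):
        fmt = fmt.item()
    if isinstance(fmt, int):
        fmt = INVERSE_MAP[fmt]
    print(f"weights were serialized in {fmt!r} layout")

class ToyLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4, 4))
        # Same (private) PyTorch hook registration style the library relies on.
        self._register_load_state_dict_pre_hook(decode_format_hook)

m = ToyLinear()
sd = m.state_dict()
sd["weight_format"] = torch.tensor(2, dtype=torch.uint8)  # as the new save path writes it
m.load_state_dict(sd)  # prints: weights were serialized in 'col_turing' layout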
@@ -711,13 +725,20 @@ class Linear8bitLt(nn.Linear):
         if not self.state.has_fp16_weights:
             if param_from_weight is not None:
                 destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
-                destination[format_name] = "row"
+                destination[format_name] = torch.tensor(0, dtype=torch.uint8)
             elif param_from_state is not None and not layout_reordered:
                 destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
-                destination[format_name] = "row"
+                destination[format_name] = torch.tensor(0, dtype=torch.uint8)
             elif param_from_state is not None:
                 destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
-                destination[format_name] = self.state.formatB
+                weights_format = self.state.formatB
+                # At this point `weights_format` is a str
+                if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
+                    raise ValueError(f"Unrecognized weights format {weights_format}")
+
+                weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format]
+
+                destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8)
 
     def _load_from_state_dict(
         self,
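
A plausible motivation for the tensor-valued entry (my reading, not stated in the commit message): tensor-only serialization backends such as safetensors reject non-tensor values, so a string `weight_format` in the state dict could not be saved there, while a uint8 scalar can. A small sketch, assuming the `safetensors` package is installed:

import torch
from safetensors.torch import load_file, save_file

state = {
    "weight": torch.zeros(4, 4, dtype=torch.int8),
    "weight_format": torch.tensor(2, dtype=torch.uint8),  # "col_turing" under the new mapping
}
save_file(state, "linear8bit.safetensors")  # fine: every value is a tensor
# Under the old scheme, state["weight_format"] = "col_turing" (a str) would
# make save_file raise, since safetensors only accepts torch.Tensor values.

restored = load_file("linear8bit.safetensors")
assert restored["weight_format"].item() == 2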
bitsandbytes/utils.py
@@ -198,3 +198,7 @@ def unpack_tensor_to_dict(tensor_data):
     unpacked_dict = json.loads(json_str)
 
     return unpacked_dict
+
+
+LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}
+INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()}
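
For completeness, a quick check of the two new mappings, assuming a bitsandbytes build that contains this commit:

from bitsandbytes.utils import (
    INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
    LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
)

# The dicts are exact inverses, so encoding a layout name to its uint8
# code and decoding it back is lossless.
for name, code in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items():
    assert INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[code] == name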