OpenDAS / bitsandbytes · Commits

Commit b8ea2b41
authored Apr 12, 2023 by Tim Dettmers

Fixed bias conversion in Linear4bit

parent e9fa03b7
Showing 1 changed file with 1 addition and 33 deletions.

bitsandbytes/nn/modules.py  (+1, -33)
@@ -205,45 +205,13 @@ class Linear4bit(nn.Linear):
         if self.compute_dtype is not None:
             x = x.to(self.compute_dtype)

-        bias = None if self.bias is None else self.bias.half(self.compute_dtype)
+        bias = None if self.bias is None else self.bias.to(self.compute_dtype)

         out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
         out = out.to(inp_dtype)

         return out

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
-        super()._save_to_state_dict(destination, prefix, keep_vars)
-
-        # we only need to save extra state if .cuda was called
-        # then we have the (1) quantization weight and the (2) quantization config
-        #quant_state = getattr(self.weight, 'quant_state', None)
-        #if quant_state is not None:
-        #    # 2. quantization state
-        #    destination[prefix + 'quant_state'] = quant_state
-        #destination[prefix + 'weight'] = self.weight.detach()
-
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
-                                      missing_keys, unexpected_keys, error_msgs)
-        #for key in unexpected_keys:
-        #    input_name = key[len(prefix):]
-        #    if input_name == "quant_state":
-        #        if getattr(self.weight, 'quant_state', None) is None:
-        #            # buffers not yet initialized, can't call them directly without
-        #            raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear4bit is "
-        #                               "not supported. Please call module.cuda() before module.load_state_dict()")
-        #        input_param = state_dict[key]
-        #        self.weight.quant_state = input_param
-        #        assert isinstance(self.weight, Param4bit)
-        #        unexpected_keys.remove(key)
-

 class LinearFP4(Linear4bit):
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4')
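For context (not part of the commit): a minimal sketch of why the bias cast switches from .half(...) to .to(...). Tensor.half() always produces torch.float16 and does not take a target dtype, so it cannot honor a configured compute_dtype such as torch.bfloat16; Tensor.to(compute_dtype) casts the bias to whatever dtype Linear4bit was configured with. The names below are illustrative stand-ins, not code from the repository.

    import torch

    # Illustrative sketch: compare the old and new bias conversions.
    compute_dtype = torch.bfloat16                 # e.g. value passed as Linear4bit(compute_dtype=...)
    bias = torch.zeros(4, dtype=torch.float32)     # stand-in for self.bias

    print(bias.half().dtype)               # torch.float16, regardless of compute_dtype
    print(bias.to(compute_dtype).dtype)    # torch.bfloat16, matching compute_dtype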