OpenDAS / bitsandbytes · Commit f997b1d9 (unverified)

Authored Oct 26, 2023 by Jeremy Howard; committed by GitHub on Oct 26, 2023.

Update utils.py to remove dupe `replace_linear`

Parent: 18e827d6
Changes: 1

Showing 1 changed file with 0 additions and 39 deletions (+0 −39)

bitsandbytes/utils.py
```diff
@@ -99,45 +99,6 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
     return idx
 
-def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
-    """
-    Replace linear modules with a new Linear module.
-    Parameters:
-        model (`torch.nn.Module`):
-            Input model or `torch.nn.Module` as the function is run recursively.
-        linear_replacement (`torch.nn.Module`):
-            The linear module that replaces the old one. Only expects standard arguments.
-            If other arguments need to be passed, use a lambda.
-        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
-            List of modules names not to convert. Defaults to `lm_head`.
-        copy_weights (`bool`):
-            Copy the weights from the old linear module to the new one
-        post_processing_fun_name (`str`):
-            A function name of the replacement linear class that is called
-            after processing.
-    """
-    for name, module in model.named_children():
-        if len(list(module.children())) > 0:
-            replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)
-
-        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
-            old_module = model._modules[name]
-            model._modules[name] = linear_replacement(
-                module.in_features,
-                module.out_features,
-                module.bias is not None,
-            )
-            if copy_weights:
-                model._modules[name].weight = old_module.weight
-                model._modules[name].bias = old_module.bias
-
-            if post_processing_function is not None:
-                func = getattr(module, post_processing_function, None)
-                if func is not None: func(module)
-    return model
-
-
 def execute_and_return(command_string: str) -> Tuple[str, str]:
     def _decode(subprocess_err_out_tuple):
```
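Only the duplicated definition is removed here; the surviving copy of `replace_linear` earlier in `bitsandbytes/utils.py` is untouched. A minimal usage sketch, assuming that copy can still be imported from `bitsandbytes.utils` and using a hypothetical stand-in class `MyLinear` instead of a real quantized layer:

```python
# Minimal sketch of calling the surviving replace_linear after this commit.
# Assumptions: `from bitsandbytes.utils import replace_linear` still resolves
# (only the duplicate definition was deleted), and MyLinear is a toy stand-in;
# in practice you would pass a quantized Linear class instead.
import torch.nn as nn

from bitsandbytes.utils import replace_linear


class MyLinear(nn.Linear):
    """Toy replacement accepting the standard (in_features, out_features, bias) arguments."""


model = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Sequential(nn.Linear(32, 8)),  # nested module, reached by the recursive walk
)

# Swap every nn.Linear (except children named "lm_head") for MyLinear,
# copying the original weights and biases into the new layers.
model = replace_linear(model, MyLinear, skip_modules=["lm_head"], copy_weights=True)
print(model)
```

After the call, both `nn.Linear` layers, including the nested one, are `MyLinear` instances carrying the original parameters.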