Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
e812136c
Unverified
Commit
e812136c
authored
Oct 31, 2023
by
Titus-von-Koeller
Committed by
GitHub
Oct 31, 2023
Browse files
Merge pull request #843 from jph00/patch-1
Update utils.py to remove dupe `replace_linear`
parents
18e827d6
f997b1d9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
39 deletions
+0
-39
bitsandbytes/utils.py
bitsandbytes/utils.py
+0
-39
No files found.
bitsandbytes/utils.py
View file @
e812136c
...
...
@@ -99,45 +99,6 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
return
idx
def
replace_linear
(
model
,
linear_replacement
,
skip_modules
=
[
"lm_head"
],
copy_weights
=
False
,
post_processing_function
=
None
):
"""
Replace linear modules with a new Linear module.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
linear_replacement (`torch.nn.Module`):
The linear module that replaces the old one. Only expects standard arguments.
If other arguments need to be passed, use a lambda.
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
List of modules names not to convert. Defaults to `lm_head`.
copy_weights (`bool`):
Copy the weights from the old linear module to the new one
post_processing_fun_name (`str`):
A function name of the replacement linear class that is called
after processing.
"""
for
name
,
module
in
model
.
named_children
():
if
len
(
list
(
module
.
children
()))
>
0
:
replace_linear
(
module
,
linear_replacement
,
skip_modules
,
copy_weights
,
post_processing_function
)
if
isinstance
(
module
,
torch
.
nn
.
Linear
)
and
name
not
in
skip_modules
:
old_module
=
model
.
_modules
[
name
]
model
.
_modules
[
name
]
=
linear_replacement
(
module
.
in_features
,
module
.
out_features
,
module
.
bias
is
not
None
,
)
if
copy_weights
:
model
.
_modules
[
name
].
weight
=
old_module
.
weight
model
.
_modules
[
name
].
bias
=
old_module
.
bias
if
post_processing_function
is
not
None
:
func
=
getattr
(
module
,
post_processing_function
,
None
)
if
func
is
not
None
:
func
(
module
)
return
model
def
execute_and_return
(
command_string
:
str
)
->
Tuple
[
str
,
str
]:
def
_decode
(
subprocess_err_out_tuple
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment