Unverified Commit e812136c authored by Titus-von-Koeller's avatar Titus-von-Koeller Committed by GitHub
Browse files

Merge pull request #843 from jph00/patch-1

Update utils.py to remove dupe `replace_linear`
parents 18e827d6 f997b1d9
......@@ -99,45 +99,6 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
return idx
def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
"""
Replace linear modules with a new Linear module.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
linear_replacement (`torch.nn.Module`):
The linear module that replaces the old one. Only expects standard arguments.
If other arguments need to be passed, use a lambda.
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
List of modules names not to convert. Defaults to `lm_head`.
copy_weights (`bool`):
Copy the weights from the old linear module to the new one
post_processing_fun_name (`str`):
A function name of the replacement linear class that is called
after processing.
"""
for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
old_module = model._modules[name]
model._modules[name] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
)
if copy_weights:
model._modules[name].weight = old_module.weight
model._modules[name].bias = old_module.bias
if post_processing_function is not None:
func = getattr(module, post_processing_function, None)
if func is not None: func(module)
return model
def execute_and_return(command_string: str) -> Tuple[str, str]:
def _decode(subprocess_err_out_tuple):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment