Commit 8706830f authored by Aarni Koskela's avatar Aarni Koskela
Browse files

Fix some bad types

parent b03ce0e0
......@@ -658,8 +658,8 @@ class Linear8bitLt(nn.Linear):
def __init__(
self,
input_features,
output_features,
input_features: int,
output_features: int,
bias=True,
has_fp16_weights=True,
memory_efficient_backward=False,
......@@ -671,9 +671,9 @@ class Linear8bitLt(nn.Linear):
Initialize Linear8bitLt class.
Args:
input_features (`str`):
input_features (`int`):
Number of input features of the linear layer.
output_features (`str`):
output_features (`int`):
Number of output features of the linear layer.
bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well.
......
......@@ -140,7 +140,7 @@ def replace_linear(
List of modules names not to convert. Defaults to `lm_head`.
copy_weights (`bool`):
Copy the weights from the old linear module to the new one
post_processing_fun_name (`str`):
post_processing_function (`str`):
A function name of the replacement linear class that is called
after processing.
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment