Commit 8706830f authored by Aarni Koskela's avatar Aarni Koskela
Browse files

Fix some bad types

parent b03ce0e0
...@@ -658,8 +658,8 @@ class Linear8bitLt(nn.Linear): ...@@ -658,8 +658,8 @@ class Linear8bitLt(nn.Linear):
def __init__( def __init__(
self, self,
input_features, input_features: int,
output_features, output_features: int,
bias=True, bias=True,
has_fp16_weights=True, has_fp16_weights=True,
memory_efficient_backward=False, memory_efficient_backward=False,
...@@ -671,9 +671,9 @@ class Linear8bitLt(nn.Linear): ...@@ -671,9 +671,9 @@ class Linear8bitLt(nn.Linear):
Initialize Linear8bitLt class. Initialize Linear8bitLt class.
Args: Args:
input_features (`str`): input_features (`int`):
Number of input features of the linear layer. Number of input features of the linear layer.
output_features (`str`): output_features (`int`):
Number of output features of the linear layer. Number of output features of the linear layer.
bias (`bool`, defaults to `True`): bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well. Whether the linear class uses the bias term as well.
......
...@@ -140,7 +140,7 @@ def replace_linear( ...@@ -140,7 +140,7 @@ def replace_linear(
List of modules names not to convert. Defaults to `lm_head`. List of modules names not to convert. Defaults to `lm_head`.
copy_weights (`bool`): copy_weights (`bool`):
Copy the weights from the old linear module to the new one Copy the weights from the old linear module to the new one
post_processing_fun_name (`str`): post_processing_function (`str`):
A function name of the replacement linear class that is called A function name of the replacement linear class that is called
after processing. after processing.
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment