Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
54860539
Commit
54860539
authored
Nov 09, 2023
by
Ruslan Svirschevski
Browse files
type hints in Params4bit constructors
parent
74c00eb1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+3
-3
No files found.
bitsandbytes/nn/modules.py
View file @
54860539
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#
#
# This source code is licensed under the MIT license found in the
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
from
typing
import
Optional
,
TypeVar
,
Union
,
overload
from
typing
import
Any
,
Dict
,
Optional
,
TypeVar
,
Union
,
overload
import
warnings
import
warnings
import
torch
import
torch
...
@@ -142,7 +142,7 @@ class Embedding(torch.nn.Embedding):
...
@@ -142,7 +142,7 @@ class Embedding(torch.nn.Embedding):
class
Params4bit
(
torch
.
nn
.
Parameter
):
class
Params4bit
(
torch
.
nn
.
Parameter
):
def
__new__
(
cls
,
data
=
None
,
requires_grad
=
True
,
quant_state
=
None
,
blocksize
=
64
,
compress_statistics
=
True
,
quant_type
=
'fp4'
)
:
def
__new__
(
cls
,
data
:
Optional
[
torch
.
Tensor
]
=
None
,
requires_grad
=
True
,
quant_state
:
QuantState
=
None
,
blocksize
:
int
=
64
,
compress_statistics
:
bool
=
True
,
quant_type
:
str
=
'fp4'
)
->
"Params4bit"
:
if
data
is
None
:
if
data
is
None
:
data
=
torch
.
empty
(
0
)
data
=
torch
.
empty
(
0
)
...
@@ -155,7 +155,7 @@ class Params4bit(torch.nn.Parameter):
...
@@ -155,7 +155,7 @@ class Params4bit(torch.nn.Parameter):
return
self
return
self
@
classmethod
@
classmethod
def
from_prequantized
(
cls
,
data
,
quantized_stats
,
requires_grad
=
False
,
device
=
'cuda'
,
**
kwargs
):
def
from_prequantized
(
cls
,
data
:
torch
.
Tensor
,
quantized_stats
:
Dict
[
str
,
Any
]
,
requires_grad
:
bool
=
False
,
device
=
'cuda'
,
**
kwargs
)
->
"Params4bit"
:
self
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
.
to
(
device
))
self
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
.
to
(
device
))
self
.
requires_grad
=
requires_grad
self
.
requires_grad
=
requires_grad
self
.
quant_state
=
QuantState
.
from_dict
(
qs_dict
=
quantized_stats
,
device
=
device
)
self
.
quant_state
=
QuantState
.
from_dict
(
qs_dict
=
quantized_stats
,
device
=
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment