OpenDAS / bitsandbytes

Commit 75377d12, authored Feb 24, 2023 by Mitchell Wortsman

new experiments

parent 5d2e23e8
Showing 2 changed files with 60 additions and 1 deletion (+60 -1):

bitsandbytes/nn/__init__.py    +1  -1
bitsandbytes/nn/modules.py     +59 -0
bitsandbytes/nn/__init__.py  (view file @ 75377d12)

@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2, Linear8bitLtMixed
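The only functional change in this file is adding the new Linear8bitLtMixed class (defined in modules.py below) to the subpackage's re-exports. A minimal sketch of how downstream code would pick it up after this commit (assuming the package on this branch is importable as bitsandbytes):

    # the new layer becomes reachable from the nn subpackage after this commit
    from bitsandbytes.nn import Linear8bitLtMixed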
bitsandbytes/nn/modules.py  (view file @ 75377d12)

@@ -407,6 +407,65 @@ class Linear8bitLt2(nn.Linear):
         return out
 
 
+class Linear8bitLtMixed(nn.Linear):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=0.0,
+        index=None,
+    ):
+        super().__init__(input_features, output_features, bias)
+        self.state = bnb.MatmulLtState()
+        self.index = index
+
+        self.state.threshold = threshold
+        self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
+        if threshold > 0.0 and not has_fp16_weights:
+            self.state.use_pool = True
+
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
+        )
+
+    def init_8bit_state(self):
+        self.state.CB = self.weight.CB
+        self.state.SCB = self.weight.SCB
+        self.weight.CB = None
+        self.weight.SCB = None
+
+    def forward(self, x):
+        self.state.is_training = self.training
+
+        if self.weight.CB is not None:
+            self.init_8bit_state()
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        # if self.bias is not None and self.bias.dtype != torch.float16:
+        #     self.bias.data = self.bias.data.half()
+
+        # out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+        out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+
+        if not self.state.has_fp16_weights:
+            if not self.state.memory_efficient_backward and self.state.CB is not None:
+                # we converted 8-bit row major to turing/ampere format in the first inference pass
+                # we no longer need the row-major weight
+                del self.state.CB
+                self.weight.data = self.state.CxB
+            elif self.state.memory_efficient_backward and self.state.CxB is not None:
+                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+                # Thus, we delete CxB from the state.
+                del self.state.CxB
+
+        return out
+
+
 class Linear8bitLtThresh(Linear8bitLt):
     def __init__(
 ...
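For context, a hedged usage sketch of the new layer. The layer sizes, threshold value, and device handling below are illustrative assumptions, not part of the commit; it presumes a CUDA build of bitsandbytes from this branch, since forward() casts inputs to fp16 and calls bnb.matmul_mixed, which only exists in this experimental code:

    import torch
    from bitsandbytes.nn import Linear8bitLtMixed

    # illustrative sizes and threshold; with has_fp16_weights=False the Int8Params
    # weight is quantized to int8 when the module is moved to the GPU
    layer = Linear8bitLtMixed(768, 768, bias=True, has_fp16_weights=False, threshold=6.0).cuda()

    x = torch.randn(4, 768, dtype=torch.float16, device="cuda")
    out = layer(x)  # first call caches CB/SCB via init_8bit_state(), then runs bnb.matmul_mixed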