OpenDAS / bitsandbytes · Commits

Commit 7b764d35, authored Feb 21, 2023 by Mitchell Wortsman

    adding half() cast

Parent: 2489d819
Showing 3 changed files with 66 additions and 9 deletions:

    bitsandbytes/autograd/_functions.py    +9   -5
    bitsandbytes/nn/__init__.py            +1   -1
    bitsandbytes/nn/modules.py             +56  -3
bitsandbytes/autograd/_functions.py

@@ -415,8 +415,8 @@ class MatMulFP8(torch.autograd.Function):
         cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
         fp8A = F.dequantize_blockwise(cA, state, blocksize=1024).to(A.dtype)
 
-        cB, state = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
-        fp8B = F.dequantize_blockwise(cB, state, blocksize=1024).to(B.dtype)
+        cB, state = F.quantize(B.float(), code=fw_code)
+        fp8B = F.dequantize(cB, state).to(B.dtype)
 
         output = torch.matmul(fp8A, fp8B)

@@ -450,9 +450,13 @@ class MatMulFP8(torch.autograd.Function):
         grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
 
         # not supported by PyTorch. TODO: create work-around
-        #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-        if req_gradA: grad_A = torch.matmul(fp8out, B.t())
-        if req_gradB: grad_B = torch.matmul(fp8A.t(), fp8out)
+        if req_gradA: grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(fp8A.dtype)
+        if req_gradB:
+            if fp8A.ndim == 3:
+                fp8At = fp8A.transpose(2, 1)
+            elif fp8A.ndim == 2:
+                fp8At = fp8A.t()
+            grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
 
         return grad_A, grad_B, None, None, None
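For context, here is a minimal sketch of the round trip the patched forward performs, written against the public bitsandbytes.functional API. The CUDA device, tensor sizes, and the use of create_dynamic_map() as a stand-in for the fw_code argument are assumptions for illustration, not part of the commit:

    import torch
    import bitsandbytes.functional as F

    # Stand-in quantization code; MatMulFP8 receives fw_code from the caller.
    code = F.create_dynamic_map().to("cuda")

    A = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
    B = torch.randn(1024, 256, device="cuda", dtype=torch.float16)

    # A keeps the blockwise quantize/dequantize path used before this commit.
    cA, state_A = F.quantize_blockwise(A, code=code, blocksize=1024)
    fp8A = F.dequantize_blockwise(cA, state_A, blocksize=1024).to(A.dtype)

    # B now goes through the plain (non-blockwise) quantize/dequantize pair.
    cB, state_B = F.quantize(B.float(), code=code)
    fp8B = F.dequantize(cB, state_B).to(B.dtype)

    # The matmul itself runs on the dequantized ("simulated FP8") tensors.
    output = torch.matmul(fp8A, fp8B)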
bitsandbytes/nn/__init__.py

@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast
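With the export updated, the three layers added in this commit can be imported directly from bitsandbytes.nn (in this fork):

    from bitsandbytes.nn import LinearInt8, Linear8bitLtThresh, LinearInt8Cast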
bitsandbytes/nn/modules.py

@@ -326,10 +326,11 @@ class Linear8bitLt(nn.Linear):
             self.init_8bit_state()
 
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != torch.float16:
-            self.bias.data = self.bias.data.half()
-        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
+        # if self.bias is not None and self.bias.dtype != torch.float16:
+        #     self.bias.data = self.bias.data.half()
+        #out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+        out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
 
         if not self.state.has_fp16_weights:
             if not self.state.memory_efficient_backward and self.state.CB is not None:
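A hedged usage sketch of the patched forward: with the explicit .half() casts inside the layer, the caller no longer has to convert a float32 activation (or the bias) before the int8 matmul. The layer sizes, the CUDA device, and the default has_fp16_weights=True are assumptions for illustration only:

    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.Linear8bitLt(64, 128, threshold=6.0).cuda()
    x = torch.randn(8, 64, device="cuda")   # float32 input on purpose
    out = layer(x)                          # forward casts x and the weight to half, then adds the bias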
@@ -344,6 +345,28 @@ class Linear8bitLt(nn.Linear):
 
         return out
 
+
+class Linear8bitLtThresh(Linear8bitLt):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=6.0,
+        index=None,
+    ):
+        super().__init__(
+            input_features,
+            output_features,
+            bias=bias,
+            has_fp16_weights=has_fp16_weights,
+            memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold,
+            index=index
+        )
+
 
 class LinearFP8(nn.Linear):
     def __init__(self, input_features, output_features, bias=True):
         super().__init__(input_features, output_features, bias)
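Since Linear8bitLtThresh only changes the default threshold (6.0 instead of Linear8bitLt's 0.0), the two constructions below should be equivalent. The sizes are arbitrary and the import assumes this fork:

    from bitsandbytes.nn import Linear8bitLt, Linear8bitLtThresh

    a = Linear8bitLtThresh(64, 128)            # threshold defaults to 6.0
    b = Linear8bitLt(64, 128, threshold=6.0)   # explicit equivalent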
@@ -361,3 +384,33 @@ class LinearFP8(nn.Linear):
 
         return out
 
+
+class LinearInt8(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.code = None
+
+    def forward(self, x: torch.Tensor):
+        if self.code is None:
+            self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
+
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code)
+        if self.bias is not None:
+            out += self.bias
+
+        return out
+
+
+class LinearInt8Cast(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.code = None
+
+    def forward(self, x: torch.Tensor):
+        if self.code is None:
+            self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
+
+        out = bnb.matmul_fp8(x.half(), self.weight.half().t(), fw_code=self.code, bw_code=self.code)
+        if self.bias is not None:
+            out += self.bias
+
+        return out
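A hedged usage sketch of the two new layers: LinearInt8 feeds the input straight into the simulated 8-bit matmul with a signed linear quantization map used for both the forward and backward codes, while LinearInt8Cast first casts the input and weight to half (the cast named in the commit message). The CUDA device and sizes are assumed for illustration:

    import torch
    from bitsandbytes.nn import LinearInt8, LinearInt8Cast

    layer = LinearInt8Cast(64, 128).cuda()
    x = torch.randn(8, 64, device="cuda")   # float32 input; forward calls x.half()
    out = layer(x)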