OpenDAS / bitsandbytes

Commit 3fbf60ad, authored Feb 23, 2023 by Mitchell Wortsman

    sim now worse than real

parent 7b764d35
Showing 3 changed files with 118 additions and 29 deletions.
bitsandbytes/autograd/_functions.py  (+32 -23)
bitsandbytes/nn/__init__.py  (+1 -1)
bitsandbytes/nn/modules.py  (+85 -5)
bitsandbytes/autograd/_functions.py

```diff
@@ -395,38 +395,41 @@ class MatMulFP8(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

     @staticmethod
-    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None):
+    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
             ctx.is_empty = True
             ctx.A = A
             ctx.B = B

             B_shape = B.shape
             if A.shape[-1] == B_shape[0]:
                 return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
             else:
                 return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

         # 1. Dequantize
         # 2. MatmulnN
-        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
-        fp8A = F.dequantize_blockwise(cA, state, blocksize=1024).to(A.dtype)
+        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
+        fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)

         cB, state = F.quantize(B.float(), code=fw_code)
         fp8B = F.dequantize(cB, state).to(B.dtype)

         output = torch.matmul(fp8A, fp8B)
+        # output is half

         # 3. Save state
+        ctx.fw_code = fw_code
         ctx.bw_code = bw_code
+        ctx.bsz = bsz
         ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

         if any(ctx.needs_input_grad[:2]):
-            ctx.tensors = (fp8A, fp8B)
+            # NOTE: we send back A, and re-quant.
+            ctx.tensors = (A, fp8B)
         else:
             ctx.tensors = (None, None)
```
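The forward pass above never runs a real FP8 kernel: it simulates FP8 by quantizing A blockwise (with the new `bsz` block size) and B per-tensor to 8-bit codes, dequantizing both back, and then doing an ordinary `torch.matmul`. A minimal standalone sketch of that round trip, reusing the same `bitsandbytes.functional` calls that appear in the diff (shapes, device, and the choice of code map are illustrative, not part of this commit):

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(4, 1024, dtype=torch.float16, device="cuda")     # activations
B = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")  # weights

fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)  # e4m3-style map, as used by LinearFP8
bsz = 1024

# Simulated FP8: quantize -> dequantize -> regular fp16 matmul on the rounded operands.
cA, state_A = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
fp8A = F.dequantize_blockwise(cA, state_A, blocksize=bsz).to(A.dtype)

cB, state_B = F.quantize(B.float(), code=fw_code)
fp8B = F.dequantize(cB, state_B).to(B.dtype)

out = torch.matmul(fp8A, fp8B)
```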
```diff
@@ -435,30 +438,36 @@ class MatMulFP8(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
-            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None

-        req_gradA, req_gradB, _, _, _ = ctx.needs_input_grad
-        fp8A, B = ctx.tensors
+        req_gradA, req_gradB, _, _, _, _ = ctx.needs_input_grad
+        A, B = ctx.tensors

         grad_A, grad_B = None, None

-        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=1024)
-        fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=1024).to(grad_output.dtype)
+        # TODO: Fix blocksize to be output_dim
+        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz)
+        fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz).to(grad_output.dtype)

+        # Cast grad_output to fp16
+        cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
+        fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)

         if len(grad_output.shape) == 3:
             grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

         # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
         # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
         # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
         # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])

         # not supported by PyTorch. TODO: create work-around
         if req_gradA:
-            grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(fp8A.dtype)
+            grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)

         if req_gradB:
-            if fp8A.ndim == 3:
-                fp8At = fp8A.transpose(2, 1)
-            elif fp8A.ndim == 2:
-                fp8At = fp8A.t()
-            grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
+            At = A.transpose(2, 1).contiguous()
+            cA, state = F.quantize(At.float(), code=ctx.fw_code)
+            fp8At = F.dequantize(cA, state).to(A.dtype)
+            grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype)

-        return grad_A, grad_B, None, None, None
+        return grad_A, grad_B, None, None, None, None


 def matmul(
```
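For orientation, the backward above still computes the standard matmul gradients, grad_A = grad_output @ B.t() and grad_B = A.t() @ grad_output; what this commit changes is which operands get the FP8 round trip. grad_output is rounded blockwise with `ctx.bsz` for grad_A, while for grad_B the full-precision A (now saved in place of fp8A) is re-quantized with the forward code and multiplied by a per-tensor-rounded grad_output (`fp8out_2`). A tiny reference check of the underlying identities, with no quantization at all (purely illustrative):

```python
import torch

A = torch.randn(8, 16, requires_grad=True)
B = torch.randn(16, 32, requires_grad=True)
out = A @ B
grad_out = torch.randn_like(out)
out.backward(grad_out)

# The exact gradients that MatMulFP8.backward approximates with FP8-rounded operands:
assert torch.allclose(A.grad, grad_out @ B.t(), atol=1e-5)
assert torch.allclose(B.grad, A.t() @ grad_out, atol=1e-5)
```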
```diff
@@ -475,8 +484,8 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)


-def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None):
-    return MatMulFP8.apply(A, B, out, fw_code, bw_code)
+def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz: int = -1):
+    return MatMulFP8.apply(A, B, out, fw_code, bw_code, bsz)


 def matmul(
```
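With the new `bsz` keyword threaded through `matmul_fp8`, callers can pick the blockwise quantization granularity per call instead of relying on the previously hard-coded 1024. A hedged usage sketch (shapes, device, and the choice of `bsz` are illustrative):

```python
import torch
import bitsandbytes as bnb

A = torch.randn(8, 512, dtype=torch.float16, device="cuda")
W = torch.randn(1024, 512, dtype=torch.float16, device="cuda")

fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)  # forward map (e4m3-style)
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)  # backward map (e5m2-style)

# bsz controls the blocksize used by quantize_blockwise inside MatMulFP8.forward/backward.
out = bnb.matmul_fp8(A, W.t(), fw_code=fw_code, bw_code=bw_code, bsz=512)
```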
bitsandbytes/nn/__init__.py

```diff
@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2
```
bitsandbytes/nn/modules.py

```diff
@@ -346,6 +346,68 @@ class Linear8bitLt(nn.Linear):
         return out


+# Not in use for now...
+class Linear8bitLt2(nn.Linear):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=0.0,
+        index=None,
+    ):
+        super().__init__(input_features, output_features, bias)
+        self.state = bnb.MatmulLtState()
+        self.index = index
+
+        self.state.threshold = threshold
+        self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
+        if threshold > 0.0 and not has_fp16_weights:
+            self.state.use_pool = True
+
+        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
+
+    def init_8bit_state(self):
+        self.state.CB = self.weight.CB
+        self.state.SCB = self.weight.SCB
+        self.weight.CB = None
+        self.weight.SCB = None
+
+    def forward(self, x):
+        self.state.is_training = self.training
+
+        if self.weight.CB is not None:
+            self.init_8bit_state()
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        # if self.bias is not None and self.bias.dtype != torch.float16:
+        #     self.bias.data = self.bias.data.half()
+
+        #out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+        out = bnb.matmul(x, self.weight, bias=None, state=self.state) + self.bias
+        #out = torch.matmul(x.half(), W.half().t()) + self.bias
+
+        if not self.state.has_fp16_weights:
+            if not self.state.memory_efficient_backward and self.state.CB is not None:
+                # we converted 8-bit row major to turing/ampere format in the first inference pass
+                # we no longer need the row-major weight
+                del self.state.CB
+                self.weight.data = self.state.CxB
+            elif self.state.memory_efficient_backward and self.state.CxB is not None:
+                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+                # Thus, we delete CxB from the state.
+                del self.state.CxB
+
+        return out
+
+
 class Linear8bitLtThresh(Linear8bitLt):
     def __init__(
         self,
```
```diff
@@ -363,7 +425,7 @@ class Linear8bitLtThresh(Linear8bitLt):
             bias=bias,
             has_fp16_weights=has_fp16_weights,
             memory_efficient_backward=memory_efficient_backward,
-            threshold=threshold,
+            threshold=6.,
             index=index
         )
```
```diff
@@ -372,13 +434,19 @@ class LinearFP8(nn.Linear):
         super().__init__(input_features, output_features, bias)
         self.bw_code = None
         self.fw_code = None
+        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
+        for i, k in enumerate(array):
+            if input_features > array[i + 1]:
+                self.bsz = k
+                break
+        print('block size is', self.bsz)

     def forward(self, x: torch.Tensor):
         if self.fw_code is None:
             self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
             self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)

-        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz)
         if self.bias is not None:
             out += self.bias
```
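The loop added to `__init__` effectively rounds `input_features` up to the next power of two in the list (floored at 64, capped at 4096) and uses that as the block size passed to `matmul_fp8`. A standalone check of that heuristic (the helper name is illustrative, not part of the commit):

```python
def pick_bsz(input_features):
    # Mirrors the selection loop added to LinearFP8 / LinearInt8 / LinearInt8Cast.__init__.
    array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
    for i, k in enumerate(array):
        if input_features > array[i + 1]:
            return k

print(pick_bsz(50))    # 64
print(pick_bsz(768))   # 1024
print(pick_bsz(1024))  # 1024
print(pick_bsz(4096))  # 4096
```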
```diff
@@ -388,27 +456,39 @@ class LinearInt8(nn.Linear):
     def __init__(self, input_features, output_features, bias=True):
         super().__init__(input_features, output_features, bias)
         self.code = None
+        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
+        for i, k in enumerate(array):
+            if input_features > array[i + 1]:
+                self.bsz = k
+                break

     def forward(self, x: torch.Tensor):
         if self.code is None:
             self.code = bnb.functional.create_linear_map(True, 8).to(x.device)

-        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz)
         if self.bias is not None:
             out += self.bias

         return out

+# This is 4 bit version.
 class LinearInt8Cast(nn.Linear):
     def __init__(self, input_features, output_features, bias=True):
         super().__init__(input_features, output_features, bias)
         self.code = None
+        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
+        for i, k in enumerate(array):
+            if input_features > array[i + 1]:
+                self.bsz = k
+                break

     def forward(self, x: torch.Tensor):
         if self.code is None:
-            self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
+            self.code = bnb.functional.create_linear_map(True, 4).to(x.device)

-        out = bnb.matmul_fp8(x.half(), self.weight.half().t(), fw_code=self.code, bw_code=self.code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz)
         if self.bias is not None:
             out += self.bias
```
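Since `LinearFP8` subclasses `nn.Linear` and is exported from `bitsandbytes.nn`, it can be swapped in for a regular linear layer; the first forward call builds the FP8 code maps, and the `bsz` chosen in `__init__` is forwarded to `matmul_fp8`. A hedged usage sketch, assuming a CUDA device and fp16 inputs (not taken from this commit):

```python
import torch
from bitsandbytes.nn import LinearFP8

layer = LinearFP8(1024, 4096, bias=True).cuda().half()
x = torch.randn(8, 1024, dtype=torch.float16, device="cuda", requires_grad=True)

y = layer(x)        # builds fw/bw FP8 maps on first call; uses the bsz picked from input_features
y.sum().backward()  # gradients flow through MatMulFP8.backward
```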