OpenDAS / bitsandbytes · Commits

Commit c5c38ca1, authored Feb 23, 2023 by Tim Dettmers
parent 7b764d35

    Added matmul_mixed.
Showing 3 changed files with 189 additions and 4 deletions:

    bitsandbytes/__init__.py               +2    -1
    bitsandbytes/autograd/_functions.py  +186    -2
    tests/test_autograd.py                 +1    -1
bitsandbytes/__init__.py

@@ -10,7 +10,8 @@ from .autograd._functions import (
     matmul,
     matmul_cublas,
     mm_cublas,
-    matmul_fp8
+    matmul_fp8,
+    matmul_mixed
 )
 from .cextension import COMPILED_WITH_CUDA
 from .nn import modules
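With this export in place, the new function is reachable from the package root alongside the existing matmul variants. A minimal sketch of the import surface this hunk adds (the call signature itself is defined in bitsandbytes/autograd/_functions.py below; everything beyond the import lines here is illustrative):

    import bitsandbytes as bnb
    from bitsandbytes import matmul_mixed

    # Both names refer to the same function object exported above.
    assert bnb.matmul_mixed is matmul_mixed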
bitsandbytes/autograd/_functions.py

@@ -461,6 +461,190 @@ class MatMulFP8(torch.autograd.Function):
         return grad_A, grad_B, None, None, None
 
 
+class MatMul8bitMixed(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
+        # default to pytorch behavior if inputs are empty
+        ctx.is_empty = False
+        if prod(A.shape) == 0:
+            ctx.is_empty = True
+            ctx.A = A
+            ctx.B = B
+            ctx.bias = bias
+            if A.shape[-1] == B.shape[0]:
+                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
+            else:
+                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
+
+        # 1. Quantize A
+        # 2. Quantize B
+        # 3. Matmul
+        # 4. Mixed-precision decomposition matmul
+        # 5. Save state
+        formatB = state.formatB
+        input_shape = A.shape
+        if state.outlier_pool is None:
+            state.outlier_pool = GlobalOutlierPooler.get_instance()
+
+        # Cast A to fp16
+        if A.dtype != torch.float16:
+            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
+
+        # 1. Quantize A
+        if len(A.shape) == 3:
+            A = A.view(-1, A.shape[-1]).contiguous()
+        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
+            A.to(torch.float16), threshold=state.threshold
+        )
+
+        if state.threshold > 0.0 and coo_tensorA is not None:
+            if state.has_fp16_weights:
+                idx = torch.unique(coo_tensorA.colidx).long()
+                CA[:, idx] = 0
+                CAt[:, idx] = 0
+                subA = A[:, idx]
+                state.subB = B[:, idx].t().contiguous()
+                state.idx = idx
+            else:
+                if state.CxB is None:
+                    # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
+                    # we also need to convert it to the turing/ampere format
+                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+        else:
+            if not state.has_fp16_weights and state.CxB is None:
+                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+            subA = None
+
+        # 2. Quantize B
+        if state.has_fp16_weights:
+            has_grad = True if (getattr(B, "grad", None) is not None) else False
+            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
+            if is_transposed:
+                B = B.contiguous()
+
+            if (state.is_training and not has_grad) or state.CxB is None:
+                state.reset_grads()
+                (
+                    CB,
+                    state.CBt,
+                    state.SCB,
+                    state.SCBt,
+                    coo_tensorB,
+                ) = F.double_quant(B.to(torch.float16))
+                state.CxB, state.SB = F.transform(CB, to_order=formatB)
+        else:
+            has_grad = False
+
+        if coo_tensorA is not None and not state.has_fp16_weights:
+            # extract outliers
+            outlier_idx = torch.unique(coo_tensorA.colidx)
+            state.idx = outlier_idx
+            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+            #     # do not use pool for 2nd FFN layer
+            #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
+            # else:
+            #     state.idx = outlier_idx
+            outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
+            state.subB = (
+                (outliers * state.SCB.view(-1, 1) / 127.0)
+                .t()
+                .contiguous()
+                .to(A.dtype)
+            )
+            CA[:, state.idx.long()] = 0
+            CAt[:, state.idx.long()] = 0
+            subA = A[:, state.idx.long()]
+
+        shapeB = state.SB[0]
+
+        if len(input_shape) == 3:
+            output_shape = (input_shape[0], input_shape[1], shapeB[0])
+        else:
+            output_shape = (input_shape[0], shapeB[0])
+
+        # 3. Matmul
+        C32A, SA = F.transform(CA, "col32")
+        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
+        # we apply the fused bias here
+
+        if bias is None or bias.dtype == torch.float16:
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+            output = output.to(A.dtype)
+        else:  # apply bias separately
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+            output = output.to(A.dtype).add_(bias)
+
+        # 4. Mixed-precision decomposition matmul
+        if coo_tensorA is not None and subA is not None:
+            output += torch.matmul(subA, state.subB)
+
+        # 5. Save state
+        ctx.state = state
+
+        ctx.formatB = formatB
+        ctx.grad_shape = input_shape
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
+
+        if any(ctx.needs_input_grad[:2]):
+            ctx.tensors = (CAt, subA, A)
+            ctx.tensor_states = (SCAt, state.idx)
+        else:
+            ctx.tensors = [None, None, None]
+            ctx.tensor_states = (None, None)
+            ctx.save_for_backward(None, None)
+
+        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
+        return clone_func(output.view(output_shape))
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.is_empty:
+            bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
+        CAt, subA, A = ctx.tensors
+        SCAt, idx = ctx.tensor_states
+        formatB = ctx.formatB
+        state = ctx.state
+        grad_A = grad_B = grad_bias = None
+
+        if req_gradBias:
+            # compute grad_bias first before changing grad_output dtype
+            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
+
+        # Cast grad_output to fp16
+        if len(grad_output.shape) == 3:
+            grad_output = grad_output.reshape(
+                -1, grad_output.shape[-1]
+            ).contiguous()
+
+        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
+
+        if req_gradB:
+            grad_B = torch.matmul(grad_output.t(), A)
+
+        if req_gradA:
+            if state.CBt is not None:
+                C32grad, Sgrad = F.transform(Cgrad, "col32")
+                if state.CxBt is None:
+                    state.CxBt, state.SBt = F.transform(
+                        state.CBt, to_order=formatB, transpose=True
+                    )
+                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
+            elif state.CB is not None:
+                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
+                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+            else:
+                raise Exception('State must contain either CBt or CB matrix for backward')
+
+        return grad_A, grad_B, None, grad_bias, None
+
+
 def matmul(
     A: tensor,
     B: tensor,
@@ -479,7 +663,7 @@ def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tens
     return MatMulFP8.apply(A, B, out, fw_code, bw_code)
 
 
-def matmul(
+def matmul_mixed(
     A: tensor,
     B: tensor,
     out: tensor = None,
@@ -490,4 +674,4 @@ def matmul(
     state = state or MatmulLtState()
     if threshold > 0.0:
         state.threshold = threshold
-    return MatMul8bitLt.apply(A, B, out, bias, state)
+    return MatMul8bitMixed.apply(A, B, out, bias, state)
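The forward pass above follows the five steps listed in its comments: quantize A, quantize B, run the int8 matmul (F.double_quant, F.transform, F.igemmlt, F.mm_dequant), add the mixed-precision decomposition term for the outlier columns, and save state for backward. The CUDA kernels are not reproduced here, but the decomposition idea behind step 4 can be sketched in plain PyTorch. Everything in the sketch below (the function name, the row/column absmax scaling, the error check) is an illustrative assumption, not part of the commit; it only mirrors the split between an int8 product for inlier columns and a floating-point product for outlier columns.

    import torch

    def mixed_decomposition_matmul_reference(A, B, threshold=6.0):
        """Illustrative split: columns of A containing entries with
        |value| >= threshold stay in floating point; the rest are quantized
        to int8 with symmetric absmax scaling and multiplied as integers.

        Conceptually mirrors steps 1-4 of MatMul8bitMixed.forward, but uses
        plain PyTorch ops instead of F.double_quant / F.igemmlt / F.mm_dequant.
        """
        # outlier feature dimensions of A
        outlier_cols = (A.abs() >= threshold).any(dim=0).nonzero().flatten()
        inlier_mask = torch.ones(A.shape[1], dtype=torch.bool)
        inlier_mask[outlier_cols] = False

        A_in, B_in = A[:, inlier_mask], B[inlier_mask, :]
        A_out, B_out = A[:, outlier_cols], B[outlier_cols, :]

        # symmetric absmax quantization (row-wise for A, column-wise for B)
        scale_A = A_in.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        scale_B = B_in.abs().amax(dim=0, keepdim=True).clamp(min=1e-8) / 127.0
        A_q = torch.round(A_in / scale_A).to(torch.int8)
        B_q = torch.round(B_in / scale_B).to(torch.int8)

        # integer matmul in a wide accumulator (the real kernel accumulates
        # int8 products in int32), then dequantize with the outer scales
        C_int = A_q.to(torch.int64) @ B_q.to(torch.int64)
        C = C_int.to(torch.float32) * scale_A * scale_B

        # mixed-precision decomposition: outlier columns stay in floating point
        return C + A_out @ B_out

    A = torch.randn(8, 64); A[:, 3] += 20.0      # inject an outlier column
    B = torch.randn(64, 16)
    approx = mixed_decomposition_matmul_reference(A, B)
    print((approx - A @ B).abs().max())          # small quantization error

The 6.0 threshold used above matches one of the decomp values exercised in tests/test_autograd.py below.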
tests/test_autograd.py

@@ -239,7 +239,7 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
 dim2.append(0)
 
 decomp = [0.0, 6.0]
-funcs = [(torch.matmul, bnb.matmul)]
+funcs = [(torch.matmul, bnb.matmul_mixed)]
 str_funcs = ["matmul"]
 req_grad = [(False, False), (True, False), (True, True), (False, True)]
 req_grad = list(product([True, False], repeat=3))
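The only change here is that the existing matmul test matrix now exercises bnb.matmul_mixed instead of bnb.matmul; with decomp = [0.0, 6.0] the parametrization covers both the plain int8 path (threshold 0.0) and the mixed-precision decomposition path (threshold 6.0). A stripped-down, hypothetical version of that comparison is sketched below; the shapes, the tolerance, and the assumption that the second operand is passed in (out_features, in_features) weight layout, as with bnb.matmul, are mine, and the real parametrized test covers far more combinations.

    import pytest
    import torch
    import bitsandbytes as bnb

    @pytest.mark.parametrize("decomp", [0.0, 6.0])
    def test_matmul_mixed_close_to_fp16(decomp):
        # Hypothetical, simplified check; the 8-bit kernels need a CUDA device.
        A = torch.randn(32, 64, dtype=torch.float16, device="cuda")
        B = torch.randn(48, 64, dtype=torch.float16, device="cuda")  # weight layout (out, in)

        out_ref = torch.matmul(A, B.t())
        out_bnb = bnb.matmul_mixed(A, B, threshold=decomp)

        # 8-bit results only approximate the fp16 reference; tolerance is illustrative.
        assert (out_bnb.float() - out_ref.float()).abs().mean() < 0.2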