OpenDAS / bitsandbytes · Commits

Commit bfa0e332, authored Aug 01, 2022 by Titus von Koeller
ran black and isort for coherent code formatting
Parent: 597a8521
Changes: 25
Showing 20 changed files with 2449 additions and 991 deletions (+2449 −991)
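The whole commit is mechanical reformatting. As a hedged illustration of what the two tools named in the commit message do (a minimal sketch assuming current `black` and `isort` Python APIs; the exact versions and options used for this commit are not recorded on this page):

    # Reproduce the kind of rewrite seen in the diffs below on a source string:
    # isort regroups and sorts the imports, black normalizes quotes, line
    # wrapping and trailing commas.
    import black
    import isort

    src = "from .nn import modules\nimport torch\nd = {'a': 1}\n"

    sorted_src = isort.code(src)                                 # import ordering
    formatted = black.format_str(sorted_src, mode=black.Mode())  # code style
    print(formatted)

In the project itself the same effect comes from running the command-line tools over the package directory.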
Changed files (additions / deletions):

bitsandbytes/__init__.py              +11    -9
bitsandbytes/autograd/_functions.py   +88    -46
bitsandbytes/cextension.py            +15    -8
bitsandbytes/cuda_setup.py            +46    -32
bitsandbytes/debug_cli.py             +0     -1
bitsandbytes/functional.py            +981   -414
bitsandbytes/nn/__init__.py           +4     -4
bitsandbytes/nn/modules.py            +147   -50
bitsandbytes/optim/__init__.py        +3     -3
bitsandbytes/optim/adagrad.py         +93    -21
bitsandbytes/optim/adam.py            +128   -51
bitsandbytes/optim/adamw.py           +85    -19
bitsandbytes/optim/lamb.py            +97    -20
bitsandbytes/optim/lars.py            +126   -41
bitsandbytes/optim/optimizer.py       +380   -185
bitsandbytes/optim/rmsprop.py         +94    -21
bitsandbytes/optim/sgd.py             +88    -21
bitsandbytes/utils.py                 +3     -1
quicktest.py                          +47    -33
setup.py                              +13    -11
bitsandbytes/__init__.py  (view file @ bfa0e332)

 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .nn import modules
-from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState
+from .autograd._functions import (
+    MatmulLtState,
+    bmm_cublas,
+    matmul,
+    matmul_cublas,
+    mm_cublas,
+)
 from .cextension import COMPILED_WITH_CUDA
+from .nn import modules

 if COMPILED_WITH_CUDA:
     from .optim import adam

-__pdoc__ = {'libbitsandbytes': False,
-            'optim.optimizer.Optimizer8bit': False,
-            'optim.optimizer.MockArgs': False}
+__pdoc__ = {
+    "libbitsandbytes": False,
+    "optim.optimizer.Optimizer8bit": False,
+    "optim.optimizer.MockArgs": False,
+}
bitsandbytes/autograd/_functions.py  (view file @ bfa0e332)

+from dataclasses import dataclass
+
 import torch
+
 import bitsandbytes as bnb
 import bitsandbytes.functional as F
-from dataclasses import dataclass

 tensor = torch.Tensor

-'''
+"""
     This class pools outlier dimensions across layers.
     This is particularly important for small models where outlier features
     are less systematic and occur with low frequency.
-'''
+"""


 class GlobalOutlierPooler(object):
     _instance = None

     def __init__(self):
-        raise RuntimeError('Call get_instance() instead')
+        raise RuntimeError("Call get_instance() instead")

     def initialize(self):
         self.outliers = set()
 ...
 @@ -29,25 +32,29 @@ class GlobalOutlierPooler(object):
         return cls._instance

     def add_outliers(self, outlier_idx, feature_dim):
-        if self.model_dim is None: self.model_dim = feature_dim
-        if feature_dim != self.model_dim: return  # we do not encode outliers for the 2nd FFN layer
+        if self.model_dim is None:
+            self.model_dim = feature_dim
+        if feature_dim != self.model_dim:
+            return  # we do not encode outliers for the 2nd FFN layer

         self.outliers.update(outlier_idx.tolist())

     def get_current_outlier_idx(self):
         return torch.Tensor(list(self.outliers)).to(torch.int64)
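To make the pooling behaviour above concrete, a small usage sketch (the indices and feature dimension are made up; only the methods shown in this diff are used):

    import torch
    from bitsandbytes.autograd._functions import GlobalOutlierPooler

    pool = GlobalOutlierPooler.get_instance()   # singleton: __init__ raises on purpose
    pool.add_outliers(torch.tensor([3, 17, 42]), feature_dim=4096)
    pool.add_outliers(torch.tensor([3, 99]), feature_dim=4096)
    idx = pool.get_current_outlier_idx()        # int64 tensor of pooled outlier columns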
 class MatMul8bit(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]):
+    def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):

         if precision[0] != 8:
             with torch.no_grad():
                 output = torch.matmul(A, B)
         else:
-            if len(B.shape) == 2: dim = 0
-            else: dim = 1
+            if len(B.shape) == 2:
+                dim = 0
+            else:
+                dim = 1
             qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
             qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
             iout = F.igemm(qA, qB)
 ...
 @@ -84,21 +91,41 @@ class MatMul8bit(torch.autograd.Function):
             else:
                 if len(B.shape) == 2 and len(A.shape) == 3:
                     grad_output = grad_output.contiguous()
-                    if not grad_output.is_contiguous(): grad_output.contiguous()
-                    qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type)
-                    if not A.is_contiguous(): A = A.contiguous()
-                    qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
+                    if not grad_output.is_contiguous():
+                        grad_output.contiguous()
+                    qgrad_output, S1 = F.vectorwise_quant(
+                        grad_output.view(-1, grad_output.shape[2]),
+                        dim=0,
+                        quant_type=quant_type,
+                    )
+                    if not A.is_contiguous():
+                        A = A.contiguous()
+                    qA, S2 = F.vectorwise_quant(
+                        A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
+                    )
                     igrad_B = F.igemm(qA.t(), qgrad_output)
-                    grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
+                    grad_B = F.vectorwise_mm_dequant(
+                        igrad_B, S2.t(), S1, grad_output.dtype, quant_type
+                    )
                 else:
-                    qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+                    qgrad_output, S1 = F.vectorwise_quant(
+                        grad_output, dim=dims, quant_type=quant_type
+                    )
                     qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
                     igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
-                    grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type)
+                    grad_B = F.vectorwise_mm_dequant(
+                        igrad_B,
+                        S2.permute(permute_dim),
+                        S1,
+                        grad_output.dtype,
+                        quant_type,
+                    )

         if A.requires_grad:
-            if len(grad_output.shape) == 3: dims = [2]
-            else: dims = [1]
+            if len(grad_output.shape) == 3:
+                dims = [2]
+            else:
+                dims = [1]

             if len(B.shape) == 3:
                 # bio -> boi
 ...
 @@ -113,10 +140,14 @@ class MatMul8bit(torch.autograd.Function):
                 with torch.no_grad():
                     grad_A = torch.matmul(grad_output, B.permute(permute_dim))
             else:
-                qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+                qgrad_output, S1 = F.vectorwise_quant(
+                    grad_output, dim=dims, quant_type=quant_type
+                )
                 qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                 igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
-                grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type)
+                grad_A = F.vectorwise_mm_dequant(
+                    igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
+                )

         return grad_A, grad_B, None, None, None
 ...
 @@ -125,6 +156,7 @@ mm_cublas = MatMul8bit.apply
 bmm_cublas = MatMul8bit.apply
 matmul_cublas = MatMul8bit.apply


 @dataclass
 class MatmulLtState:
     CB = None
 ...
 @@ -159,7 +191,6 @@ class MatmulLtState:
 class MatMul8bitLt(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A, B, out=None, state=MatmulLtState()):
-        # 1. Quantize A
 ...
 @@ -171,11 +202,15 @@ class MatMul8bitLt(torch.autograd.Function):
         requires_gradB = B.requires_grad
         formatB = state.formatB
         input_shape = A.shape
-        if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance()
-        assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!'
+        if state.outlier_pool is None:
+            state.outlier_pool = GlobalOutlierPooler.get_instance()
+        assert (
+            A.dtype == torch.float16
+        ), f"The input data type needs to be fp16 but {A.dtype} was found!"

-        if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous()
+        # 1. Quantize A
+        if len(A.shape) == 3:
+            A = A.view(-1, A.shape[-1]).contiguous()
         CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)

         if state.threshold > 0.0 and coo_tensorA is not None:
 ...
 @@ -191,8 +226,8 @@ class MatMul8bitLt(torch.autograd.Function):
                 # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                 # we also need to convert it to the turing/ampere format
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
-                #state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
-                #if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
+                # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
+                # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
                 #    # generate outlier index and subB
                 #    outlier_idx = torch.unique(coo_tensorA.colidx).long()
                 #    state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
 ...
 @@ -203,24 +238,24 @@ class MatMul8bitLt(torch.autograd.Function):
                 #    state.idx = outlier_idx
                 #    state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
-                #if state.idx is not None:
+                # if state.idx is not None:
                 #    # extract outliers
                 #    CA[:, state.idx] = 0
                 #    CAt[:, state.idx] = 0
                 #    subA = A[:, state.idx]
-                #else:
+                # else:
                 #    subA = None
             else:
                 if not state.has_fp16_weights and state.CxB is None:
                     state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
             subA = None

         # 2. Quantize B
         if state.has_fp16_weights:
-            has_grad = (True if (getattr(B, 'grad', None) is not None) else False)
+            has_grad = True if (getattr(B, "grad", None) is not None) else False
             is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
-            if is_transposed: B = B.contiguous()
+            if is_transposed:
+                B = B.contiguous()

             if (state.is_training and not has_grad) or state.CxB is None:
                 state.reset_grads()
 ...
 @@ -234,14 +269,16 @@ class MatMul8bitLt(torch.autograd.Function):
                 outlier_idx = torch.unique(coo_tensorA.colidx)
                 state.idx = outlier_idx
-                #state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
-                #if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+                # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+                # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
                 #    # do not use pool for 2nd FFN layer
                 #    state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
-                #else:
+                # else:
                 #    state.idx = outlier_idx
                 outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
-                state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
+                state.subB = (
+                    (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
+                )
                 CA[:, state.idx.long()] = 0
                 CAt[:, state.idx.long()] = 0
                 subA = A[:, state.idx.long()]
 ...
 @@ -254,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function):
             output_shape = (input_shape[0], shapeB[0])

         # 3. Matmul
-        C32A, SA = F.transform(CA, 'col32')
+        C32A, SA = F.transform(CA, "col32")
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         output = F.mm_dequant(out32, Sout32, SCA, state.SCB)
 ...
 @@ -277,7 +314,7 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)

-        #clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
+        # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         clone_func = torch.clone
         return clone_func(output.view(output_shape))
 ...
 @@ -288,7 +325,7 @@ class MatMul8bitLt(torch.autograd.Function):
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.'
+        assert state.has_fp16_weights, "Backprop only supported for fp16 weights."

         if len(grad_output.shape) == 3:
             grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
 ...
 @@ -298,18 +335,22 @@ class MatMul8bitLt(torch.autograd.Function):
         Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
         if req_gradB:
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
-            C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True)
+            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
             gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
             grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
             if state.threshold > 0.0 and subA is not None:
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

         if req_gradA:
-            C32grad, Sgrad = F.transform(Cgrad, 'col32')
+            C32grad, Sgrad = F.transform(Cgrad, "col32")
             if state.CxBt is None:
-                state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
+                state.CxBt, state.SBt = F.transform(
+                    state.CBt, to_order=formatB, transpose=True
+                )
             gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
+                ctx.grad_shape
+            )

         return grad_A, grad_B, None, None, None, None, None
 ...
 @@ -317,9 +358,10 @@ class MatMul8bitLt(torch.autograd.Function):
 matmul = MatMul8bitLt.apply


-def matmul(A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0):
+def matmul(
+    A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
+):
     state = state or MatmulLtState()
     if threshold > 0.0:
         state.threshold = threshold
     return MatMul8bitLt.apply(A, B, out, state)
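A usage sketch for the `matmul` wrapper defined at the end of this file (shapes, device and threshold value are illustrative; the forward pass above asserts fp16 inputs):

    import torch
    import bitsandbytes as bnb

    A = torch.randn(8, 64, dtype=torch.float16, device="cuda")
    B = torch.randn(32, 64, dtype=torch.float16, device="cuda", requires_grad=True)

    state = bnb.MatmulLtState()                         # re-exported in bitsandbytes/__init__.py
    out = bnb.matmul(A, B, state=state, threshold=6.0)  # threshold > 0 enables outlier handling
    print(out.shape)                                    # (8, 32): (input rows, B rows), per output_shape above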
bitsandbytes/cextension.py  (view file @ bfa0e332)

 import ctypes as ct
 import os
 from warnings import warn

 from bitsandbytes.cuda_setup import evaluate_cuda_setup
 ...
 @@ -8,17 +9,21 @@ class CUDALibrary_Singleton(object):
     _instance = None

     def __init__(self):
-        raise RuntimeError('Call get_instance() instead')
+        raise RuntimeError("Call get_instance() instead")

     def initialize(self):
         self.context = {}
         binary_name = evaluate_cuda_setup()
-        if not os.path.exists(os.path.dirname(__file__) + f'/{binary_name}'):
-            print(f'TODO: compile library for specific version: {binary_name}')
-            print('defaulting to libbitsandbytes.so')
-            self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
+        if not os.path.exists(os.path.dirname(__file__) + f"/{binary_name}"):
+            print(f"TODO: compile library for specific version: {binary_name}")
+            print("defaulting to libbitsandbytes.so")
+            self.lib = ct.cdll.LoadLibrary(
+                os.path.dirname(__file__) + "/libbitsandbytes.so"
+            )
         else:
-            self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + f'/{binary_name}')
+            self.lib = ct.cdll.LoadLibrary(
+                os.path.dirname(__file__) + f"/{binary_name}"
+            )

     @classmethod
     def get_instance(cls):
 ...
 @@ -35,6 +40,8 @@ try:
     lib.get_cusparse.restype = ct.c_void_p
     COMPILED_WITH_CUDA = True
 except AttributeError:
-    warn("The installed version of bitsandbytes was compiled without GPU support. "
-         "8-bit optimizers and GPU quantization are unavailable.")
+    warn(
+        "The installed version of bitsandbytes was compiled without GPU support. "
+        "8-bit optimizers and GPU quantization are unavailable."
+    )
     COMPILED_WITH_CUDA = False
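The flag set at the bottom of this file is the guard the rest of the package keys off; a minimal sketch of the intended check:

    from bitsandbytes.cextension import COMPILED_WITH_CUDA

    if COMPILED_WITH_CUDA:
        # GPU build: the 8-bit optimizers can be imported (mirrors bitsandbytes/__init__.py)
        from bitsandbytes.optim import adam
    else:
        print("CPU-only build: 8-bit optimizers and GPU quantization are unavailable.")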
bitsandbytes/cuda_setup.py  (view file @ bfa0e332)

 ...
 @@ -18,31 +18,36 @@ evaluation:
     - based on that set the default path
 """

-import ctypes
-import shlex
-import subprocess
 from os import environ as env
 from pathlib import Path
 from typing import Set, Union

-from .utils import warn_of_missing_prerequisite, print_err
+import ctypes
+import shlex
+import subprocess
+
+from .utils import print_err, warn_of_missing_prerequisite


 def execute_and_return(strCMD):
-    proc = subprocess.Popen(shlex.split(strCMD), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    proc = subprocess.Popen(
+        shlex.split(strCMD), stdout=subprocess.PIPE, stderr=subprocess.PIPE
+    )
     out, err = proc.communicate()
     out, err = out.decode("UTF-8").strip(), err.decode("UTF-8").strip()
     return out, err


 def check_cuda_result(cuda, result_val):
     if result_val != 0:
         cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
         print(f"Count not initialize CUDA - failure!")
-        raise Exception('CUDA exception!')
+        raise Exception("CUDA exception!")
     return result_val


 # taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
 def get_compute_capability():
-    libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll')
+    libnames = ("libcuda.so", "libcuda.dylib", "cuda.dll")
     for libname in libnames:
         try:
             cuda = ctypes.CDLL(libname)
 ...
 @@ -51,8 +56,7 @@ def get_compute_capability():
         else:
             break
     else:
-        raise OSError("could not load any of: " + ' '.join(libnames))
+        raise OSError("could not load any of: " + " ".join(libnames))

     nGpus = ctypes.c_int()
     cc_major = ctypes.c_int()
 ...
 @@ -69,39 +73,43 @@ def get_compute_capability():
     ccs = []
     for i in range(nGpus.value):
         result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
-        result = check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device))
-        ccs.append(f'{cc_major.value}.{cc_minor.value}')
+        result = check_cuda_result(
+            cuda,
+            cuda.cuDeviceComputeCapability(
+                ctypes.byref(cc_major), ctypes.byref(cc_minor), device
+            ),
+        )
+        ccs.append(f"{cc_major.value}.{cc_minor.value}")

-    #TODO: handle different compute capabilities; for now, take the max
+    # TODO: handle different compute capabilities; for now, take the max
     ccs.sort()
-    return ccs[-1]
+    # return ccs[-1]
+    return ccs


 CUDA_RUNTIME_LIB: str = "libcudart.so"


 def tokenize_paths(paths: str) -> Set[Path]:
-    return {Path(ld_path) for ld_path in paths.split(':') if ld_path}
+    return {Path(ld_path) for ld_path in paths.split(":") if ld_path}


 def get_cuda_runtime_lib_path(
     # TODO: replace this with logic for all paths in env vars
     LD_LIBRARY_PATH: Union[str, None] = env.get("LD_LIBRARY_PATH")
 ) -> Union[Path, None]:
-    """ # TODO: add doc-string
-    """
+    """# TODO: add doc-string"""

     if not LD_LIBRARY_PATH:
         warn_of_missing_prerequisite(
-            'LD_LIBRARY_PATH is completely missing from environment!'
+            "LD_LIBRARY_PATH is completely missing from environment!"
         )
         return None

     ld_library_paths: Set[Path] = tokenize_paths(LD_LIBRARY_PATH)

-    non_existent_directories: Set[Path] = {path for path in ld_library_paths if not path.exists()}
+    non_existent_directories: Set[Path] = {
+        path for path in ld_library_paths if not path.exists()
+    }

     if non_existent_directories:
 ...
 @@ -111,7 +119,8 @@ def get_cuda_runtime_lib_path(
     )

     cuda_runtime_libs: Set[Path] = {
-        path / CUDA_RUNTIME_LIB for path in ld_library_paths
+        path / CUDA_RUNTIME_LIB
+        for path in ld_library_paths
         if (path / CUDA_RUNTIME_LIB).is_file()
     } - non_existent_directories
 ...
 @@ -126,26 +135,31 @@ def get_cuda_runtime_lib_path(
     single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs))
     return single_cuda_runtime_lib_dir


 def evaluate_cuda_setup():
     cuda_path = get_cuda_runtime_lib_path()
     cc = get_compute_capability()
-    binary_name = 'libbitsandbytes_cpu.so'
+    binary_name = "libbitsandbytes_cpu.so"

     if not (has_gpu := bool(cc)):
-        print('WARNING: No GPU detected! Check our CUDA paths. Processing to load CPU-only library...')
+        print(
+            "WARNING: No GPU detected! Check our CUDA paths. Processing to load CPU-only library..."
+        )
         return binary_name

-    has_cublaslt = cc in ['7.5', '8.0', '8.6']
+    has_cublaslt = cc in ["7.5", "8.0", "8.6"]

     # TODO:
     # (1) Model missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
     # (2) Multiple CUDA versions installed

     cuda_home = str(Path(cuda_path).parent.parent)
-    ls_output, err = execute_and_return(f'{cuda_home}/bin/nvcc --version')
-    cuda_version = ls_output.split('\n')[3].split(',')[-1].strip().lower().replace('v', '')
-    major, minor, revision = cuda_version.split('.')
-    cuda_version_string = f'{major}{minor}'
+    ls_output, err = execute_and_return(f"{cuda_home}/bin/nvcc --version")
+    cuda_version = (
+        ls_output.split("\n")[3].split(",")[-1].strip().lower().replace("v", "")
+    )
+    major, minor, revision = cuda_version.split(".")
+    cuda_version_string = f"{major}{minor}"

     binary_name = f'libbitsandbytes_cuda{cuda_version_string}_{("cublaslt" if has_cublaslt else "")}.so'
 ...
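A rough driving sketch for the helpers above (the example binary name just instantiates the f-string at the end of evaluate_cuda_setup; actual values depend on the machine):

    from bitsandbytes.cuda_setup import evaluate_cuda_setup, get_cuda_runtime_lib_path

    lib_dir = get_cuda_runtime_lib_path()  # scans LD_LIBRARY_PATH entries for libcudart.so
    binary_name = evaluate_cuda_setup()    # e.g. "libbitsandbytes_cuda112_cublaslt.so"
    print(lib_dir, binary_name)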
bitsandbytes/debug_cli.py  (view file @ bfa0e332)

 import typer

 cli = typer.Typer()
 ...
bitsandbytes/functional.py  (view file @ bfa0e332)

 [diff collapsed: +981 −414]
bitsandbytes/nn/__init__.py  (view file @ bfa0e332)

 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import StableEmbedding, Linear8bit, Linear8bitLt, Int8Params
+from .modules import Int8Params, Linear8bit, Linear8bitLt, StableEmbedding
bitsandbytes/nn/modules.py  (view file @ bfa0e332)

 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import torch
-import bitsandbytes as bnb
-from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict
-from torch import Tensor, device, dtype
-from torch import nn
-from torch.nn.parameter import Parameter
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    overload,
+)
+
+import torch
 import torch.nn.functional as F
+from torch import Tensor, device, dtype, nn
+from torch.nn.parameter import Parameter
+
+import bitsandbytes as bnb
 from bitsandbytes.optim import GlobalOptimManager

-T = TypeVar('T', bound='torch.nn.Module')
+T = TypeVar("T", bound="torch.nn.Module")


 class StableEmbedding(torch.nn.Embedding):
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
-                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
-                 sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
-        super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        max_norm: Optional[float] = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
+        _weight: Optional[Tensor] = None,
+    ) -> None:
+        super(StableEmbedding, self).__init__(
+            num_embeddings,
+            embedding_dim,
+            padding_idx,
+            max_norm,
+            norm_type,
+            scale_grad_by_freq,
+            sparse,
+            _weight,
+        )
         self.norm = torch.nn.LayerNorm(embedding_dim)
-        GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})
+        GlobalOptimManager.get_instance().register_module_override(
+            self, "weight", {"optim_bits": 32}
+        )

     def reset_parameters(self) -> None:
         torch.nn.init.xavier_uniform_(self.weight)
         self._fill_padding_idx_with_zero()

-    '''
-    !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
+    """
+    !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
     to make the Layer compatible with Pytorch < 1.9.
     This means that if this changes in future PyTorch releases this need to change too
     which is cumbersome. However, with this we can ensure compatibility with previous
     PyTorch releases.
-    '''
+    """

     def _fill_padding_idx_with_zero(self) -> None:
         if self.padding_idx is not None:
             with torch.no_grad():
 ...
 @@ -41,29 +61,55 @@ class StableEmbedding(torch.nn.Embedding):

     def forward(self, input: Tensor) -> Tensor:
-        emb = F.embedding(input, self.weight, self.padding_idx, self.max_norm,
-                          self.norm_type, self.scale_grad_by_freq, self.sparse)
+        emb = F.embedding(
+            input,
+            self.weight,
+            self.padding_idx,
+            self.max_norm,
+            self.norm_type,
+            self.scale_grad_by_freq,
+            self.sparse,
+        )

         return self.norm(emb)


 class Embedding(torch.nn.Embedding):
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
-                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
-                 sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
-        super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
-        GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        max_norm: Optional[float] = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
+        _weight: Optional[Tensor] = None,
+    ) -> None:
+        super(Embedding, self).__init__(
+            num_embeddings,
+            embedding_dim,
+            padding_idx,
+            max_norm,
+            norm_type,
+            scale_grad_by_freq,
+            sparse,
+            _weight,
+        )
+        GlobalOptimManager.get_instance().register_module_override(
+            self, "weight", {"optim_bits": 32}
+        )

     def reset_parameters(self) -> None:
         torch.nn.init.xavier_uniform_(self.weight)
         self._fill_padding_idx_with_zero()

-    '''
-    !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
+    """
+    !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
     to make the Layer compatible with Pytorch < 1.9.
     This means that if this changes in future PyTorch releases this need to change too
     which is cumbersome. However, with this we can ensure compatibility with previous
     PyTorch releases.
-    '''
+    """

     def _fill_padding_idx_with_zero(self) -> None:
         if self.padding_idx is not None:
             with torch.no_grad():
 ...
 @@ -71,13 +117,22 @@ class Embedding(torch.nn.Embedding):

     def forward(self, input: Tensor) -> Tensor:
-        emb = F.embedding(input, self.weight, self.padding_idx, self.max_norm,
-                          self.norm_type, self.scale_grad_by_freq, self.sparse)
+        emb = F.embedding(
+            input,
+            self.weight,
+            self.padding_idx,
+            self.max_norm,
+            self.norm_type,
+            self.scale_grad_by_freq,
+            self.sparse,
+        )

         return emb


 class Int8Params(torch.nn.Parameter):
-    def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None):
+    def __new__(
+        cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None
+    ):
         cls.has_fp16_weights = has_fp16_weights
         cls.CB = None
         cls.SCB = None
 ...
 @@ -96,14 +151,18 @@ class Int8Params(torch.nn.Parameter):
             del CBt
             del SCBt
             self.data = CB
-            setattr(self, 'CB', CB)
-            setattr(self, 'SCB', SCB)
+            setattr(self, "CB", CB)
+            setattr(self, "SCB", SCB)

         return self

     @overload
-    def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...) -> T:
+    def to(
+        self: T,
+        device: Optional[Union[int, device]] = ...,
+        dtype: Optional[Union[dtype, str]] = ...,
+        non_blocking: bool = ...,
+    ) -> T:
         ...

     @overload
 ...
 @@ -115,23 +174,41 @@ class Int8Params(torch.nn.Parameter):
         ...

     def to(self, *args, **kwargs):
-        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
-        if device is not None and device.type == 'cuda' and self.data.device.type == 'cpu': return self.cuda(device)
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
+            *args, **kwargs
+        )
+        if (
+            device is not None
+            and device.type == "cuda"
+            and self.data.device.type == "cpu"
+        ):
+            return self.cuda(device)
         else:
-            new_param = Int8Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights)
+            new_param = Int8Params(
+                super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+                requires_grad=self.requires_grad,
+                has_fp16_weights=self.has_fp16_weights,
+            )
             new_param.CB = self.CB
             new_param.SCB = self.SCB

             return new_param


 class Linear8bitLt(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, threshold=0.0, index=None):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=True,
+        threshold=0.0,
+        index=None,
+    ):
         super(Linear8bitLt, self).__init__(input_features, output_features, bias)
         self.state = bnb.MatmulLtState()
         self.index = index

         self.state.threshold = threshold
         self.state.has_fp16_weights = has_fp16_weights
 ...
 @@ -149,9 +226,10 @@ class Linear8bitLt(nn.Linear):

     def forward(self, x):
         self.state.is_training = self.training
-        if self.weight.CB is not None: self.init_8bit_state()
-        #assert not self.state.has_fp16_weights
-        #if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
+        if self.weight.CB is not None:
+            self.init_8bit_state()
+        # assert not self.state.has_fp16_weights
+        # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None

         out = bnb.matmul(x, self.weight, state=self.state)
 ...
 @@ -166,8 +244,18 @@ class Linear8bitLt(nn.Linear):
         return out


 class Linear8bit(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True, quant_type='vector', index=None, args=None, sparse_decomp=False):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        quant_type="vector",
+        index=None,
+        args=None,
+        sparse_decomp=False,
+    ):
         super(Linear8bit, self).__init__(input_features, output_features, bias)
         self.quant_type = quant_type
         self.index = index
 ...
 @@ -178,15 +266,24 @@ class Linear8bit(nn.Linear):
             self.iter += 1
             if self.iter % self.args.clip_freq == 0:
                 with torch.no_grad():
-                    maxval, maxidx = torch.topk(torch.abs(self.weight.flatten()), k=self.args.clip_idx)
+                    maxval, maxidx = torch.topk(
+                        torch.abs(self.weight.flatten()), k=self.args.clip_idx
+                    )
                     if not dist.is_initialized() or dist.get_rank() == 0:
-                        print('clip', maxval[-1].item())
+                        print("clip", maxval[-1].item())
                     self.weight.clip_(-maxval[-1], maxval[-1])

         if self.args is not None:
-            out = bnb.nn.functional.sparse_decomposed_linear8bit(x, self.weight, self.bias, qval=self.args.sparse_decomp_val, quant_type=self.args.quant_type)
+            out = bnb.nn.functional.sparse_decomposed_linear8bit(
+                x,
+                self.weight,
+                self.bias,
+                qval=self.args.sparse_decomp_val,
+                quant_type=self.args.quant_type,
+            )
         else:
-            out = bnb.nn.functional.linear8bit(x, self.weight, self.bias, quant_type=self.args.quant_type)
+            out = bnb.nn.functional.linear8bit(
+                x, self.weight, self.bias, quant_type=self.args.quant_type
+            )

         return out
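Construction sketch for the layers reformatted above (feature sizes and the threshold are placeholders, not values taken from this commit):

    import bitsandbytes as bnb

    # StableEmbedding: Embedding + LayerNorm, weight pinned to 32-bit optimizer state
    emb = bnb.nn.StableEmbedding(num_embeddings=10000, embedding_dim=512)

    # Linear8bitLt: linear layer whose forward runs through bnb.matmul with a MatmulLtState
    lin = bnb.nn.Linear8bitLt(512, 2048, bias=True, has_fp16_weights=False, threshold=6.0)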
bitsandbytes/optim/__init__.py  (view file @ bfa0e332)

 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from bitsandbytes.cextension import COMPILED_WITH_CUDA
 ...
bitsandbytes/optim/adagrad.py  (view file @ bfa0e332)

 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from bitsandbytes.optim.optimizer import Optimizer1State


 class Adagrad(Optimizer1State):
-    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
-            optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+    def __init__(
+        self,
+        params,
+        lr=1e-2,
+        lr_decay=0,
+        weight_decay=0,
+        initial_accumulator_value=0,
+        eps=1e-10,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= weight_decay:
 ...
 @@ -14,15 +27,39 @@ class Adagrad(Optimizer1State):
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if initial_accumulator_value != 0.0:
-            raise ValueError('Initial accumulator value != 0.0 not supported!')
+            raise ValueError("Initial accumulator value != 0.0 not supported!")
         if lr_decay != 0.0:
-            raise ValueError('Lr Decay != 0.0 not supported!')
-        super(Adagrad, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
-                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+            raise ValueError("Lr Decay != 0.0 not supported!")
+        super(Adagrad, self).__init__(
+            "adagrad",
+            params,
+            lr,
+            (0.0, 0.0),
+            eps,
+            weight_decay,
+            optim_bits,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+        )


 class Adagrad8bit(Optimizer1State):
-    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
-            optim_bits=8, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+    def __init__(
+        self,
+        params,
+        lr=1e-2,
+        lr_decay=0,
+        weight_decay=0,
+        initial_accumulator_value=0,
+        eps=1e-10,
+        optim_bits=8,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= weight_decay:
 ...
 @@ -30,16 +67,40 @@ class Adagrad8bit(Optimizer1State):
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if initial_accumulator_value != 0.0:
-            raise ValueError('Initial accumulator value != 0.0 not supported!')
+            raise ValueError("Initial accumulator value != 0.0 not supported!")
         if lr_decay != 0.0:
-            raise ValueError('Lr Decay != 0.0 not supported!')
+            raise ValueError("Lr Decay != 0.0 not supported!")
         assert block_wise
-        super(Adagrad8bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
-                weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+        super(Adagrad8bit, self).__init__(
+            "adagrad",
+            params,
+            lr,
+            (0.0, 0.0),
+            eps,
+            weight_decay,
+            8,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+        )


 class Adagrad32bit(Optimizer1State):
-    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
-            optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+    def __init__(
+        self,
+        params,
+        lr=1e-2,
+        lr_decay=0,
+        weight_decay=0,
+        initial_accumulator_value=0,
+        eps=1e-10,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= weight_decay:
 ...
 @@ -47,8 +108,19 @@ class Adagrad32bit(Optimizer1State):
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if initial_accumulator_value != 0.0:
-            raise ValueError('Initial accumulator value != 0.0 not supported!')
+            raise ValueError("Initial accumulator value != 0.0 not supported!")
         if lr_decay != 0.0:
-            raise ValueError('Lr Decay != 0.0 not supported!')
-        super(Adagrad32bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
-                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+            raise ValueError("Lr Decay != 0.0 not supported!")
+        super(Adagrad32bit, self).__init__(
+            "adagrad",
+            params,
+            lr,
+            (0.0, 0.0),
+            eps,
+            weight_decay,
+            32,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+        )
bitsandbytes/optim/adam.py  (view file @ bfa0e332)

 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import math
 ...
 @@ -8,29 +8,97 @@ import os
 import torch
 import torch.distributed as dist
-from bitsandbytes.optim.optimizer import Optimizer2State
+
 import bitsandbytes.functional as F
+from bitsandbytes.optim.optimizer import Optimizer2State


 class Adam(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
-            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
-        super(Adam, self).__init__('adam', params, lr, betas, eps,
-                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        amsgrad=False,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
+        super(Adam, self).__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            optim_bits,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+        )


 class Adam8bit(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False,
-            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
-        super(Adam8bit, self).__init__('adam', params, lr, betas, eps,
-                weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        amsgrad=False,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
+        super(Adam8bit, self).__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            8,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+        )


 class Adam32bit(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False,
-            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
-        super(Adam32bit, self).__init__('adam', params, lr, betas, eps,
-                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        amsgrad=False,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
+        super(Adam32bit, self).__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            32,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+        )


 class AnalysisAdam(torch.optim.Optimizer):
 ...
 @@ -68,8 +136,8 @@ class AnalysisAdam(torch.optim.Optimizer):
         eps=1e-8,
         weight_decay=0,
         amsgrad=False,
-        bnb_analysis='dynamic-blockwise',
-        savedir=None
+        bnb_analysis="dynamic-blockwise",
+        savedir=None,
     ):
         defaults = dict(
             lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
 ...
 @@ -124,9 +192,13 @@ class AnalysisAdam(torch.optim.Optimizer):
                     state["exp_avg"] = torch.zeros_like(p_data_fp32)
                     # Exponential moving average of squared gradient values
                     state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
-                    state['abserrors'] = torch.zeros((256, 256), device=p_data_fp32.device)
-                    state['relerrors'] = torch.zeros((256, 256), device=p_data_fp32.device)
-                    state['counts'] = torch.zeros((256, 256), device=p_data_fp32.device)
+                    state["abserrors"] = torch.zeros(
+                        (256, 256), device=p_data_fp32.device
+                    )
+                    state["relerrors"] = torch.zeros(
+                        (256, 256), device=p_data_fp32.device
+                    )
+                    state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
                     if amsgrad:
                         # Maintains max of all exp. moving avg. of sq. grad. values
                         state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
 ...
 @@ -143,9 +215,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                 bias_correction1 = 1 - beta1 ** state["step"]
                 bias_correction2 = 1 - beta2 ** state["step"]
                 step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
-                e = state['abserrors']
-                rele = state['relerrors']
-                counts = state['counts']
+                e = state["abserrors"]
+                rele = state["relerrors"]
+                counts = state["counts"]

                 if group["weight_decay"] != 0:
                     p_data_fp32.add_(
 ...
 @@ -156,77 +228,84 @@ class AnalysisAdam(torch.optim.Optimizer):
                 if amsgrad:
                     max_exp_avg_sq = state["max_exp_avg_sq"]

                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                 denom = exp_avg_sq.sqrt().add_(group["eps"])
-                update_fp32 = exp_avg/denom
+                update_fp32 = exp_avg / denom

-                if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000*1000:
+                if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
                     # embedding layer or too small
-                    p_data_fp32 += -step_size*update_fp32
+                    p_data_fp32 += -step_size * update_fp32
                 else:
-                    if self.analysis == 'dynamic-blockwise':
+                    if self.analysis == "dynamic-blockwise":
                         code1 = F.create_dynamic_map(signed=True).to(p.device)
                         code2 = F.create_dynamic_map(signed=False).to(p.device)
                         C1, S1 = F.quantize_blockwise(exp_avg, code=code1)
                         state1 = F.dequantize_blockwise(C1, S1)
                         C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2)
                         state2 = F.dequantize_blockwise(C2, S2)
-                    elif self.analysis == 'dynamic':
+                    elif self.analysis == "dynamic":
                         code1 = F.create_dynamic_map(signed=True).to(p.device)
                         code2 = F.create_dynamic_map(signed=False).to(p.device)
                         C1, S1 = F.quantize(exp_avg, code=code1)
                         state1 = F.dequantize(C1, S1)
                         C2, S2 = F.quantize(exp_avg_sq, code=code2)
                         state2 = F.dequantize(C2, S2)
-                    elif self.analysis == 'linear':
+                    elif self.analysis == "linear":
                         code1 = F.create_linear_map(signed=True).to(p.device)
                         code2 = F.create_linear_map(signed=False).to(p.device)
                         C1, S1 = F.quantize(exp_avg, code=code1)
                         state1 = F.dequantize(C1, S1)
                         C2, S2 = F.quantize(exp_avg_sq, code=code2)
                         state2 = F.dequantize(C2, S2)
-                    elif self.analysis == 'quantile':
+                    elif self.analysis == "quantile":
                         code1 = F.estimate_quantiles(exp_avg)
                         code2 = F.estimate_quantiles(exp_avg_sq)
                         C1 = F.quantize_no_absmax(exp_avg, code=code1)
                         state1 = F.dequantize_no_absmax(C1, code1)
                         C2 = F.quantize_no_absmax(exp_avg_sq, code=code2)
                         state2 = F.dequantize_no_absmax(C2, code2)
-                    elif self.analysis == 'my-quantization-routine':
+                    elif self.analysis == "my-quantization-routine":
                         pass
                         # 1. get code
                         # 2. quantize
                         # 3. dequantize
                         # Error will be calculated automatically!
                     else:
-                        raise ValueError(f'Invalid analysis value: {self.analysis}!')
+                        raise ValueError(f"Invalid analysis value: {self.analysis}!")

                     denom = state2.sqrt().add_(group["eps"])
-                    update_8bit = state1/denom
+                    update_8bit = state1 / denom

-                    abserr = torch.abs(update_8bit-update_fp32)
-                    relerr = abserr/torch.abs(update_fp32+1e-6)
+                    abserr = torch.abs(update_8bit - update_fp32)
+                    relerr = abserr / torch.abs(update_fp32 + 1e-6)

                     C1, C2 = C1.int(), C2.int()

                     F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr)
                     F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr)
-                    F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr))
+                    F.histogram_scatter_add_2d(
+                        counts, C1.int(), C2.int(), torch.ones_like(abserr)
+                    )

-                    p_data_fp32 += -step_size*update_fp32
+                    p_data_fp32 += -step_size * update_fp32

                     if not dist.is_initialized() or dist.get_rank() == 0:
-                        if self.savedir != '' and state['step'] % 100 == 0:
-                            if not os.path.exists(self.savedir): os.makedirs(self.savedir)
-                            shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape])
-                            pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl')
-                            pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl')
-                            pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl')
+                        if self.savedir != "" and state["step"] % 100 == 0:
+                            if not os.path.exists(self.savedir):
+                                os.makedirs(self.savedir)
+                            shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
+                            pathe = os.path.join(
+                                self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
+                            )
+                            pathrele = os.path.join(
+                                self.savedir, f"{p_id}_{shapestr}_relerr.pkl"
+                            )
+                            pathcounts = os.path.join(
+                                self.savedir, f"{p_id}_{shapestr}_counts.pkl"
+                            )
                             torch.save(e, pathe)
                             torch.save(rele, pathrele)
                             torch.save(counts, pathcounts)
 ...
 @@ -234,6 +313,4 @@ class AnalysisAdam(torch.optim.Optimizer):
                 if p.data.dtype in {torch.float16, torch.bfloat16}:
                     p.data.copy_(p_data_fp32)

         return loss
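The signatures above are drop-in replacements for torch.optim.Adam; a short usage sketch (model, device and hyperparameters are placeholders, not taken from this commit):

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(512, 512).cuda()
    opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999), block_wise=True)

    out = model(torch.randn(4, 512, device="cuda"))
    out.sum().backward()
    opt.step()
    opt.zero_grad()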
bitsandbytes/optim/adamw.py  (view file @ bfa0e332)

 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from bitsandbytes.optim.optimizer import Optimizer2State


 class AdamW(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
-            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
-        super(AdamW, self).__init__('adam', params, lr, betas, eps,
-                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=1e-2,
+        amsgrad=False,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
+        super(AdamW, self).__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            optim_bits,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+        )


 class AdamW8bit(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False,
-            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
-        super(AdamW8bit, self).__init__('adam', params, lr, betas, eps,
-                weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=1e-2,
+        amsgrad=False,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
+        super(AdamW8bit, self).__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            8,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+        )


 class AdamW32bit(Optimizer2State):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False,
-            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
-        super(AdamW32bit, self).__init__('adam', params, lr, betas, eps,
-                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=1e-2,
+        amsgrad=False,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
+        super(AdamW32bit, self).__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            32,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+        )
bitsandbytes/optim/lamb.py  (view file @ bfa0e332)

 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from bitsandbytes.optim.optimizer import Optimizer2State


 class LAMB(Optimizer2State):
-    def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, adam_w_mode=True,
-            optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
-        super(LAMB, self).__init__('lamb', params, lr, betas, eps,
-                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        bias_correction=True,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        amsgrad=False,
+        adam_w_mode=True,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=False,
+        max_unorm=1.0,
+    ):
+        super(LAMB, self).__init__(
+            "lamb",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            optim_bits,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            max_unorm=1.0,
+        )


 class LAMB8bit(Optimizer2State):
-    def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, adam_w_mode=True,
-            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
-        super(LAMB8bit, self).__init__('lamb', params, lr, betas, eps,
-                weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        bias_correction=True,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        amsgrad=False,
+        adam_w_mode=True,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=False,
+        max_unorm=1.0,
+    ):
+        super(LAMB8bit, self).__init__(
+            "lamb",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            8,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            max_unorm=1.0,
+        )


 class LAMB32bit(Optimizer2State):
-    def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, adam_w_mode=True,
-            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
-        super(LAMB32bit, self).__init__('lamb', params, lr, betas, eps,
-                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        bias_correction=True,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        amsgrad=False,
+        adam_w_mode=True,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=False,
+        max_unorm=1.0,
+    ):
+        super(LAMB32bit, self).__init__(
+            "lamb",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            32,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            max_unorm=1.0,
+        )
bitsandbytes/optim/lars.py
View file @
bfa0e332
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
torch
from
torch.optim
import
Optimizer
from
bitsandbytes.optim.optimizer
import
Optimizer1State
class
LARS
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
,
):
if
momentum
==
0
:
raise
NotImplementedError
(
f
'LARS without momentum is not supported!'
)
super
(
LARS
,
self
).
__init__
(
'lars'
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
)
raise
NotImplementedError
(
f
"LARS without momentum is not supported!"
)
super
(
LARS
,
self
).
__init__
(
"lars"
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
,
)
class
LARS8bit
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
,
):
if
momentum
==
0
:
raise
NotImplementedError
(
f
'LARS without momentum is not supported!'
)
super
(
LARS8bit
,
self
).
__init__
(
'lars'
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
)
raise
NotImplementedError
(
f
"LARS without momentum is not supported!"
)
super
(
LARS8bit
,
self
).
__init__
(
"lars"
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
,
)
class
LARS32bit
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
,
):
if
momentum
==
0
:
raise
NotImplementedError
(
f
'LARS without momentum is not supported!'
)
super
(
LARS32bit
,
self
).
__init__
(
'lars'
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
)
raise
NotImplementedError
(
f
"LARS without momentum is not supported!"
)
super
(
LARS32bit
,
self
).
__init__
(
"lars"
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
,
)
class PytorchLARS(Optimizer):
    def __init__(
        self,
        params,
        lr=0.01,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        max_unorm=0.02,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
... @@ -45,8 +123,14 @@ class PytorchLARS(Optimizer):
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            max_unorm=max_unorm,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(PytorchLARS, self).__init__(params, defaults)
... @@ -54,7 +138,7 @@ class PytorchLARS(Optimizer):
    def __setstate__(self, state):
        super(PytorchLARS, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)

    @torch.no_grad()
    def step(self, closure=None):
... @@ -73,15 +157,16 @@ class PytorchLARS(Optimizer):
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]
            max_unorm = group["max_unorm"]
            lr = group["lr"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                d_p = p.grad
... @@ -89,16 +174,16 @@ class PytorchLARS(Optimizer):
                    d_p = d_p.add(param, alpha=weight_decay)

                if momentum != 0:
                    buf = state.get("momentum_buffer", None)

                    if buf is None:
                        buf = torch.clone(d_p).detach()
                        state["momentum_buffer"] = buf
                    else:
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

                    if nesterov:
                        update = d_p + buf * momentum
                    else:
                        update = buf
... @@ -107,9 +192,9 @@ class PytorchLARS(Optimizer):
                assert p.dtype == torch.float32
                pnorm = torch.norm(p.detach())
                unorm = torch.norm(update)
                if unorm > max_unorm * pnorm:
                    update_scale = max_unorm * pnorm / unorm

                p.add_(update, alpha=-lr * update_scale)

        return loss
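The tail of step() applies a trust-ratio style clip: the update is scaled down whenever its norm exceeds max_unorm times the parameter norm. A standalone sketch of that rule, mirroring the lines above (the explicit update_scale = 1.0 initialization is an assumption added here so the function stands alone):

import torch

def clip_update(p: torch.Tensor, update: torch.Tensor, lr: float, max_unorm: float = 0.02):
    # scale the update down when its norm exceeds max_unorm * ||p||
    pnorm = torch.norm(p.detach())
    unorm = torch.norm(update)
    update_scale = 1.0
    if unorm > max_unorm * pnorm:
        update_scale = (max_unorm * pnorm / unorm).item()
    p.add_(update, alpha=-lr * update_scale)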
bitsandbytes/optim/optimizer.py View file @ bfa0e332 (diff collapsed)
bitsandbytes/optim/rmsprop.py View file @ bfa0e332
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
class RMSprop(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        weight_decay=0,
        momentum=0,
        centered=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if alpha == 0:
            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
        if centered:
            raise NotImplementedError(f"Centered RMSprop is not supported!")
        super(RMSprop, self).__init__(
            "rmsprop",
            params,
            lr,
            (alpha, momentum),
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
class RMSprop8bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        weight_decay=0,
        momentum=0,
        centered=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if alpha == 0:
            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
        if centered:
            raise NotImplementedError(f"Centered RMSprop is not supported!")
        super(RMSprop8bit, self).__init__(
            "rmsprop",
            params,
            lr,
            (alpha, momentum),
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
class RMSprop32bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        weight_decay=0,
        momentum=0,
        centered=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if alpha == 0:
            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
        if centered:
            raise NotImplementedError(f"Centered RMSprop is not supported!")
        super(RMSprop32bit, self).__init__(
            "rmsprop",
            params,
            lr,
            (alpha, momentum),
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
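A minimal, illustrative way to drop in the 8-bit RMSprop defined above (not part of the commit; the model and hyperparameter values are assumptions, and the constructor rejects alpha == 0 and centered=True):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(2048, 2048).cuda()   # placeholder module
optimizer = bnb.optim.RMSprop8bit(model.parameters(), lr=1e-2, alpha=0.99, eps=1e-8)

out = model(torch.randn(8, 2048, device="cuda"))
out.sum().backward()
optimizer.step()
optimizer.zero_grad()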
bitsandbytes/optim/sgd.py View file @ bfa0e332
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
class SGD(Optimizer1State):
    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if momentum == 0:
            raise NotImplementedError(f"SGD without momentum is not supported!")
        super(SGD, self).__init__(
            "momentum",
            params,
            lr,
            (momentum, dampening),
            0.0,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
class SGD8bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if momentum == 0:
            raise NotImplementedError(f"SGD without momentum is not supported!")
        super(SGD8bit, self).__init__(
            "momentum",
            params,
            lr,
            (momentum, dampening),
            0.0,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
class SGD32bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if momentum == 0:
            raise NotImplementedError(f"SGD without momentum is not supported!")
        super(SGD32bit, self).__init__(
            "momentum",
            params,
            lr,
            (momentum, dampening),
            0.0,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
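As with the other wrappers, the SGD classes only forward their arguments to Optimizer1State, and momentum=0 raises. A hedged usage sketch (model and values are illustrative; the class is assumed to be exported via bitsandbytes.optim):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()
# 8-bit optimizer state is only used for parameter tensors with at least min_8bit_size elements
optimizer = bnb.optim.SGD8bit(model.parameters(), lr=0.01, momentum=0.9, min_8bit_size=4096)

model(torch.randn(4, 4096, device="cuda")).sum().backward()
optimizer.step()
optimizer.zero_grad()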
bitsandbytes/utils.py View file @ bfa0e332
import sys


def print_err(s: str) -> None:
    print(s, file=sys.stderr)


def warn_of_missing_prerequisite(s: str) -> None:
    print_err("WARNING, missing pre-requisite: " + s)
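For illustration only, the helpers above route messages to stderr with a fixed prefix; the message text here is a made-up example:

from bitsandbytes.utils import warn_of_missing_prerequisite

warn_of_missing_prerequisite("CUDA toolkit not found on PATH")
# writes to stderr: WARNING, missing pre-requisite: CUDA toolkit not found on PATH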
quicktest.py View file @ bfa0e332
from itertools import product

import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F


def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
    k = 25
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
        elif dims == 3:
            A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "colx")
        if dims == 2:
            C2, SC = F.transform(
                torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"),
                "col32",
            )
        else:
            C2, SC = F.transform(
                torch.zeros(A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device="cuda"),
                "col32",
            )
        F.igemmlt(A2, B2, C2, SA, SB, SC)
        C3, S = F.transform(C2, "row", state=SC)
        # torch.testing.assert_allclose(C1, C3.float())
        # print(C1)
        # print(C2)
        # print(C3)
        allclose = torch.allclose(C1, C3.float())
        if allclose:
            print(C1)
... @@ -33,29 +47,29 @@ def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
            print(C3)

    ## transposed
    # A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
    # if dims == 2:
    #     B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
    #     C1 = torch.matmul(A.float(), B.float().t())
    # elif dims == 3:
    #     B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
    #     C1 = torch.matmul(B.float(), A.t().float())
    #     C1 = C1.permute([2, 0, 1])
    # A2, SA = F.transform(A, 'col32')
    # B2, SB = F.transform(B, 'colx')
    # if dims == 2:
    #     C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
    # else:
    #     C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
    #     state = (C2.shape, 'row', A.shape[0])
    #     C2, SC = F.transform(C2, 'col32', state=state)
    # F.igemmlt(A2, B2, C2, SA, SB, SC)
    # C3, S = F.transform(C2, 'row', state=SC, ld=[0])
    # torch.testing.assert_allclose(C1, C3.float())

    ## weight update
    # if dims == 3:
    #     A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
    #     B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
    #     C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())
... @@ -73,18 +87,18 @@ dims = (2, 3)
ldb = [0]

n = 2
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))

for ldb in range(32, 4096, 32):
    # for ldb in [None]:
    val = test_igemmlt(2, 2, 2, 2, 2, ldb)
    if val:
        print(val, ldb)
    else:
        print("nope", ldb)
# for val in values:
#     test_igemmlt(*val)
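Distilled from the dims == 2 path of the loop above, the int8 GEMM round-trip that quicktest.py exercises looks roughly like this (a sketch that assumes a CUDA device; the shapes are illustrative and the calls mirror the script's F.transform / F.igemmlt usage):

import torch
import bitsandbytes.functional as F

A = torch.randint(-128, 127, size=(64, 256), device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=(128, 256), device="cuda").to(torch.int8)
C_ref = torch.matmul(A.float(), B.t().float())   # float reference result

A2, SA = F.transform(A, "col32")                 # retile A for the int8 kernel
B2, SB = F.transform(B, "colx")                  # retile B
C2, SC = F.transform(torch.zeros(64, 128, dtype=torch.int32, device="cuda"), "col32")
F.igemmlt(A2, B2, C2, SA, SB, SC)                # int8 matmul into the int32 output C2
C3, _ = F.transform(C2, "row", state=SC)         # back to row-major layout
print(torch.allclose(C_ref, C3.float()))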
setup.py View file @ bfa0e332 (diff collapsed)