OpenDAS / bitsandbytes · Commits

Commit bfa0e332, authored Aug 01, 2022 by Titus von Koeller

    ran black and isort for coherent code formatting

Parent: 597a8521

Showing 20 changed files with 2449 additions and 991 deletions (+2449 −991)
bitsandbytes/__init__.py              +11   -9
bitsandbytes/autograd/_functions.py   +88   -46
bitsandbytes/cextension.py            +15   -8
bitsandbytes/cuda_setup.py            +46   -32
bitsandbytes/debug_cli.py             +0    -1
bitsandbytes/functional.py            +981  -414
bitsandbytes/nn/__init__.py           +4    -4
bitsandbytes/nn/modules.py            +147  -50
bitsandbytes/optim/__init__.py        +3    -3
bitsandbytes/optim/adagrad.py         +93   -21
bitsandbytes/optim/adam.py            +128  -51
bitsandbytes/optim/adamw.py           +85   -19
bitsandbytes/optim/lamb.py            +97   -20
bitsandbytes/optim/lars.py            +126  -41
bitsandbytes/optim/optimizer.py       +380  -185
bitsandbytes/optim/rmsprop.py         +94   -21
bitsandbytes/optim/sgd.py             +88   -21
bitsandbytes/utils.py                 +3    -1
quicktest.py                          +47   -33
setup.py                              +13   -11
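The commit is a pure formatting pass, not a functional change. As a rough illustration of what black and isort do to a source string (a minimal sketch using their Python APIs; the exact versions, configuration, and invocation used for this commit are not shown on this page):

    # Hypothetical illustration of the formatting pass named in the commit message.
    # Assumes `pip install black isort`; the project's real config may differ.
    import black
    import isort

    messy = 'from .nn import modules\nimport torch\nx = {"a":1,  "b":2}\n'

    sorted_src = isort.code(messy)                              # group and alphabetize imports
    formatted = black.format_str(sorted_src, mode=black.FileMode())  # normalize quotes, spacing, wrapping
    print(formatted)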
bitsandbytes/__init__.py  (view file @ bfa0e332)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .autograd._functions import (
    MatmulLtState,
    bmm_cublas,
    matmul,
    matmul_cublas,
    mm_cublas,
)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules

if COMPILED_WITH_CUDA:
    from .optim import adam

__pdoc__ = {
    "libbitsandbytes": False,
    "optim.optimizer.Optimizer8bit": False,
    "optim.optimizer.MockArgs": False,
}
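Because the 8-bit optimizers are only imported when the compiled library exposes GPU symbols, downstream code can gate on the exported flag. A minimal sketch (assumes bitsandbytes is installed; not part of this diff):

    import bitsandbytes as bnb

    # COMPILED_WITH_CUDA comes from bitsandbytes.cextension (see that file below);
    # it is False when only the CPU library could be loaded.
    if bnb.COMPILED_WITH_CUDA:
        from bitsandbytes.optim import adam  # 8-bit optimizers are importable
    else:
        print("CPU-only build: 8-bit optimizers and GPU quantization unavailable")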
bitsandbytes/autograd/_functions.py  (view file @ bfa0e332)

from dataclasses import dataclass

import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F

tensor = torch.Tensor

"""
    This class pools outlier dimensions across layers.
    This is particularly important for small models where outlier features
    are less systematic and occur with low frequency.
"""


class GlobalOutlierPooler(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.outliers = set()

...
@@ -29,25 +32,29 @@ class GlobalOutlierPooler(object):
        return cls._instance

    def add_outliers(self, outlier_idx, feature_dim):
        if self.model_dim is None:
            self.model_dim = feature_dim
        if feature_dim != self.model_dim:
            return  # we do not encode outliers for the 2nd FFN layer
        self.outliers.update(outlier_idx.tolist())

    def get_current_outlier_idx(self):
        return torch.Tensor(list(self.outliers)).to(torch.int64)


class MatMul8bit(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):
        if precision[0] != 8:
            with torch.no_grad():
                output = torch.matmul(A, B)
        else:
            if len(B.shape) == 2:
                dim = 0
            else:
                dim = 1
            qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
            qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
            iout = F.igemm(qA, qB)

...
@@ -84,21 +91,41 @@ class MatMul8bit(torch.autograd.Function):
            else:
                if len(B.shape) == 2 and len(A.shape) == 3:
                    grad_output = grad_output.contiguous()
                    if not grad_output.is_contiguous():
                        grad_output.contiguous()
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output.view(-1, grad_output.shape[2]),
                        dim=0,
                        quant_type=quant_type,
                    )
                    if not A.is_contiguous():
                        A = A.contiguous()
                    qA, S2 = F.vectorwise_quant(
                        A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
                    )
                    igrad_B = F.igemm(qA.t(), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B, S2.t(), S1, grad_output.dtype, quant_type
                    )
                else:
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output, dim=dims, quant_type=quant_type
                    )
                    qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
                    igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B,
                        S2.permute(permute_dim),
                        S1,
                        grad_output.dtype,
                        quant_type,
                    )

        if A.requires_grad:
            if len(grad_output.shape) == 3:
                dims = [2]
            else:
                dims = [1]

            if len(B.shape) == 3:
                # bio -> boi

...
@@ -113,10 +140,14 @@ class MatMul8bit(torch.autograd.Function):
                with torch.no_grad():
                    grad_A = torch.matmul(grad_output, B.permute(permute_dim))
            else:
                qgrad_output, S1 = F.vectorwise_quant(
                    grad_output, dim=dims, quant_type=quant_type
                )
                qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
                grad_A = F.vectorwise_mm_dequant(
                    igrad_A,
                    S1,
                    S3.permute(permute_dim),
                    grad_output.dtype,
                    quant_type,
                )

        return grad_A, grad_B, None, None, None

...
@@ -125,6 +156,7 @@ mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply


@dataclass
class MatmulLtState:
    CB = None

...
@@ -159,7 +191,6 @@ class MatmulLtState:
class MatMul8bitLt(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, state=MatmulLtState()):
        # 1. Quantize A

...
@@ -171,11 +202,15 @@ class MatMul8bitLt(torch.autograd.Function):
        requires_gradB = B.requires_grad
        formatB = state.formatB
        input_shape = A.shape
        if state.outlier_pool is None:
            state.outlier_pool = GlobalOutlierPooler.get_instance()
        assert (
            A.dtype == torch.float16
        ), f"The input data type needs to be fp16 but {A.dtype} was found!"

        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)

        if state.threshold > 0.0 and coo_tensorA is not None:

...
@@ -191,8 +226,8 @@ class MatMul8bitLt(torch.autograd.Function):
                # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                # we also need to convert it to the turing/ampere format
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
                # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
                # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
                #    # generate outlier index and subB
                #    outlier_idx = torch.unique(coo_tensorA.colidx).long()
                #    state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])

...
@@ -203,24 +238,24 @@ class MatMul8bitLt(torch.autograd.Function):
            #    state.idx = outlier_idx
            #    state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
            # if state.idx is not None:
            #    # extract outliers
            #    CA[:, state.idx] = 0
            #    CAt[:, state.idx] = 0
            #    subA = A[:, state.idx]
            # else:
            #    subA = None
        else:
            if not state.has_fp16_weights and state.CxB is None:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None

        # 2. Quantize B
        if state.has_fp16_weights:
            has_grad = True if (getattr(B, "grad", None) is not None) else False
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed:
                B = B.contiguous()

            if (state.is_training and not has_grad) or state.CxB is None:
                state.reset_grads()

...
@@ -234,14 +269,16 @@ class MatMul8bitLt(torch.autograd.Function):
                outlier_idx = torch.unique(coo_tensorA.colidx)
                state.idx = outlier_idx
                # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
                # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
                #    # do not use pool for 2nd FFN layer
                #    state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
                # else:
                #    state.idx = outlier_idx
            outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
            state.subB = (
                (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
            )
            CA[:, state.idx.long()] = 0
            CAt[:, state.idx.long()] = 0
            subA = A[:, state.idx.long()]

...
@@ -254,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function):
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
        C32A, SA = F.transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
        output = F.mm_dequant(out32, Sout32, SCA, state.SCB)

...
@@ -277,7 +314,7 @@ class MatMul8bitLt(torch.autograd.Function):
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
        clone_func = torch.clone
        return clone_func(output.view(output_shape))

...
@@ -288,7 +325,7 @@ class MatMul8bitLt(torch.autograd.Function):
        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state
        assert state.has_fp16_weights, "Backprop only supported for fp16 weights."

        if len(grad_output.shape) == 3:
            grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()

...
@@ -298,18 +335,22 @@ class MatMul8bitLt(torch.autograd.Function):
        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
        if req_gradB:
            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
            if state.threshold > 0.0 and subA is not None:
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if req_gradA:
            C32grad, Sgrad = F.transform(Cgrad, "col32")
            if state.CxBt is None:
                state.CxBt, state.SBt = F.transform(
                    state.CBt, to_order=formatB, transpose=True
                )
            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
                ctx.grad_shape
            )

        return grad_A, grad_B, None, None, None, None, None

...
@@ -317,9 +358,10 @@ class MatMul8bitLt(torch.autograd.Function):
matmul = MatMul8bitLt.apply


def matmul(
    A: tensor,
    B: tensor,
    out: tensor = None,
    state: MatmulLtState = None,
    threshold=0.0,
):
    state = state or MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return MatMul8bitLt.apply(A, B, out, state)
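The module-level matmul wrapper above routes fp16 inputs through MatMul8bitLt and lets callers enable the sparse outlier decomposition via threshold. A minimal usage sketch (shapes and the threshold value are illustrative, and a CUDA build of the library is assumed):

    import torch
    import bitsandbytes as bnb

    # A holds activations (must be fp16 per the assert in forward);
    # B is a Linear-style weight of shape (out_features, in_features).
    A = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
    B = torch.randn(4096, 1024, dtype=torch.float16, device="cuda", requires_grad=True)

    # threshold > 0.0 populates state.threshold and activates the outlier path in forward()
    out = bnb.matmul(A, B, threshold=6.0)        # shape (8, 4096)
    out.float().sum().backward()                 # exercises MatMul8bitLt.backward above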
bitsandbytes/cextension.py  (view file @ bfa0e332)

import ctypes as ct
import os
from warnings import warn

from bitsandbytes.cuda_setup import evaluate_cuda_setup

...
@@ -8,17 +9,21 @@ class CUDALibrary_Singleton(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.context = {}
        binary_name = evaluate_cuda_setup()
        if not os.path.exists(os.path.dirname(__file__) + f"/{binary_name}"):
            print(f"TODO: compile library for specific version: {binary_name}")
            print("defaulting to libbitsandbytes.so")
            self.lib = ct.cdll.LoadLibrary(
                os.path.dirname(__file__) + "/libbitsandbytes.so"
            )
        else:
            self.lib = ct.cdll.LoadLibrary(
                os.path.dirname(__file__) + f"/{binary_name}"
            )

    @classmethod
    def get_instance(cls):

...
@@ -35,6 +40,8 @@ try:
    lib.get_cusparse.restype = ct.c_void_p
    COMPILED_WITH_CUDA = True
except AttributeError:
    warn(
        "The installed version of bitsandbytes was compiled without GPU support. "
        "8-bit optimizers and GPU quantization are unavailable."
    )
    COMPILED_WITH_CUDA = False
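Both CUDALibrary_Singleton and GlobalOutlierPooler use the same raise-in-__init__ singleton idiom: direct construction fails, and get_instance is the only entry point. A generic sketch of that pattern (the body of get_instance is abridged in the hunks above, so this is an assumption about its shape rather than the library's exact code):

    class LazySingleton(object):
        _instance = None

        def __init__(self):
            # Direct construction is forbidden, exactly as in the classes above.
            raise RuntimeError("Call get_instance() instead")

        def initialize(self):
            self.context = {}

        @classmethod
        def get_instance(cls):
            # Create the one instance on first use, bypassing __init__ via __new__.
            if cls._instance is None:
                cls._instance = cls.__new__(cls)
                cls._instance.initialize()
            return cls._instance

    handle = LazySingleton.get_instance()
    assert handle is LazySingleton.get_instance()  # same object every time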
bitsandbytes/cuda_setup.py  (view file @ bfa0e332)

...
@@ -18,31 +18,36 @@ evaluation:
    - based on that set the default path
"""

import ctypes
import shlex
import subprocess
from os import environ as env
from pathlib import Path
from typing import Set, Union

from .utils import print_err, warn_of_missing_prerequisite


def execute_and_return(strCMD):
    proc = subprocess.Popen(
        shlex.split(strCMD), stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    out, err = proc.communicate()
    out, err = out.decode("UTF-8").strip(), err.decode("UTF-8").strip()
    return out, err


def check_cuda_result(cuda, result_val):
    if result_val != 0:
        cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
        print(f"Count not initialize CUDA - failure!")
        raise Exception("CUDA exception!")
    return result_val


# taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
def get_compute_capability():
    libnames = ("libcuda.so", "libcuda.dylib", "cuda.dll")
    for libname in libnames:
        try:
            cuda = ctypes.CDLL(libname)

...
@@ -51,8 +56,7 @@ def get_compute_capability():
        else:
            break
    else:
        raise OSError("could not load any of: " + " ".join(libnames))

    nGpus = ctypes.c_int()
    cc_major = ctypes.c_int()

...
@@ -69,39 +73,43 @@ def get_compute_capability():
    ccs = []
    for i in range(nGpus.value):
        result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
        result = check_cuda_result(
            cuda,
            cuda.cuDeviceComputeCapability(
                ctypes.byref(cc_major), ctypes.byref(cc_minor), device
            ),
        )
        ccs.append(f"{cc_major.value}.{cc_minor.value}")

    # TODO: handle different compute capabilities; for now, take the max
    ccs.sort()
    # return ccs[-1]
    return ccs


CUDA_RUNTIME_LIB: str = "libcudart.so"


def tokenize_paths(paths: str) -> Set[Path]:
    return {
        Path(ld_path) for ld_path in paths.split(":") if ld_path
    }


def get_cuda_runtime_lib_path(
    # TODO: replace this with logic for all paths in env vars
    LD_LIBRARY_PATH: Union[str, None] = env.get("LD_LIBRARY_PATH")
) -> Union[Path, None]:
    """# TODO: add doc-string"""

    if not LD_LIBRARY_PATH:
        warn_of_missing_prerequisite(
            "LD_LIBRARY_PATH is completely missing from environment!"
        )
        return None

    ld_library_paths: Set[Path] = tokenize_paths(LD_LIBRARY_PATH)

    non_existent_directories: Set[Path] = {
        path for path in ld_library_paths if not path.exists()
    }

    if non_existent_directories:

...
@@ -111,7 +119,8 @@ def get_cuda_runtime_lib_path(
    )

    cuda_runtime_libs: Set[Path] = {
        path / CUDA_RUNTIME_LIB
        for path in ld_library_paths
        if (path / CUDA_RUNTIME_LIB).is_file()
    } - non_existent_directories

...
@@ -126,26 +135,31 @@ def get_cuda_runtime_lib_path(
    single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs))
    return single_cuda_runtime_lib_dir


def evaluate_cuda_setup():
    cuda_path = get_cuda_runtime_lib_path()
    cc = get_compute_capability()
    binary_name = "libbitsandbytes_cpu.so"

    if not (has_gpu := bool(cc)):
        print(
            "WARNING: No GPU detected! Check our CUDA paths. Processing to load CPU-only library..."
        )
        return binary_name

    has_cublaslt = cc in ["7.5", "8.0", "8.6"]

    # TODO:
    # (1) Model missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
    # (2) Multiple CUDA versions installed

    cuda_home = str(Path(cuda_path).parent.parent)
    ls_output, err = execute_and_return(f"{cuda_home}/bin/nvcc --version")
    cuda_version = (
        ls_output.split("\n")[3].split(",")[-1].strip().lower().replace("v", "")
    )
    major, minor, revision = cuda_version.split(".")
    cuda_version_string = f"{major}{minor}"

    binary_name = f'libbitsandbytes_cuda{cuda_version_string}_{("cublaslt" if has_cublaslt else "")}.so'

...
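The path handling above is plain string/Path set logic, so its behaviour is easy to check in isolation. A small sketch of tokenize_paths and the runtime-library lookup on a made-up LD_LIBRARY_PATH (the directories are hypothetical):

    from pathlib import Path
    from typing import Set

    CUDA_RUNTIME_LIB = "libcudart.so"

    def tokenize_paths(paths: str) -> Set[Path]:
        # same logic as the function above: split on ':' and drop empty entries
        return {Path(ld_path) for ld_path in paths.split(":") if ld_path}

    dirs = tokenize_paths("/usr/lib/x86_64-linux-gnu:/usr/local/cuda-11.3/lib64:")
    # -> {PosixPath('/usr/lib/x86_64-linux-gnu'), PosixPath('/usr/local/cuda-11.3/lib64')}

    # candidate libcudart.so locations, mirroring get_cuda_runtime_lib_path
    candidates = {d / CUDA_RUNTIME_LIB for d in dirs if (d / CUDA_RUNTIME_LIB).is_file()}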
bitsandbytes/debug_cli.py  (view file @ bfa0e332)

import typer

cli = typer.Typer()

...
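At this point debug_cli.py only instantiates the Typer app; commands would be attached with typer's decorator API. A hypothetical example (the `cuda` command below is not part of this commit):

    import typer

    cli = typer.Typer()

    @cli.command()
    def cuda():
        """Hypothetical diagnostic command; would report what the CUDA setup detected."""
        typer.echo("debug info would go here")

    if __name__ == "__main__":
        cli()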
bitsandbytes/functional.py  (view file @ bfa0e332)

(diff collapsed)
bitsandbytes/nn/__init__.py  (view file @ bfa0e332)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import Int8Params, Linear8bit, Linear8bitLt, StableEmbedding
bitsandbytes/nn/modules.py  (view file @ bfa0e332)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
    overload,
)

import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
from torch.nn.parameter import Parameter

import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager

T = TypeVar("T", bound="torch.nn.Module")


class StableEmbedding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        super(StableEmbedding, self).__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        self.norm = torch.nn.LayerNorm(embedding_dim)
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
        to make the Layer compatible with Pytorch < 1.9.
        This means that if this changes in future PyTorch releases this need to change too
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():

...
@@ -41,29 +61,55 @@ class StableEmbedding(torch.nn.Embedding):
    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return self.norm(emb)


class Embedding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        super(Embedding, self).__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
        to make the Layer compatible with Pytorch < 1.9.
        This means that if this changes in future PyTorch releases this need to change too
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():

...
@@ -71,13 +117,22 @@ class Embedding(torch.nn.Embedding):
    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return emb


class Int8Params(torch.nn.Parameter):
    def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None):
        cls.has_fp16_weights = has_fp16_weights
        cls.CB = None
        cls.SCB = None

...
@@ -96,14 +151,18 @@ class Int8Params(torch.nn.Parameter):
            del CBt
            del SCBt
            self.data = CB
            setattr(self, "CB", CB)
            setattr(self, "SCB", SCB)

        return self

    @overload
    def to(
        self: T,
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> T:
        ...

    @overload

...
@@ -115,23 +174,41 @@ class Int8Params(torch.nn.Parameter):
        ...

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )

        if (
            device is not None
            and device.type == "cuda"
            and self.data.device.type == "cpu"
        ):
            return self.cuda(device)
        else:
            new_param = Int8Params(
                super().to(
                    device=device, dtype=dtype, non_blocking=non_blocking
                ),
                requires_grad=self.requires_grad,
                has_fp16_weights=self.has_fp16_weights,
            )
            new_param.CB = self.CB
            new_param.SCB = self.SCB

            return new_param


class Linear8bitLt(nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        has_fp16_weights=True,
        threshold=0.0,
        index=None,
    ):
        super(Linear8bitLt, self).__init__(input_features, output_features, bias)
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights

...
@@ -149,9 +226,10 @@ class Linear8bitLt(nn.Linear):
    def forward(self, x):
        self.state.is_training = self.training

        if self.weight.CB is not None:
            self.init_8bit_state()
        # assert not self.state.has_fp16_weights
        # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None

        out = bnb.matmul(x, self.weight, state=self.state)

...
@@ -166,8 +244,18 @@ class Linear8bitLt(nn.Linear):
        return out


class Linear8bit(nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        quant_type="vector",
        index=None,
        args=None,
        sparse_decomp=False,
    ):
        super(Linear8bit, self).__init__(input_features, output_features, bias)
        self.quant_type = quant_type
        self.index = index

...
@@ -178,15 +266,24 @@ class Linear8bit(nn.Linear):
            self.iter += 1
            if self.iter % self.args.clip_freq == 0:
                with torch.no_grad():
                    maxval, maxidx = torch.topk(
                        torch.abs(self.weight.flatten()), k=self.args.clip_idx
                    )
                    if not dist.is_initialized() or dist.get_rank() == 0:
                        print("clip", maxval[-1].item())
                    self.weight.clip_(-maxval[-1], maxval[-1])

        if self.args is not None:
            out = bnb.nn.functional.sparse_decomposed_linear8bit(
                x,
                self.weight,
                self.bias,
                qval=self.args.sparse_decomp_val,
                quant_type=self.args.quant_type,
            )
        else:
            out = bnb.nn.functional.linear8bit(
                x, self.weight, self.bias, quant_type=self.args.quant_type
            )

        return out
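Given the constructor signatures above, StableEmbedding and Linear8bitLt drop into an ordinary PyTorch module. A minimal usage sketch (dimensions and the threshold value are illustrative, and the 8-bit path assumes a CUDA build):

    import torch
    import bitsandbytes as bnb

    class TinyModel(torch.nn.Module):
        def __init__(self, vocab=1000, dim=512):
            super().__init__()
            # layer-norm'ed embedding whose optimizer state is forced to 32-bit
            self.emb = bnb.nn.StableEmbedding(vocab, dim)
            # 8-bit linear layer; threshold=6.0 enables the outlier decomposition
            self.proj = bnb.nn.Linear8bitLt(dim, dim, threshold=6.0)

        def forward(self, ids):
            return self.proj(self.emb(ids))

    model = TinyModel().half().cuda()   # Linear8bitLt expects fp16 activations
    out = model(torch.randint(0, 1000, (4, 16), device="cuda"))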
bitsandbytes/optim/__init__.py  (view file @ bfa0e332)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.cextension import COMPILED_WITH_CUDA

...
bitsandbytes/optim/adagrad.py  (view file @ bfa0e332)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State


class Adagrad(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        lr_decay=0,
        weight_decay=0,
        initial_accumulator_value=0,
        eps=1e-10,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= weight_decay:

...
@@ -14,15 +27,39 @@ class Adagrad(Optimizer1State):
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if initial_accumulator_value != 0.0:
            raise ValueError("Initial accumulator value != 0.0 not supported!")
        if lr_decay != 0.0:
            raise ValueError("Lr Decay != 0.0 not supported!")
        super(Adagrad, self).__init__(
            "adagrad",
            params,
            lr,
            (0.0, 0.0),
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Adagrad8bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        lr_decay=0,
        weight_decay=0,
        initial_accumulator_value=0,
        eps=1e-10,
        optim_bits=8,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= weight_decay:

...
@@ -30,16 +67,40 @@ class Adagrad8bit(Optimizer1State):
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if initial_accumulator_value != 0.0:
            raise ValueError("Initial accumulator value != 0.0 not supported!")
        if lr_decay != 0.0:
            raise ValueError("Lr Decay != 0.0 not supported!")
        assert block_wise
        super(Adagrad8bit, self).__init__(
            "adagrad",
            params,
            lr,
            (0.0, 0.0),
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Adagrad32bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        lr_decay=0,
        weight_decay=0,
        initial_accumulator_value=0,
        eps=1e-10,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= weight_decay:

...
@@ -47,8 +108,19 @@ class Adagrad32bit(Optimizer1State):
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if initial_accumulator_value != 0.0:
            raise ValueError("Initial accumulator value != 0.0 not supported!")
        if lr_decay != 0.0:
            raise ValueError("Lr Decay != 0.0 not supported!")
        super(Adagrad32bit, self).__init__(
            "adagrad",
            params,
            lr,
            (0.0, 0.0),
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
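The 8-bit optimizer classes keep the torch.optim constructor shape, so they can replace a standard optimizer directly; the same pattern applies to the Adam variants in the next file. A minimal sketch with Adagrad8bit (hyperparameters are illustrative and a CUDA build of the library is assumed):

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(1024, 1024).cuda()
    # same call shape as torch.optim.Adagrad, plus the 8-bit specific knobs
    opt = bnb.optim.Adagrad8bit(model.parameters(), lr=1e-2, min_8bit_size=4096)

    x = torch.randn(16, 1024, device="cuda")
    loss = model(x).pow(2).mean()
    loss.backward()
    opt.step()        # optimizer state is held block-wise quantized in 8 bits
    opt.zero_grad()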
bitsandbytes/optim/adam.py  (view file @ bfa0e332)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math

...
@@ -8,29 +8,97 @@ import os
import torch
import torch.distributed as dist

import bitsandbytes.functional as F
from bitsandbytes.optim.optimizer import Optimizer2State


class Adam(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(Adam, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Adam8bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(Adam8bit, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Adam32bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(Adam32bit, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class AnalysisAdam(torch.optim.Optimizer):

...
@@ -68,8 +136,8 @@ class AnalysisAdam(torch.optim.Optimizer):
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        bnb_analysis="dynamic-blockwise",
        savedir=None,
    ):
        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad

...
@@ -124,9 +192,13 @@ class AnalysisAdam(torch.optim.Optimizer):
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                    state["abserrors"] = torch.zeros(
                        (256, 256), device=p_data_fp32.device
                    )
                    state["relerrors"] = torch.zeros(
                        (256, 256), device=p_data_fp32.device
                    )
                    state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)

...
@@ -143,9 +215,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                e = state["abserrors"]
                rele = state["relerrors"]
                counts = state["counts"]

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(

...
@@ -156,77 +228,84 @@ class AnalysisAdam(torch.optim.Optimizer):
                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(group["eps"])
                update_fp32 = exp_avg / denom

                if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
                    # embedding layer or too small
                    p_data_fp32 += -step_size * update_fp32
                else:
                    if self.analysis == "dynamic-blockwise":
                        code1 = F.create_dynamic_map(signed=True).to(p.device)
                        code2 = F.create_dynamic_map(signed=False).to(p.device)
                        C1, S1 = F.quantize_blockwise(exp_avg, code=code1)
                        state1 = F.dequantize_blockwise(C1, S1)
                        C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2)
                        state2 = F.dequantize_blockwise(C2, S2)
                    elif self.analysis == "dynamic":
                        code1 = F.create_dynamic_map(signed=True).to(p.device)
                        code2 = F.create_dynamic_map(signed=False).to(p.device)
                        C1, S1 = F.quantize(exp_avg, code=code1)
                        state1 = F.dequantize(C1, S1)
                        C2, S2 = F.quantize(exp_avg_sq, code=code2)
                        state2 = F.dequantize(C2, S2)
                    elif self.analysis == "linear":
                        code1 = F.create_linear_map(signed=True).to(p.device)
                        code2 = F.create_linear_map(signed=False).to(p.device)
                        C1, S1 = F.quantize(exp_avg, code=code1)
                        state1 = F.dequantize(C1, S1)
                        C2, S2 = F.quantize(exp_avg_sq, code=code2)
                        state2 = F.dequantize(C2, S2)
                    elif self.analysis == "quantile":
                        code1 = F.estimate_quantiles(exp_avg)
                        code2 = F.estimate_quantiles(exp_avg_sq)
                        C1 = F.quantize_no_absmax(exp_avg, code=code1)
                        state1 = F.dequantize_no_absmax(C1, code1)
                        C2 = F.quantize_no_absmax(exp_avg_sq, code=code2)
                        state2 = F.dequantize_no_absmax(C2, code2)
                    elif self.analysis == "my-quantization-routine":
                        pass
                        # 1. get code
                        # 2. quantize
                        # 3. dequantize
                        # Error will be calculated automatically!
                    else:
                        raise ValueError(f"Invalid analysis value:
{
self
.
analysis
}
!
'
)
raise
ValueError
(
f
"
Invalid analysis value:
{
self
.
analysis
}
!
"
)
denom
=
state2
.
sqrt
().
add_
(
group
[
"eps"
])
denom
=
state2
.
sqrt
().
add_
(
group
[
"eps"
])
update_8bit
=
state1
/
denom
update_8bit
=
state1
/
denom
abserr
=
torch
.
abs
(
update_8bit
-
update_fp32
)
abserr
=
torch
.
abs
(
update_8bit
-
update_fp32
)
relerr
=
abserr
/
torch
.
abs
(
update_fp32
+
1e-6
)
relerr
=
abserr
/
torch
.
abs
(
update_fp32
+
1e-6
)
C1
,
C2
=
C1
.
int
(),
C2
.
int
()
C1
,
C2
=
C1
.
int
(),
C2
.
int
()
F
.
histogram_scatter_add_2d
(
e
,
C1
.
int
(),
C2
.
int
(),
abserr
)
F
.
histogram_scatter_add_2d
(
e
,
C1
.
int
(),
C2
.
int
(),
abserr
)
F
.
histogram_scatter_add_2d
(
rele
,
C1
.
int
(),
C2
.
int
(),
relerr
)
F
.
histogram_scatter_add_2d
(
rele
,
C1
.
int
(),
C2
.
int
(),
relerr
)
F
.
histogram_scatter_add_2d
(
counts
,
C1
.
int
(),
C2
.
int
(),
torch
.
ones_like
(
abserr
))
F
.
histogram_scatter_add_2d
(
counts
,
C1
.
int
(),
C2
.
int
(),
torch
.
ones_like
(
abserr
)
p_data_fp32
+=
-
step_size
*
update_fp32
)
p_data_fp32
+=
-
step_size
*
update_fp32
if
not
dist
.
is_initialized
()
or
dist
.
get_rank
()
==
0
:
if
not
dist
.
is_initialized
()
or
dist
.
get_rank
()
==
0
:
if
self
.
savedir
!=
''
and
state
[
'step'
]
%
100
==
0
:
if
self
.
savedir
!=
""
and
state
[
"step"
]
%
100
==
0
:
if
not
os
.
path
.
exists
(
self
.
savedir
):
os
.
makedirs
(
self
.
savedir
)
if
not
os
.
path
.
exists
(
self
.
savedir
):
shapestr
=
'_'
.
join
([
str
(
dim
)
for
dim
in
p_data_fp32
.
shape
])
os
.
makedirs
(
self
.
savedir
)
pathe
=
os
.
path
.
join
(
self
.
savedir
,
f
'
{
p_id
}
_
{
shapestr
}
_abserr.pkl'
)
shapestr
=
"_"
.
join
([
str
(
dim
)
for
dim
in
p_data_fp32
.
shape
])
pathrele
=
os
.
path
.
join
(
self
.
savedir
,
f
'
{
p_id
}
_
{
shapestr
}
_relerr.pkl'
)
pathe
=
os
.
path
.
join
(
pathcounts
=
os
.
path
.
join
(
self
.
savedir
,
f
'
{
p_id
}
_
{
shapestr
}
_counts.pkl'
)
self
.
savedir
,
f
"
{
p_id
}
_
{
shapestr
}
_abserr.pkl"
)
pathrele
=
os
.
path
.
join
(
self
.
savedir
,
f
"
{
p_id
}
_
{
shapestr
}
_relerr.pkl"
)
pathcounts
=
os
.
path
.
join
(
self
.
savedir
,
f
"
{
p_id
}
_
{
shapestr
}
_counts.pkl"
)
torch
.
save
(
e
,
pathe
)
torch
.
save
(
e
,
pathe
)
torch
.
save
(
rele
,
pathrele
)
torch
.
save
(
rele
,
pathrele
)
torch
.
save
(
counts
,
pathcounts
)
torch
.
save
(
counts
,
pathcounts
)
...
@@ -234,6 +313,4 @@ class AnalysisAdam(torch.optim.Optimizer):
...
@@ -234,6 +313,4 @@ class AnalysisAdam(torch.optim.Optimizer):
if
p
.
data
.
dtype
in
{
torch
.
float16
,
torch
.
bfloat16
}:
if
p
.
data
.
dtype
in
{
torch
.
float16
,
torch
.
bfloat16
}:
p
.
data
.
copy_
(
p_data_fp32
)
p
.
data
.
copy_
(
p_data_fp32
)
return
loss
return
loss
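The "my-quantization-routine" stub above documents the expected contract for a custom analysis branch: build a quantization code, quantize both Adam states, and dequantize them back so the absolute and relative update errors can be histogrammed. The sketch below replays that round trip outside the optimizer with the blockwise helpers the class already uses; the tensor shapes and values are illustrative assumptions, not part of this commit, and a CUDA-enabled bitsandbytes build is assumed.

import torch

import bitsandbytes.functional as F

# Stand-ins for the two Adam states tracked per parameter (shapes are assumptions).
exp_avg = torch.randn(4096, device="cuda")
exp_avg_sq = torch.rand(4096, device="cuda")

# 1. get code: signed map for the first moment, unsigned for the second
code1 = F.create_dynamic_map(signed=True).to(exp_avg.device)
code2 = F.create_dynamic_map(signed=False).to(exp_avg.device)

# 2. quantize and 3. dequantize, exactly as the "dynamic-blockwise" branch does
C1, S1 = F.quantize_blockwise(exp_avg, code=code1)
state1 = F.dequantize_blockwise(C1, S1)
C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2)
state2 = F.dequantize_blockwise(C2, S2)

# The analysis step then compares the dequantized (8-bit) update against the fp32 update.
eps = 1e-8
update_fp32 = exp_avg / (exp_avg_sq.sqrt() + eps)
update_8bit = state1 / (state2.sqrt() + eps)
abserr = torch.abs(update_8bit - update_fp32)
relerr = abserr / torch.abs(update_fp32 + 1e-6)
print(abserr.mean().item(), relerr.mean().item())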
bitsandbytes/optim/adamw.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State


class AdamW(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(AdamW, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class AdamW8bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(AdamW8bit, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class AdamW32bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(AdamW32bit, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
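AdamW8bit and AdamW32bit differ from AdamW only in the bit width they hard-code in place of the user-supplied optim_bits (8 and 32), so the 8-bit variant is a drop-in replacement for torch.optim.AdamW. A minimal usage sketch, assuming a CUDA-enabled bitsandbytes build; the toy model and tensors are illustrative assumptions:

import torch

from bitsandbytes.optim.adamw import AdamW8bit

# hypothetical toy model; any nn.Module with CUDA parameters works the same way
model = torch.nn.Linear(512, 512).cuda()
opt = AdamW8bit(model.parameters(), lr=1e-3, weight_decay=1e-2)

out = model(torch.randn(16, 512, device="cuda"))
out.sum().backward()
opt.step()
opt.zero_grad()

Parameters with fewer than min_8bit_size (default 4096) elements, such as the bias above, keep 32-bit optimizer state even under AdamW8bit.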
bitsandbytes/optim/lamb.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State


class LAMB(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        bias_correction=True,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        adam_w_mode=True,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=False,
        max_unorm=1.0,
    ):
        super(LAMB, self).__init__(
            "lamb",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
            max_unorm=1.0,
        )


class LAMB8bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        bias_correction=True,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        adam_w_mode=True,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=False,
        max_unorm=1.0,
    ):
        super(LAMB8bit, self).__init__(
            "lamb",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
            max_unorm=1.0,
        )


class LAMB32bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        bias_correction=True,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        adam_w_mode=True,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=False,
        max_unorm=1.0,
    ):
        super(LAMB32bit, self).__init__(
            "lamb",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
            max_unorm=1.0,
        )
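The LAMB wrappers follow the same pattern, with block_wise defaulting to False. A minimal sketch under the same assumptions (CUDA-enabled build, illustrative toy model):

import torch

from bitsandbytes.optim.lamb import LAMB8bit

model = torch.nn.Linear(1024, 1024).cuda()  # assumed toy model
opt = LAMB8bit(model.parameters(), lr=1e-3, weight_decay=1e-2)

loss = model(torch.randn(32, 1024, device="cuda")).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()

Note that although the constructors accept a max_unorm argument, each of them forwards the literal 1.0 to the base class, so a user-supplied value is effectively ignored in this version.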
bitsandbytes/optim/lars.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.optim import Optimizer

from bitsandbytes.optim.optimizer import Optimizer1State


class LARS(Optimizer1State):
    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        max_unorm=0.02,
    ):
        if momentum == 0:
            raise NotImplementedError(f"LARS without momentum is not supported!")
        super(LARS, self).__init__(
            "lars",
            params,
            lr,
            (momentum, dampening),
            0.0,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            max_unorm=max_unorm,
            block_wise=False,
        )


class LARS8bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        max_unorm=0.02,
    ):
        if momentum == 0:
            raise NotImplementedError(f"LARS without momentum is not supported!")
        super(LARS8bit, self).__init__(
            "lars",
            params,
            lr,
            (momentum, dampening),
            0.0,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            max_unorm=max_unorm,
            block_wise=False,
        )


class LARS32bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        max_unorm=0.02,
    ):
        if momentum == 0:
            raise NotImplementedError(f"LARS without momentum is not supported!")
        super(LARS32bit, self).__init__(
            "lars",
            params,
            lr,
            (momentum, dampening),
            0.0,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            max_unorm=max_unorm,
            block_wise=False,
        )


class PytorchLARS(Optimizer):
    def __init__(
        self,
        params,
        lr=0.01,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        max_unorm=0.02,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
...
@@ -45,8 +123,14 @@ class PytorchLARS(Optimizer):
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            max_unorm=max_unorm,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(PytorchLARS, self).__init__(params, defaults)
...
@@ -54,7 +138,7 @@ class PytorchLARS(Optimizer):
    def __setstate__(self, state):
        super(PytorchLARS, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)

    @torch.no_grad()
    def step(self, closure=None):
...
@@ -73,15 +157,16 @@ class PytorchLARS(Optimizer):
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]
            max_unorm = group["max_unorm"]
            lr = group["lr"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                d_p = p.grad
...
@@ -89,16 +174,16 @@ class PytorchLARS(Optimizer):
                    d_p = d_p.add(param, alpha=weight_decay)

                if momentum != 0:
                    buf = state.get("momentum_buffer", None)

                    if buf is None:
                        buf = torch.clone(d_p).detach()
                        state["momentum_buffer"] = buf
                    else:
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

                    if nesterov:
                        update = d_p + buf * momentum
                    else:
                        update = buf
...
@@ -107,9 +192,9 @@ class PytorchLARS(Optimizer):
                assert p.dtype == torch.float32
                pnorm = torch.norm(p.detach())
                unorm = torch.norm(update)
                if unorm > max_unorm * pnorm:
                    update_scale = max_unorm * pnorm / unorm

                p.add_(update, alpha=-lr * update_scale)

        return loss
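The LARS wrappers require a non-zero momentum (the constructors raise NotImplementedError otherwise) and clip the update norm to max_unorm times the parameter norm, while PytorchLARS above is a plain-PyTorch reference implementation of the same trust-ratio update. A minimal sketch for the 8-bit variant, again with an illustrative toy model on a CUDA-enabled build:

import torch

from bitsandbytes.optim.lars import LARS8bit

model = torch.nn.Linear(2048, 2048).cuda()  # assumed toy model

# momentum must be non-zero, otherwise the constructor raises NotImplementedError
opt = LARS8bit(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

loss = model(torch.randn(64, 2048, device="cuda")).square().mean()
loss.backward()
opt.step()
opt.zero_grad()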
bitsandbytes/optim/optimizer.py (diff collapsed)
bitsandbytes/optim/rmsprop.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State


class RMSprop(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        weight_decay=0,
        momentum=0,
        centered=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if alpha == 0:
            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
        if centered:
            raise NotImplementedError(f"Centered RMSprop is not supported!")
        super(RMSprop, self).__init__(
            "rmsprop",
            params,
            lr,
            (alpha, momentum),
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class RMSprop8bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        weight_decay=0,
        momentum=0,
        centered=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if alpha == 0:
            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
        if centered:
            raise NotImplementedError(f"Centered RMSprop is not supported!")
        super(RMSprop8bit, self).__init__(
            "rmsprop",
            params,
            lr,
            (alpha, momentum),
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class RMSprop32bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        weight_decay=0,
        momentum=0,
        centered=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if alpha == 0:
            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
        if centered:
            raise NotImplementedError(f"Centered RMSprop is not supported!")
        super(RMSprop32bit, self).__init__(
            "rmsprop",
            params,
            lr,
            (alpha, momentum),
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
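The RMSprop wrappers reject alpha == 0 and centered=True; everything else mirrors torch.optim.RMSprop while keeping the squared-gradient state in 8 or 32 bit. A minimal sketch under the same assumptions (CUDA-enabled build, illustrative toy model):

import torch

from bitsandbytes.optim.rmsprop import RMSprop8bit

model = torch.nn.Linear(1024, 1024).cuda()  # assumed toy model

# alpha=0 and centered=True both raise NotImplementedError in these wrappers
opt = RMSprop8bit(model.parameters(), lr=1e-2, alpha=0.99, momentum=0.9)

loss = model(torch.randn(8, 1024, device="cuda")).abs().mean()
loss.backward()
opt.step()
opt.zero_grad()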
bitsandbytes/optim/sgd.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State


class SGD(Optimizer1State):
    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if momentum == 0:
            raise NotImplementedError(f"SGD without momentum is not supported!")
        super(SGD, self).__init__(
            "momentum",
            params,
            lr,
            (momentum, dampening),
            0.0,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class SGD8bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if momentum == 0:
            raise NotImplementedError(f"SGD without momentum is not supported!")
        super(SGD8bit, self).__init__(
            "momentum",
            params,
            lr,
            (momentum, dampening),
            0.0,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class SGD32bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if momentum == 0:
            raise NotImplementedError(f"SGD without momentum is not supported!")
        super(SGD32bit, self).__init__(
            "momentum",
            params,
            lr,
            (momentum, dampening),
            0.0,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
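SGD maps onto the base "momentum" optimizer, so passing optim_bits=8 to SGD and using SGD8bit configure the same 8-bit state; momentum must be non-zero in either case. A small sketch (illustrative toy model, CUDA-enabled build assumed):

import torch

from bitsandbytes.optim.sgd import SGD, SGD8bit

model = torch.nn.Linear(1024, 1024).cuda()  # assumed toy model

# SGD without momentum raises NotImplementedError, so momentum must be set.
# These two calls forward the same arguments to the base Optimizer1State:
opt_a = SGD(model.parameters(), lr=0.05, momentum=0.9, optim_bits=8)
opt_b = SGD8bit(model.parameters(), lr=0.05, momentum=0.9)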
bitsandbytes/utils.py

import sys


def print_err(s: str) -> None:
    print(s, file=sys.stderr)


def warn_of_missing_prerequisite(s: str) -> None:
    print_err("WARNING, missing pre-requisite: " + s)
quicktest.py

from itertools import product

import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F


def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
    k = 25
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
        elif dims == 3:
            A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
                torch.int8
            )
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "colx")
        if dims == 2:
            C2, SC = F.transform(
                torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"),
                "col32",
            )
        else:
            C2, SC = F.transform(
                torch.zeros(
                    A.shape[0],
                    A.shape[1],
                    B.shape[0],
                    dtype=torch.int32,
                    device="cuda",
                ),
                "col32",
            )
        F.igemmlt(A2, B2, C2, SA, SB, SC)
        C3, S = F.transform(C2, "row", state=SC)
        # torch.testing.assert_allclose(C1, C3.float())
        # print(C1)
        # print(C2)
        # print(C3)
        allclose = torch.allclose(C1, C3.float())
        if allclose:
            print(C1)
...
@@ -33,29 +47,29 @@ def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
            print(C3)

        ## transposed
        # A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
        # if dims == 2:
        #     B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
        #     C1 = torch.matmul(A.float(), B.float().t())
        # elif dims == 3:
        #     B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
        #     C1 = torch.matmul(B.float(), A.t().float())
        #     C1 = C1.permute([2, 0, 1])

        # A2, SA = F.transform(A, 'col32')
        # B2, SB = F.transform(B, 'colx')
        # if dims == 2:
        #     C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
        # else:
        #     C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
        #     state = (C2.shape, 'row', A.shape[0])
        #     C2, SC = F.transform(C2, 'col32', state=state)
        # F.igemmlt(A2, B2, C2, SA, SB, SC)
        # C3, S = F.transform(C2, 'row', state=SC, ld=[0])
        # torch.testing.assert_allclose(C1, C3.float())

        ## weight update
        # if dims == 3:
        #     A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
        #     B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
        #     C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())
...
@@ -73,18 +87,18 @@ dims = (2, 3)
ldb = [0]

n = 2
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()

values = list(product(dim1, dim2, dim3, dim4, dims, ldb))

for ldb in range(32, 4096, 32):
    # for ldb in [None]:
    val = test_igemmlt(2, 2, 2, 2, 2, ldb)
    if val:
        print(val, ldb)
    else:
        print("nope", ldb)
# for val in values:
#     test_igemmlt(*val)
setup.py (diff collapsed)