Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
bfa0e332
Commit
bfa0e332
authored
Aug 01, 2022
by
Titus von Koeller
Browse files
ran black and isort for coherent code formatting
parent
597a8521
Changes
25
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2449 additions
and
991 deletions
+2449
-991
bitsandbytes/__init__.py
bitsandbytes/__init__.py
+11
-9
bitsandbytes/autograd/_functions.py
bitsandbytes/autograd/_functions.py
+88
-46
bitsandbytes/cextension.py
bitsandbytes/cextension.py
+15
-8
bitsandbytes/cuda_setup.py
bitsandbytes/cuda_setup.py
+46
-32
bitsandbytes/debug_cli.py
bitsandbytes/debug_cli.py
+0
-1
bitsandbytes/functional.py
bitsandbytes/functional.py
+981
-414
bitsandbytes/nn/__init__.py
bitsandbytes/nn/__init__.py
+4
-4
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+147
-50
bitsandbytes/optim/__init__.py
bitsandbytes/optim/__init__.py
+3
-3
bitsandbytes/optim/adagrad.py
bitsandbytes/optim/adagrad.py
+93
-21
bitsandbytes/optim/adam.py
bitsandbytes/optim/adam.py
+128
-51
bitsandbytes/optim/adamw.py
bitsandbytes/optim/adamw.py
+85
-19
bitsandbytes/optim/lamb.py
bitsandbytes/optim/lamb.py
+97
-20
bitsandbytes/optim/lars.py
bitsandbytes/optim/lars.py
+126
-41
bitsandbytes/optim/optimizer.py
bitsandbytes/optim/optimizer.py
+380
-185
bitsandbytes/optim/rmsprop.py
bitsandbytes/optim/rmsprop.py
+94
-21
bitsandbytes/optim/sgd.py
bitsandbytes/optim/sgd.py
+88
-21
bitsandbytes/utils.py
bitsandbytes/utils.py
+3
-1
quicktest.py
quicktest.py
+47
-33
setup.py
setup.py
+13
-11
No files found.
bitsandbytes/__init__.py
View file @
bfa0e332
...
...
@@ -3,14 +3,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from
.
nn
import
modules
from
.autograd._functions
import
mm_cublas
,
bmm
_cublas
,
m
atmul
_cublas
,
matmul
,
MatmulLtState
from
.
autograd._functions
import
(
MatmulLtState
,
bmm_cublas
,
matmul
,
matmul
_cublas
,
m
m
_cublas
)
from
.cextension
import
COMPILED_WITH_CUDA
from
.nn
import
modules
if
COMPILED_WITH_CUDA
:
from
.optim
import
adam
__pdoc__
=
{
'libbitsandbytes'
:
False
,
'optim.optimizer.Optimizer8bit'
:
False
,
'optim.optimizer.MockArgs'
:
False
}
__pdoc__
=
{
"libbitsandbytes"
:
False
,
"optim.optimizer.Optimizer8bit"
:
False
,
"optim.optimizer.MockArgs"
:
False
,
}
bitsandbytes/autograd/_functions.py
View file @
bfa0e332
from
dataclasses
import
dataclass
import
torch
import
bitsandbytes
as
bnb
import
bitsandbytes.functional
as
F
from
dataclasses
import
dataclass
tensor
=
torch
.
Tensor
'''
"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
'''
"""
class
GlobalOutlierPooler
(
object
):
_instance
=
None
def
__init__
(
self
):
raise
RuntimeError
(
'
Call get_instance() instead
'
)
raise
RuntimeError
(
"
Call get_instance() instead
"
)
def
initialize
(
self
):
self
.
outliers
=
set
()
...
...
@@ -29,25 +32,29 @@ class GlobalOutlierPooler(object):
return
cls
.
_instance
def
add_outliers
(
self
,
outlier_idx
,
feature_dim
):
if
self
.
model_dim
is
None
:
self
.
model_dim
=
feature_dim
if
feature_dim
!=
self
.
model_dim
:
return
# we do not encode outliers for the 2nd FFN layer
if
self
.
model_dim
is
None
:
self
.
model_dim
=
feature_dim
if
feature_dim
!=
self
.
model_dim
:
return
# we do not encode outliers for the 2nd FFN layer
self
.
outliers
.
update
(
outlier_idx
.
tolist
())
def
get_current_outlier_idx
(
self
):
return
torch
.
Tensor
(
list
(
self
.
outliers
)).
to
(
torch
.
int64
)
class
MatMul8bit
(
torch
.
autograd
.
Function
):
class
MatMul8bit
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
A
,
B
,
out
=
None
,
quant_type
=
'
vector
'
,
precision
=
[
8
,
8
,
8
]):
def
forward
(
ctx
,
A
,
B
,
out
=
None
,
quant_type
=
"
vector
"
,
precision
=
[
8
,
8
,
8
]):
if
precision
[
0
]
!=
8
:
with
torch
.
no_grad
():
output
=
torch
.
matmul
(
A
,
B
)
else
:
if
len
(
B
.
shape
)
==
2
:
dim
=
0
else
:
dim
=
1
if
len
(
B
.
shape
)
==
2
:
dim
=
0
else
:
dim
=
1
qA
,
SA
=
F
.
vectorwise_quant
(
A
,
dim
=-
1
,
quant_type
=
quant_type
)
qB
,
SB
=
F
.
vectorwise_quant
(
B
,
dim
=
dim
,
quant_type
=
quant_type
)
iout
=
F
.
igemm
(
qA
,
qB
)
...
...
@@ -84,21 +91,41 @@ class MatMul8bit(torch.autograd.Function):
else
:
if
len
(
B
.
shape
)
==
2
and
len
(
A
.
shape
)
==
3
:
grad_output
=
grad_output
.
contiguous
()
if
not
grad_output
.
is_contiguous
():
grad_output
.
contiguous
()
qgrad_output
,
S1
=
F
.
vectorwise_quant
(
grad_output
.
view
(
-
1
,
grad_output
.
shape
[
2
]),
dim
=
0
,
quant_type
=
quant_type
)
if
not
A
.
is_contiguous
():
A
=
A
.
contiguous
()
qA
,
S2
=
F
.
vectorwise_quant
(
A
.
view
(
-
1
,
A
.
shape
[
2
]),
dim
=
0
,
quant_type
=
quant_type
)
if
not
grad_output
.
is_contiguous
():
grad_output
.
contiguous
()
qgrad_output
,
S1
=
F
.
vectorwise_quant
(
grad_output
.
view
(
-
1
,
grad_output
.
shape
[
2
]),
dim
=
0
,
quant_type
=
quant_type
,
)
if
not
A
.
is_contiguous
():
A
=
A
.
contiguous
()
qA
,
S2
=
F
.
vectorwise_quant
(
A
.
view
(
-
1
,
A
.
shape
[
2
]),
dim
=
0
,
quant_type
=
quant_type
)
igrad_B
=
F
.
igemm
(
qA
.
t
(),
qgrad_output
)
grad_B
=
F
.
vectorwise_mm_dequant
(
igrad_B
,
S2
.
t
(),
S1
,
grad_output
.
dtype
,
quant_type
)
grad_B
=
F
.
vectorwise_mm_dequant
(
igrad_B
,
S2
.
t
(),
S1
,
grad_output
.
dtype
,
quant_type
)
else
:
qgrad_output
,
S1
=
F
.
vectorwise_quant
(
grad_output
,
dim
=
dims
,
quant_type
=
quant_type
)
qgrad_output
,
S1
=
F
.
vectorwise_quant
(
grad_output
,
dim
=
dims
,
quant_type
=
quant_type
)
qA
,
S2
=
F
.
vectorwise_quant
(
A
,
dim
=
dims
,
quant_type
=
quant_type
)
igrad_B
=
F
.
igemm
(
qA
.
permute
(
permute_dim
),
qgrad_output
)
grad_B
=
F
.
vectorwise_mm_dequant
(
igrad_B
,
S2
.
permute
(
permute_dim
),
S1
,
grad_output
.
dtype
,
quant_type
)
grad_B
=
F
.
vectorwise_mm_dequant
(
igrad_B
,
S2
.
permute
(
permute_dim
),
S1
,
grad_output
.
dtype
,
quant_type
,
)
if
A
.
requires_grad
:
if
len
(
grad_output
.
shape
)
==
3
:
dims
=
[
2
]
else
:
dims
=
[
1
]
if
len
(
grad_output
.
shape
)
==
3
:
dims
=
[
2
]
else
:
dims
=
[
1
]
if
len
(
B
.
shape
)
==
3
:
# bio -> boi
...
...
@@ -113,10 +140,14 @@ class MatMul8bit(torch.autograd.Function):
with
torch
.
no_grad
():
grad_A
=
torch
.
matmul
(
grad_output
,
B
.
permute
(
permute_dim
))
else
:
qgrad_output
,
S1
=
F
.
vectorwise_quant
(
grad_output
,
dim
=
dims
,
quant_type
=
quant_type
)
qgrad_output
,
S1
=
F
.
vectorwise_quant
(
grad_output
,
dim
=
dims
,
quant_type
=
quant_type
)
qB
,
S3
=
F
.
vectorwise_quant
(
B
,
dim
=
dim_B
,
quant_type
=
quant_type
)
igrad_A
=
F
.
igemm
(
qgrad_output
,
qB
.
permute
(
permute_dim
))
grad_A
=
F
.
vectorwise_mm_dequant
(
igrad_A
,
S1
,
S3
.
permute
(
permute_dim
),
grad_output
.
dtype
,
quant_type
)
grad_A
=
F
.
vectorwise_mm_dequant
(
igrad_A
,
S1
,
S3
.
permute
(
permute_dim
),
grad_output
.
dtype
,
quant_type
)
return
grad_A
,
grad_B
,
None
,
None
,
None
...
...
@@ -125,6 +156,7 @@ mm_cublas = MatMul8bit.apply
bmm_cublas
=
MatMul8bit
.
apply
matmul_cublas
=
MatMul8bit
.
apply
@
dataclass
class
MatmulLtState
:
CB
=
None
...
...
@@ -159,7 +191,6 @@ class MatmulLtState:
class
MatMul8bitLt
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
A
,
B
,
out
=
None
,
state
=
MatmulLtState
()):
# 1. Quantize A
...
...
@@ -171,11 +202,15 @@ class MatMul8bitLt(torch.autograd.Function):
requires_gradB
=
B
.
requires_grad
formatB
=
state
.
formatB
input_shape
=
A
.
shape
if
state
.
outlier_pool
is
None
:
state
.
outlier_pool
=
GlobalOutlierPooler
.
get_instance
()
assert
A
.
dtype
==
torch
.
float16
,
f
'The input data type needs to be fp16 but
{
A
.
dtype
}
was found!'
if
state
.
outlier_pool
is
None
:
state
.
outlier_pool
=
GlobalOutlierPooler
.
get_instance
()
assert
(
A
.
dtype
==
torch
.
float16
),
f
"The input data type needs to be fp16 but
{
A
.
dtype
}
was found!"
# 1. Quantize A
if
len
(
A
.
shape
)
==
3
:
A
=
A
.
view
(
-
1
,
A
.
shape
[
-
1
]).
contiguous
()
if
len
(
A
.
shape
)
==
3
:
A
=
A
.
view
(
-
1
,
A
.
shape
[
-
1
]).
contiguous
()
CA
,
CAt
,
SCA
,
SCAt
,
coo_tensorA
=
F
.
double_quant
(
A
,
threshold
=
state
.
threshold
)
if
state
.
threshold
>
0.0
and
coo_tensorA
is
not
None
:
...
...
@@ -191,8 +226,8 @@ class MatMul8bitLt(torch.autograd.Function):
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state
.
CxB
,
state
.
SB
=
F
.
transform
(
state
.
CB
,
to_order
=
formatB
)
#state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
#if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
#
state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
#
if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# # generate outlier index and subB
# outlier_idx = torch.unique(coo_tensorA.colidx).long()
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
...
...
@@ -203,24 +238,24 @@ class MatMul8bitLt(torch.autograd.Function):
# state.idx = outlier_idx
# state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
#if state.idx is not None:
#
if state.idx is not None:
# # extract outliers
# CA[:, state.idx] = 0
# CAt[:, state.idx] = 0
# subA = A[:, state.idx]
#else:
#
else:
# subA = None
else
:
if
not
state
.
has_fp16_weights
and
state
.
CxB
is
None
:
state
.
CxB
,
state
.
SB
=
F
.
transform
(
state
.
CB
,
to_order
=
formatB
)
subA
=
None
# 2. Quantize B
if
state
.
has_fp16_weights
:
has_grad
=
(
True
if
(
getattr
(
B
,
'
grad
'
,
None
)
is
not
None
)
else
False
)
has_grad
=
True
if
(
getattr
(
B
,
"
grad
"
,
None
)
is
not
None
)
else
False
is_transposed
=
not
B
.
is_contiguous
()
and
B
.
shape
[
0
]
==
B
.
stride
(
1
)
if
is_transposed
:
B
=
B
.
contiguous
()
if
is_transposed
:
B
=
B
.
contiguous
()
if
(
state
.
is_training
and
not
has_grad
)
or
state
.
CxB
is
None
:
state
.
reset_grads
()
...
...
@@ -234,14 +269,16 @@ class MatMul8bitLt(torch.autograd.Function):
outlier_idx
=
torch
.
unique
(
coo_tensorA
.
colidx
)
state
.
idx
=
outlier_idx
#state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
#if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
#
state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
#
if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
#else:
#
else:
# state.idx = outlier_idx
outliers
=
F
.
extract_outliers
(
state
.
CxB
,
state
.
SB
,
state
.
idx
.
int
())
state
.
subB
=
(
outliers
*
state
.
SCB
.
view
(
-
1
,
1
)
/
127.0
).
t
().
contiguous
().
half
()
state
.
subB
=
(
(
outliers
*
state
.
SCB
.
view
(
-
1
,
1
)
/
127.0
).
t
().
contiguous
().
half
()
)
CA
[:,
state
.
idx
.
long
()]
=
0
CAt
[:,
state
.
idx
.
long
()]
=
0
subA
=
A
[:,
state
.
idx
.
long
()]
...
...
@@ -254,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function):
output_shape
=
(
input_shape
[
0
],
shapeB
[
0
])
# 3. Matmul
C32A
,
SA
=
F
.
transform
(
CA
,
'
col32
'
)
C32A
,
SA
=
F
.
transform
(
CA
,
"
col32
"
)
out32
,
Sout32
=
F
.
igemmlt
(
C32A
,
state
.
CxB
,
SA
,
state
.
SB
)
output
=
F
.
mm_dequant
(
out32
,
Sout32
,
SCA
,
state
.
SCB
)
...
...
@@ -277,7 +314,7 @@ class MatMul8bitLt(torch.autograd.Function):
ctx
.
tensor_states
=
(
None
,
None
)
ctx
.
save_for_backward
(
None
,
None
)
#clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
#
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
clone_func
=
torch
.
clone
return
clone_func
(
output
.
view
(
output_shape
))
...
...
@@ -288,7 +325,7 @@ class MatMul8bitLt(torch.autograd.Function):
SCAt
,
idx
=
ctx
.
tensor_states
formatB
=
ctx
.
formatB
state
=
ctx
.
state
assert
state
.
has_fp16_weights
,
'
Backprop only supported for fp16 weights.
'
assert
state
.
has_fp16_weights
,
"
Backprop only supported for fp16 weights.
"
if
len
(
grad_output
.
shape
)
==
3
:
grad_output
=
grad_output
.
view
(
-
1
,
grad_output
.
shape
[
-
1
]).
contiguous
()
...
...
@@ -298,18 +335,22 @@ class MatMul8bitLt(torch.autograd.Function):
Cgrad
,
Cgradt
,
SCgrad
,
SCgradt
,
coo_tensor
=
F
.
double_quant
(
grad_output
)
if
req_gradB
:
CxAt
,
SAt
=
F
.
transform
(
CAt
,
formatB
,
transpose
=
True
)
C32grad
,
Sgrad
=
F
.
transform
(
Cgradt
,
'
col32
'
,
transpose
=
True
)
C32grad
,
Sgrad
=
F
.
transform
(
Cgradt
,
"
col32
"
,
transpose
=
True
)
gradB32
,
SgradB32
=
F
.
igemmlt
(
C32grad
,
CxAt
,
Sgrad
,
SAt
)
grad_B
=
F
.
mm_dequant
(
gradB32
,
SgradB32
,
SCgradt
,
SCAt
)
if
state
.
threshold
>
0.0
and
subA
is
not
None
:
grad_B
[:,
idx
]
+=
torch
.
matmul
(
grad_output
.
t
(),
subA
)
if
req_gradA
:
C32grad
,
Sgrad
=
F
.
transform
(
Cgrad
,
'
col32
'
)
C32grad
,
Sgrad
=
F
.
transform
(
Cgrad
,
"
col32
"
)
if
state
.
CxBt
is
None
:
state
.
CxBt
,
state
.
SBt
=
F
.
transform
(
state
.
CBt
,
to_order
=
formatB
,
transpose
=
True
)
state
.
CxBt
,
state
.
SBt
=
F
.
transform
(
state
.
CBt
,
to_order
=
formatB
,
transpose
=
True
)
gradA32
,
SgradA32
=
F
.
igemmlt
(
C32grad
,
state
.
CxBt
,
Sgrad
,
state
.
SBt
)
grad_A
=
F
.
mm_dequant
(
gradA32
,
SgradA32
,
SCgrad
,
state
.
SCBt
).
view
(
ctx
.
grad_shape
)
grad_A
=
F
.
mm_dequant
(
gradA32
,
SgradA32
,
SCgrad
,
state
.
SCBt
).
view
(
ctx
.
grad_shape
)
return
grad_A
,
grad_B
,
None
,
None
,
None
,
None
,
None
...
...
@@ -317,9 +358,10 @@ class MatMul8bitLt(torch.autograd.Function):
matmul
=
MatMul8bitLt
.
apply
def
matmul
(
A
:
tensor
,
B
:
tensor
,
out
:
tensor
=
None
,
state
:
MatmulLtState
=
None
,
threshold
=
0.0
):
def
matmul
(
A
:
tensor
,
B
:
tensor
,
out
:
tensor
=
None
,
state
:
MatmulLtState
=
None
,
threshold
=
0.0
):
state
=
state
or
MatmulLtState
()
if
threshold
>
0.0
:
state
.
threshold
=
threshold
return
MatMul8bitLt
.
apply
(
A
,
B
,
out
,
state
)
bitsandbytes/cextension.py
View file @
bfa0e332
import
ctypes
as
ct
import
os
from
warnings
import
warn
from
bitsandbytes.cuda_setup
import
evaluate_cuda_setup
...
...
@@ -8,17 +9,21 @@ class CUDALibrary_Singleton(object):
_instance
=
None
def
__init__
(
self
):
raise
RuntimeError
(
'
Call get_instance() instead
'
)
raise
RuntimeError
(
"
Call get_instance() instead
"
)
def
initialize
(
self
):
self
.
context
=
{}
binary_name
=
evaluate_cuda_setup
()
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
__file__
)
+
f
'/
{
binary_name
}
'
):
print
(
f
'TODO: compile library for specific version:
{
binary_name
}
'
)
print
(
'defaulting to libbitsandbytes.so'
)
self
.
lib
=
ct
.
cdll
.
LoadLibrary
(
os
.
path
.
dirname
(
__file__
)
+
'/libbitsandbytes.so'
)
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
__file__
)
+
f
"/
{
binary_name
}
"
):
print
(
f
"TODO: compile library for specific version:
{
binary_name
}
"
)
print
(
"defaulting to libbitsandbytes.so"
)
self
.
lib
=
ct
.
cdll
.
LoadLibrary
(
os
.
path
.
dirname
(
__file__
)
+
"/libbitsandbytes.so"
)
else
:
self
.
lib
=
ct
.
cdll
.
LoadLibrary
(
os
.
path
.
dirname
(
__file__
)
+
f
'/
{
binary_name
}
'
)
self
.
lib
=
ct
.
cdll
.
LoadLibrary
(
os
.
path
.
dirname
(
__file__
)
+
f
"/
{
binary_name
}
"
)
@
classmethod
def
get_instance
(
cls
):
...
...
@@ -35,6 +40,8 @@ try:
lib
.
get_cusparse
.
restype
=
ct
.
c_void_p
COMPILED_WITH_CUDA
=
True
except
AttributeError
:
warn
(
"The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers and GPU quantization are unavailable."
)
warn
(
"The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers and GPU quantization are unavailable."
)
COMPILED_WITH_CUDA
=
False
bitsandbytes/cuda_setup.py
View file @
bfa0e332
...
...
@@ -18,31 +18,36 @@ evaluation:
- based on that set the default path
"""
import
ctypes
import
shlex
import
subprocess
from
os
import
environ
as
env
from
pathlib
import
Path
from
typing
import
Set
,
Union
from
.utils
import
warn_of_missing_prerequisite
,
print_err
import
ctypes
import
shlex
import
subprocess
from
.utils
import
print_err
,
warn_of_missing_prerequisite
def
execute_and_return
(
strCMD
):
proc
=
subprocess
.
Popen
(
shlex
.
split
(
strCMD
),
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
proc
=
subprocess
.
Popen
(
shlex
.
split
(
strCMD
),
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
out
,
err
=
proc
.
communicate
()
out
,
err
=
out
.
decode
(
"UTF-8"
).
strip
(),
err
.
decode
(
"UTF-8"
).
strip
()
return
out
,
err
def
check_cuda_result
(
cuda
,
result_val
):
if
result_val
!=
0
:
cuda
.
cuGetErrorString
(
result_val
,
ctypes
.
byref
(
error_str
))
print
(
f
"Count not initialize CUDA - failure!"
)
raise
Exception
(
'
CUDA exception!
'
)
raise
Exception
(
"
CUDA exception!
"
)
return
result_val
# taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
def
get_compute_capability
():
libnames
=
(
'
libcuda.so
'
,
'
libcuda.dylib
'
,
'
cuda.dll
'
)
libnames
=
(
"
libcuda.so
"
,
"
libcuda.dylib
"
,
"
cuda.dll
"
)
for
libname
in
libnames
:
try
:
cuda
=
ctypes
.
CDLL
(
libname
)
...
...
@@ -51,8 +56,7 @@ def get_compute_capability():
else
:
break
else
:
raise
OSError
(
"could not load any of: "
+
' '
.
join
(
libnames
))
raise
OSError
(
"could not load any of: "
+
" "
.
join
(
libnames
))
nGpus
=
ctypes
.
c_int
()
cc_major
=
ctypes
.
c_int
()
...
...
@@ -69,39 +73,43 @@ def get_compute_capability():
ccs
=
[]
for
i
in
range
(
nGpus
.
value
):
result
=
check_cuda_result
(
cuda
,
cuda
.
cuDeviceGet
(
ctypes
.
byref
(
device
),
i
))
result
=
check_cuda_result
(
cuda
,
cuda
.
cuDeviceComputeCapability
(
ctypes
.
byref
(
cc_major
),
ctypes
.
byref
(
cc_minor
),
device
))
ccs
.
append
(
f
'
{
cc_major
.
value
}
.
{
cc_minor
.
value
}
'
)
result
=
check_cuda_result
(
cuda
,
cuda
.
cuDeviceComputeCapability
(
ctypes
.
byref
(
cc_major
),
ctypes
.
byref
(
cc_minor
),
device
),
)
ccs
.
append
(
f
"
{
cc_major
.
value
}
.
{
cc_minor
.
value
}
"
)
#TODO: handle different compute capabilities; for now, take the max
#
TODO: handle different compute capabilities; for now, take the max
ccs
.
sort
()
return
ccs
[
-
1
]
# return ccs[-1]
return
ccs
CUDA_RUNTIME_LIB
:
str
=
"libcudart.so"
def
tokenize_paths
(
paths
:
str
)
->
Set
[
Path
]:
return
{
Path
(
ld_path
)
for
ld_path
in
paths
.
split
(
':'
)
if
ld_path
}
return
{
Path
(
ld_path
)
for
ld_path
in
paths
.
split
(
":"
)
if
ld_path
}
def
get_cuda_runtime_lib_path
(
# TODO: replace this with logic for all paths in env vars
LD_LIBRARY_PATH
:
Union
[
str
,
None
]
=
env
.
get
(
"LD_LIBRARY_PATH"
)
)
->
Union
[
Path
,
None
]:
""" # TODO: add doc-string
"""
"""# TODO: add doc-string"""
if
not
LD_LIBRARY_PATH
:
warn_of_missing_prerequisite
(
'
LD_LIBRARY_PATH is completely missing from environment!
'
"
LD_LIBRARY_PATH is completely missing from environment!
"
)
return
None
ld_library_paths
:
Set
[
Path
]
=
tokenize_paths
(
LD_LIBRARY_PATH
)
non_existent_directories
:
Set
[
Path
]
=
{
path
for
path
in
ld_library_paths
if
not
path
.
exists
()
path
for
path
in
ld_library_paths
if
not
path
.
exists
()
}
if
non_existent_directories
:
...
...
@@ -111,7 +119,8 @@ def get_cuda_runtime_lib_path(
)
cuda_runtime_libs
:
Set
[
Path
]
=
{
path
/
CUDA_RUNTIME_LIB
for
path
in
ld_library_paths
path
/
CUDA_RUNTIME_LIB
for
path
in
ld_library_paths
if
(
path
/
CUDA_RUNTIME_LIB
).
is_file
()
}
-
non_existent_directories
...
...
@@ -126,26 +135,31 @@ def get_cuda_runtime_lib_path(
single_cuda_runtime_lib_dir
=
next
(
iter
(
cuda_runtime_libs
))
return
single_cuda_runtime_lib_dir
def
evaluate_cuda_setup
():
cuda_path
=
get_cuda_runtime_lib_path
()
cc
=
get_compute_capability
()
binary_name
=
'
libbitsandbytes_cpu.so
'
binary_name
=
"
libbitsandbytes_cpu.so
"
if
not
(
has_gpu
:
=
bool
(
cc
)):
print
(
'WARNING: No GPU detected! Check our CUDA paths. Processing to load CPU-only library...'
)
print
(
"WARNING: No GPU detected! Check our CUDA paths. Processing to load CPU-only library..."
)
return
binary_name
has_cublaslt
=
cc
in
[
'
7.5
'
,
'
8.0
'
,
'
8.6
'
]
has_cublaslt
=
cc
in
[
"
7.5
"
,
"
8.0
"
,
"
8.6
"
]
# TODO:
# (1) Model missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
# (2) Multiple CUDA versions installed
cuda_home
=
str
(
Path
(
cuda_path
).
parent
.
parent
)
ls_output
,
err
=
execute_and_return
(
f
'
{
cuda_home
}
/bin/nvcc --version'
)
cuda_version
=
ls_output
.
split
(
'
\n
'
)[
3
].
split
(
','
)[
-
1
].
strip
().
lower
().
replace
(
'v'
,
''
)
major
,
minor
,
revision
=
cuda_version
.
split
(
'.'
)
cuda_version_string
=
f
'
{
major
}{
minor
}
'
ls_output
,
err
=
execute_and_return
(
f
"
{
cuda_home
}
/bin/nvcc --version"
)
cuda_version
=
(
ls_output
.
split
(
"
\n
"
)[
3
].
split
(
","
)[
-
1
].
strip
().
lower
().
replace
(
"v"
,
""
)
)
major
,
minor
,
revision
=
cuda_version
.
split
(
"."
)
cuda_version_string
=
f
"
{
major
}{
minor
}
"
binary_name
=
f
'libbitsandbytes_cuda
{
cuda_version_string
}
_
{
(
"cublaslt"
if
has_cublaslt
else
""
)
}
.so'
...
...
bitsandbytes/debug_cli.py
View file @
bfa0e332
import
typer
cli
=
typer
.
Typer
()
...
...
bitsandbytes/functional.py
View file @
bfa0e332
...
...
@@ -9,47 +9,68 @@ from typing import Tuple
import
torch
from
torch
import
Tensor
from
.cextension
import
lib
,
COMPILED_WITH_CUDA
from
.cextension
import
COMPILED_WITH_CUDA
,
lib
name2qmap
=
{}
if
COMPILED_WITH_CUDA
:
'''
C FUNCTIONS FOR OPTIMIZERS
'''
"""
C FUNCTIONS FOR OPTIMIZERS
"""
str2optimizer32bit
=
{}
str2optimizer32bit
[
'
adam
'
]
=
(
lib
.
cadam32bit_g32
,
lib
.
cadam32bit_g16
)
str2optimizer32bit
[
'
momentum
'
]
=
(
lib
.
cmomentum32bit_g32
,
lib
.
cmomentum32bit_g16
)
str2optimizer32bit
[
'
rmsprop
'
]
=
(
lib
.
crmsprop32bit_g32
,
lib
.
crmsprop32bit_g16
)
str2optimizer32bit
[
'
adagrad
'
]
=
(
lib
.
cadagrad32bit_g32
,
lib
.
cadagrad32bit_g16
)
str2optimizer32bit
[
'
lars
'
]
=
(
lib
.
cmomentum32bit_g32
,
lib
.
cmomentum32bit_g16
)
str2optimizer32bit
[
'
lamb
'
]
=
(
lib
.
cadam32bit_g32
,
lib
.
cadam32bit_g16
)
str2optimizer32bit
[
"
adam
"
]
=
(
lib
.
cadam32bit_g32
,
lib
.
cadam32bit_g16
)
str2optimizer32bit
[
"
momentum
"
]
=
(
lib
.
cmomentum32bit_g32
,
lib
.
cmomentum32bit_g16
)
str2optimizer32bit
[
"
rmsprop
"
]
=
(
lib
.
crmsprop32bit_g32
,
lib
.
crmsprop32bit_g16
)
str2optimizer32bit
[
"
adagrad
"
]
=
(
lib
.
cadagrad32bit_g32
,
lib
.
cadagrad32bit_g16
)
str2optimizer32bit
[
"
lars
"
]
=
(
lib
.
cmomentum32bit_g32
,
lib
.
cmomentum32bit_g16
)
str2optimizer32bit
[
"
lamb
"
]
=
(
lib
.
cadam32bit_g32
,
lib
.
cadam32bit_g16
)
str2optimizer8bit
=
{}
str2optimizer8bit
[
'adam'
]
=
(
lib
.
cadam_static_8bit_g32
,
lib
.
cadam_static_8bit_g16
)
str2optimizer8bit
[
'momentum'
]
=
(
lib
.
cmomentum_static_8bit_g32
,
lib
.
cmomentum_static_8bit_g16
)
str2optimizer8bit
[
'rmsprop'
]
=
(
lib
.
crmsprop_static_8bit_g32
,
lib
.
crmsprop_static_8bit_g16
)
str2optimizer8bit
[
'lamb'
]
=
(
lib
.
cadam_static_8bit_g32
,
lib
.
cadam_static_8bit_g16
)
str2optimizer8bit
[
'lars'
]
=
(
lib
.
cmomentum_static_8bit_g32
,
lib
.
cmomentum_static_8bit_g16
)
str2optimizer8bit
[
"adam"
]
=
(
lib
.
cadam_static_8bit_g32
,
lib
.
cadam_static_8bit_g16
)
str2optimizer8bit
[
"momentum"
]
=
(
lib
.
cmomentum_static_8bit_g32
,
lib
.
cmomentum_static_8bit_g16
,
)
str2optimizer8bit
[
"rmsprop"
]
=
(
lib
.
crmsprop_static_8bit_g32
,
lib
.
crmsprop_static_8bit_g16
,
)
str2optimizer8bit
[
"lamb"
]
=
(
lib
.
cadam_static_8bit_g32
,
lib
.
cadam_static_8bit_g16
)
str2optimizer8bit
[
"lars"
]
=
(
lib
.
cmomentum_static_8bit_g32
,
lib
.
cmomentum_static_8bit_g16
,
)
str2optimizer8bit_blockwise
=
{}
str2optimizer8bit_blockwise
[
'adam'
]
=
(
lib
.
cadam_8bit_blockwise_fp32
,
lib
.
cadam_8bit_blockwise_fp16
)
str2optimizer8bit_blockwise
[
'momentum'
]
=
(
lib
.
cmomentum_8bit_blockwise_fp32
,
lib
.
cmomentum_8bit_blockwise_fp16
)
str2optimizer8bit_blockwise
[
'rmsprop'
]
=
(
lib
.
crmsprop_8bit_blockwise_fp32
,
lib
.
crmsprop_8bit_blockwise_fp16
)
str2optimizer8bit_blockwise
[
'adagrad'
]
=
(
lib
.
cadagrad_8bit_blockwise_fp32
,
lib
.
cadagrad_8bit_blockwise_fp16
)
str2optimizer8bit_blockwise
[
"adam"
]
=
(
lib
.
cadam_8bit_blockwise_fp32
,
lib
.
cadam_8bit_blockwise_fp16
,
)
str2optimizer8bit_blockwise
[
"momentum"
]
=
(
lib
.
cmomentum_8bit_blockwise_fp32
,
lib
.
cmomentum_8bit_blockwise_fp16
,
)
str2optimizer8bit_blockwise
[
"rmsprop"
]
=
(
lib
.
crmsprop_8bit_blockwise_fp32
,
lib
.
crmsprop_8bit_blockwise_fp16
,
)
str2optimizer8bit_blockwise
[
"adagrad"
]
=
(
lib
.
cadagrad_8bit_blockwise_fp32
,
lib
.
cadagrad_8bit_blockwise_fp16
,
)
class
CUBLAS_Context
(
object
):
_instance
=
None
def
__init__
(
self
):
raise
RuntimeError
(
'
Call get_instance() instead
'
)
raise
RuntimeError
(
"
Call get_instance() instead
"
)
def
initialize
(
self
):
self
.
context
=
{}
#prev_device = torch.cuda.current_device()
#for i in range(torch.cuda.device_count()):
#
prev_device = torch.cuda.current_device()
#
for i in range(torch.cuda.device_count()):
# torch.cuda.set_device(torch.device('cuda', i))
# self.context.append(ct.c_void_p(lib.get_context()))
#torch.cuda.set_device(prev_device)
#
torch.cuda.set_device(prev_device)
@
classmethod
def
get_instance
(
cls
):
...
...
@@ -66,11 +87,12 @@ class CUBLAS_Context(object):
torch
.
cuda
.
set_device
(
prev_device
)
return
self
.
context
[
device
.
index
]
class
Cusparse_Context
(
object
):
_instance
=
None
def
__init__
(
self
):
raise
RuntimeError
(
'
Call get_instance() instead
'
)
raise
RuntimeError
(
"
Call get_instance() instead
"
)
def
initialize
(
self
):
self
.
context
=
ct
.
c_void_p
(
lib
.
get_cusparse
())
...
...
@@ -82,14 +104,16 @@ class Cusparse_Context(object):
cls
.
_instance
.
initialize
()
return
cls
.
_instance
def
create_linear_map
(
signed
=
True
):
if
signed
:
return
torch
.
linspace
(
-
1.0
,
1.0
,
256
)
else
:
return
torch
.
linspace
(
0.0
,
1.0
,
256
)
def
create_dynamic_map
(
signed
=
True
,
n
=
7
):
'''
"""
Creates the dynamic quantiztion map.
The dynamic data type is made up of a dynamic exponent and
...
...
@@ -103,46 +127,54 @@ def create_dynamic_map(signed=True, n=7):
For more details see
(8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]
'''
"""
data
=
[]
# these are additional items that come from the case
# where all the exponent bits are zero and no
# indicator bit is present
additional_items
=
2
**
(
7
-
n
)
-
1
if
not
signed
:
additional_items
=
2
*
additional_items
additional_items
=
2
**
(
7
-
n
)
-
1
if
not
signed
:
additional_items
=
2
*
additional_items
for
i
in
range
(
n
):
fraction_items
=
2
**
(
i
+
7
-
n
)
+
1
if
signed
else
2
**
(
i
+
7
-
n
+
1
)
+
1
fraction_items
=
2
**
(
i
+
7
-
n
)
+
1
if
signed
else
2
**
(
i
+
7
-
n
+
1
)
+
1
boundaries
=
torch
.
linspace
(
0.1
,
1
,
fraction_items
)
means
=
(
boundaries
[:
-
1
]
+
boundaries
[
1
:])
/
2.0
data
+=
((
10
**
(
-
(
n
-
1
)
+
i
))
*
means
).
tolist
()
means
=
(
boundaries
[:
-
1
]
+
boundaries
[
1
:])
/
2.0
data
+=
((
10
**
(
-
(
n
-
1
)
+
i
))
*
means
).
tolist
()
if
signed
:
data
+=
(
-
(
10
**
(
-
(
n
-
1
)
+
i
))
*
means
).
tolist
()
data
+=
(
-
(
10
**
(
-
(
n
-
1
)
+
i
))
*
means
).
tolist
()
if
additional_items
>
0
:
boundaries
=
torch
.
linspace
(
0.1
,
1
,
additional_items
+
1
)
means
=
(
boundaries
[:
-
1
]
+
boundaries
[
1
:])
/
2.0
data
+=
((
10
**
(
-
(
n
-
1
)
+
i
))
*
means
).
tolist
()
boundaries
=
torch
.
linspace
(
0.1
,
1
,
additional_items
+
1
)
means
=
(
boundaries
[:
-
1
]
+
boundaries
[
1
:])
/
2.0
data
+=
((
10
**
(
-
(
n
-
1
)
+
i
))
*
means
).
tolist
()
if
signed
:
data
+=
(
-
(
10
**
(
-
(
n
-
1
)
+
i
))
*
means
).
tolist
()
data
+=
(
-
(
10
**
(
-
(
n
-
1
)
+
i
))
*
means
).
tolist
()
data
.
append
(
0
)
data
.
append
(
1.0
)
data
.
sort
()
return
Tensor
(
data
)
def
get_special_format_str
():
major
,
minor
=
torch
.
cuda
.
get_device_capability
()
if
major
<
7
:
print
(
f
'Device with CUDA capability of
{
major
}
not supported for 8-bit matmul. Device has no tensor cores!'
)
print
(
f
"Device with CUDA capability of
{
major
}
not supported for 8-bit matmul. Device has no tensor cores!"
)
assert
major
>=
7
if
major
==
7
:
return
'col_turing'
elif
major
==
8
:
return
'col_ampere'
else
:
return
'col_turing'
if
major
==
7
:
return
"col_turing"
elif
major
==
8
:
return
"col_ampere"
else
:
return
"col_turing"
def
get_ptr
(
A
:
Tensor
)
->
ct
.
c_void_p
:
'''
"""
Get the ctypes pointer from a PyTorch Tensor.
Parameters
...
...
@@ -153,31 +185,39 @@ def get_ptr(A: Tensor) -> ct.c_void_p:
Returns
-------
ctypes.c_void_p
'''
if
A
is
None
:
return
None
else
:
return
ct
.
c_void_p
(
A
.
data
.
storage
().
data_ptr
())
"""
if
A
is
None
:
return
None
else
:
return
ct
.
c_void_p
(
A
.
data
.
storage
().
data_ptr
())
def
pre_call
(
device
):
prev_device
=
torch
.
cuda
.
current_device
()
torch
.
cuda
.
set_device
(
device
)
return
prev_device
def
post_call
(
prev_device
):
torch
.
cuda
.
set_device
(
prev_device
)
def
get_transform_func
(
dtype
,
orderA
,
orderOut
,
transpose
=
False
):
name
=
f
'ctransform_
{
(
8
if
dtype
==
torch
.
int8
else
32
)
}
_
{
orderA
}
_to_
{
orderOut
}
_
{
"t"
if
transpose
else
"n"
}
'
if
not
hasattr
(
lib
,
name
):
print
(
name
)
raise
ValueError
(
f
'Transform function not supported:
{
orderA
}
to
{
orderOut
}
for data type
{
dtype
}
and transpose=
{
transpose
}
'
)
raise
ValueError
(
f
"Transform function not supported:
{
orderA
}
to
{
orderOut
}
for data type
{
dtype
}
and transpose=
{
transpose
}
"
)
else
:
return
getattr
(
lib
,
name
)
class
GlobalData
(
object
):
_instance
=
None
def
__init__
(
self
):
raise
RuntimeError
(
'
Call get_instance() instead
'
)
raise
RuntimeError
(
"
Call get_instance() instead
"
)
def
initialize
(
self
):
self
.
data
=
{}
...
...
@@ -190,15 +230,17 @@ class GlobalData(object):
return
cls
.
_instance
def
get_transform_buffer
(
shape
,
dtype
,
device
,
to_order
,
from_order
=
'row'
,
transpose
=
False
):
#init_func = torch.empty
def
get_transform_buffer
(
shape
,
dtype
,
device
,
to_order
,
from_order
=
"row"
,
transpose
=
False
):
# init_func = torch.empty
init_func
=
torch
.
zeros
dims
=
len
(
shape
)
if
dims
==
2
:
rows
=
shape
[
0
]
elif
dims
==
3
:
rows
=
shape
[
0
]
*
shape
[
1
]
rows
=
shape
[
0
]
*
shape
[
1
]
cols
=
shape
[
-
1
]
state
=
(
shape
,
to_order
)
...
...
@@ -209,30 +251,39 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order='row', trans
cols
=
tmp
state
=
(
shape
[::
-
1
],
to_order
)
if
to_order
==
'
row
'
or
to_order
==
'
col
'
:
if
to_order
==
"
row
"
or
to_order
==
"
col
"
:
return
init_func
(
shape
,
dtype
=
dtype
,
device
=
device
),
state
elif
to_order
==
'
col32
'
:
elif
to_order
==
"
col32
"
:
# blocks of 32 columns (padded)
cols
=
32
*
((
cols
+
31
)
//
32
)
cols
=
32
*
((
cols
+
31
)
//
32
)
return
init_func
((
rows
,
cols
),
dtype
=
dtype
,
device
=
device
),
state
elif
to_order
==
'
col_turing
'
:
elif
to_order
==
"
col_turing
"
:
# blocks of 32 columns and 8 rows
cols
=
32
*
((
cols
+
31
)
//
32
)
rows
=
8
*
((
rows
+
7
)
//
8
)
cols
=
32
*
((
cols
+
31
)
//
32
)
rows
=
8
*
((
rows
+
7
)
//
8
)
return
init_func
((
rows
,
cols
),
dtype
=
dtype
,
device
=
device
),
state
elif
to_order
==
'
col_ampere
'
:
elif
to_order
==
"
col_ampere
"
:
# blocks of 32 columns and 32 rows
cols
=
32
*
((
cols
+
31
)
//
32
)
rows
=
32
*
((
rows
+
31
)
//
32
)
cols
=
32
*
((
cols
+
31
)
//
32
)
rows
=
32
*
((
rows
+
31
)
//
32
)
return
init_func
((
rows
,
cols
),
dtype
=
dtype
,
device
=
device
),
state
else
:
raise
NotImplementedError
(
f
'
To_order not supported:
{
to_order
}
'
)
raise
NotImplementedError
(
f
"
To_order not supported:
{
to_order
}
"
)
def
nvidia_transform
(
A
,
to_order
,
from_order
=
'row'
,
out
=
None
,
transpose
=
False
,
state
=
None
,
ld
=
None
):
if
state
is
None
:
state
=
(
A
.
shape
,
from_order
)
else
:
from_order
=
state
[
1
]
if
out
is
None
:
out
,
new_state
=
get_transform_buffer
(
state
[
0
],
A
.
dtype
,
A
.
device
,
to_order
,
state
[
1
])
else
:
new_state
=
(
state
[
1
],
to_order
)
def
nvidia_transform
(
A
,
to_order
,
from_order
=
"row"
,
out
=
None
,
transpose
=
False
,
state
=
None
,
ld
=
None
):
if
state
is
None
:
state
=
(
A
.
shape
,
from_order
)
else
:
from_order
=
state
[
1
]
if
out
is
None
:
out
,
new_state
=
get_transform_buffer
(
state
[
0
],
A
.
dtype
,
A
.
device
,
to_order
,
state
[
1
]
)
else
:
new_state
=
(
state
[
1
],
to_order
)
func
=
get_transform_func
(
A
.
dtype
,
from_order
,
to_order
,
transpose
)
shape
=
state
[
0
]
...
...
@@ -242,10 +293,10 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s
elif
ld
is
not
None
:
n
=
math
.
prod
(
shape
)
dim1
=
math
.
prod
([
shape
[
i
]
for
i
in
ld
])
dim2
=
ct
.
c_int32
(
n
//
dim1
)
dim2
=
ct
.
c_int32
(
n
//
dim1
)
dim1
=
ct
.
c_int32
(
dim1
)
else
:
dim1
=
ct
.
c_int32
(
shape
[
0
]
*
shape
[
1
])
dim1
=
ct
.
c_int32
(
shape
[
0
]
*
shape
[
1
])
dim2
=
ct
.
c_int32
(
shape
[
2
])
ptr
=
CUBLAS_Context
.
get_instance
().
get_context
(
A
.
device
)
...
...
@@ -253,11 +304,13 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s
ptrOut
=
get_ptr
(
out
)
func
(
ptr
,
get_ptr
(
A
),
get_ptr
(
out
),
dim1
,
dim2
)
return
out
,
new_state
def
estimate_quantiles
(
A
:
Tensor
,
out
:
Tensor
=
None
,
offset
:
float
=
1
/
512
)
->
Tensor
:
'''
def
estimate_quantiles
(
A
:
Tensor
,
out
:
Tensor
=
None
,
offset
:
float
=
1
/
512
)
->
Tensor
:
"""
Estimates 256 equidistant quantiles on the input tensor eCDF.
Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
...
...
@@ -282,18 +335,26 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens
-------
torch.Tensor:
The 256 quantiles in float32 datatype.
'''
if
out
is
None
:
out
=
torch
.
zeros
((
256
,),
dtype
=
torch
.
float32
,
device
=
A
.
device
)
"""
if
out
is
None
:
out
=
torch
.
zeros
((
256
,),
dtype
=
torch
.
float32
,
device
=
A
.
device
)
if
A
.
dtype
==
torch
.
float32
:
lib
.
cestimate_quantiles_fp32
(
get_ptr
(
A
),
get_ptr
(
out
),
ct
.
c_float
(
offset
),
ct
.
c_int
(
A
.
numel
()))
lib
.
cestimate_quantiles_fp32
(
get_ptr
(
A
),
get_ptr
(
out
),
ct
.
c_float
(
offset
),
ct
.
c_int
(
A
.
numel
())
)
elif
A
.
dtype
==
torch
.
float16
:
lib
.
cestimate_quantiles_fp16
(
get_ptr
(
A
),
get_ptr
(
out
),
ct
.
c_float
(
offset
),
ct
.
c_int
(
A
.
numel
()))
lib
.
cestimate_quantiles_fp16
(
get_ptr
(
A
),
get_ptr
(
out
),
ct
.
c_float
(
offset
),
ct
.
c_int
(
A
.
numel
())
)
else
:
raise
NotImplementedError
(
f
'
Not supported data type
{
A
.
dtype
}
'
)
raise
NotImplementedError
(
f
"
Not supported data type
{
A
.
dtype
}
"
)
return
out
def
quantize_blockwise
(
A
:
Tensor
,
code
:
Tensor
=
None
,
absmax
:
Tensor
=
None
,
rand
=
None
,
out
:
Tensor
=
None
)
->
Tensor
:
'''
def
quantize_blockwise
(
A
:
Tensor
,
code
:
Tensor
=
None
,
absmax
:
Tensor
=
None
,
rand
=
None
,
out
:
Tensor
=
None
)
->
Tensor
:
"""
Quantize tensor A in blocks of size 4096 values.
Quantizes tensor A by dividing it into blocks of 4096 values.
...
...
@@ -319,51 +380,96 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N
The 8-bit tensor.
tuple(torch.Tensor, torch.Tensor):
The quantization state to undo the quantization.
'''
"""
if
code
is
None
:
if
'dynamic'
not
in
name2qmap
:
name2qmap
[
'dynamic'
]
=
create_dynamic_map
().
to
(
A
.
device
)
code
=
name2qmap
[
'dynamic'
]
if
"dynamic"
not
in
name2qmap
:
name2qmap
[
"dynamic"
]
=
create_dynamic_map
().
to
(
A
.
device
)
code
=
name2qmap
[
"dynamic"
]
code
=
code
.
to
(
A
.
device
)
if
absmax
is
None
:
n
=
A
.
numel
()
num_blocks
=
4096
blocks
=
n
//
num_blocks
blocks
=
n
//
num_blocks
blocks
+=
1
if
n
%
num_blocks
>
0
else
0
absmax
=
torch
.
zeros
((
blocks
,),
device
=
A
.
device
)
if
out
is
None
:
out
=
torch
.
zeros_like
(
A
,
dtype
=
torch
.
uint8
)
if
out
is
None
:
out
=
torch
.
zeros_like
(
A
,
dtype
=
torch
.
uint8
)
if
A
.
device
.
type
!=
'
cpu
'
:
if
A
.
device
.
type
!=
"
cpu
"
:
if
rand
is
not
None
:
assert
rand
.
numel
()
>=
1024
rand_offset
=
random
.
randint
(
0
,
1023
)
if
A
.
dtype
==
torch
.
float32
:
lib
.
cquantize_blockwise_stochastic_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
get_ptr
(
rand
),
ct
.
c_int32
(
rand_offset
),
ct
.
c_int
(
A
.
numel
()))
lib
.
cquantize_blockwise_stochastic_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
get_ptr
(
rand
),
ct
.
c_int32
(
rand_offset
),
ct
.
c_int
(
A
.
numel
()),
)
elif
A
.
dtype
==
torch
.
float16
:
lib
.
cquantize_blockwise_stochastic_fp16
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
get_ptr
(
rand
),
ct
.
c_int32
(
rand_offset
),
ct
.
c_int
(
A
.
numel
()))
lib
.
cquantize_blockwise_stochastic_fp16
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
get_ptr
(
rand
),
ct
.
c_int32
(
rand_offset
),
ct
.
c_int
(
A
.
numel
()),
)
else
:
raise
ValueError
(
f
'Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
'
)
raise
ValueError
(
f
"Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
"
)
else
:
if
A
.
dtype
==
torch
.
float32
:
lib
.
cquantize_blockwise_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()))
lib
.
cquantize_blockwise_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()),
)
elif
A
.
dtype
==
torch
.
float16
:
lib
.
cquantize_blockwise_fp16
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()))
lib
.
cquantize_blockwise_fp16
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()),
)
else
:
raise
ValueError
(
f
'Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
'
)
raise
ValueError
(
f
"Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
"
)
else
:
# cpu
assert
rand
is
None
lib
.
cquantize_blockwise_cpu_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()))
lib
.
cquantize_blockwise_cpu_fp32
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()),
)
return
out
,
(
absmax
,
code
)
def
dequantize_blockwise
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
code
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
4096
)
->
Tensor
:
'''
def
dequantize_blockwise
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
code
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
4096
,
)
->
Tensor
:
"""
Dequantizes blockwise quantized values.
Dequantizes the tensor A with maximum absolute values absmax in
...
...
@@ -387,57 +493,94 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None,
-------
torch.Tensor:
Dequantized tensor (default: float32)
'''
"""
assert
quant_state
is
not
None
or
absmax
is
not
None
if
code
is
None
and
quant_state
is
None
:
if
'dynamic'
not
in
name2qmap
:
name2qmap
[
'dynamic'
]
=
create_dynamic_map
().
to
(
A
.
device
)
code
=
name2qmap
[
'dynamic'
]
if
"dynamic"
not
in
name2qmap
:
name2qmap
[
"dynamic"
]
=
create_dynamic_map
().
to
(
A
.
device
)
code
=
name2qmap
[
"dynamic"
]
code
=
code
.
to
(
A
.
device
)
if
out
is
None
:
out
=
torch
.
zeros_like
(
A
,
dtype
=
torch
.
float32
)
if
quant_state
is
None
:
quant_state
=
(
absmax
,
code
)
if
out
is
None
:
out
=
torch
.
zeros_like
(
A
,
dtype
=
torch
.
float32
)
if
quant_state
is
None
:
quant_state
=
(
absmax
,
code
)
if
blocksize
not
in
[
2048
,
4096
]:
raise
ValueError
(
f
'The blockwise of
{
blocksize
}
is not supported. Supported values: [2048 4096]'
)
raise
ValueError
(
f
"The blockwise of
{
blocksize
}
is not supported. Supported values: [2048 4096]"
)
if
A
.
device
.
type
!=
'
cpu
'
:
if
A
.
device
.
type
!=
"
cpu
"
:
if
out
.
dtype
==
torch
.
float32
:
lib
.
cdequantize_blockwise_fp32
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
A
.
numel
()))
lib
.
cdequantize_blockwise_fp32
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
A
.
numel
()),
)
elif
out
.
dtype
==
torch
.
float16
:
lib
.
cdequantize_blockwise_fp16
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
A
.
numel
()))
lib
.
cdequantize_blockwise_fp16
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
A
.
numel
()),
)
else
:
raise
ValueError
(
f
'Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
'
)
raise
ValueError
(
f
"Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
"
)
else
:
lib
.
cdequantize_blockwise_cpu_fp32
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()))
lib
.
cdequantize_blockwise_cpu_fp32
(
get_ptr
(
quant_state
[
1
]),
get_ptr
(
A
),
get_ptr
(
quant_state
[
0
]),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()),
)
return
out
def
quantize
(
A
:
Tensor
,
code
:
Tensor
=
None
,
out
:
Tensor
=
None
)
->
Tensor
:
def
quantize
(
A
:
Tensor
,
code
:
Tensor
=
None
,
out
:
Tensor
=
None
)
->
Tensor
:
if
code
is
None
:
if
'dynamic'
not
in
name2qmap
:
name2qmap
[
'dynamic'
]
=
create_dynamic_map
().
to
(
A
.
device
)
code
=
name2qmap
[
'dynamic'
]
if
"dynamic"
not
in
name2qmap
:
name2qmap
[
"dynamic"
]
=
create_dynamic_map
().
to
(
A
.
device
)
code
=
name2qmap
[
"dynamic"
]
code
=
code
.
to
(
A
.
device
)
absmax
=
torch
.
abs
(
A
).
max
()
inp
=
A
/
absmax
inp
=
A
/
absmax
out
=
quantize_no_absmax
(
inp
,
code
,
out
)
return
out
,
(
absmax
,
code
)
def
dequantize
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
code
:
Tensor
=
None
,
out
:
Tensor
=
None
)
->
Tensor
:
def
dequantize
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
code
:
Tensor
=
None
,
out
:
Tensor
=
None
,
)
->
Tensor
:
assert
quant_state
is
not
None
or
absmax
is
not
None
if
code
is
None
and
quant_state
is
None
:
if
'dynamic'
not
in
name2qmap
:
name2qmap
[
'dynamic'
]
=
create_dynamic_map
().
to
(
A
.
device
)
code
=
name2qmap
[
'dynamic'
]
if
"dynamic"
not
in
name2qmap
:
name2qmap
[
"dynamic"
]
=
create_dynamic_map
().
to
(
A
.
device
)
code
=
name2qmap
[
"dynamic"
]
code
=
code
.
to
(
A
.
device
)
if
quant_state
is
None
:
quant_state
=
(
absmax
,
code
)
if
quant_state
is
None
:
quant_state
=
(
absmax
,
code
)
out
=
dequantize_no_absmax
(
A
,
quant_state
[
1
],
out
)
return
out
*
quant_state
[
0
]
return
out
*
quant_state
[
0
]
def
quantize_no_absmax
(
A
:
Tensor
,
code
:
Tensor
,
out
:
Tensor
=
None
)
->
Tensor
:
'''
def
quantize_no_absmax
(
A
:
Tensor
,
code
:
Tensor
,
out
:
Tensor
=
None
)
->
Tensor
:
"""
Quantizes input tensor to 8-bit.
Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
...
...
@@ -456,13 +599,15 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
-------
torch.Tensor:
Quantized 8-bit tensor.
'''
if
out
is
None
:
out
=
torch
.
zeros_like
(
A
,
dtype
=
torch
.
uint8
)
"""
if
out
is
None
:
out
=
torch
.
zeros_like
(
A
,
dtype
=
torch
.
uint8
)
lib
.
cquantize
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()))
return
out
def
dequantize_no_absmax
(
A
:
Tensor
,
code
:
Tensor
,
out
:
Tensor
=
None
)
->
Tensor
:
'''
def
dequantize_no_absmax
(
A
:
Tensor
,
code
:
Tensor
,
out
:
Tensor
=
None
)
->
Tensor
:
"""
Dequantizes the 8-bit tensor to 32-bit.
Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
...
...
@@ -481,17 +626,31 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
-------
torch.Tensor:
32-bit output tensor.
'''
if
out
is
None
:
out
=
torch
.
zeros_like
(
A
,
dtype
=
torch
.
float32
)
"""
if
out
is
None
:
out
=
torch
.
zeros_like
(
A
,
dtype
=
torch
.
float32
)
lib
.
cdequantize
(
get_ptr
(
code
),
get_ptr
(
A
),
get_ptr
(
out
),
ct
.
c_int
(
A
.
numel
()))
return
out
def
optimizer_update_32bit
(
optimizer_name
:
str
,
g
:
Tensor
,
p
:
Tensor
,
state1
:
Tensor
,
beta1
:
float
,
eps
:
float
,
step
:
int
,
lr
:
float
,
state2
:
Tensor
=
None
,
beta2
:
float
=
0.0
,
weight_decay
:
float
=
0.0
,
gnorm_scale
:
float
=
1.0
,
unorm_vec
:
Tensor
=
None
,
max_unorm
:
float
=
0.0
,
skip_zeros
=
False
)
->
None
:
'''
def
optimizer_update_32bit
(
optimizer_name
:
str
,
g
:
Tensor
,
p
:
Tensor
,
state1
:
Tensor
,
beta1
:
float
,
eps
:
float
,
step
:
int
,
lr
:
float
,
state2
:
Tensor
=
None
,
beta2
:
float
=
0.0
,
weight_decay
:
float
=
0.0
,
gnorm_scale
:
float
=
1.0
,
unorm_vec
:
Tensor
=
None
,
max_unorm
:
float
=
0.0
,
skip_zeros
=
False
,
)
->
None
:
"""
Performs an inplace optimizer update with one or two optimizer states.
Universal optimizer update for 32-bit state and 32/16-bit gradients/weights.
...
...
@@ -528,33 +687,84 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
The maximum update norm relative to the weight norm.
skip_zeros : bool
Whether to skip zero-valued gradients or not (default: False).
'''
"""
param_norm
=
0.0
if
max_unorm
>
0.0
:
param_norm
=
torch
.
norm
(
p
.
data
.
float
())
if
optimizer_name
not
in
str2optimizer32bit
:
raise
NotImplementedError
(
f
'Optimizer not implemented:
{
optimizer_name
}
. Choices:
{
","
.
join
(
str2optimizer32bit
.
keys
())
}
'
)
raise
NotImplementedError
(
f
'Optimizer not implemented:
{
optimizer_name
}
. Choices:
{
","
.
join
(
str2optimizer32bit
.
keys
())
}
'
)
if
g
.
dtype
==
torch
.
float32
and
state1
.
dtype
==
torch
.
float32
:
str2optimizer32bit
[
optimizer_name
][
0
](
get_ptr
(
g
),
get_ptr
(
p
),
get_ptr
(
state1
),
get_ptr
(
state2
),
get_ptr
(
unorm_vec
),
ct
.
c_float
(
max_unorm
),
ct
.
c_float
(
param_norm
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_float
(
weight_decay
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()))
str2optimizer32bit
[
optimizer_name
][
0
](
get_ptr
(
g
),
get_ptr
(
p
),
get_ptr
(
state1
),
get_ptr
(
state2
),
get_ptr
(
unorm_vec
),
ct
.
c_float
(
max_unorm
),
ct
.
c_float
(
param_norm
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_float
(
weight_decay
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()),
)
elif
g
.
dtype
==
torch
.
float16
and
state1
.
dtype
==
torch
.
float32
:
str2optimizer32bit
[
optimizer_name
][
1
](
get_ptr
(
g
),
get_ptr
(
p
),
get_ptr
(
state1
),
get_ptr
(
state2
),
get_ptr
(
unorm_vec
),
ct
.
c_float
(
max_unorm
),
ct
.
c_float
(
param_norm
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_float
(
weight_decay
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()))
str2optimizer32bit
[
optimizer_name
][
1
](
get_ptr
(
g
),
get_ptr
(
p
),
get_ptr
(
state1
),
get_ptr
(
state2
),
get_ptr
(
unorm_vec
),
ct
.
c_float
(
max_unorm
),
ct
.
c_float
(
param_norm
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_float
(
weight_decay
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()),
)
else
:
raise
ValueError
(
f
'Gradient+optimizer bit data type combination not supported: grad
{
g
.
dtype
}
, optimizer
{
state1
.
dtype
}
'
)
def
optimizer_update_8bit
(
optimizer_name
:
str
,
g
:
Tensor
,
p
:
Tensor
,
state1
:
Tensor
,
state2
:
Tensor
,
beta1
:
float
,
beta2
:
float
,
eps
:
float
,
step
:
int
,
lr
:
float
,
qmap1
:
Tensor
,
qmap2
:
Tensor
,
max1
:
Tensor
,
max2
:
Tensor
,
new_max1
:
Tensor
,
new_max2
:
Tensor
,
weight_decay
:
float
=
0.0
,
gnorm_scale
:
float
=
1.0
,
unorm_vec
:
Tensor
=
None
,
max_unorm
:
float
=
0.0
)
->
None
:
'''
raise
ValueError
(
f
"Gradient+optimizer bit data type combination not supported: grad
{
g
.
dtype
}
, optimizer
{
state1
.
dtype
}
"
)
def
optimizer_update_8bit
(
optimizer_name
:
str
,
g
:
Tensor
,
p
:
Tensor
,
state1
:
Tensor
,
state2
:
Tensor
,
beta1
:
float
,
beta2
:
float
,
eps
:
float
,
step
:
int
,
lr
:
float
,
qmap1
:
Tensor
,
qmap2
:
Tensor
,
max1
:
Tensor
,
max2
:
Tensor
,
new_max1
:
Tensor
,
new_max2
:
Tensor
,
weight_decay
:
float
=
0.0
,
gnorm_scale
:
float
=
1.0
,
unorm_vec
:
Tensor
=
None
,
max_unorm
:
float
=
0.0
,
)
->
None
:
"""
Performs an inplace Adam update.
Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights.
...
...
@@ -602,56 +812,135 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten
The tensor for the update norm.
max_unorm : float
The maximum update norm relative to the weight norm.
'''
"""
param_norm
=
0.0
if
max_unorm
>
0.0
:
param_norm
=
torch
.
norm
(
p
.
data
.
float
())
if
g
.
dtype
==
torch
.
float32
and
state1
.
dtype
==
torch
.
uint8
:
str2optimizer8bit
[
optimizer_name
][
0
](
get_ptr
(
p
),
get_ptr
(
g
),
get_ptr
(
state1
),
get_ptr
(
state2
),
get_ptr
(
unorm_vec
),
ct
.
c_float
(
max_unorm
),
ct
.
c_float
(
param_norm
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
get_ptr
(
qmap1
),
get_ptr
(
qmap2
),
get_ptr
(
max1
),
get_ptr
(
max2
),
get_ptr
(
new_max1
),
get_ptr
(
new_max2
),
ct
.
c_float
(
weight_decay
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_int32
(
g
.
numel
()))
str2optimizer8bit
[
optimizer_name
][
0
](
get_ptr
(
p
),
get_ptr
(
g
),
get_ptr
(
state1
),
get_ptr
(
state2
),
get_ptr
(
unorm_vec
),
ct
.
c_float
(
max_unorm
),
ct
.
c_float
(
param_norm
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
get_ptr
(
qmap1
),
get_ptr
(
qmap2
),
get_ptr
(
max1
),
get_ptr
(
max2
),
get_ptr
(
new_max1
),
get_ptr
(
new_max2
),
ct
.
c_float
(
weight_decay
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_int32
(
g
.
numel
()),
)
elif
g
.
dtype
==
torch
.
float16
and
state1
.
dtype
==
torch
.
uint8
:
str2optimizer8bit
[
optimizer_name
][
1
](
get_ptr
(
p
),
get_ptr
(
g
),
get_ptr
(
state1
),
get_ptr
(
state2
),
get_ptr
(
unorm_vec
),
ct
.
c_float
(
max_unorm
),
ct
.
c_float
(
param_norm
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
get_ptr
(
qmap1
),
get_ptr
(
qmap2
),
get_ptr
(
max1
),
get_ptr
(
max2
),
get_ptr
(
new_max1
),
get_ptr
(
new_max2
),
ct
.
c_float
(
weight_decay
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_int32
(
g
.
numel
()))
str2optimizer8bit
[
optimizer_name
][
1
](
get_ptr
(
p
),
get_ptr
(
g
),
get_ptr
(
state1
),
get_ptr
(
state2
),
get_ptr
(
unorm_vec
),
ct
.
c_float
(
max_unorm
),
ct
.
c_float
(
param_norm
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
get_ptr
(
qmap1
),
get_ptr
(
qmap2
),
get_ptr
(
max1
),
get_ptr
(
max2
),
get_ptr
(
new_max1
),
get_ptr
(
new_max2
),
ct
.
c_float
(
weight_decay
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_int32
(
g
.
numel
()),
)
else
:
raise
ValueError
(
f
'Gradient+optimizer bit data type combination not supported: grad
{
g
.
dtype
}
, optimizer
{
state1
.
dtype
}
'
)
def
optimizer_update_8bit_blockwise
(
optimizer_name
:
str
,
g
:
Tensor
,
p
:
Tensor
,
state1
:
Tensor
,
state2
:
Tensor
,
beta1
:
float
,
beta2
:
float
,
eps
:
float
,
step
:
int
,
lr
:
float
,
qmap1
:
Tensor
,
qmap2
:
Tensor
,
absmax1
:
Tensor
,
absmax2
:
Tensor
,
weight_decay
:
float
=
0.0
,
gnorm_scale
:
float
=
1.0
,
skip_zeros
=
False
)
->
None
:
raise
ValueError
(
f
"Gradient+optimizer bit data type combination not supported: grad
{
g
.
dtype
}
, optimizer
{
state1
.
dtype
}
"
)
def
optimizer_update_8bit_blockwise
(
optimizer_name
:
str
,
g
:
Tensor
,
p
:
Tensor
,
state1
:
Tensor
,
state2
:
Tensor
,
beta1
:
float
,
beta2
:
float
,
eps
:
float
,
step
:
int
,
lr
:
float
,
qmap1
:
Tensor
,
qmap2
:
Tensor
,
absmax1
:
Tensor
,
absmax2
:
Tensor
,
weight_decay
:
float
=
0.0
,
gnorm_scale
:
float
=
1.0
,
skip_zeros
=
False
,
)
->
None
:
if
g
.
dtype
==
torch
.
float32
and
state1
.
dtype
==
torch
.
uint8
:
str2optimizer8bit_blockwise
[
optimizer_name
][
0
](
get_ptr
(
p
),
get_ptr
(
g
),
get_ptr
(
state1
),
get_ptr
(
state2
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
get_ptr
(
qmap1
),
get_ptr
(
qmap2
),
get_ptr
(
absmax1
),
get_ptr
(
absmax2
),
ct
.
c_float
(
weight_decay
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()))
str2optimizer8bit_blockwise
[
optimizer_name
][
0
](
get_ptr
(
p
),
get_ptr
(
g
),
get_ptr
(
state1
),
get_ptr
(
state2
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
get_ptr
(
qmap1
),
get_ptr
(
qmap2
),
get_ptr
(
absmax1
),
get_ptr
(
absmax2
),
ct
.
c_float
(
weight_decay
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()),
)
elif
g
.
dtype
==
torch
.
float16
and
state1
.
dtype
==
torch
.
uint8
:
str2optimizer8bit_blockwise
[
optimizer_name
][
1
](
get_ptr
(
p
),
get_ptr
(
g
),
get_ptr
(
state1
),
get_ptr
(
state2
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
get_ptr
(
qmap1
),
get_ptr
(
qmap2
),
get_ptr
(
absmax1
),
get_ptr
(
absmax2
),
ct
.
c_float
(
weight_decay
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()))
str2optimizer8bit_blockwise
[
optimizer_name
][
1
](
get_ptr
(
p
),
get_ptr
(
g
),
get_ptr
(
state1
),
get_ptr
(
state2
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
get_ptr
(
qmap1
),
get_ptr
(
qmap2
),
get_ptr
(
absmax1
),
get_ptr
(
absmax2
),
ct
.
c_float
(
weight_decay
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()),
)
else
:
raise
ValueError
(
f
'Gradient+optimizer bit data type combination not supported: grad
{
g
.
dtype
}
, optimizer
{
state1
.
dtype
}
'
)
raise
ValueError
(
f
"Gradient+optimizer bit data type combination not supported: grad
{
g
.
dtype
}
, optimizer
{
state1
.
dtype
}
"
)
def
percentile_clipping
(
grad
:
Tensor
,
gnorm_vec
:
Tensor
,
step
:
int
,
percentile
:
int
=
5
):
def
percentile_clipping
(
grad
:
Tensor
,
gnorm_vec
:
Tensor
,
step
:
int
,
percentile
:
int
=
5
):
"""Applies percentile clipping
grad: torch.Tensor
...
...
@@ -663,11 +952,21 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
"""
if
grad
.
dtype
==
torch
.
float32
:
lib
.
cpercentile_clipping_g32
(
get_ptr
(
grad
),
get_ptr
(
gnorm_vec
),
ct
.
c_int32
(
step
),
ct
.
c_int32
(
grad
.
numel
()))
lib
.
cpercentile_clipping_g32
(
get_ptr
(
grad
),
get_ptr
(
gnorm_vec
),
ct
.
c_int32
(
step
),
ct
.
c_int32
(
grad
.
numel
()),
)
elif
grad
.
dtype
==
torch
.
float16
:
lib
.
cpercentile_clipping_g16
(
get_ptr
(
grad
),
get_ptr
(
gnorm_vec
),
ct
.
c_int32
(
step
),
ct
.
c_int32
(
grad
.
numel
()))
lib
.
cpercentile_clipping_g16
(
get_ptr
(
grad
),
get_ptr
(
gnorm_vec
),
ct
.
c_int32
(
step
),
ct
.
c_int32
(
grad
.
numel
()),
)
else
:
raise
ValueError
(
f
'
Gradient type
{
grad
.
dtype
}
not supported!
'
)
raise
ValueError
(
f
"
Gradient type
{
grad
.
dtype
}
not supported!
"
)
current_gnorm
=
torch
.
sqrt
(
gnorm_vec
[
step
%
100
])
vals
,
idx
=
torch
.
sort
(
gnorm_vec
)
...
...
@@ -675,31 +974,44 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
gnorm_scale
=
1.0
if
current_gnorm
>
clip_value
:
gnorm_scale
=
clip_value
/
current_gnorm
gnorm_scale
=
clip_value
/
current_gnorm
return
current_gnorm
,
clip_value
,
gnorm_scale
def
histogram_scatter_add_2d
(
histogram
:
Tensor
,
index1
:
Tensor
,
index2
:
Tensor
,
source
:
Tensor
):
def
histogram_scatter_add_2d
(
histogram
:
Tensor
,
index1
:
Tensor
,
index2
:
Tensor
,
source
:
Tensor
):
assert
len
(
histogram
.
shape
)
==
2
assert
histogram
.
dtype
==
torch
.
float32
assert
source
.
dtype
==
torch
.
float32
assert
index1
.
dtype
==
torch
.
int32
assert
index2
.
dtype
==
torch
.
int32
assert
histogram
.
device
.
type
==
'
cuda
'
assert
index1
.
device
.
type
==
'
cuda
'
assert
index2
.
device
.
type
==
'
cuda
'
assert
source
.
device
.
type
==
'
cuda
'
assert
histogram
.
device
.
type
==
"
cuda
"
assert
index1
.
device
.
type
==
"
cuda
"
assert
index2
.
device
.
type
==
"
cuda
"
assert
source
.
device
.
type
==
"
cuda
"
maxdim1
=
ct
.
c_int32
(
histogram
.
shape
[
0
])
n
=
ct
.
c_int32
(
index1
.
numel
())
lib
.
chistogram_scatter_add_2d
(
get_ptr
(
histogram
),
get_ptr
(
index1
),
get_ptr
(
index2
),
get_ptr
(
source
),
maxdim1
,
n
)
lib
.
chistogram_scatter_add_2d
(
get_ptr
(
histogram
),
get_ptr
(
index1
),
get_ptr
(
index2
),
get_ptr
(
source
),
maxdim1
,
n
,
)
def
check_matmul
(
A
,
B
,
out
,
transposed_A
,
transposed_B
,
expected_type
=
torch
.
int8
):
if
not
torch
.
cuda
.
is_initialized
():
torch
.
cuda
.
init
()
if
not
torch
.
cuda
.
is_initialized
():
torch
.
cuda
.
init
()
if
A
.
dtype
!=
expected_type
or
B
.
dtype
!=
expected_type
:
raise
TypeError
(
f
'Expected torch.int8 input tensors A and B, but got
{
A
.
dtype
}
and
{
B
.
dtype
}
'
)
raise
TypeError
(
f
"Expected torch.int8 input tensors A and B, but got
{
A
.
dtype
}
and
{
B
.
dtype
}
"
)
sA
=
A
.
shape
sB
=
B
.
shape
...
...
@@ -709,64 +1021,101 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
correct
=
True
if
len
(
sA
)
==
2
and
len
(
sB
)
==
2
:
if
not
tA
and
not
tB
and
A
.
shape
[
1
]
!=
B
.
shape
[
0
]:
correct
=
False
elif
tA
and
not
tB
and
A
.
shape
[
0
]
!=
B
.
shape
[
0
]:
correct
=
False
elif
tA
and
tB
and
A
.
shape
[
0
]
!=
B
.
shape
[
1
]:
correct
=
False
elif
not
tA
and
tB
and
A
.
shape
[
1
]
!=
B
.
shape
[
1
]:
correct
=
False
if
not
tA
and
not
tB
and
A
.
shape
[
1
]
!=
B
.
shape
[
0
]:
correct
=
False
elif
tA
and
not
tB
and
A
.
shape
[
0
]
!=
B
.
shape
[
0
]:
correct
=
False
elif
tA
and
tB
and
A
.
shape
[
0
]
!=
B
.
shape
[
1
]:
correct
=
False
elif
not
tA
and
tB
and
A
.
shape
[
1
]
!=
B
.
shape
[
1
]:
correct
=
False
elif
len
(
sA
)
==
3
and
len
(
sB
)
==
2
:
if
not
tA
and
not
tB
and
A
.
shape
[
2
]
!=
B
.
shape
[
0
]:
correct
=
False
elif
tA
and
not
tB
and
A
.
shape
[
1
]
!=
B
.
shape
[
0
]:
correct
=
False
elif
tA
and
tB
and
A
.
shape
[
1
]
!=
B
.
shape
[
1
]:
correct
=
False
elif
not
tA
and
tB
and
A
.
shape
[
2
]
!=
B
.
shape
[
1
]:
correct
=
False
if
not
tA
and
not
tB
and
A
.
shape
[
2
]
!=
B
.
shape
[
0
]:
correct
=
False
elif
tA
and
not
tB
and
A
.
shape
[
1
]
!=
B
.
shape
[
0
]:
correct
=
False
elif
tA
and
tB
and
A
.
shape
[
1
]
!=
B
.
shape
[
1
]:
correct
=
False
elif
not
tA
and
tB
and
A
.
shape
[
2
]
!=
B
.
shape
[
1
]:
correct
=
False
elif
len
(
sA
)
==
3
and
len
(
sB
)
==
3
:
if
not
tA
and
not
tB
and
A
.
shape
[
2
]
!=
B
.
shape
[
1
]:
correct
=
False
elif
tA
and
not
tB
and
A
.
shape
[
1
]
!=
B
.
shape
[
1
]:
correct
=
False
elif
tA
and
tB
and
A
.
shape
[
1
]
!=
B
.
shape
[
2
]:
correct
=
False
elif
not
tA
and
tB
and
A
.
shape
[
2
]
!=
B
.
shape
[
2
]:
correct
=
False
if
not
tA
and
not
tB
and
A
.
shape
[
2
]
!=
B
.
shape
[
1
]:
correct
=
False
elif
tA
and
not
tB
and
A
.
shape
[
1
]
!=
B
.
shape
[
1
]:
correct
=
False
elif
tA
and
tB
and
A
.
shape
[
1
]
!=
B
.
shape
[
2
]:
correct
=
False
elif
not
tA
and
tB
and
A
.
shape
[
2
]
!=
B
.
shape
[
2
]:
correct
=
False
if
out
is
not
None
:
sout
=
out
.
shape
# special case common in backprop
if
not
correct
and
len
(
sA
)
==
3
and
len
(
sB
)
==
3
:
if
(
sout
[
0
]
==
sA
[
2
]
and
sout
[
1
]
==
sB
[
2
]
and
sA
[
0
]
==
sB
[
0
]
and
sA
[
1
]
==
sB
[
1
]):
if
(
sout
[
0
]
==
sA
[
2
]
and
sout
[
1
]
==
sB
[
2
]
and
sA
[
0
]
==
sB
[
0
]
and
sA
[
1
]
==
sB
[
1
]
):
correct
=
True
else
:
if
len
(
sA
)
==
2
and
len
(
sB
)
==
2
:
if
not
tA
and
not
tB
:
sout
=
(
sA
[
0
],
sB
[
1
])
elif
tA
and
tB
:
sout
=
(
sA
[
1
],
sB
[
0
])
elif
tA
and
not
tB
:
sout
=
(
sA
[
1
],
sB
[
1
])
elif
not
tA
and
tB
:
sout
=
(
sA
[
0
],
sB
[
0
])
if
not
tA
and
not
tB
:
sout
=
(
sA
[
0
],
sB
[
1
])
elif
tA
and
tB
:
sout
=
(
sA
[
1
],
sB
[
0
])
elif
tA
and
not
tB
:
sout
=
(
sA
[
1
],
sB
[
1
])
elif
not
tA
and
tB
:
sout
=
(
sA
[
0
],
sB
[
0
])
elif
len
(
sA
)
==
3
and
len
(
sB
)
==
2
:
if
not
tA
and
not
tB
:
sout
=
(
sA
[
0
],
sA
[
1
],
sB
[
1
])
elif
tA
and
tB
:
sout
=
(
sA
[
0
],
sA
[
2
],
sB
[
0
])
elif
tA
and
not
tB
:
sout
=
(
sA
[
0
],
sA
[
2
],
sB
[
1
])
elif
not
tA
and
tB
:
sout
=
(
sA
[
0
],
sA
[
1
],
sB
[
0
])
if
not
tA
and
not
tB
:
sout
=
(
sA
[
0
],
sA
[
1
],
sB
[
1
])
elif
tA
and
tB
:
sout
=
(
sA
[
0
],
sA
[
2
],
sB
[
0
])
elif
tA
and
not
tB
:
sout
=
(
sA
[
0
],
sA
[
2
],
sB
[
1
])
elif
not
tA
and
tB
:
sout
=
(
sA
[
0
],
sA
[
1
],
sB
[
0
])
elif
len
(
sA
)
==
3
and
len
(
sB
)
==
3
:
if
not
tA
and
not
tB
:
sout
=
(
sA
[
0
],
sA
[
1
],
sB
[
2
])
elif
tA
and
tB
:
sout
=
(
sA
[
0
],
sA
[
2
],
sB
[
1
])
elif
tA
and
not
tB
:
sout
=
(
sA
[
0
],
sA
[
2
],
sB
[
2
])
elif
not
tA
and
tB
:
sout
=
(
sA
[
0
],
sA
[
1
],
sB
[
1
])
if
not
tA
and
not
tB
:
sout
=
(
sA
[
0
],
sA
[
1
],
sB
[
2
])
elif
tA
and
tB
:
sout
=
(
sA
[
0
],
sA
[
2
],
sB
[
1
])
elif
tA
and
not
tB
:
sout
=
(
sA
[
0
],
sA
[
2
],
sB
[
2
])
elif
not
tA
and
tB
:
sout
=
(
sA
[
0
],
sA
[
1
],
sB
[
1
])
if
not
correct
:
raise
ValueError
(
f
'Tensor dimensions incorrect for matrix mulitiplication: A x B:
{
sA
}
x
{
sB
}
with transpose for A x B:
{
tA
}
x
{
tB
}
.'
)
raise
ValueError
(
f
"Tensor dimensions incorrect for matrix mulitiplication: A x B:
{
sA
}
x
{
sB
}
with transpose for A x B:
{
tA
}
x
{
tB
}
."
)
return
sout
def
igemm
(
A
:
Tensor
,
B
:
Tensor
,
out
:
Tensor
=
None
,
transposed_A
=
False
,
transposed_B
=
False
):
def
igemm
(
A
:
Tensor
,
B
:
Tensor
,
out
:
Tensor
=
None
,
transposed_A
=
False
,
transposed_B
=
False
):
sout
=
check_matmul
(
A
,
B
,
out
,
transposed_A
,
transposed_B
)
if
out
is
None
:
out
=
torch
.
zeros
(
size
=
sout
,
dtype
=
torch
.
int32
,
device
=
A
.
device
)
if
out
is
None
:
out
=
torch
.
zeros
(
size
=
sout
,
dtype
=
torch
.
int32
,
device
=
A
.
device
)
if
len
(
A
.
shape
)
==
3
and
len
(
B
.
shape
)
==
3
:
if
A
.
shape
[
0
]
==
B
.
shape
[
0
]
and
A
.
shape
[
2
]
==
B
.
shape
[
1
]:
return
batched_igemm
(
A
,
B
,
out
)
sA
=
A
.
shape
sB
=
B
.
shape
if
transposed_A
and
len
(
sA
)
==
2
:
sA
=
(
sA
[
1
],
sA
[
0
])
elif
transposed_A
and
len
(
sA
)
==
3
:
sA
=
(
sA
[
0
],
sA
[
2
],
sA
[
0
])
if
transposed_B
and
len
(
sB
)
==
2
:
sB
=
(
sB
[
1
],
sB
[
0
])
elif
transposed_B
and
len
(
sB
)
==
3
:
sB
=
(
sB
[
0
],
sB
[
2
],
sB
[
0
])
if
transposed_A
and
len
(
sA
)
==
2
:
sA
=
(
sA
[
1
],
sA
[
0
])
elif
transposed_A
and
len
(
sA
)
==
3
:
sA
=
(
sA
[
0
],
sA
[
2
],
sA
[
0
])
if
transposed_B
and
len
(
sB
)
==
2
:
sB
=
(
sB
[
1
],
sB
[
0
])
elif
transposed_B
and
len
(
sB
)
==
3
:
sB
=
(
sB
[
0
],
sB
[
2
],
sB
[
0
])
# this is a mess: cuBLAS expect column major, but PyTorch is row major.
# So to perform the matrix multiplication, we have to treat A, B, and C matrices
# (transpose of row major is column major)
...
...
@@ -777,23 +1126,28 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed
# row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
# column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
if
len
(
sB
)
==
2
:
if
B
.
stride
()[
0
]
==
B
.
shape
[
1
]:
transposed_B
=
False
elif
B
.
stride
()[
1
]
==
B
.
shape
[
0
]:
transposed_B
=
True
if
B
.
stride
()[
0
]
==
B
.
shape
[
1
]:
transposed_B
=
False
elif
B
.
stride
()[
1
]
==
B
.
shape
[
0
]:
transposed_B
=
True
if
len
(
A
.
shape
)
==
2
:
if
A
.
stride
()[
0
]
==
A
.
shape
[
1
]:
transposed_A
=
False
elif
A
.
stride
()[
1
]
==
A
.
shape
[
0
]:
transposed_A
=
True
if
A
.
stride
()[
0
]
==
A
.
shape
[
1
]:
transposed_A
=
False
elif
A
.
stride
()[
1
]
==
A
.
shape
[
0
]:
transposed_A
=
True
else
:
if
A
.
stride
()[
1
]
==
A
.
shape
[
2
]:
transposed_A
=
False
elif
A
.
stride
()[
2
]
==
A
.
shape
[
1
]:
transposed_A
=
True
if
A
.
stride
()[
1
]
==
A
.
shape
[
2
]:
transposed_A
=
False
elif
A
.
stride
()[
2
]
==
A
.
shape
[
1
]:
transposed_A
=
True
if
len
(
sA
)
==
2
:
n
=
sA
[
0
]
ldb
=
A
.
stride
()[
1
if
transposed_A
else
0
]
elif
len
(
sA
)
==
3
and
len
(
sB
)
==
2
:
n
=
sA
[
0
]
*
sA
[
1
]
n
=
sA
[
0
]
*
sA
[
1
]
ldb
=
sA
[
2
]
m
=
sB
[
1
]
k
=
sB
[
0
]
lda
=
B
.
stride
()[(
1
if
transposed_B
else
0
)]
...
...
@@ -802,34 +1156,52 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed
# special case
assert
len
(
sA
)
==
3
if
not
(
sA
[
0
]
==
sB
[
0
]
and
sA
[
1
]
==
sB
[
1
]):
raise
ValueError
(
f
'Only bsi,bso->io supported for tensor contractions, but dims for A x B were:
{
sA
}
x
{
sB
}
'
)
raise
ValueError
(
f
"Only bsi,bso->io supported for tensor contractions, but dims for A x B were:
{
sA
}
x
{
sB
}
"
)
transposed_A
=
True
transposed_B
=
False
m
=
sB
[
2
]
n
=
sA
[
2
]
k
=
sB
[
0
]
*
sB
[
1
]
k
=
sB
[
0
]
*
sB
[
1
]
lda
=
m
ldb
=
sA
[
2
]
ldc
=
m
ptr
=
CUBLAS_Context
.
get_instance
().
get_context
(
A
.
device
)
# B^T @ A^T = C^T
# [km, nk -> mn]
lib
.
cigemm
(
ptr
,
ct
.
c_bool
(
transposed_B
),
ct
.
c_bool
(
transposed_A
),
ct
.
c_int32
(
m
),
ct
.
c_int32
(
n
),
ct
.
c_int32
(
k
),
get_ptr
(
B
),
get_ptr
(
A
),
get_ptr
(
out
),
ct
.
c_int32
(
lda
),
ct
.
c_int32
(
ldb
),
ct
.
c_int32
(
ldc
))
lib
.
cigemm
(
ptr
,
ct
.
c_bool
(
transposed_B
),
ct
.
c_bool
(
transposed_A
),
ct
.
c_int32
(
m
),
ct
.
c_int32
(
n
),
ct
.
c_int32
(
k
),
get_ptr
(
B
),
get_ptr
(
A
),
get_ptr
(
out
),
ct
.
c_int32
(
lda
),
ct
.
c_int32
(
ldb
),
ct
.
c_int32
(
ldc
),
)
return
out
def
batched_igemm
(
A
:
Tensor
,
B
:
Tensor
,
out
:
Tensor
=
None
,
transposed_A
=
False
,
transposed_B
=
False
):
def
batched_igemm
(
A
:
Tensor
,
B
:
Tensor
,
out
:
Tensor
=
None
,
transposed_A
=
False
,
transposed_B
=
False
):
if
not
len
(
A
.
shape
)
==
3
or
not
len
(
B
.
shape
)
==
3
:
raise
ValueError
(
f
'Expected 3-dimensional tensors for bmm, but got shapes A and B:
{
A
.
shape
}
and
{
B
.
shape
}
'
)
raise
ValueError
(
f
"Expected 3-dimensional tensors for bmm, but got shapes A and B:
{
A
.
shape
}
and
{
B
.
shape
}
"
)
sout
=
check_matmul
(
A
,
B
,
out
,
transposed_A
,
transposed_B
)
if
out
is
None
:
out
=
torch
.
zeros
(
size
=
sout
,
dtype
=
torch
.
int32
,
device
=
A
.
device
)
if
out
is
None
:
out
=
torch
.
zeros
(
size
=
sout
,
dtype
=
torch
.
int32
,
device
=
A
.
device
)
if
B
.
is_contiguous
():
lda
=
B
.
stride
()[
1
]
...
...
@@ -886,17 +1258,33 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
ldc
=
m
strideA
=
B
.
shape
[
1
]
*
B
.
shape
[
2
]
strideB
=
A
.
shape
[
1
]
*
A
.
shape
[
2
]
strideC
=
A
.
shape
[
1
]
*
B
.
shape
[
2
]
strideA
=
B
.
shape
[
1
]
*
B
.
shape
[
2
]
strideB
=
A
.
shape
[
1
]
*
A
.
shape
[
2
]
strideC
=
A
.
shape
[
1
]
*
B
.
shape
[
2
]
ptr
=
CUBLAS_Context
.
get_instance
().
get_context
(
A
.
device
)
lib
.
cbatched_igemm
(
ptr
,
ct
.
c_bool
(
transposed_B
),
ct
.
c_bool
(
transposed_A
),
ct
.
c_int32
(
m
),
ct
.
c_int32
(
n
),
ct
.
c_int32
(
k
),
get_ptr
(
B
),
get_ptr
(
A
),
get_ptr
(
out
),
ct
.
c_int32
(
lda
),
ct
.
c_int32
(
ldb
),
ct
.
c_int32
(
ldc
),
ct
.
c_long
(
strideA
),
ct
.
c_long
(
strideB
),
ct
.
c_long
(
strideC
),
ct
.
c_uint32
(
num_batch
))
lib
.
cbatched_igemm
(
ptr
,
ct
.
c_bool
(
transposed_B
),
ct
.
c_bool
(
transposed_A
),
ct
.
c_int32
(
m
),
ct
.
c_int32
(
n
),
ct
.
c_int32
(
k
),
get_ptr
(
B
),
get_ptr
(
A
),
get_ptr
(
out
),
ct
.
c_int32
(
lda
),
ct
.
c_int32
(
ldb
),
ct
.
c_int32
(
ldc
),
ct
.
c_long
(
strideA
),
ct
.
c_long
(
strideB
),
ct
.
c_long
(
strideC
),
ct
.
c_uint32
(
num_batch
),
)
return
out
def
igemmlt
(
A
,
B
,
SA
,
SB
,
out
=
None
,
Sout
=
None
,
dtype
=
torch
.
int32
):
shapeA
=
SA
[
0
]
shapeB
=
SB
[
0
]
...
...
@@ -905,28 +1293,34 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
if
dimsA
==
2
:
m
=
shapeA
[
0
]
elif
dimsA
==
3
:
m
=
shapeA
[
0
]
*
shapeA
[
1
]
m
=
shapeA
[
0
]
*
shapeA
[
1
]
if
dimsB
==
2
:
rows
=
n
=
shapeB
[
0
]
elif
dimsB
==
3
:
rows
=
n
=
shapeB
[
0
]
*
shapeB
[
1
]
rows
=
n
=
shapeB
[
0
]
*
shapeB
[
1
]
if
dimsA
==
2
and
out
is
None
:
out
,
Sout
=
get_transform_buffer
((
shapeA
[
0
],
shapeB
[
0
]),
dtype
,
A
.
device
,
'col32'
,
'row'
)
out
,
Sout
=
get_transform_buffer
(
(
shapeA
[
0
],
shapeB
[
0
]),
dtype
,
A
.
device
,
"col32"
,
"row"
)
elif
dimsA
==
3
and
out
is
None
:
out
,
Sout
=
get_transform_buffer
((
shapeA
[
0
],
shapeA
[
1
],
shapeB
[
0
]),
dtype
,
A
.
device
,
'col32'
,
'row'
)
out
,
Sout
=
get_transform_buffer
(
(
shapeA
[
0
],
shapeA
[
1
],
shapeB
[
0
]),
dtype
,
A
.
device
,
"col32"
,
"row"
)
assert
dimsB
!=
3
,
'
len(B.shape)==3 not supported
'
assert
A
.
device
.
type
==
'
cuda
'
assert
B
.
device
.
type
==
'
cuda
'
assert
dimsB
!=
3
,
"
len(B.shape)==3 not supported
"
assert
A
.
device
.
type
==
"
cuda
"
assert
B
.
device
.
type
==
"
cuda
"
assert
A
.
dtype
==
torch
.
int8
assert
B
.
dtype
==
torch
.
int8
assert
out
.
dtype
==
dtype
assert
SA
[
1
]
==
'col32'
assert
SB
[
1
]
in
[
'col_turing'
,
'col_ampere'
]
assert
Sout
[
1
]
==
'col32'
assert
shapeA
[
-
1
]
==
shapeB
[
-
1
],
f
'Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B =
{
shapeA
}
@
{
shapeB
}
'
assert
SA
[
1
]
==
"col32"
assert
SB
[
1
]
in
[
"col_turing"
,
"col_ampere"
]
assert
Sout
[
1
]
==
"col32"
assert
(
shapeA
[
-
1
]
==
shapeB
[
-
1
]
),
f
"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B =
{
shapeA
}
@
{
shapeB
}
"
formatB
=
SB
[
1
]
prev_device
=
A
.
device
torch
.
cuda
.
set_device
(
A
.
device
)
...
...
@@ -937,53 +1331,76 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
ptrC
=
get_ptr
(
out
)
k
=
shapeA
[
-
1
]
lda
=
ct
.
c_int32
(
m
*
32
)
if
formatB
==
'
col_turing
'
:
lda
=
ct
.
c_int32
(
m
*
32
)
if
formatB
==
"
col_turing
"
:
# turing: tiles with rows filled up to multiple of 8 rows by 32 columns
# n = rows
ldb
=
ct
.
c_int32
(((
rows
+
7
)
//
8
)
*
8
*
32
)
ldb
=
ct
.
c_int32
(((
rows
+
7
)
//
8
)
*
8
*
32
)
else
:
# ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
# n = rows
ldb
=
ct
.
c_int32
(((
rows
+
31
)
//
32
)
*
32
*
32
)
ldb
=
ct
.
c_int32
(((
rows
+
31
)
//
32
)
*
32
*
32
)
ldc
=
ct
.
c_int32
(
m
*
32
)
ldc
=
ct
.
c_int32
(
m
*
32
)
m
=
ct
.
c_int32
(
m
)
n
=
ct
.
c_int32
(
n
)
k
=
ct
.
c_int32
(
k
)
has_error
=
0
ptrRowScale
=
get_ptr
(
None
)
if
formatB
==
'
col_turing
'
:
if
formatB
==
"
col_turing
"
:
if
dtype
==
torch
.
int32
:
has_error
=
lib
.
cigemmlt_turing_32
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
has_error
=
lib
.
cigemmlt_turing_32
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
else
:
has_error
=
lib
.
cigemmlt_turing_8
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
elif
formatB
==
'col_ampere'
:
has_error
=
lib
.
cigemmlt_turing_8
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
elif
formatB
==
"col_ampere"
:
if
dtype
==
torch
.
int32
:
has_error
=
lib
.
cigemmlt_ampere_32
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
has_error
=
lib
.
cigemmlt_ampere_32
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
else
:
has_error
=
lib
.
cigemmlt_ampere_8
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
has_error
=
lib
.
cigemmlt_ampere_8
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
if
has_error
==
1
:
raise
Exception
(
'
cublasLt ran into an error!
'
)
raise
Exception
(
"
cublasLt ran into an error!
"
)
torch
.
cuda
.
set_device
(
prev_device
)
return
out
,
Sout
def
mm_dequant
(
A
,
quant_state
,
row_stats
,
col_stats
,
out
=
None
,
new_row_stats
=
None
,
new_col_stats
=
None
):
def
mm_dequant
(
A
,
quant_state
,
row_stats
,
col_stats
,
out
=
None
,
new_row_stats
=
None
,
new_col_stats
=
None
,
):
assert
A
.
dtype
==
torch
.
int32
out_shape
=
quant_state
[
0
]
if
len
(
out_shape
)
==
3
:
out_shape
=
(
out_shape
[
0
]
*
out_shape
[
1
],
out_shape
[
2
])
if
out
is
None
:
out
=
torch
.
empty
(
out_shape
,
dtype
=
torch
.
float16
,
device
=
A
.
device
)
if
new_row_stats
is
None
:
new_row_stats
=
torch
.
empty
(
out_shape
[
0
],
dtype
=
torch
.
float32
,
device
=
A
.
device
)
if
new_col_stats
is
None
:
new_col_stats
=
torch
.
empty
(
out_shape
[
1
],
dtype
=
torch
.
float32
,
device
=
A
.
device
)
assert
new_row_stats
.
shape
[
0
]
==
row_stats
.
shape
[
0
],
f
"
{
new_row_stats
.
shape
}
vs
{
row_stats
.
shape
}
"
assert
new_col_stats
.
shape
[
0
]
==
col_stats
.
shape
[
0
],
f
"
{
new_col_stats
.
shape
}
vs
{
col_stats
.
shape
}
"
if
len
(
out_shape
)
==
3
:
out_shape
=
(
out_shape
[
0
]
*
out_shape
[
1
],
out_shape
[
2
])
if
out
is
None
:
out
=
torch
.
empty
(
out_shape
,
dtype
=
torch
.
float16
,
device
=
A
.
device
)
if
new_row_stats
is
None
:
new_row_stats
=
torch
.
empty
(
out_shape
[
0
],
dtype
=
torch
.
float32
,
device
=
A
.
device
)
if
new_col_stats
is
None
:
new_col_stats
=
torch
.
empty
(
out_shape
[
1
],
dtype
=
torch
.
float32
,
device
=
A
.
device
)
assert
(
new_row_stats
.
shape
[
0
]
==
row_stats
.
shape
[
0
]
),
f
"
{
new_row_stats
.
shape
}
vs
{
row_stats
.
shape
}
"
assert
(
new_col_stats
.
shape
[
0
]
==
col_stats
.
shape
[
0
]
),
f
"
{
new_col_stats
.
shape
}
vs
{
col_stats
.
shape
}
"
ptrA
=
get_ptr
(
A
)
ptrOut
=
get_ptr
(
out
)
...
...
@@ -994,27 +1411,47 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non
numRows
=
ct
.
c_int32
(
out_shape
[
0
])
numCols
=
ct
.
c_int32
(
out_shape
[
1
])
lib
.
cdequant_mm_int32_fp16
(
ptrA
,
ptrRowStats
,
ptrColStats
,
ptrOut
,
ptrNewRowStats
,
ptrNewColStats
,
numRows
,
numCols
)
lib
.
cdequant_mm_int32_fp16
(
ptrA
,
ptrRowStats
,
ptrColStats
,
ptrOut
,
ptrNewRowStats
,
ptrNewColStats
,
numRows
,
numCols
,
)
return
out
def
get_colrow_absmax
(
A
,
row_stats
=
None
,
col_stats
=
None
,
nnz_block_ptr
=
None
,
threshold
=
0.0
):
def
get_colrow_absmax
(
A
,
row_stats
=
None
,
col_stats
=
None
,
nnz_block_ptr
=
None
,
threshold
=
0.0
):
assert
A
.
dtype
==
torch
.
float16
device
=
A
.
device
cols
=
A
.
shape
[
-
1
]
if
len
(
A
.
shape
)
==
3
:
rows
=
A
.
shape
[
0
]
*
A
.
shape
[
1
]
rows
=
A
.
shape
[
0
]
*
A
.
shape
[
1
]
else
:
rows
=
A
.
shape
[
0
]
col_tiles
=
(
cols
+
255
)
//
256
tiled_rows
=
((
rows
+
15
)
//
16
)
*
16
if
row_stats
is
None
:
row_stats
=
torch
.
empty
((
rows
,),
dtype
=
torch
.
float32
,
device
=
device
).
fill_
(
-
50000.0
)
if
col_stats
is
None
:
col_stats
=
torch
.
empty
((
cols
,),
dtype
=
torch
.
float32
,
device
=
device
).
fill_
(
-
50000.0
)
if
nnz_block_ptr
is
None
and
threshold
>
0.0
:
nnz_block_ptr
=
torch
.
zeros
(((
tiled_rows
*
col_tiles
)
+
1
,),
dtype
=
torch
.
int32
,
device
=
device
)
col_tiles
=
(
cols
+
255
)
//
256
tiled_rows
=
((
rows
+
15
)
//
16
)
*
16
if
row_stats
is
None
:
row_stats
=
torch
.
empty
((
rows
,),
dtype
=
torch
.
float32
,
device
=
device
).
fill_
(
-
50000.0
)
if
col_stats
is
None
:
col_stats
=
torch
.
empty
((
cols
,),
dtype
=
torch
.
float32
,
device
=
device
).
fill_
(
-
50000.0
)
if
nnz_block_ptr
is
None
and
threshold
>
0.0
:
nnz_block_ptr
=
torch
.
zeros
(
((
tiled_rows
*
col_tiles
)
+
1
,),
dtype
=
torch
.
int32
,
device
=
device
)
ptrA
=
get_ptr
(
A
)
ptrRowStats
=
get_ptr
(
row_stats
)
...
...
@@ -1024,16 +1461,17 @@ def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, thr
cols
=
ct
.
c_int32
(
cols
)
prev_device
=
pre_call
(
A
.
device
)
lib
.
cget_col_row_stats
(
ptrA
,
ptrRowStats
,
ptrColStats
,
ptrNnzrows
,
ct
.
c_float
(
threshold
),
rows
,
cols
)
lib
.
cget_col_row_stats
(
ptrA
,
ptrRowStats
,
ptrColStats
,
ptrNnzrows
,
ct
.
c_float
(
threshold
),
rows
,
cols
)
post_call
(
prev_device
)
if
threshold
>
0.0
:
nnz_block_ptr
.
cumsum_
(
0
)
return
row_stats
,
col_stats
,
nnz_block_ptr
class
COOSparseTensor
(
object
):
def
__init__
(
self
,
rows
,
cols
,
nnz
,
rowidx
,
colidx
,
values
):
assert
rowidx
.
dtype
==
torch
.
int32
...
...
@@ -1050,6 +1488,7 @@ class COOSparseTensor(object):
self
.
colidx
=
colidx
self
.
values
=
values
class
CSRSparseTensor
(
object
):
def
__init__
(
self
,
rows
,
cols
,
nnz
,
rowptr
,
colidx
,
values
):
assert
rowptr
.
dtype
==
torch
.
int32
...
...
@@ -1057,7 +1496,7 @@ class CSRSparseTensor(object):
assert
values
.
dtype
==
torch
.
float16
assert
values
.
numel
()
==
nnz
assert
colidx
.
numel
()
==
nnz
assert
rowptr
.
numel
()
==
rows
+
1
assert
rowptr
.
numel
()
==
rows
+
1
self
.
rows
=
rows
self
.
cols
=
cols
...
...
@@ -1066,6 +1505,7 @@ class CSRSparseTensor(object):
self
.
colidx
=
colidx
self
.
values
=
values
class
CSCSparseTensor
(
object
):
def
__init__
(
self
,
rows
,
cols
,
nnz
,
colptr
,
rowidx
,
values
):
assert
colptr
.
dtype
==
torch
.
int32
...
...
@@ -1073,7 +1513,7 @@ class CSCSparseTensor(object):
assert
values
.
dtype
==
torch
.
float16
assert
values
.
numel
()
==
nnz
assert
rowidx
.
numel
()
==
nnz
assert
colptr
.
numel
()
==
cols
+
1
assert
colptr
.
numel
()
==
cols
+
1
self
.
rows
=
rows
self
.
cols
=
cols
...
...
@@ -1082,13 +1522,17 @@ class CSCSparseTensor(object):
self
.
rowidx
=
rowidx
self
.
values
=
values
def
coo2csr
(
cooA
):
values
,
counts
=
torch
.
unique
(
cooA
.
rowidx
,
return_counts
=
True
)
values
.
add_
(
1
)
rowptr
=
torch
.
zeros
((
cooA
.
rows
+
1
,
),
dtype
=
torch
.
int32
,
device
=
cooA
.
rowidx
.
device
)
rowptr
=
torch
.
zeros
((
cooA
.
rows
+
1
,),
dtype
=
torch
.
int32
,
device
=
cooA
.
rowidx
.
device
)
rowptr
.
scatter_
(
index
=
values
.
long
(),
src
=
counts
.
int
(),
dim
=
0
)
rowptr
.
cumsum_
(
0
)
return
CSRSparseTensor
(
cooA
.
rows
,
cooA
.
cols
,
cooA
.
nnz
,
rowptr
,
cooA
.
colidx
,
cooA
.
values
)
return
CSRSparseTensor
(
cooA
.
rows
,
cooA
.
cols
,
cooA
.
nnz
,
rowptr
,
cooA
.
colidx
,
cooA
.
values
)
def
coo2csc
(
cooA
):
val
,
col2rowidx
=
torch
.
sort
(
cooA
.
colidx
)
...
...
@@ -1096,11 +1540,12 @@ def coo2csc(cooA):
values
=
cooA
.
values
[
col2rowidx
]
colvalues
,
counts
=
torch
.
unique
(
val
,
return_counts
=
True
)
colvalues
.
add_
(
1
)
colptr
=
torch
.
zeros
((
cooA
.
cols
+
1
,
),
dtype
=
torch
.
int32
,
device
=
cooA
.
colidx
.
device
)
colptr
=
torch
.
zeros
((
cooA
.
cols
+
1
,),
dtype
=
torch
.
int32
,
device
=
cooA
.
colidx
.
device
)
colptr
.
scatter_
(
index
=
colvalues
.
long
(),
src
=
counts
.
int
(),
dim
=
0
)
colptr
.
cumsum_
(
0
)
return
CSCSparseTensor
(
cooA
.
rows
,
cooA
.
cols
,
cooA
.
nnz
,
colptr
,
rowidx
,
values
)
def
coo_zeros
(
rows
,
cols
,
nnz
,
device
,
dtype
=
torch
.
half
):
rowidx
=
torch
.
zeros
((
nnz
,),
dtype
=
torch
.
int32
,
device
=
device
)
colidx
=
torch
.
zeros
((
nnz
,),
dtype
=
torch
.
int32
,
device
=
device
)
...
...
@@ -1108,23 +1553,27 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
return
COOSparseTensor
(
rows
,
cols
,
nnz
,
rowidx
,
colidx
,
values
)
def
double_quant
(
A
,
col_stats
=
None
,
row_stats
=
None
,
out_col
=
None
,
out_row
=
None
,
threshold
=
0.0
):
def
double_quant
(
A
,
col_stats
=
None
,
row_stats
=
None
,
out_col
=
None
,
out_row
=
None
,
threshold
=
0.0
):
device
=
A
.
device
assert
A
.
dtype
==
torch
.
half
assert
device
.
type
==
'
cuda
'
assert
device
.
type
==
"
cuda
"
prev_device
=
pre_call
(
A
.
device
)
cols
=
A
.
shape
[
-
1
]
if
len
(
A
.
shape
)
==
3
:
rows
=
A
.
shape
[
0
]
*
A
.
shape
[
1
]
rows
=
A
.
shape
[
0
]
*
A
.
shape
[
1
]
else
:
rows
=
A
.
shape
[
0
]
if
row_stats
is
None
or
col_stats
is
None
:
row_stats
,
col_stats
,
nnz_row_ptr
=
get_colrow_absmax
(
A
,
threshold
=
threshold
)
if
out_col
is
None
:
out_col
=
torch
.
zeros
(
A
.
shape
,
device
=
device
,
dtype
=
torch
.
int8
)
if
out_row
is
None
:
out_row
=
torch
.
zeros
(
A
.
shape
,
device
=
device
,
dtype
=
torch
.
int8
)
if
out_col
is
None
:
out_col
=
torch
.
zeros
(
A
.
shape
,
device
=
device
,
dtype
=
torch
.
int8
)
if
out_row
is
None
:
out_row
=
torch
.
zeros
(
A
.
shape
,
device
=
device
,
dtype
=
torch
.
int8
)
coo_tensor
=
None
ptrA
=
get_ptr
(
A
)
...
...
@@ -1136,21 +1585,62 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None,
if
threshold
>
0.0
:
nnz
=
nnz_row_ptr
[
-
1
].
item
()
if
nnz
>
0
:
coo_tensor
=
coo_zeros
(
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz_row_ptr
[
-
1
].
item
(),
device
)
coo_tensor
=
coo_zeros
(
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz_row_ptr
[
-
1
].
item
(),
device
)
ptrRowIdx
=
get_ptr
(
coo_tensor
.
rowidx
)
ptrColIdx
=
get_ptr
(
coo_tensor
.
colidx
)
ptrVal
=
get_ptr
(
coo_tensor
.
values
)
ptrRowPtr
=
get_ptr
(
nnz_row_ptr
)
lib
.
cdouble_rowcol_quant
(
ptrA
,
ptrRowStats
,
ptrColStats
,
ptrOutCol
,
ptrOutRow
,
ptrRowIdx
,
ptrColIdx
,
ptrVal
,
ptrRowPtr
,
ct
.
c_float
(
threshold
),
ct
.
c_int32
(
rows
),
ct
.
c_int32
(
cols
))
lib
.
cdouble_rowcol_quant
(
ptrA
,
ptrRowStats
,
ptrColStats
,
ptrOutCol
,
ptrOutRow
,
ptrRowIdx
,
ptrColIdx
,
ptrVal
,
ptrRowPtr
,
ct
.
c_float
(
threshold
),
ct
.
c_int32
(
rows
),
ct
.
c_int32
(
cols
),
)
val
,
idx
=
torch
.
sort
(
coo_tensor
.
rowidx
)
coo_tensor
.
rowidx
=
val
coo_tensor
.
colidx
=
coo_tensor
.
colidx
[
idx
]
coo_tensor
.
values
=
coo_tensor
.
values
[
idx
]
else
:
lib
.
cdouble_rowcol_quant
(
ptrA
,
ptrRowStats
,
ptrColStats
,
ptrOutCol
,
ptrOutRow
,
None
,
None
,
None
,
None
,
ct
.
c_float
(
0.0
),
ct
.
c_int32
(
rows
),
ct
.
c_int32
(
cols
))
lib
.
cdouble_rowcol_quant
(
ptrA
,
ptrRowStats
,
ptrColStats
,
ptrOutCol
,
ptrOutRow
,
None
,
None
,
None
,
None
,
ct
.
c_float
(
0.0
),
ct
.
c_int32
(
rows
),
ct
.
c_int32
(
cols
),
)
else
:
lib
.
cdouble_rowcol_quant
(
ptrA
,
ptrRowStats
,
ptrColStats
,
ptrOutCol
,
ptrOutRow
,
None
,
None
,
None
,
None
,
ct
.
c_float
(
threshold
),
ct
.
c_int32
(
rows
),
ct
.
c_int32
(
cols
))
lib
.
cdouble_rowcol_quant
(
ptrA
,
ptrRowStats
,
ptrColStats
,
ptrOutCol
,
ptrOutRow
,
None
,
None
,
None
,
None
,
ct
.
c_float
(
threshold
),
ct
.
c_int32
(
rows
),
ct
.
c_int32
(
cols
),
)
post_call
(
prev_device
)
return
out_row
,
out_col
,
row_stats
,
col_stats
,
coo_tensor
...
...
@@ -1159,69 +1649,81 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None,
def
get_special_format_str
():
major
,
minor
=
torch
.
cuda
.
get_device_capability
()
if
major
<
7
:
print
(
f
'Device with CUDA capability of
{
major
}
not supported for 8-bit matmul. Device has no tensor cores!'
)
print
(
f
"Device with CUDA capability of
{
major
}
not supported for 8-bit matmul. Device has no tensor cores!"
)
assert
major
>=
7
if
major
==
7
:
return
'col_turing'
elif
major
==
8
:
return
'col_ampere'
else
:
return
'col_turing'
if
major
==
7
:
return
"col_turing"
elif
major
==
8
:
return
"col_ampere"
else
:
return
"col_turing"
def
transform
(
A
,
to_order
,
from_order
=
'row'
,
out
=
None
,
transpose
=
False
,
state
=
None
,
ld
=
None
):
if
state
is
None
:
state
=
(
A
.
shape
,
from_order
)
else
:
from_order
=
state
[
1
]
if
out
is
None
:
out
,
new_state
=
get_transform_buffer
(
state
[
0
],
A
.
dtype
,
A
.
device
,
to_order
,
state
[
1
],
transpose
)
else
:
new_state
=
(
state
[
0
],
to_order
)
# (shape, order)
def
transform
(
A
,
to_order
,
from_order
=
"row"
,
out
=
None
,
transpose
=
False
,
state
=
None
,
ld
=
None
):
if
state
is
None
:
state
=
(
A
.
shape
,
from_order
)
else
:
from_order
=
state
[
1
]
if
out
is
None
:
out
,
new_state
=
get_transform_buffer
(
state
[
0
],
A
.
dtype
,
A
.
device
,
to_order
,
state
[
1
],
transpose
)
else
:
new_state
=
(
state
[
0
],
to_order
)
# (shape, order)
shape
=
state
[
0
]
if
len
(
shape
)
==
2
:
dim1
=
ct
.
c_int32
(
shape
[
0
])
dim2
=
ct
.
c_int32
(
shape
[
1
])
else
:
dim1
=
ct
.
c_int32
(
shape
[
0
]
*
shape
[
1
])
dim1
=
ct
.
c_int32
(
shape
[
0
]
*
shape
[
1
])
dim2
=
ct
.
c_int32
(
shape
[
2
])
ptrA
=
get_ptr
(
A
)
ptrOut
=
get_ptr
(
out
)
if
to_order
==
'
col32
'
:
if
to_order
==
"
col32
"
:
if
transpose
:
lib
.
ctransform_row2col32T
(
get_ptr
(
A
),
get_ptr
(
out
),
dim1
,
dim2
)
else
:
lib
.
ctransform_row2col32
(
get_ptr
(
A
),
get_ptr
(
out
),
dim1
,
dim2
)
elif
to_order
==
'
col_turing
'
:
elif
to_order
==
"
col_turing
"
:
if
transpose
:
lib
.
ctransform_row2turingT
(
get_ptr
(
A
),
get_ptr
(
out
),
dim1
,
dim2
)
else
:
lib
.
ctransform_row2turing
(
get_ptr
(
A
),
get_ptr
(
out
),
dim1
,
dim2
)
elif
to_order
==
'
col_ampere
'
:
elif
to_order
==
"
col_ampere
"
:
if
transpose
:
lib
.
ctransform_row2ampereT
(
get_ptr
(
A
),
get_ptr
(
out
),
dim1
,
dim2
)
else
:
lib
.
ctransform_row2ampere
(
get_ptr
(
A
),
get_ptr
(
out
),
dim1
,
dim2
)
elif
to_order
==
'
row
'
:
if
from_order
==
'
col_turing
'
:
elif
to_order
==
"
row
"
:
if
from_order
==
"
col_turing
"
:
lib
.
ctransform_turing2row
(
get_ptr
(
A
),
get_ptr
(
out
),
dim1
,
dim2
)
elif
from_order
==
'
col_ampere
'
:
elif
from_order
==
"
col_ampere
"
:
lib
.
ctransform_ampere2row
(
get_ptr
(
A
),
get_ptr
(
out
),
dim1
,
dim2
)
else
:
raise
NotImplementedError
(
f
'Transform function not implemented: From
{
from_order
}
to
{
to_order
}
'
)
raise
NotImplementedError
(
f
"Transform function not implemented: From
{
from_order
}
to
{
to_order
}
"
)
return
out
,
new_state
def
spmm_coo
(
cooA
,
B
,
out
=
None
):
if
out
is
None
:
out
=
torch
.
empty
((
cooA
.
rows
,
B
.
shape
[
1
]),
device
=
B
.
device
,
dtype
=
B
.
dtype
)
if
out
is
None
:
out
=
torch
.
empty
((
cooA
.
rows
,
B
.
shape
[
1
]),
device
=
B
.
device
,
dtype
=
B
.
dtype
)
nnz
=
cooA
.
nnz
assert
cooA
.
rowidx
.
numel
()
==
nnz
assert
cooA
.
colidx
.
numel
()
==
nnz
assert
cooA
.
values
.
numel
()
==
nnz
assert
cooA
.
cols
==
B
.
shape
[
0
]
transposed_B
=
(
False
if
B
.
is_contiguous
()
else
True
)
transposed_B
=
False
if
B
.
is_contiguous
()
else
True
ldb
=
B
.
stride
()[(
1
if
transposed_B
else
0
)]
ldc
=
B
.
shape
[
1
]
...
...
@@ -1240,19 +1742,37 @@ def spmm_coo(cooA, B, out=None):
cldb
=
ct
.
c_int32
(
ldb
)
cldc
=
ct
.
c_int32
(
ldc
)
lib
.
cspmm_coo
(
ptr
,
ptrRowidx
,
ptrColidx
,
ptrValues
,
cnnz
,
crowsA
,
ccolsA
,
ccolsB
,
cldb
,
ptrB
,
cldc
,
ptrC
,
ct
.
c_bool
(
transposed_B
))
lib
.
cspmm_coo
(
ptr
,
ptrRowidx
,
ptrColidx
,
ptrValues
,
cnnz
,
crowsA
,
ccolsA
,
ccolsB
,
cldb
,
ptrB
,
cldc
,
ptrC
,
ct
.
c_bool
(
transposed_B
),
)
return
out
def
spmm_coo_very_sparse
(
cooA
,
B
,
dequant_stats
=
None
,
out
=
None
):
if
out
is
None
:
out
=
torch
.
zeros
((
cooA
.
rows
,
B
.
shape
[
1
]),
device
=
B
.
device
,
dtype
=
cooA
.
values
.
dtype
)
if
out
is
None
:
out
=
torch
.
zeros
(
(
cooA
.
rows
,
B
.
shape
[
1
]),
device
=
B
.
device
,
dtype
=
cooA
.
values
.
dtype
)
nnz
=
cooA
.
nnz
assert
cooA
.
rowidx
.
numel
()
==
nnz
assert
cooA
.
colidx
.
numel
()
==
nnz
assert
cooA
.
values
.
numel
()
==
nnz
assert
cooA
.
cols
==
B
.
shape
[
0
],
f
'
{
cooA
.
cols
}
vs
{
B
.
shape
}
'
assert
cooA
.
cols
==
B
.
shape
[
0
],
f
"
{
cooA
.
cols
}
vs
{
B
.
shape
}
"
transposed_B
=
(
False
if
B
.
is_contiguous
()
else
True
)
transposed_B
=
False
if
B
.
is_contiguous
()
else
True
ldb
=
B
.
stride
()[(
1
if
transposed_B
else
0
)]
ldc
=
B
.
shape
[
1
]
...
...
@@ -1262,7 +1782,9 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
max_count
,
max_idx
=
torch
.
sort
(
counts
,
descending
=
True
)
max_idx
=
max_idx
.
int
()
max_count
=
max_count
.
int
()
assert
max_count
[
0
]
<=
32
,
f
'Current max count per row is 8 but found
{
max_count
[
0
]
}
.'
assert
(
max_count
[
0
]
<=
32
),
f
"Current max count per row is 8 but found
{
max_count
[
0
]
}
."
assert
B
.
dtype
in
[
torch
.
float16
,
torch
.
int8
]
ptrOffset
=
get_ptr
(
offset
)
ptrMaxCount
=
get_ptr
(
max_count
)
...
...
@@ -1282,134 +1804,183 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
ccolsB
=
ct
.
c_int32
(
B
.
shape
[
1
])
cldb
=
ct
.
c_int32
(
ldb
)
cldc
=
ct
.
c_int32
(
ldc
)
#print(cooA.rowidx[:64])
#print(cooA.colidx[:64].sort()[0])
#
print(cooA.rowidx[:64])
#
print(cooA.colidx[:64].sort()[0])
if
B
.
dtype
==
torch
.
float16
:
lib
.
cspmm_coo_very_sparse_naive_fp16
(
ptrMaxCount
,
ptrMaxIdx
,
ptrOffset
,
ptrRowidx
,
ptrColidx
,
ptrValues
,
ptrB
,
ptrC
,
ptrDequantStats
,
cnnz_rows
,
cnnz
,
crowsA
,
crowsB
,
ccolsB
)
lib
.
cspmm_coo_very_sparse_naive_fp16
(
ptrMaxCount
,
ptrMaxIdx
,
ptrOffset
,
ptrRowidx
,
ptrColidx
,
ptrValues
,
ptrB
,
ptrC
,
ptrDequantStats
,
cnnz_rows
,
cnnz
,
crowsA
,
crowsB
,
ccolsB
,
)
elif
B
.
dtype
==
torch
.
int8
:
lib
.
cspmm_coo_very_sparse_naive_int8
(
ptrMaxCount
,
ptrMaxIdx
,
ptrOffset
,
ptrRowidx
,
ptrColidx
,
ptrValues
,
ptrB
,
ptrC
,
ptrDequantStats
,
cnnz_rows
,
cnnz
,
crowsA
,
crowsB
,
ccolsB
)
#else: assertion error
lib
.
cspmm_coo_very_sparse_naive_int8
(
ptrMaxCount
,
ptrMaxIdx
,
ptrOffset
,
ptrRowidx
,
ptrColidx
,
ptrValues
,
ptrB
,
ptrC
,
ptrDequantStats
,
cnnz_rows
,
cnnz
,
crowsA
,
crowsB
,
ccolsB
,
)
# else: assertion error
return
out
C
=
127.0
def
vectorwise_quant
(
x
,
dim
=
1
,
quant_type
=
'vector'
):
if
quant_type
==
'linear'
:
def
vectorwise_quant
(
x
,
dim
=
1
,
quant_type
=
"vector"
):
if
quant_type
==
"linear"
:
max1
=
torch
.
abs
(
x
).
max
().
float
()
xq
=
torch
.
round
(
x
/
max1
*
127
).
to
(
torch
.
int8
)
xq
=
torch
.
round
(
x
/
max1
*
127
).
to
(
torch
.
int8
)
return
xq
,
max1
elif
quant_type
in
[
'
vector
'
,
'
row
'
]:
elif
quant_type
in
[
"
vector
"
,
"
row
"
]:
max1
=
torch
.
amax
(
torch
.
abs
(
x
),
dim
=
dim
,
keepdim
=
True
)
xq
=
torch
.
round
(
x
*
(
C
/
max1
)).
to
(
torch
.
int8
)
xq
=
torch
.
round
(
x
*
(
C
/
max1
)).
to
(
torch
.
int8
)
return
xq
,
max1
elif
quant_type
==
'
zeropoint
'
:
elif
quant_type
==
"
zeropoint
"
:
dtype
=
x
.
dtype
x
=
x
.
float
()
dyna
=
x
.
max
()
-
x
.
min
()
if
dyna
==
0
:
dyna
=
1
qx
=
255.
/
dyna
if
dyna
==
0
:
dyna
=
1
qx
=
255.0
/
dyna
minx
=
x
.
min
()
zpx
=
torch
.
round
(
minx
*
qx
)
x
=
torch
.
round
(
qx
*
x
-
zpx
)
+
zpx
zpx
=
torch
.
round
(
minx
*
qx
)
x
=
torch
.
round
(
qx
*
x
-
zpx
)
+
zpx
return
x
,
qx
elif
quant_type
in
[
'
vector-zeropoint
'
,
'
row-zeropoint
'
]:
elif
quant_type
in
[
"
vector-zeropoint
"
,
"
row-zeropoint
"
]:
dtype
=
x
.
dtype
x
=
x
.
float
()
dyna
=
(
torch
.
amax
(
x
,
dim
=
dim
,
keepdim
=
True
)
-
torch
.
amin
(
x
,
dim
=
dim
,
keepdim
=
True
))
dyna
[
dyna
==
0
]
=
1
qx
=
255.
/
dyna
dyna
=
torch
.
amax
(
x
,
dim
=
dim
,
keepdim
=
True
)
-
torch
.
amin
(
x
,
dim
=
dim
,
keepdim
=
True
)
dyna
[
dyna
==
0
]
=
1
qx
=
255.0
/
dyna
minx
=
torch
.
amin
(
x
,
dim
=
dim
,
keepdim
=
True
)
zpx
=
torch
.
round
(
minx
*
qx
)
x
=
torch
.
round
(
qx
*
x
-
zpx
)
+
zpx
zpx
=
torch
.
round
(
minx
*
qx
)
x
=
torch
.
round
(
qx
*
x
-
zpx
)
+
zpx
return
x
,
qx
elif
quant_type
==
'
truncated-vector
'
:
elif
quant_type
==
"
truncated-vector
"
:
with
torch
.
no_grad
():
absx
=
torch
.
abs
(
x
)
max1
=
torch
.
amax
(
absx
,
dim
=
dim
,
keepdim
=
True
)
max1
=
max1
*
0.7
idx
=
(
absx
>
max1
.
expand_as
(
absx
)
)
max1
=
max1
*
0.7
idx
=
absx
>
max1
.
expand_as
(
absx
)
sign
=
torch
.
sign
(
x
[
idx
])
x
[
idx
]
=
max1
.
expand_as
(
absx
)[
idx
]
*
sign
xq
=
torch
.
round
(
x
/
max1
*
C
).
to
(
torch
.
int8
)
x
[
idx
]
=
max1
.
expand_as
(
absx
)[
idx
]
*
sign
xq
=
torch
.
round
(
x
/
max1
*
C
).
to
(
torch
.
int8
)
return
xq
,
max1
else
:
return
None
else
:
return
None
def
vectorwise_dequant
(
xq
,
max1
,
quant_type
=
'
vector
'
):
if
quant_type
==
'
vector
'
:
x
=
(
xq
/
C
*
max1
).
to
(
torch
.
float32
)
def
vectorwise_dequant
(
xq
,
max1
,
quant_type
=
"
vector
"
):
if
quant_type
==
"
vector
"
:
x
=
(
xq
/
C
*
max1
).
to
(
torch
.
float32
)
return
x
else
:
return
None
else
:
return
None
def
vectorwise_mm_dequant
(
xq
,
S1
,
S2
,
dtype
=
torch
.
half
,
quant_type
=
'
vector
'
):
if
quant_type
==
'
linear
'
:
norm
=
S1
*
S2
/
(
C
*
C
)
def
vectorwise_mm_dequant
(
xq
,
S1
,
S2
,
dtype
=
torch
.
half
,
quant_type
=
"
vector
"
):
if
quant_type
==
"
linear
"
:
norm
=
S1
*
S2
/
(
C
*
C
)
# double cast needed to prevent overflows
return
(
xq
.
float
()
*
norm
).
to
(
dtype
)
elif
quant_type
==
'
zeropoint
'
:
norm
=
1.0
/
(
S1
*
S2
)
return
(
xq
.
float
()
*
norm
).
to
(
dtype
)
elif
quant_type
==
'
row-zeropoint
'
:
norm
=
1.0
/
(
S1
*
S2
)
return
(
xq
.
float
()
*
norm
).
to
(
dtype
)
elif
quant_type
==
"
zeropoint
"
:
norm
=
1.0
/
(
S1
*
S2
)
return
(
xq
.
float
()
*
norm
).
to
(
dtype
)
elif
quant_type
==
"
row-zeropoint
"
:
norm
=
1.0
/
(
S1
*
S2
)
x
=
xq
.
float
()
if
len
(
S1
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S1
=
S1
.
squeeze
(
0
)
if
len
(
S2
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S2
=
S2
.
squeeze
(
0
)
if
len
(
S1
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S1
=
S1
.
squeeze
(
0
)
if
len
(
S2
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S2
=
S2
.
squeeze
(
0
)
if
len
(
S1
.
shape
)
==
2
:
x
*=
norm
else
:
x
*=
norm
return
x
.
to
(
dtype
)
elif
quant_type
==
'
vector-zeropoint
'
:
elif
quant_type
==
"
vector-zeropoint
"
:
x
=
xq
.
float
()
if
len
(
S1
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S1
=
S1
.
squeeze
(
0
)
if
len
(
S2
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S2
=
S2
.
squeeze
(
0
)
if
len
(
S1
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S1
=
S1
.
squeeze
(
0
)
if
len
(
S2
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S2
=
S2
.
squeeze
(
0
)
if
len
(
S1
.
shape
)
==
2
:
x
*=
1.0
/
S1
x
*=
1.0
/
S1
else
:
x
*=
1.0
/
S1
x
*=
1.0
/
S2
.
t
()
x
*=
1.0
/
S1
x
*=
1.0
/
S2
.
t
()
return
x
.
to
(
dtype
)
elif
quant_type
==
'
row
'
:
elif
quant_type
==
"
row
"
:
x
=
xq
.
float
()
if
len
(
S1
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S1
=
S1
.
squeeze
(
0
)
if
len
(
S2
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S2
=
S2
.
squeeze
(
0
)
if
len
(
S1
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S1
=
S1
.
squeeze
(
0
)
if
len
(
S2
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S2
=
S2
.
squeeze
(
0
)
if
len
(
S1
.
shape
)
==
2
:
x
*=
S1
*
S2
/
(
C
*
C
)
x
*=
S1
*
S2
/
(
C
*
C
)
else
:
x
*=
S1
*
S2
/
(
C
*
C
)
x
*=
S1
*
S2
/
(
C
*
C
)
return
x
.
to
(
dtype
)
elif
quant_type
in
[
'
truncated-vector
'
,
'
vector
'
]:
elif
quant_type
in
[
"
truncated-vector
"
,
"
vector
"
]:
x
=
xq
.
float
()
if
len
(
S1
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S1
=
S1
.
squeeze
(
0
)
if
len
(
S2
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S2
=
S2
.
squeeze
(
0
)
if
len
(
S1
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S1
=
S1
.
squeeze
(
0
)
if
len
(
S2
.
shape
)
==
3
and
len
(
x
.
shape
)
==
2
:
S2
=
S2
.
squeeze
(
0
)
if
len
(
S1
.
shape
)
==
2
:
x
*=
S1
/
C
x
*=
S1
/
C
else
:
x
*=
S1
/
C
x
*=
S2
/
C
x
*=
S1
/
C
x
*=
S2
/
C
return
x
.
to
(
dtype
)
else
:
return
None
else
:
return
None
def
dequant_min_max
(
xq
,
A
,
B
,
SA
,
SB
,
dtype
=
torch
.
half
):
offset
=
B
.
float
().
t
().
sum
(
0
)
*
(
SA
[
0
]
+
SA
[
1
])
offset
=
B
.
float
().
t
().
sum
(
0
)
*
(
SA
[
0
]
+
SA
[
1
])
x
=
xq
.
float
()
if
len
(
xq
.
shape
)
==
2
and
len
(
SB
.
shape
)
==
3
:
SB
=
SB
.
squeeze
(
0
)
if
len
(
xq
.
shape
)
==
2
and
len
(
SB
.
shape
)
==
3
:
SB
=
SB
.
squeeze
(
0
)
if
len
(
SB
.
shape
)
==
2
:
x
*=
SB
.
t
()
/
127
x
*=
SB
.
t
()
/
127
else
:
x
*=
SB
/
127
x
*=
SA
[
1
]
/
127
x
+=
offset
x
*=
SB
/
127
x
*=
SA
[
1
]
/
127
x
+=
offset
return
x
.
to
(
dtype
)
def
extract_outliers
(
A
,
SA
,
idx
):
shapeA
=
SA
[
0
]
formatA
=
SA
[
1
]
assert
formatA
in
[
'
col_turing
'
,
'
col_ampere
'
]
assert
A
.
device
.
type
==
'
cuda
'
assert
formatA
in
[
"
col_turing
"
,
"
col_ampere
"
]
assert
A
.
device
.
type
==
"
cuda
"
out
=
torch
.
zeros
((
shapeA
[
0
],
idx
.
numel
()),
dtype
=
torch
.
int8
,
device
=
A
.
device
)
...
...
@@ -1420,13 +1991,9 @@ def extract_outliers(A, SA, idx):
ptrIdx
=
get_ptr
(
idx
)
ptrOut
=
get_ptr
(
out
)
if
formatA
==
'
col_turing
'
:
if
formatA
==
"
col_turing
"
:
lib
.
cextractOutliers_turing
(
ptrA
,
ptrIdx
,
ptrOut
,
idx_size
,
rows
,
cols
)
elif
formatA
==
'
col_ampere
'
:
elif
formatA
==
"
col_ampere
"
:
lib
.
cextractOutliers_ampere
(
ptrA
,
ptrIdx
,
ptrOut
,
idx_size
,
rows
,
cols
)
return
out
bitsandbytes/nn/__init__.py
View file @
bfa0e332
...
...
@@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from
.modules
import
StableEmbedding
,
Linear8bit
,
Linear8bitLt
,
Int8Params
from
.modules
import
Int8Params
,
Linear8bit
,
Linear8bitLt
,
StableEmbedding
bitsandbytes/nn/modules.py
View file @
bfa0e332
...
...
@@ -2,38 +2,58 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
torch
import
bitsandbytes
as
bnb
from
typing
import
(
Any
,
Callable
,
Dict
,
Iterator
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypeVar
,
Union
,
overload
)
from
typing
import
Union
,
Tuple
,
Any
,
Callable
,
Iterator
,
Set
,
Optional
,
overload
,
TypeVar
,
Mapping
,
Dict
from
torch
import
Tensor
,
device
,
dtype
from
torch
import
nn
from
torch.nn.parameter
import
Parameter
import
torch
import
torch.nn.functional
as
F
from
torch
import
Tensor
,
device
,
dtype
,
nn
from
torch.nn.parameter
import
Parameter
import
bitsandbytes
as
bnb
from
bitsandbytes.optim
import
GlobalOptimManager
T
=
TypeVar
(
'T'
,
bound
=
'torch.nn.Module'
)
T
=
TypeVar
(
"T"
,
bound
=
"torch.nn.Module"
)
class
StableEmbedding
(
torch
.
nn
.
Embedding
):
def
__init__
(
self
,
num_embeddings
:
int
,
embedding_dim
:
int
,
padding_idx
:
Optional
[
int
]
=
None
,
max_norm
:
Optional
[
float
]
=
None
,
norm_type
:
float
=
2.
,
scale_grad_by_freq
:
bool
=
False
,
sparse
:
bool
=
False
,
_weight
:
Optional
[
Tensor
]
=
None
)
->
None
:
super
(
StableEmbedding
,
self
).
__init__
(
num_embeddings
,
embedding_dim
,
padding_idx
,
max_norm
,
norm_type
,
scale_grad_by_freq
,
sparse
,
_weight
)
def
__init__
(
self
,
num_embeddings
:
int
,
embedding_dim
:
int
,
padding_idx
:
Optional
[
int
]
=
None
,
max_norm
:
Optional
[
float
]
=
None
,
norm_type
:
float
=
2.0
,
scale_grad_by_freq
:
bool
=
False
,
sparse
:
bool
=
False
,
_weight
:
Optional
[
Tensor
]
=
None
,
)
->
None
:
super
(
StableEmbedding
,
self
).
__init__
(
num_embeddings
,
embedding_dim
,
padding_idx
,
max_norm
,
norm_type
,
scale_grad_by_freq
,
sparse
,
_weight
,
)
self
.
norm
=
torch
.
nn
.
LayerNorm
(
embedding_dim
)
GlobalOptimManager
.
get_instance
().
register_module_override
(
self
,
'weight'
,
{
'optim_bits'
:
32
})
GlobalOptimManager
.
get_instance
().
register_module_override
(
self
,
"weight"
,
{
"optim_bits"
:
32
}
)
def
reset_parameters
(
self
)
->
None
:
torch
.
nn
.
init
.
xavier_uniform_
(
self
.
weight
)
self
.
_fill_padding_idx_with_zero
()
'''
!!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
"""
!!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
to make the Layer compatible with Pytorch < 1.9.
This means that if this changes in future PyTorch releases this need to change too
which is cumbersome. However, with this we can ensure compatibility with previous
PyTorch releases.
'''
"""
def
_fill_padding_idx_with_zero
(
self
)
->
None
:
if
self
.
padding_idx
is
not
None
:
with
torch
.
no_grad
():
...
...
@@ -41,29 +61,55 @@ class StableEmbedding(torch.nn.Embedding):
def
forward
(
self
,
input
:
Tensor
)
->
Tensor
:
emb
=
F
.
embedding
(
input
,
self
.
weight
,
self
.
padding_idx
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
sparse
)
input
,
self
.
weight
,
self
.
padding_idx
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
sparse
,
)
return
self
.
norm
(
emb
)
class
Embedding
(
torch
.
nn
.
Embedding
):
def
__init__
(
self
,
num_embeddings
:
int
,
embedding_dim
:
int
,
padding_idx
:
Optional
[
int
]
=
None
,
max_norm
:
Optional
[
float
]
=
None
,
norm_type
:
float
=
2.
,
scale_grad_by_freq
:
bool
=
False
,
sparse
:
bool
=
False
,
_weight
:
Optional
[
Tensor
]
=
None
)
->
None
:
super
(
Embedding
,
self
).
__init__
(
num_embeddings
,
embedding_dim
,
padding_idx
,
max_norm
,
norm_type
,
scale_grad_by_freq
,
sparse
,
_weight
)
GlobalOptimManager
.
get_instance
().
register_module_override
(
self
,
'weight'
,
{
'optim_bits'
:
32
})
def
__init__
(
self
,
num_embeddings
:
int
,
embedding_dim
:
int
,
padding_idx
:
Optional
[
int
]
=
None
,
max_norm
:
Optional
[
float
]
=
None
,
norm_type
:
float
=
2.0
,
scale_grad_by_freq
:
bool
=
False
,
sparse
:
bool
=
False
,
_weight
:
Optional
[
Tensor
]
=
None
,
)
->
None
:
super
(
Embedding
,
self
).
__init__
(
num_embeddings
,
embedding_dim
,
padding_idx
,
max_norm
,
norm_type
,
scale_grad_by_freq
,
sparse
,
_weight
,
)
GlobalOptimManager
.
get_instance
().
register_module_override
(
self
,
"weight"
,
{
"optim_bits"
:
32
}
)
def
reset_parameters
(
self
)
->
None
:
torch
.
nn
.
init
.
xavier_uniform_
(
self
.
weight
)
self
.
_fill_padding_idx_with_zero
()
'''
!!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
"""
!!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
to make the Layer compatible with Pytorch < 1.9.
This means that if this changes in future PyTorch releases this need to change too
which is cumbersome. However, with this we can ensure compatibility with previous
PyTorch releases.
'''
"""
def
_fill_padding_idx_with_zero
(
self
)
->
None
:
if
self
.
padding_idx
is
not
None
:
with
torch
.
no_grad
():
...
...
@@ -71,13 +117,22 @@ class Embedding(torch.nn.Embedding):
def
forward
(
self
,
input
:
Tensor
)
->
Tensor
:
emb
=
F
.
embedding
(
input
,
self
.
weight
,
self
.
padding_idx
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
sparse
)
input
,
self
.
weight
,
self
.
padding_idx
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
sparse
,
)
return
emb
class
Int8Params
(
torch
.
nn
.
Parameter
):
def
__new__
(
cls
,
data
=
None
,
requires_grad
=
True
,
has_fp16_weights
=
False
,
CB
=
None
,
SCB
=
None
):
def
__new__
(
cls
,
data
=
None
,
requires_grad
=
True
,
has_fp16_weights
=
False
,
CB
=
None
,
SCB
=
None
):
cls
.
has_fp16_weights
=
has_fp16_weights
cls
.
CB
=
None
cls
.
SCB
=
None
...
...
@@ -96,14 +151,18 @@ class Int8Params(torch.nn.Parameter):
del
CBt
del
SCBt
self
.
data
=
CB
setattr
(
self
,
'
CB
'
,
CB
)
setattr
(
self
,
'
SCB
'
,
SCB
)
setattr
(
self
,
"
CB
"
,
CB
)
setattr
(
self
,
"
SCB
"
,
SCB
)
return
self
@
overload
def
to
(
self
:
T
,
device
:
Optional
[
Union
[
int
,
device
]]
=
...,
dtype
:
Optional
[
Union
[
dtype
,
str
]]
=
...,
non_blocking
:
bool
=
...)
->
T
:
def
to
(
self
:
T
,
device
:
Optional
[
Union
[
int
,
device
]]
=
...,
dtype
:
Optional
[
Union
[
dtype
,
str
]]
=
...,
non_blocking
:
bool
=
...,
)
->
T
:
...
@
overload
...
...
@@ -115,23 +174,41 @@ class Int8Params(torch.nn.Parameter):
...
def
to
(
self
,
*
args
,
**
kwargs
):
device
,
dtype
,
non_blocking
,
convert_to_format
=
torch
.
_C
.
_nn
.
_parse_to
(
*
args
,
**
kwargs
)
if
device
is
not
None
and
device
.
type
==
'cuda'
and
self
.
data
.
device
.
type
==
'cpu'
:
return
self
.
cuda
(
device
)
device
,
dtype
,
non_blocking
,
convert_to_format
=
torch
.
_C
.
_nn
.
_parse_to
(
*
args
,
**
kwargs
)
if
(
device
is
not
None
and
device
.
type
==
"cuda"
and
self
.
data
.
device
.
type
==
"cpu"
):
return
self
.
cuda
(
device
)
else
:
new_param
=
Int8Params
(
super
().
to
(
device
=
device
,
dtype
=
dtype
,
non_blocking
=
non_blocking
),
requires_grad
=
self
.
requires_grad
,
has_fp16_weights
=
self
.
has_fp16_weights
)
new_param
=
Int8Params
(
super
().
to
(
device
=
device
,
dtype
=
dtype
,
non_blocking
=
non_blocking
),
requires_grad
=
self
.
requires_grad
,
has_fp16_weights
=
self
.
has_fp16_weights
,
)
new_param
.
CB
=
self
.
CB
new_param
.
SCB
=
self
.
SCB
return
new_param
class
Linear8bitLt
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
has_fp16_weights
=
True
,
threshold
=
0.0
,
index
=
None
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
has_fp16_weights
=
True
,
threshold
=
0.0
,
index
=
None
,
):
super
(
Linear8bitLt
,
self
).
__init__
(
input_features
,
output_features
,
bias
)
self
.
state
=
bnb
.
MatmulLtState
()
self
.
index
=
index
self
.
index
=
index
self
.
state
.
threshold
=
threshold
self
.
state
.
has_fp16_weights
=
has_fp16_weights
...
...
@@ -149,9 +226,10 @@ class Linear8bitLt(nn.Linear):
def
forward
(
self
,
x
):
self
.
state
.
is_training
=
self
.
training
if
self
.
weight
.
CB
is
not
None
:
self
.
init_8bit_state
()
#assert not self.state.has_fp16_weights
#if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
if
self
.
weight
.
CB
is
not
None
:
self
.
init_8bit_state
()
# assert not self.state.has_fp16_weights
# if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
out
=
bnb
.
matmul
(
x
,
self
.
weight
,
state
=
self
.
state
)
...
...
@@ -166,8 +244,18 @@ class Linear8bitLt(nn.Linear):
return
out
class
Linear8bit
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
quant_type
=
'vector'
,
index
=
None
,
args
=
None
,
sparse_decomp
=
False
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
quant_type
=
"vector"
,
index
=
None
,
args
=
None
,
sparse_decomp
=
False
,
):
super
(
Linear8bit
,
self
).
__init__
(
input_features
,
output_features
,
bias
)
self
.
quant_type
=
quant_type
self
.
index
=
index
...
...
@@ -178,15 +266,24 @@ class Linear8bit(nn.Linear):
self
.
iter
+=
1
if
self
.
iter
%
self
.
args
.
clip_freq
==
0
:
with
torch
.
no_grad
():
maxval
,
maxidx
=
torch
.
topk
(
torch
.
abs
(
self
.
weight
.
flatten
()),
k
=
self
.
args
.
clip_idx
)
maxval
,
maxidx
=
torch
.
topk
(
torch
.
abs
(
self
.
weight
.
flatten
()),
k
=
self
.
args
.
clip_idx
)
if
not
dist
.
is_initialized
()
or
dist
.
get_rank
()
==
0
:
print
(
'
clip
'
,
maxval
[
-
1
].
item
())
print
(
"
clip
"
,
maxval
[
-
1
].
item
())
self
.
weight
.
clip_
(
-
maxval
[
-
1
],
maxval
[
-
1
])
if
self
.
args
is
not
None
:
out
=
bnb
.
nn
.
functional
.
sparse_decomposed_linear8bit
(
x
,
self
.
weight
,
self
.
bias
,
qval
=
self
.
args
.
sparse_decomp_val
,
quant_type
=
self
.
args
.
quant_type
)
out
=
bnb
.
nn
.
functional
.
sparse_decomposed_linear8bit
(
x
,
self
.
weight
,
self
.
bias
,
qval
=
self
.
args
.
sparse_decomp_val
,
quant_type
=
self
.
args
.
quant_type
,
)
else
:
out
=
bnb
.
nn
.
functional
.
linear8bit
(
x
,
self
.
weight
,
self
.
bias
,
quant_type
=
self
.
args
.
quant_type
)
out
=
bnb
.
nn
.
functional
.
linear8bit
(
x
,
self
.
weight
,
self
.
bias
,
quant_type
=
self
.
args
.
quant_type
)
return
out
bitsandbytes/optim/__init__.py
View file @
bfa0e332
bitsandbytes/optim/adagrad.py
View file @
bfa0e332
...
...
@@ -4,9 +4,22 @@
# LICENSE file in the root directory of this source tree.
from
bitsandbytes.optim.optimizer
import
Optimizer1State
class
Adagrad
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
=
1e-2
,
lr_decay
=
0
,
weight_decay
=
0
,
initial_accumulator_value
=
0
,
eps
=
1e-10
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
def
__init__
(
self
,
params
,
lr
=
1e-2
,
lr_decay
=
0
,
weight_decay
=
0
,
initial_accumulator_value
=
0
,
eps
=
1e-10
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
):
if
not
0.0
<=
lr
:
raise
ValueError
(
"Invalid learning rate: {}"
.
format
(
lr
))
if
not
0.0
<=
weight_decay
:
...
...
@@ -14,15 +27,39 @@ class Adagrad(Optimizer1State):
if
not
0.0
<=
eps
:
raise
ValueError
(
"Invalid epsilon value: {}"
.
format
(
eps
))
if
initial_accumulator_value
!=
0.0
:
raise
ValueError
(
'
Initial accumulator value != 0.0 not supported!
'
)
raise
ValueError
(
"
Initial accumulator value != 0.0 not supported!
"
)
if
lr_decay
!=
0.0
:
raise
ValueError
(
'Lr Decay != 0.0 not supported!'
)
super
(
Adagrad
,
self
).
__init__
(
'adagrad'
,
params
,
lr
,
(
0.0
,
0.0
),
eps
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
raise
ValueError
(
"Lr Decay != 0.0 not supported!"
)
super
(
Adagrad
,
self
).
__init__
(
"adagrad"
,
params
,
lr
,
(
0.0
,
0.0
),
eps
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
)
class
Adagrad8bit
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
=
1e-2
,
lr_decay
=
0
,
weight_decay
=
0
,
initial_accumulator_value
=
0
,
eps
=
1e-10
,
optim_bits
=
8
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
def
__init__
(
self
,
params
,
lr
=
1e-2
,
lr_decay
=
0
,
weight_decay
=
0
,
initial_accumulator_value
=
0
,
eps
=
1e-10
,
optim_bits
=
8
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
):
if
not
0.0
<=
lr
:
raise
ValueError
(
"Invalid learning rate: {}"
.
format
(
lr
))
if
not
0.0
<=
weight_decay
:
...
...
@@ -30,16 +67,40 @@ class Adagrad8bit(Optimizer1State):
if
not
0.0
<=
eps
:
raise
ValueError
(
"Invalid epsilon value: {}"
.
format
(
eps
))
if
initial_accumulator_value
!=
0.0
:
raise
ValueError
(
'
Initial accumulator value != 0.0 not supported!
'
)
raise
ValueError
(
"
Initial accumulator value != 0.0 not supported!
"
)
if
lr_decay
!=
0.0
:
raise
ValueError
(
'
Lr Decay != 0.0 not supported!
'
)
raise
ValueError
(
"
Lr Decay != 0.0 not supported!
"
)
assert
block_wise
super
(
Adagrad8bit
,
self
).
__init__
(
'adagrad'
,
params
,
lr
,
(
0.0
,
0.0
),
eps
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
super
(
Adagrad8bit
,
self
).
__init__
(
"adagrad"
,
params
,
lr
,
(
0.0
,
0.0
),
eps
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
)
class
Adagrad32bit
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
=
1e-2
,
lr_decay
=
0
,
weight_decay
=
0
,
initial_accumulator_value
=
0
,
eps
=
1e-10
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
def
__init__
(
self
,
params
,
lr
=
1e-2
,
lr_decay
=
0
,
weight_decay
=
0
,
initial_accumulator_value
=
0
,
eps
=
1e-10
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
):
if
not
0.0
<=
lr
:
raise
ValueError
(
"Invalid learning rate: {}"
.
format
(
lr
))
if
not
0.0
<=
weight_decay
:
...
...
@@ -47,8 +108,19 @@ class Adagrad32bit(Optimizer1State):
if
not
0.0
<=
eps
:
raise
ValueError
(
"Invalid epsilon value: {}"
.
format
(
eps
))
if
initial_accumulator_value
!=
0.0
:
raise
ValueError
(
'
Initial accumulator value != 0.0 not supported!
'
)
raise
ValueError
(
"
Initial accumulator value != 0.0 not supported!
"
)
if
lr_decay
!=
0.0
:
raise
ValueError
(
'Lr Decay != 0.0 not supported!'
)
super
(
Adagrad32bit
,
self
).
__init__
(
'adagrad'
,
params
,
lr
,
(
0.0
,
0.0
),
eps
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
raise
ValueError
(
"Lr Decay != 0.0 not supported!"
)
super
(
Adagrad32bit
,
self
).
__init__
(
"adagrad"
,
params
,
lr
,
(
0.0
,
0.0
),
eps
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
)
bitsandbytes/optim/adam.py
View file @
bfa0e332
...
...
@@ -8,29 +8,97 @@ import os
import
torch
import
torch.distributed
as
dist
from
bitsandbytes.optim.optimizer
import
Optimizer2State
import
bitsandbytes.functional
as
F
from
bitsandbytes.optim.optimizer
import
Optimizer2State
class
Adam
(
Optimizer2State
):
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
0
,
amsgrad
=
False
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
super
(
Adam
,
self
).
__init__
(
'adam'
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
0
,
amsgrad
=
False
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
):
super
(
Adam
,
self
).
__init__
(
"adam"
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
)
class
Adam8bit
(
Optimizer2State
):
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
0
,
amsgrad
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
super
(
Adam8bit
,
self
).
__init__
(
'adam'
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
0
,
amsgrad
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
):
super
(
Adam8bit
,
self
).
__init__
(
"adam"
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
)
class
Adam32bit
(
Optimizer2State
):
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
0
,
amsgrad
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
super
(
Adam32bit
,
self
).
__init__
(
'adam'
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
0
,
amsgrad
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
):
super
(
Adam32bit
,
self
).
__init__
(
"adam"
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
)
class
AnalysisAdam
(
torch
.
optim
.
Optimizer
):
...
...
@@ -68,8 +136,8 @@ class AnalysisAdam(torch.optim.Optimizer):
eps
=
1e-8
,
weight_decay
=
0
,
amsgrad
=
False
,
bnb_analysis
=
'
dynamic-blockwise
'
,
savedir
=
None
bnb_analysis
=
"
dynamic-blockwise
"
,
savedir
=
None
,
):
defaults
=
dict
(
lr
=
lr
,
betas
=
betas
,
eps
=
eps
,
weight_decay
=
weight_decay
,
amsgrad
=
amsgrad
...
...
@@ -124,9 +192,13 @@ class AnalysisAdam(torch.optim.Optimizer):
state
[
"exp_avg"
]
=
torch
.
zeros_like
(
p_data_fp32
)
# Exponential moving average of squared gradient values
state
[
"exp_avg_sq"
]
=
torch
.
zeros_like
(
p_data_fp32
)
state
[
'abserrors'
]
=
torch
.
zeros
((
256
,
256
),
device
=
p_data_fp32
.
device
)
state
[
'relerrors'
]
=
torch
.
zeros
((
256
,
256
),
device
=
p_data_fp32
.
device
)
state
[
'counts'
]
=
torch
.
zeros
((
256
,
256
),
device
=
p_data_fp32
.
device
)
state
[
"abserrors"
]
=
torch
.
zeros
(
(
256
,
256
),
device
=
p_data_fp32
.
device
)
state
[
"relerrors"
]
=
torch
.
zeros
(
(
256
,
256
),
device
=
p_data_fp32
.
device
)
state
[
"counts"
]
=
torch
.
zeros
((
256
,
256
),
device
=
p_data_fp32
.
device
)
if
amsgrad
:
# Maintains max of all exp. moving avg. of sq. grad. values
state
[
"max_exp_avg_sq"
]
=
torch
.
zeros_like
(
p_data_fp32
)
...
...
@@ -143,9 +215,9 @@ class AnalysisAdam(torch.optim.Optimizer):
bias_correction1
=
1
-
beta1
**
state
[
"step"
]
bias_correction2
=
1
-
beta2
**
state
[
"step"
]
step_size
=
group
[
"lr"
]
*
math
.
sqrt
(
bias_correction2
)
/
bias_correction1
e
=
state
[
'
abserrors
'
]
rele
=
state
[
'
relerrors
'
]
counts
=
state
[
'
counts
'
]
e
=
state
[
"
abserrors
"
]
rele
=
state
[
"
relerrors
"
]
counts
=
state
[
"
counts
"
]
if
group
[
"weight_decay"
]
!=
0
:
p_data_fp32
.
add_
(
...
...
@@ -156,77 +228,84 @@ class AnalysisAdam(torch.optim.Optimizer):
if
amsgrad
:
max_exp_avg_sq
=
state
[
"max_exp_avg_sq"
]
# Decay the first and second moment running average coefficient
exp_avg
.
mul_
(
beta1
).
add_
(
grad
,
alpha
=
1
-
beta1
)
exp_avg_sq
.
mul_
(
beta2
).
addcmul_
(
grad
,
grad
,
value
=
1
-
beta2
)
denom
=
exp_avg_sq
.
sqrt
().
add_
(
group
[
"eps"
])
update_fp32
=
exp_avg
/
denom
update_fp32
=
exp_avg
/
denom
if
p_data_fp32
.
numel
()
<=
8192
or
p_data_fp32
.
numel
()
>
50000
*
1000
:
if
p_data_fp32
.
numel
()
<=
8192
or
p_data_fp32
.
numel
()
>
50000
*
1000
:
# embedding layer or too small
p_data_fp32
+=
-
step_size
*
update_fp32
p_data_fp32
+=
-
step_size
*
update_fp32
else
:
if
self
.
analysis
==
'
dynamic-blockwise
'
:
if
self
.
analysis
==
"
dynamic-blockwise
"
:
code1
=
F
.
create_dynamic_map
(
signed
=
True
).
to
(
p
.
device
)
code2
=
F
.
create_dynamic_map
(
signed
=
False
).
to
(
p
.
device
)
C1
,
S1
=
F
.
quantize_blockwise
(
exp_avg
,
code
=
code1
)
state1
=
F
.
dequantize_blockwise
(
C1
,
S1
)
C2
,
S2
=
F
.
quantize_blockwise
(
exp_avg_sq
,
code
=
code2
)
state2
=
F
.
dequantize_blockwise
(
C2
,
S2
)
elif
self
.
analysis
==
'
dynamic
'
:
elif
self
.
analysis
==
"
dynamic
"
:
code1
=
F
.
create_dynamic_map
(
signed
=
True
).
to
(
p
.
device
)
code2
=
F
.
create_dynamic_map
(
signed
=
False
).
to
(
p
.
device
)
C1
,
S1
=
F
.
quantize
(
exp_avg
,
code
=
code1
)
state1
=
F
.
dequantize
(
C1
,
S1
)
C2
,
S2
=
F
.
quantize
(
exp_avg_sq
,
code
=
code2
)
state2
=
F
.
dequantize
(
C2
,
S2
)
elif
self
.
analysis
==
'
linear
'
:
elif
self
.
analysis
==
"
linear
"
:
code1
=
F
.
create_linear_map
(
signed
=
True
).
to
(
p
.
device
)
code2
=
F
.
create_linear_map
(
signed
=
False
).
to
(
p
.
device
)
C1
,
S1
=
F
.
quantize
(
exp_avg
,
code
=
code1
)
state1
=
F
.
dequantize
(
C1
,
S1
)
C2
,
S2
=
F
.
quantize
(
exp_avg_sq
,
code
=
code2
)
state2
=
F
.
dequantize
(
C2
,
S2
)
elif
self
.
analysis
==
'
quantile
'
:
elif
self
.
analysis
==
"
quantile
"
:
code1
=
F
.
estimate_quantiles
(
exp_avg
)
code2
=
F
.
estimate_quantiles
(
exp_avg_sq
)
C1
=
F
.
quantize_no_absmax
(
exp_avg
,
code
=
code1
)
state1
=
F
.
dequantize_no_absmax
(
C1
,
code1
)
C2
=
F
.
quantize_no_absmax
(
exp_avg_sq
,
code
=
code2
)
state2
=
F
.
dequantize_no_absmax
(
C2
,
code2
)
elif
self
.
analysis
==
'
my-quantization-routine
'
:
elif
self
.
analysis
==
"
my-quantization-routine
"
:
pass
# 1. get code
# 2. quantize
# 3. dequantize
# Error will be calculated automatically!
else
:
raise
ValueError
(
f
'
Invalid analysis value:
{
self
.
analysis
}
!
'
)
raise
ValueError
(
f
"
Invalid analysis value:
{
self
.
analysis
}
!
"
)
denom
=
state2
.
sqrt
().
add_
(
group
[
"eps"
])
update_8bit
=
state1
/
denom
update_8bit
=
state1
/
denom
abserr
=
torch
.
abs
(
update_8bit
-
update_fp32
)
relerr
=
abserr
/
torch
.
abs
(
update_fp32
+
1e-6
)
abserr
=
torch
.
abs
(
update_8bit
-
update_fp32
)
relerr
=
abserr
/
torch
.
abs
(
update_fp32
+
1e-6
)
C1
,
C2
=
C1
.
int
(),
C2
.
int
()
F
.
histogram_scatter_add_2d
(
e
,
C1
.
int
(),
C2
.
int
(),
abserr
)
F
.
histogram_scatter_add_2d
(
rele
,
C1
.
int
(),
C2
.
int
(),
relerr
)
F
.
histogram_scatter_add_2d
(
counts
,
C1
.
int
(),
C2
.
int
(),
torch
.
ones_like
(
abserr
))
p_data_fp32
+=
-
step_size
*
update_fp32
F
.
histogram_scatter_add_2d
(
counts
,
C1
.
int
(),
C2
.
int
(),
torch
.
ones_like
(
abserr
)
)
p_data_fp32
+=
-
step_size
*
update_fp32
if
not
dist
.
is_initialized
()
or
dist
.
get_rank
()
==
0
:
if
self
.
savedir
!=
''
and
state
[
'step'
]
%
100
==
0
:
if
not
os
.
path
.
exists
(
self
.
savedir
):
os
.
makedirs
(
self
.
savedir
)
shapestr
=
'_'
.
join
([
str
(
dim
)
for
dim
in
p_data_fp32
.
shape
])
pathe
=
os
.
path
.
join
(
self
.
savedir
,
f
'
{
p_id
}
_
{
shapestr
}
_abserr.pkl'
)
pathrele
=
os
.
path
.
join
(
self
.
savedir
,
f
'
{
p_id
}
_
{
shapestr
}
_relerr.pkl'
)
pathcounts
=
os
.
path
.
join
(
self
.
savedir
,
f
'
{
p_id
}
_
{
shapestr
}
_counts.pkl'
)
if
self
.
savedir
!=
""
and
state
[
"step"
]
%
100
==
0
:
if
not
os
.
path
.
exists
(
self
.
savedir
):
os
.
makedirs
(
self
.
savedir
)
shapestr
=
"_"
.
join
([
str
(
dim
)
for
dim
in
p_data_fp32
.
shape
])
pathe
=
os
.
path
.
join
(
self
.
savedir
,
f
"
{
p_id
}
_
{
shapestr
}
_abserr.pkl"
)
pathrele
=
os
.
path
.
join
(
self
.
savedir
,
f
"
{
p_id
}
_
{
shapestr
}
_relerr.pkl"
)
pathcounts
=
os
.
path
.
join
(
self
.
savedir
,
f
"
{
p_id
}
_
{
shapestr
}
_counts.pkl"
)
torch
.
save
(
e
,
pathe
)
torch
.
save
(
rele
,
pathrele
)
torch
.
save
(
counts
,
pathcounts
)
...
...
@@ -234,6 +313,4 @@ class AnalysisAdam(torch.optim.Optimizer):
if
p
.
data
.
dtype
in
{
torch
.
float16
,
torch
.
bfloat16
}:
p
.
data
.
copy_
(
p_data_fp32
)
return
loss
bitsandbytes/optim/adamw.py
View file @
bfa0e332
...
...
@@ -4,24 +4,90 @@
# LICENSE file in the root directory of this source tree.
from
bitsandbytes.optim.optimizer
import
Optimizer2State
class
AdamW
(
Optimizer2State
):
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
1e-2
,
amsgrad
=
False
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
super
(
AdamW
,
self
).
__init__
(
'adam'
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
1e-2
,
amsgrad
=
False
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
):
super
(
AdamW
,
self
).
__init__
(
"adam"
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
)
class
AdamW8bit
(
Optimizer2State
):
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
1e-2
,
amsgrad
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
super
(
AdamW8bit
,
self
).
__init__
(
'adam'
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
1e-2
,
amsgrad
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
):
super
(
AdamW8bit
,
self
).
__init__
(
"adam"
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
)
class
AdamW32bit
(
Optimizer2State
):
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
1e-2
,
amsgrad
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
super
(
AdamW32bit
,
self
).
__init__
(
'adam'
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
class
AdamW32bit
(
Optimizer2State
):
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
1e-2
,
amsgrad
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
):
super
(
AdamW32bit
,
self
).
__init__
(
"adam"
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
)
bitsandbytes/optim/lamb.py
View file @
bfa0e332
...
...
@@ -4,25 +4,102 @@
# LICENSE file in the root directory of this source tree.
from
bitsandbytes.optim.optimizer
import
Optimizer2State
class
LAMB
(
Optimizer2State
):
def
__init__
(
self
,
params
,
lr
=
1e-3
,
bias_correction
=
True
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
0
,
amsgrad
=
False
,
adam_w_mode
=
True
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
False
,
max_unorm
=
1.0
):
super
(
LAMB
,
self
).
__init__
(
'lamb'
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
max_unorm
=
1.0
)
def
__init__
(
self
,
params
,
lr
=
1e-3
,
bias_correction
=
True
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
0
,
amsgrad
=
False
,
adam_w_mode
=
True
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
False
,
max_unorm
=
1.0
,
):
super
(
LAMB
,
self
).
__init__
(
"lamb"
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
max_unorm
=
1.0
,
)
class
LAMB8bit
(
Optimizer2State
):
def
__init__
(
self
,
params
,
lr
=
1e-3
,
bias_correction
=
True
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
0
,
amsgrad
=
False
,
adam_w_mode
=
True
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
False
,
max_unorm
=
1.0
):
super
(
LAMB8bit
,
self
).
__init__
(
'lamb'
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
max_unorm
=
1.0
)
class
LAMB32bit
(
Optimizer2State
):
def
__init__
(
self
,
params
,
lr
=
1e-3
,
bias_correction
=
True
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
0
,
amsgrad
=
False
,
adam_w_mode
=
True
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
False
,
max_unorm
=
1.0
):
super
(
LAMB32bit
,
self
).
__init__
(
'lamb'
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
max_unorm
=
1.0
)
class
LAMB8bit
(
Optimizer2State
):
def
__init__
(
self
,
params
,
lr
=
1e-3
,
bias_correction
=
True
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
0
,
amsgrad
=
False
,
adam_w_mode
=
True
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
False
,
max_unorm
=
1.0
,
):
super
(
LAMB8bit
,
self
).
__init__
(
"lamb"
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
max_unorm
=
1.0
,
)
class
LAMB32bit
(
Optimizer2State
):
def
__init__
(
self
,
params
,
lr
=
1e-3
,
bias_correction
=
True
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
0
,
amsgrad
=
False
,
adam_w_mode
=
True
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
False
,
max_unorm
=
1.0
,
):
super
(
LAMB32bit
,
self
).
__init__
(
"lamb"
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
max_unorm
=
1.0
,
)
bitsandbytes/optim/lars.py
View file @
bfa0e332
...
...
@@ -3,41 +3,119 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
torch
from
torch.optim
import
Optimizer
from
bitsandbytes.optim.optimizer
import
Optimizer1State
class
LARS
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
,
):
if
momentum
==
0
:
raise
NotImplementedError
(
f
'LARS without momentum is not supported!'
)
super
(
LARS
,
self
).
__init__
(
'lars'
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
)
raise
NotImplementedError
(
f
"LARS without momentum is not supported!"
)
super
(
LARS
,
self
).
__init__
(
"lars"
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
,
)
class
LARS8bit
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
,
):
if
momentum
==
0
:
raise
NotImplementedError
(
f
'LARS without momentum is not supported!'
)
super
(
LARS8bit
,
self
).
__init__
(
'lars'
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
)
raise
NotImplementedError
(
f
"LARS without momentum is not supported!"
)
super
(
LARS8bit
,
self
).
__init__
(
"lars"
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
,
)
class
LARS32bit
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
,
):
if
momentum
==
0
:
raise
NotImplementedError
(
f
'LARS without momentum is not supported!'
)
super
(
LARS32bit
,
self
).
__init__
(
'lars'
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
)
raise
NotImplementedError
(
f
"LARS without momentum is not supported!"
)
super
(
LARS32bit
,
self
).
__init__
(
"lars"
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
,
)
class
PytorchLARS
(
Optimizer
):
def
__init__
(
self
,
params
,
lr
=
0.01
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
max_unorm
=
0.02
):
def
__init__
(
self
,
params
,
lr
=
0.01
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
max_unorm
=
0.02
,
):
if
lr
<
0.0
:
raise
ValueError
(
"Invalid learning rate: {}"
.
format
(
lr
))
if
momentum
<
0.0
:
...
...
@@ -45,8 +123,14 @@ class PytorchLARS(Optimizer):
if
weight_decay
<
0.0
:
raise
ValueError
(
"Invalid weight_decay value: {}"
.
format
(
weight_decay
))
defaults
=
dict
(
lr
=
lr
,
momentum
=
momentum
,
dampening
=
dampening
,
weight_decay
=
weight_decay
,
nesterov
=
nesterov
,
max_unorm
=
max_unorm
)
defaults
=
dict
(
lr
=
lr
,
momentum
=
momentum
,
dampening
=
dampening
,
weight_decay
=
weight_decay
,
nesterov
=
nesterov
,
max_unorm
=
max_unorm
,
)
if
nesterov
and
(
momentum
<=
0
or
dampening
!=
0
):
raise
ValueError
(
"Nesterov momentum requires a momentum and zero dampening"
)
super
(
PytorchLARS
,
self
).
__init__
(
params
,
defaults
)
...
...
@@ -54,7 +138,7 @@ class PytorchLARS(Optimizer):
def
__setstate__
(
self
,
state
):
super
(
PytorchLARS
,
self
).
__setstate__
(
state
)
for
group
in
self
.
param_groups
:
group
.
setdefault
(
'
nesterov
'
,
False
)
group
.
setdefault
(
"
nesterov
"
,
False
)
@
torch
.
no_grad
()
def
step
(
self
,
closure
=
None
):
...
...
@@ -73,15 +157,16 @@ class PytorchLARS(Optimizer):
params_with_grad
=
[]
d_p_list
=
[]
momentum_buffer_list
=
[]
weight_decay
=
group
[
'
weight_decay
'
]
momentum
=
group
[
'
momentum
'
]
dampening
=
group
[
'
dampening
'
]
nesterov
=
group
[
'
nesterov
'
]
max_unorm
=
group
[
'
max_unorm
'
]
lr
=
group
[
'
lr
'
]
weight_decay
=
group
[
"
weight_decay
"
]
momentum
=
group
[
"
momentum
"
]
dampening
=
group
[
"
dampening
"
]
nesterov
=
group
[
"
nesterov
"
]
max_unorm
=
group
[
"
max_unorm
"
]
lr
=
group
[
"
lr
"
]
for
p
in
group
[
'params'
]:
if
p
.
grad
is
None
:
continue
for
p
in
group
[
"params"
]:
if
p
.
grad
is
None
:
continue
state
=
self
.
state
[
p
]
d_p
=
p
.
grad
...
...
@@ -89,16 +174,16 @@ class PytorchLARS(Optimizer):
d_p
=
d_p
.
add
(
param
,
alpha
=
weight_decay
)
if
momentum
!=
0
:
buf
=
state
.
get
(
'
momentum_buffer
'
,
None
)
buf
=
state
.
get
(
"
momentum_buffer
"
,
None
)
if
buf
is
None
:
buf
=
torch
.
clone
(
d_p
).
detach
()
state
[
'
momentum_buffer
'
]
=
buf
state
[
"
momentum_buffer
"
]
=
buf
else
:
buf
.
mul_
(
momentum
).
add_
(
d_p
,
alpha
=
1
-
dampening
)
if
nesterov
:
update
=
d_p
+
buf
*
momentum
update
=
d_p
+
buf
*
momentum
else
:
update
=
buf
...
...
@@ -107,9 +192,9 @@ class PytorchLARS(Optimizer):
assert
p
.
dtype
==
torch
.
float32
pnorm
=
torch
.
norm
(
p
.
detach
())
unorm
=
torch
.
norm
(
update
)
if
unorm
>
max_unorm
*
pnorm
:
update_scale
=
max_unorm
*
pnorm
/
unorm
if
unorm
>
max_unorm
*
pnorm
:
update_scale
=
max_unorm
*
pnorm
/
unorm
p
.
add_
(
update
,
alpha
=-
lr
*
update_scale
)
p
.
add_
(
update
,
alpha
=-
lr
*
update_scale
)
return
loss
bitsandbytes/optim/optimizer.py
View file @
bfa0e332
...
...
@@ -2,12 +2,15 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from
collections
import
abc
as
container_abcs
from
collections
import
defaultdict
from
copy
import
deepcopy
from
itertools
import
chain
import
torch
import
bitsandbytes.functional
as
F
from
copy
import
deepcopy
from
itertools
import
chain
from
collections
import
defaultdict
,
abc
as
container_abcs
class
MockArgs
(
object
):
def
__init__
(
self
,
initial_data
):
...
...
@@ -19,7 +22,7 @@ class GlobalOptimManager(object):
_instance
=
None
def
__init__
(
self
):
raise
RuntimeError
(
'
Call get_instance() instead
'
)
raise
RuntimeError
(
"
Call get_instance() instead
"
)
def
initialize
(
self
):
self
.
pid2config
=
{}
...
...
@@ -38,15 +41,15 @@ class GlobalOptimManager(object):
def
register_parameters
(
self
,
params
):
param_groups
=
list
(
params
)
if
not
isinstance
(
param_groups
[
0
],
dict
):
param_groups
=
[{
'
params
'
:
param_groups
}]
param_groups
=
[{
"
params
"
:
param_groups
}]
for
group_index
,
group
in
enumerate
(
param_groups
):
for
p_index
,
p
in
enumerate
(
group
[
'
params
'
]):
for
p_index
,
p
in
enumerate
(
group
[
"
params
"
]):
if
id
(
p
)
in
self
.
pid2config
:
self
.
index2config
[(
group_index
,
p_index
)]
=
self
.
pid2config
[
id
(
p
)]
def
override_config
(
self
,
parameters
,
key
=
None
,
value
=
None
,
key_value_dict
=
None
):
'''
"""
Overrides initial optimizer config for specific parameters.
The key-values of the optimizer config for the input parameters are overidden
...
...
@@ -63,7 +66,7 @@ class GlobalOptimManager(object):
The value for the hyperparamters.
key_value_dict : dict
A dictionary with multiple key-values to override.
'''
"""
self
.
uses_config_override
=
True
if
isinstance
(
parameters
,
torch
.
nn
.
Parameter
):
parameters
=
[
parameters
]
...
...
@@ -75,16 +78,16 @@ class GlobalOptimManager(object):
if
key_value_dict
is
not
None
:
for
p
in
parameters
:
if
id
(
p
)
in
self
.
pid2config
:
self
.
pid2config
[
id
(
p
)].
update
(
key_value_dict
)
else
:
self
.
pid2config
[
id
(
p
)]
=
key_value_dict
if
id
(
p
)
in
self
.
pid2config
:
self
.
pid2config
[
id
(
p
)].
update
(
key_value_dict
)
else
:
self
.
pid2config
[
id
(
p
)]
=
key_value_dict
def
register_module_override
(
self
,
module
,
param_name
,
config
):
self
.
module_weight_config_triple
.
append
((
module
,
param_name
,
config
))
class
Optimizer8bit
(
torch
.
optim
.
Optimizer
):
def
__init__
(
self
,
params
,
defaults
,
optim_bits
=
32
):
super
(
Optimizer8bit
,
self
).
__init__
(
params
,
defaults
)
self
.
initialized
=
False
...
...
@@ -92,23 +95,32 @@ class Optimizer8bit(torch.optim.Optimizer):
self
.
mng
=
GlobalOptimManager
.
get_instance
()
self
.
non_castable_tensor_keys
=
set
(
[
'qmap1'
,
'qmap2'
,
'max1'
,
'max2'
,
'new_max1'
,
'new_max2'
,
'state1'
,
'state2'
,
'gnorm_vec'
,
'absmax1'
,
'absmax2'
,
'unorm_vec'
])
if
optim_bits
==
8
:
self
.
fill_qmap
()
[
"qmap1"
,
"qmap2"
,
"max1"
,
"max2"
,
"new_max1"
,
"new_max2"
,
"state1"
,
"state2"
,
"gnorm_vec"
,
"absmax1"
,
"absmax2"
,
"unorm_vec"
,
]
)
if
optim_bits
==
8
:
self
.
fill_qmap
()
def
fill_qmap
(
self
):
self
.
name2qmap
[
'
dynamic
'
]
=
F
.
create_dynamic_map
(
signed
=
True
)
self
.
name2qmap
[
'
udynamic
'
]
=
F
.
create_dynamic_map
(
signed
=
False
)
self
.
name2qmap
[
"
dynamic
"
]
=
F
.
create_dynamic_map
(
signed
=
True
)
self
.
name2qmap
[
"
udynamic
"
]
=
F
.
create_dynamic_map
(
signed
=
False
)
def
__setstate__
(
self
,
state
):
super
(
Optimizer8bit
,
self
).
__setstate__
(
state
)
def
load_state_dict
(
self
,
state_dict
):
r
"""Loads the optimizer state.
...
...
@@ -120,21 +132,28 @@ class Optimizer8bit(torch.optim.Optimizer):
state_dict
=
deepcopy
(
state_dict
)
# Validate the state_dict
groups
=
self
.
param_groups
saved_groups
=
state_dict
[
'
param_groups
'
]
saved_groups
=
state_dict
[
"
param_groups
"
]
if
len
(
groups
)
!=
len
(
saved_groups
):
raise
ValueError
(
"loaded state dict has a different number of "
"parameter groups"
)
param_lens
=
(
len
(
g
[
'params'
])
for
g
in
groups
)
saved_lens
=
(
len
(
g
[
'params'
])
for
g
in
saved_groups
)
raise
ValueError
(
"loaded state dict has a different number of "
"parameter groups"
)
param_lens
=
(
len
(
g
[
"params"
])
for
g
in
groups
)
saved_lens
=
(
len
(
g
[
"params"
])
for
g
in
saved_groups
)
if
any
(
p_len
!=
s_len
for
p_len
,
s_len
in
zip
(
param_lens
,
saved_lens
)):
raise
ValueError
(
"loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group"
)
raise
ValueError
(
"loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group"
)
# Update the state
id_map
=
{
old_id
:
p
for
old_id
,
p
in
zip
(
chain
.
from_iterable
((
g
[
'params'
]
for
g
in
saved_groups
)),
chain
.
from_iterable
((
g
[
'params'
]
for
g
in
groups
)))}
id_map
=
{
old_id
:
p
for
old_id
,
p
in
zip
(
chain
.
from_iterable
((
g
[
"params"
]
for
g
in
saved_groups
)),
chain
.
from_iterable
((
g
[
"params"
]
for
g
in
groups
)),
)
}
def
cast
(
param
,
value
):
r
"""Make a deep copy of value, casting all tensors to device of param."""
...
...
@@ -161,7 +180,7 @@ class Optimizer8bit(torch.optim.Optimizer):
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state
=
defaultdict
(
dict
)
for
k
,
v
in
state_dict
[
'
state
'
].
items
():
for
k
,
v
in
state_dict
[
"
state
"
].
items
():
if
k
in
id_map
:
param
=
id_map
[
k
]
state
[
param
]
=
cast
(
param
,
v
)
...
...
@@ -170,15 +189,15 @@ class Optimizer8bit(torch.optim.Optimizer):
# Update parameter groups, setting their 'params' value
def
update_group
(
group
,
new_group
):
new_group
[
'
params
'
]
=
group
[
'
params
'
]
new_group
[
"
params
"
]
=
group
[
"
params
"
]
return
new_group
param_groups
=
[
update_group
(
g
,
ng
)
for
g
,
ng
in
zip
(
groups
,
saved_groups
)]
self
.
__setstate__
({
'
state
'
:
state
,
'
param_groups
'
:
param_groups
})
param_groups
=
[
update_group
(
g
,
ng
)
for
g
,
ng
in
zip
(
groups
,
saved_groups
)]
self
.
__setstate__
({
"
state
"
:
state
,
"
param_groups
"
:
param_groups
})
def
to_gpu
(
self
):
for
gindex
,
group
in
enumerate
(
self
.
param_groups
):
for
pindex
,
p
in
enumerate
(
group
[
'
params
'
]):
for
pindex
,
p
in
enumerate
(
group
[
"
params
"
]):
if
p
in
self
.
state
:
values
=
self
.
state
[
p
]
for
k
,
v
in
values
.
items
():
...
...
@@ -189,17 +208,23 @@ class Optimizer8bit(torch.optim.Optimizer):
for
module
,
attr
,
config
in
self
.
mng
.
module_weight_config_triple
:
pmodule
=
getattr
(
module
,
attr
)
assert
pmodule
is
not
None
assert
isinstance
(
pmodule
,
torch
.
Tensor
)
or
isinstance
(
pmodule
,
torch
.
Parameter
)
assert
isinstance
(
pmodule
,
torch
.
Tensor
)
or
isinstance
(
pmodule
,
torch
.
Parameter
)
found
=
False
for
gindex
,
group
in
enumerate
(
self
.
param_groups
):
if
found
:
break
for
pindex
,
p
in
enumerate
(
group
[
'params'
]):
if
found
:
break
if
found
:
break
for
pindex
,
p
in
enumerate
(
group
[
"params"
]):
if
found
:
break
if
id
(
p
)
==
id
(
pmodule
):
# found the matching parameter
# init override
self
.
mng
.
pid2config
[
id
(
p
)]
=
config
self
.
mng
.
index2config
[(
gindex
,
pindex
)]
=
self
.
mng
.
pid2config
[
id
(
p
)]
self
.
mng
.
index2config
[(
gindex
,
pindex
)]
=
self
.
mng
.
pid2config
[
id
(
p
)
]
found
=
True
@
torch
.
no_grad
()
...
...
@@ -223,7 +248,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self
.
initialized
=
True
for
gindex
,
group
in
enumerate
(
self
.
param_groups
):
for
pindex
,
p
in
enumerate
(
group
[
'
params
'
]):
for
pindex
,
p
in
enumerate
(
group
[
"
params
"
]):
if
p
.
grad
is
None
:
continue
state
=
self
.
state
[
p
]
...
...
@@ -236,58 +261,70 @@ class Optimizer8bit(torch.optim.Optimizer):
def
get_config
(
self
,
gindex
,
pindex
,
group
):
config
=
{}
config
[
'
betas
'
]
=
group
[
'
betas
'
]
config
[
'
eps
'
]
=
group
[
'
eps
'
]
config
[
'
weight_decay
'
]
=
group
[
'
weight_decay
'
]
config
[
'
lr
'
]
=
group
[
'
lr
'
]
config
[
'
optim_bits
'
]
=
self
.
args
.
optim_bits
config
[
'
min_8bit_size
'
]
=
self
.
args
.
min_8bit_size
config
[
'
percentile_clipping
'
]
=
self
.
args
.
percentile_clipping
config
[
'
block_wise
'
]
=
self
.
args
.
block_wise
config
[
'
max_unorm
'
]
=
self
.
args
.
max_unorm
config
[
'
skip_zeros
'
]
=
self
.
args
.
skip_zeros
config
[
"
betas
"
]
=
group
[
"
betas
"
]
config
[
"
eps
"
]
=
group
[
"
eps
"
]
config
[
"
weight_decay
"
]
=
group
[
"
weight_decay
"
]
config
[
"
lr
"
]
=
group
[
"
lr
"
]
config
[
"
optim_bits
"
]
=
self
.
args
.
optim_bits
config
[
"
min_8bit_size
"
]
=
self
.
args
.
min_8bit_size
config
[
"
percentile_clipping
"
]
=
self
.
args
.
percentile_clipping
config
[
"
block_wise
"
]
=
self
.
args
.
block_wise
config
[
"
max_unorm
"
]
=
self
.
args
.
max_unorm
config
[
"
skip_zeros
"
]
=
self
.
args
.
skip_zeros
if
(
gindex
,
pindex
)
in
self
.
mng
.
index2config
:
config
.
update
(
self
.
mng
.
index2config
[(
gindex
,
pindex
)])
return
config
def
init_state
(
self
,
group
,
p
,
gindex
,
pindex
):
raise
NotImplementedError
(
f
'
init_state method needs to be overidden
'
)
raise
NotImplementedError
(
f
"
init_state method needs to be overidden
"
)
def
update_step
(
self
,
group
,
p
,
gindex
,
pindex
):
raise
NotImplementedError
(
f
'The update_step method needs to be overidden'
)
raise
NotImplementedError
(
f
"The update_step method needs to be overidden"
)
class
Optimizer2State
(
Optimizer8bit
):
def
__init__
(
self
,
optimizer_name
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
0.0
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
max_unorm
=
0.0
,
skip_zeros
=
False
):
def
__init__
(
self
,
optimizer_name
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
0.0
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
max_unorm
=
0.0
,
skip_zeros
=
False
,
):
if
not
0.0
<=
lr
:
raise
ValueError
(
"Invalid learning rate: {}"
.
format
(
lr
))
if
not
0.0
<=
eps
:
raise
ValueError
(
"Invalid epsilon value: {}"
.
format
(
eps
))
if
isinstance
(
betas
,
str
):
# format: '(beta1, beta2)'
betas
=
betas
.
replace
(
'('
,
''
).
replace
(
')'
,
''
).
strip
().
split
(
','
)
betas
=
betas
.
replace
(
"("
,
""
).
replace
(
")"
,
""
).
strip
().
split
(
","
)
betas
=
[
float
(
b
)
for
b
in
betas
]
for
i
in
range
(
len
(
betas
)):
if
not
0.0
<=
betas
[
i
]
<
1.0
:
raise
ValueError
(
f
"Invalid beta parameter at index
{
i
}
:
{
betas
[
i
]
}
"
)
if
not
0.0
<=
weight_decay
:
raise
ValueError
(
"Invalid weight_decay value: {}"
.
format
(
weight_decay
))
defaults
=
dict
(
lr
=
lr
,
betas
=
betas
,
eps
=
eps
,
weight_decay
=
weight_decay
)
defaults
=
dict
(
lr
=
lr
,
betas
=
betas
,
eps
=
eps
,
weight_decay
=
weight_decay
)
super
(
Optimizer2State
,
self
).
__init__
(
params
,
defaults
,
optim_bits
)
if
args
is
None
:
args
=
{}
args
[
'
optim_bits
'
]
=
optim_bits
args
[
'
percentile_clipping
'
]
=
100
args
[
'
min_8bit_size
'
]
=
min_8bit_size
args
[
'
percentile_clipping
'
]
=
percentile_clipping
args
[
'
block_wise
'
]
=
block_wise
args
[
'
max_unorm
'
]
=
max_unorm
args
[
'
skip_zeros
'
]
=
skip_zeros
args
[
"
optim_bits
"
]
=
optim_bits
args
[
"
percentile_clipping
"
]
=
100
args
[
"
min_8bit_size
"
]
=
min_8bit_size
args
[
"
percentile_clipping
"
]
=
percentile_clipping
args
[
"
block_wise
"
]
=
block_wise
args
[
"
max_unorm
"
]
=
max_unorm
args
[
"
skip_zeros
"
]
=
skip_zeros
self
.
args
=
MockArgs
(
args
)
else
:
...
...
@@ -299,50 +336,83 @@ class Optimizer2State(Optimizer8bit):
def
init_state
(
self
,
group
,
p
,
gindex
,
pindex
):
config
=
self
.
get_config
(
gindex
,
pindex
,
group
)
if
config
[
'
optim_bits
'
]
==
32
:
if
config
[
"
optim_bits
"
]
==
32
:
dtype
=
torch
.
float32
elif
config
[
'
optim_bits
'
]
==
8
:
elif
config
[
"
optim_bits
"
]
==
8
:
dtype
=
torch
.
uint8
else
:
raise
NotImplementedError
(
f
'Amount of optimizer bits not supported:
{
config
[
"optim_bits"
]
}
'
)
else
:
raise
NotImplementedError
(
f
'Amount of optimizer bits not supported:
{
config
[
"optim_bits"
]
}
'
)
if
p
.
numel
()
<
config
[
'min_8bit_size'
]:
dtype
=
torch
.
float32
if
p
.
numel
()
<
config
[
"min_8bit_size"
]:
dtype
=
torch
.
float32
state
=
self
.
state
[
p
]
state
[
'
step
'
]
=
0
state
[
"
step
"
]
=
0
if
dtype
==
torch
.
float32
or
(
dtype
==
torch
.
uint8
and
p
.
numel
()
<
4096
):
state
[
'state1'
]
=
torch
.
zeros_like
(
p
,
memory_format
=
torch
.
preserve_format
,
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
'state2'
]
=
torch
.
zeros_like
(
p
,
memory_format
=
torch
.
preserve_format
,
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
"state1"
]
=
torch
.
zeros_like
(
p
,
memory_format
=
torch
.
preserve_format
,
dtype
=
torch
.
float32
,
device
=
p
.
device
,
)
state
[
"state2"
]
=
torch
.
zeros_like
(
p
,
memory_format
=
torch
.
preserve_format
,
dtype
=
torch
.
float32
,
device
=
p
.
device
,
)
elif
dtype
==
torch
.
uint8
:
if
state
[
'step'
]
==
0
:
if
'dynamic'
not
in
self
.
name2qmap
:
self
.
fill_qmap
()
self
.
name2qmap
[
'dynamic'
]
=
self
.
name2qmap
[
'dynamic'
].
to
(
p
.
device
)
self
.
name2qmap
[
'udynamic'
]
=
self
.
name2qmap
[
'udynamic'
].
to
(
p
.
device
)
state
[
'state1'
]
=
torch
.
zeros_like
(
p
,
memory_format
=
torch
.
preserve_format
,
dtype
=
torch
.
uint8
,
device
=
p
.
device
)
state
[
'qmap1'
]
=
self
.
name2qmap
[
'dynamic'
]
state
[
'state2'
]
=
torch
.
zeros_like
(
p
,
memory_format
=
torch
.
preserve_format
,
dtype
=
torch
.
uint8
,
device
=
p
.
device
)
state
[
'qmap2'
]
=
self
.
name2qmap
[
'udynamic'
]
if
config
[
'block_wise'
]:
if
state
[
"step"
]
==
0
:
if
"dynamic"
not
in
self
.
name2qmap
:
self
.
fill_qmap
()
self
.
name2qmap
[
"dynamic"
]
=
self
.
name2qmap
[
"dynamic"
].
to
(
p
.
device
)
self
.
name2qmap
[
"udynamic"
]
=
self
.
name2qmap
[
"udynamic"
].
to
(
p
.
device
)
state
[
"state1"
]
=
torch
.
zeros_like
(
p
,
memory_format
=
torch
.
preserve_format
,
dtype
=
torch
.
uint8
,
device
=
p
.
device
,
)
state
[
"qmap1"
]
=
self
.
name2qmap
[
"dynamic"
]
state
[
"state2"
]
=
torch
.
zeros_like
(
p
,
memory_format
=
torch
.
preserve_format
,
dtype
=
torch
.
uint8
,
device
=
p
.
device
,
)
state
[
"qmap2"
]
=
self
.
name2qmap
[
"udynamic"
]
if
config
[
"block_wise"
]:
n
=
p
.
numel
()
blocks
=
n
//
2048
blocks
=
n
//
2048
blocks
+=
1
if
n
%
2048
>
0
else
0
state
[
'absmax1'
]
=
torch
.
zeros
((
blocks
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
'absmax2'
]
=
torch
.
zeros
((
blocks
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
"absmax1"
]
=
torch
.
zeros
(
(
blocks
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
"absmax2"
]
=
torch
.
zeros
(
(
blocks
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
else
:
state
[
'max1'
]
=
torch
.
zeros
((
1
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
'new_max1'
]
=
torch
.
zeros
((
1
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
'max2'
]
=
torch
.
zeros
((
1
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
'new_max2'
]
=
torch
.
zeros
((
1
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
"max1"
]
=
torch
.
zeros
((
1
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
"new_max1"
]
=
torch
.
zeros
(
(
1
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
"max2"
]
=
torch
.
zeros
((
1
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
"new_max2"
]
=
torch
.
zeros
(
(
1
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
if
config
[
'
percentile_clipping
'
]
<
100
:
state
[
'
gnorm_vec
'
]
=
torch
.
zeros
((
100
,),
device
=
p
.
device
)
if
config
[
"
percentile_clipping
"
]
<
100
:
state
[
"
gnorm_vec
"
]
=
torch
.
zeros
((
100
,),
device
=
p
.
device
)
if
config
[
'
max_unorm
'
]
>
0.0
:
state
[
'
unorm_vec
'
]
=
torch
.
zeros
((
1
,),
device
=
p
.
device
)
if
config
[
"
max_unorm
"
]
>
0.0
:
state
[
"
unorm_vec
"
]
=
torch
.
zeros
((
1
,),
device
=
p
.
device
)
@
torch
.
no_grad
()
def
update_step
(
self
,
group
,
p
,
gindex
,
pindex
):
...
...
@@ -351,41 +421,101 @@ class Optimizer2State(Optimizer8bit):
config
=
self
.
get_config
(
gindex
,
pindex
,
group
)
state
[
'
step
'
]
+=
1
step
=
state
[
'
step
'
]
state
[
"
step
"
]
+=
1
step
=
state
[
"
step
"
]
if
config
[
'percentile_clipping'
]
<
100
:
current_gnorm
,
clip_value
,
gnorm_scale
=
F
.
percentile_clipping
(
grad
,
state
[
'gnorm_vec'
],
step
,
config
[
'percentile_clipping'
])
if
config
[
"percentile_clipping"
]
<
100
:
current_gnorm
,
clip_value
,
gnorm_scale
=
F
.
percentile_clipping
(
grad
,
state
[
"gnorm_vec"
],
step
,
config
[
"percentile_clipping"
]
)
else
:
gnorm_scale
=
1.0
if
state
[
'state1'
].
dtype
==
torch
.
float
:
F
.
optimizer_update_32bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
config
[
'betas'
][
0
],
config
[
'eps'
],
step
,
config
[
'lr'
],
state
[
'state2'
],
config
[
'betas'
][
1
],
config
[
'weight_decay'
],
gnorm_scale
,
state
[
'unorm_vec'
]
if
config
[
'max_unorm'
]
>
0.0
else
None
,
max_unorm
=
config
[
'max_unorm'
],
skip_zeros
=
config
[
'skip_zeros'
])
elif
state
[
'state1'
].
dtype
==
torch
.
uint8
and
not
config
[
'block_wise'
]:
F
.
optimizer_update_8bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
state
[
'state2'
],
config
[
'betas'
][
0
],
config
[
'betas'
][
1
],
config
[
'eps'
],
step
,
config
[
'lr'
],
state
[
'qmap1'
],
state
[
'qmap2'
],
state
[
'max1'
],
state
[
'max2'
],
state
[
'new_max1'
],
state
[
'new_max2'
],
config
[
'weight_decay'
],
gnorm_scale
=
gnorm_scale
,
unorm_vec
=
state
[
'unorm_vec'
]
if
config
[
'max_unorm'
]
>
0.0
else
None
,
max_unorm
=
config
[
'max_unorm'
])
if
state
[
"state1"
].
dtype
==
torch
.
float
:
F
.
optimizer_update_32bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
"state1"
],
config
[
"betas"
][
0
],
config
[
"eps"
],
step
,
config
[
"lr"
],
state
[
"state2"
],
config
[
"betas"
][
1
],
config
[
"weight_decay"
],
gnorm_scale
,
state
[
"unorm_vec"
]
if
config
[
"max_unorm"
]
>
0.0
else
None
,
max_unorm
=
config
[
"max_unorm"
],
skip_zeros
=
config
[
"skip_zeros"
],
)
elif
state
[
"state1"
].
dtype
==
torch
.
uint8
and
not
config
[
"block_wise"
]:
F
.
optimizer_update_8bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
"state1"
],
state
[
"state2"
],
config
[
"betas"
][
0
],
config
[
"betas"
][
1
],
config
[
"eps"
],
step
,
config
[
"lr"
],
state
[
"qmap1"
],
state
[
"qmap2"
],
state
[
"max1"
],
state
[
"max2"
],
state
[
"new_max1"
],
state
[
"new_max2"
],
config
[
"weight_decay"
],
gnorm_scale
=
gnorm_scale
,
unorm_vec
=
state
[
"unorm_vec"
]
if
config
[
"max_unorm"
]
>
0.0
else
None
,
max_unorm
=
config
[
"max_unorm"
],
)
# swap maxes
state
[
'max1'
],
state
[
'new_max1'
]
=
state
[
'new_max1'
],
state
[
'max1'
]
state
[
'max2'
],
state
[
'new_max2'
]
=
state
[
'new_max2'
],
state
[
'max2'
]
elif
state
[
'state1'
].
dtype
==
torch
.
uint8
and
config
[
'block_wise'
]:
F
.
optimizer_update_8bit_blockwise
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
state
[
'state2'
],
config
[
'betas'
][
0
],
config
[
'betas'
][
1
],
config
[
'eps'
],
step
,
config
[
'lr'
],
state
[
'qmap1'
],
state
[
'qmap2'
],
state
[
'absmax1'
],
state
[
'absmax2'
],
config
[
'weight_decay'
],
gnorm_scale
=
gnorm_scale
,
skip_zeros
=
config
[
'skip_zeros'
])
state
[
"max1"
],
state
[
"new_max1"
]
=
state
[
"new_max1"
],
state
[
"max1"
]
state
[
"max2"
],
state
[
"new_max2"
]
=
state
[
"new_max2"
],
state
[
"max2"
]
elif
state
[
"state1"
].
dtype
==
torch
.
uint8
and
config
[
"block_wise"
]:
F
.
optimizer_update_8bit_blockwise
(
self
.
optimizer_name
,
grad
,
p
,
state
[
"state1"
],
state
[
"state2"
],
config
[
"betas"
][
0
],
config
[
"betas"
][
1
],
config
[
"eps"
],
step
,
config
[
"lr"
],
state
[
"qmap1"
],
state
[
"qmap2"
],
state
[
"absmax1"
],
state
[
"absmax2"
],
config
[
"weight_decay"
],
gnorm_scale
=
gnorm_scale
,
skip_zeros
=
config
[
"skip_zeros"
],
)
class
Optimizer1State
(
Optimizer8bit
):
def
__init__
(
self
,
optimizer_name
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.0
),
eps
=
1e-8
,
weight_decay
=
0.0
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
max_unorm
=
0.0
,
skip_zeros
=
False
):
def
__init__
(
self
,
optimizer_name
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.0
),
eps
=
1e-8
,
weight_decay
=
0.0
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
max_unorm
=
0.0
,
skip_zeros
=
False
,
):
if
not
0.0
<=
lr
:
raise
ValueError
(
"Invalid learning rate: {}"
.
format
(
lr
))
if
not
0.0
<=
eps
:
...
...
@@ -395,19 +525,18 @@ class Optimizer1State(Optimizer8bit):
raise
ValueError
(
f
"Invalid beta parameter at index
{
i
}
:
{
betas
[
i
]
}
"
)
if
not
0.0
<=
weight_decay
:
raise
ValueError
(
"Invalid weight_decay value: {}"
.
format
(
weight_decay
))
defaults
=
dict
(
lr
=
lr
,
betas
=
betas
,
eps
=
eps
,
weight_decay
=
weight_decay
)
defaults
=
dict
(
lr
=
lr
,
betas
=
betas
,
eps
=
eps
,
weight_decay
=
weight_decay
)
super
(
Optimizer1State
,
self
).
__init__
(
params
,
defaults
,
optim_bits
)
if
args
is
None
:
args
=
{}
args
[
'
optim_bits
'
]
=
optim_bits
args
[
'
percentile_clipping
'
]
=
100
args
[
'
min_8bit_size
'
]
=
min_8bit_size
args
[
'
percentile_clipping
'
]
=
percentile_clipping
args
[
'
block_wise
'
]
=
block_wise
args
[
'
max_unorm
'
]
=
max_unorm
args
[
'
skip_zeros
'
]
=
skip_zeros
args
[
"
optim_bits
"
]
=
optim_bits
args
[
"
percentile_clipping
"
]
=
100
args
[
"
min_8bit_size
"
]
=
min_8bit_size
args
[
"
percentile_clipping
"
]
=
percentile_clipping
args
[
"
block_wise
"
]
=
block_wise
args
[
"
max_unorm
"
]
=
max_unorm
args
[
"
skip_zeros
"
]
=
skip_zeros
self
.
args
=
MockArgs
(
args
)
else
:
...
...
@@ -419,43 +548,61 @@ class Optimizer1State(Optimizer8bit):
def
init_state
(
self
,
group
,
p
,
gindex
,
pindex
):
config
=
self
.
get_config
(
gindex
,
pindex
,
group
)
if
config
[
'
optim_bits
'
]
==
32
:
if
config
[
"
optim_bits
"
]
==
32
:
dtype
=
torch
.
float32
elif
config
[
'
optim_bits
'
]
==
8
:
elif
config
[
"
optim_bits
"
]
==
8
:
dtype
=
torch
.
uint8
else
:
raise
NotImplementedError
(
f
'Amount of optimizer bits not supported:
{
config
[
"optim_bits"
]
}
'
)
else
:
raise
NotImplementedError
(
f
'Amount of optimizer bits not supported:
{
config
[
"optim_bits"
]
}
'
)
if
p
.
numel
()
<
config
[
'min_8bit_size'
]:
dtype
=
torch
.
float32
if
p
.
numel
()
<
config
[
"min_8bit_size"
]:
dtype
=
torch
.
float32
state
=
self
.
state
[
p
]
state
[
'
step
'
]
=
0
state
[
"
step
"
]
=
0
if
dtype
==
torch
.
float32
or
(
dtype
==
torch
.
uint8
and
p
.
numel
()
<
4096
):
state
[
'state1'
]
=
torch
.
zeros_like
(
p
,
memory_format
=
torch
.
preserve_format
,
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
"state1"
]
=
torch
.
zeros_like
(
p
,
memory_format
=
torch
.
preserve_format
,
dtype
=
torch
.
float32
,
device
=
p
.
device
,
)
elif
dtype
==
torch
.
uint8
:
if
state
[
'step'
]
==
0
:
if
'dynamic'
not
in
self
.
name2qmap
:
self
.
fill_qmap
()
self
.
name2qmap
[
'dynamic'
]
=
self
.
name2qmap
[
'dynamic'
].
to
(
p
.
device
)
state
[
'state1'
]
=
torch
.
zeros_like
(
p
,
memory_format
=
torch
.
preserve_format
,
dtype
=
torch
.
uint8
,
device
=
p
.
device
)
state
[
'qmap1'
]
=
self
.
name2qmap
[
'dynamic'
]
if
config
[
'block_wise'
]:
if
state
[
"step"
]
==
0
:
if
"dynamic"
not
in
self
.
name2qmap
:
self
.
fill_qmap
()
self
.
name2qmap
[
"dynamic"
]
=
self
.
name2qmap
[
"dynamic"
].
to
(
p
.
device
)
state
[
"state1"
]
=
torch
.
zeros_like
(
p
,
memory_format
=
torch
.
preserve_format
,
dtype
=
torch
.
uint8
,
device
=
p
.
device
,
)
state
[
"qmap1"
]
=
self
.
name2qmap
[
"dynamic"
]
if
config
[
"block_wise"
]:
n
=
p
.
numel
()
blocks
=
n
//
2048
blocks
=
n
//
2048
blocks
+=
1
if
n
%
2048
>
0
else
0
state
[
'absmax1'
]
=
torch
.
zeros
((
blocks
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
"absmax1"
]
=
torch
.
zeros
(
(
blocks
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
else
:
state
[
'max1'
]
=
torch
.
zeros
((
1
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
'new_max1'
]
=
torch
.
zeros
((
1
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
if
config
[
'percentile_clipping'
]
<
100
:
state
[
'gnorm_vec'
]
=
torch
.
zeros
((
100
,),
device
=
p
.
device
)
state
[
"max1"
]
=
torch
.
zeros
((
1
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
state
[
"new_max1"
]
=
torch
.
zeros
(
(
1
,),
dtype
=
torch
.
float32
,
device
=
p
.
device
)
if
config
[
'max_unorm'
]
>
0.
0
:
state
[
'u
norm_vec
'
]
=
torch
.
zeros
((
1
,),
device
=
p
.
device
)
if
config
[
"percentile_clipping"
]
<
10
0
:
state
[
"g
norm_vec
"
]
=
torch
.
zeros
((
1
00
,),
device
=
p
.
device
)
if
config
[
"max_unorm"
]
>
0.0
:
state
[
"unorm_vec"
]
=
torch
.
zeros
((
1
,),
device
=
p
.
device
)
@
torch
.
no_grad
()
def
update_step
(
self
,
group
,
p
,
gindex
,
pindex
):
...
...
@@ -464,29 +611,77 @@ class Optimizer1State(Optimizer8bit):
config
=
self
.
get_config
(
gindex
,
pindex
,
group
)
state
[
'
step
'
]
+=
1
step
=
state
[
'
step
'
]
state
[
"
step
"
]
+=
1
step
=
state
[
"
step
"
]
if
config
[
'percentile_clipping'
]
<
100
:
current_gnorm
,
clip_value
,
gnorm_scale
=
F
.
percentile_clipping
(
grad
,
state
[
'gnorm_vec'
],
step
,
config
[
'percentile_clipping'
])
if
config
[
"percentile_clipping"
]
<
100
:
current_gnorm
,
clip_value
,
gnorm_scale
=
F
.
percentile_clipping
(
grad
,
state
[
"gnorm_vec"
],
step
,
config
[
"percentile_clipping"
]
)
else
:
gnorm_scale
=
1.0
if
state
[
'state1'
].
dtype
==
torch
.
float
:
F
.
optimizer_update_32bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
config
[
'betas'
][
0
],
config
[
'eps'
],
step
,
config
[
'lr'
],
None
,
0.0
,
config
[
'weight_decay'
],
gnorm_scale
,
state
[
'unorm_vec'
]
if
config
[
'max_unorm'
]
>
0.0
else
None
,
max_unorm
=
config
[
'max_unorm'
],
skip_zeros
=
config
[
'skip_zeros'
])
elif
state
[
'state1'
].
dtype
==
torch
.
uint8
and
not
config
[
'block_wise'
]:
F
.
optimizer_update_8bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
None
,
config
[
'betas'
][
0
],
config
[
'betas'
][
1
],
config
[
'eps'
],
step
,
config
[
'lr'
],
state
[
'qmap1'
],
None
,
state
[
'max1'
],
None
,
state
[
'new_max1'
],
None
,
config
[
'weight_decay'
],
gnorm_scale
,
state
[
'unorm_vec'
]
if
config
[
'max_unorm'
]
>
0.0
else
None
,
max_unorm
=
config
[
'max_unorm'
])
state
[
'max1'
],
state
[
'new_max1'
]
=
state
[
'new_max1'
],
state
[
'max1'
]
elif
state
[
'state1'
].
dtype
==
torch
.
uint8
and
config
[
'block_wise'
]:
F
.
optimizer_update_8bit_blockwise
(
self
.
optimizer_name
,
grad
,
p
,
state
[
'state1'
],
None
,
config
[
'betas'
][
0
],
config
[
'betas'
][
1
],
config
[
'eps'
],
step
,
config
[
'lr'
],
state
[
'qmap1'
],
None
,
state
[
'absmax1'
],
None
,
config
[
'weight_decay'
],
gnorm_scale
=
gnorm_scale
,
skip_zeros
=
config
[
'skip_zeros'
])
if
state
[
"state1"
].
dtype
==
torch
.
float
:
F
.
optimizer_update_32bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
"state1"
],
config
[
"betas"
][
0
],
config
[
"eps"
],
step
,
config
[
"lr"
],
None
,
0.0
,
config
[
"weight_decay"
],
gnorm_scale
,
state
[
"unorm_vec"
]
if
config
[
"max_unorm"
]
>
0.0
else
None
,
max_unorm
=
config
[
"max_unorm"
],
skip_zeros
=
config
[
"skip_zeros"
],
)
elif
state
[
"state1"
].
dtype
==
torch
.
uint8
and
not
config
[
"block_wise"
]:
F
.
optimizer_update_8bit
(
self
.
optimizer_name
,
grad
,
p
,
state
[
"state1"
],
None
,
config
[
"betas"
][
0
],
config
[
"betas"
][
1
],
config
[
"eps"
],
step
,
config
[
"lr"
],
state
[
"qmap1"
],
None
,
state
[
"max1"
],
None
,
state
[
"new_max1"
],
None
,
config
[
"weight_decay"
],
gnorm_scale
,
state
[
"unorm_vec"
]
if
config
[
"max_unorm"
]
>
0.0
else
None
,
max_unorm
=
config
[
"max_unorm"
],
)
state
[
"max1"
],
state
[
"new_max1"
]
=
state
[
"new_max1"
],
state
[
"max1"
]
elif
state
[
"state1"
].
dtype
==
torch
.
uint8
and
config
[
"block_wise"
]:
F
.
optimizer_update_8bit_blockwise
(
self
.
optimizer_name
,
grad
,
p
,
state
[
"state1"
],
None
,
config
[
"betas"
][
0
],
config
[
"betas"
][
1
],
config
[
"eps"
],
step
,
config
[
"lr"
],
state
[
"qmap1"
],
None
,
state
[
"absmax1"
],
None
,
config
[
"weight_decay"
],
gnorm_scale
=
gnorm_scale
,
skip_zeros
=
config
[
"skip_zeros"
],
)
bitsandbytes/optim/rmsprop.py
View file @
bfa0e332
...
...
@@ -4,33 +4,106 @@
# LICENSE file in the root directory of this source tree.
from
bitsandbytes.optim.optimizer
import
Optimizer1State
class
RMSprop
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
=
1e-2
,
alpha
=
0.99
,
eps
=
1e-8
,
weight_decay
=
0
,
momentum
=
0
,
centered
=
False
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
def
__init__
(
self
,
params
,
lr
=
1e-2
,
alpha
=
0.99
,
eps
=
1e-8
,
weight_decay
=
0
,
momentum
=
0
,
centered
=
False
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
):
if
alpha
==
0
:
raise
NotImplementedError
(
f
'
RMSprop with alpha==0.0 is not supported!
'
)
raise
NotImplementedError
(
f
"
RMSprop with alpha==0.0 is not supported!
"
)
if
centered
:
raise
NotImplementedError
(
f
'Centered RMSprop is not supported!'
)
super
(
RMSprop
,
self
).
__init__
(
'rmsprop'
,
params
,
lr
,
(
alpha
,
momentum
),
eps
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
raise
NotImplementedError
(
f
"Centered RMSprop is not supported!"
)
super
(
RMSprop
,
self
).
__init__
(
"rmsprop"
,
params
,
lr
,
(
alpha
,
momentum
),
eps
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
)
class
RMSprop8bit
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
=
1e-2
,
alpha
=
0.99
,
eps
=
1e-8
,
weight_decay
=
0
,
momentum
=
0
,
centered
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
def
__init__
(
self
,
params
,
lr
=
1e-2
,
alpha
=
0.99
,
eps
=
1e-8
,
weight_decay
=
0
,
momentum
=
0
,
centered
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
):
if
alpha
==
0
:
raise
NotImplementedError
(
f
'
RMSprop with alpha==0.0 is not supported!
'
)
raise
NotImplementedError
(
f
"
RMSprop with alpha==0.0 is not supported!
"
)
if
centered
:
raise
NotImplementedError
(
f
'Centered RMSprop is not supported!'
)
super
(
RMSprop8bit
,
self
).
__init__
(
'rmsprop'
,
params
,
lr
,
(
alpha
,
momentum
),
eps
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
raise
NotImplementedError
(
f
"Centered RMSprop is not supported!"
)
super
(
RMSprop8bit
,
self
).
__init__
(
"rmsprop"
,
params
,
lr
,
(
alpha
,
momentum
),
eps
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
)
class
RMSprop32bit
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
=
1e-2
,
alpha
=
0.99
,
eps
=
1e-8
,
weight_decay
=
0
,
momentum
=
0
,
centered
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
def
__init__
(
self
,
params
,
lr
=
1e-2
,
alpha
=
0.99
,
eps
=
1e-8
,
weight_decay
=
0
,
momentum
=
0
,
centered
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
):
if
alpha
==
0
:
raise
NotImplementedError
(
f
'
RMSprop with alpha==0.0 is not supported!
'
)
raise
NotImplementedError
(
f
"
RMSprop with alpha==0.0 is not supported!
"
)
if
centered
:
raise
NotImplementedError
(
f
'Centered RMSprop is not supported!'
)
super
(
RMSprop32bit
,
self
).
__init__
(
'rmsprop'
,
params
,
lr
,
(
alpha
,
momentum
),
eps
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
raise
NotImplementedError
(
f
"Centered RMSprop is not supported!"
)
super
(
RMSprop32bit
,
self
).
__init__
(
"rmsprop"
,
params
,
lr
,
(
alpha
,
momentum
),
eps
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
)
bitsandbytes/optim/sgd.py
View file @
bfa0e332
...
...
@@ -4,29 +4,96 @@
# LICENSE file in the root directory of this source tree.
from
bitsandbytes.optim.optimizer
import
Optimizer1State
class
SGD
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
):
if
momentum
==
0
:
raise
NotImplementedError
(
f
'SGD without momentum is not supported!'
)
super
(
SGD
,
self
).
__init__
(
'momentum'
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
raise
NotImplementedError
(
f
"SGD without momentum is not supported!"
)
super
(
SGD
,
self
).
__init__
(
"momentum"
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
)
class
SGD8bit
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
):
if
momentum
==
0
:
raise
NotImplementedError
(
f
'SGD without momentum is not supported!'
)
super
(
SGD8bit
,
self
).
__init__
(
'momentum'
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
raise
NotImplementedError
(
f
"SGD without momentum is not supported!"
)
super
(
SGD8bit
,
self
).
__init__
(
"momentum"
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
)
class
SGD32bit
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
,
):
if
momentum
==
0
:
raise
NotImplementedError
(
f
'SGD without momentum is not supported!'
)
super
(
SGD32bit
,
self
).
__init__
(
'momentum'
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
raise
NotImplementedError
(
f
"SGD without momentum is not supported!"
)
super
(
SGD32bit
,
self
).
__init__
(
"momentum"
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
,
)
bitsandbytes/utils.py
View file @
bfa0e332
import
sys
def
print_err
(
s
:
str
)
->
None
:
print
(
s
,
file
=
sys
.
stderr
)
def
warn_of_missing_prerequisite
(
s
:
str
)
->
None
:
print_err
(
'
WARNING, missing pre-requisite:
'
+
s
)
print_err
(
"
WARNING, missing pre-requisite:
"
+
s
)
quicktest.py
View file @
bfa0e332
from
itertools
import
product
import
torch
import
bitsandbytes
as
bnb
import
bitsandbytes.functional
as
F
from
itertools
import
product
def
test_igemmlt
(
dim1
,
dim2
,
dim3
,
dim4
,
dims
,
ldb
):
k
=
25
for
i
in
range
(
k
):
if
dims
==
2
:
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim3
),
device
=
'cuda'
).
to
(
torch
.
int8
)
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim3
),
device
=
"cuda"
).
to
(
torch
.
int8
)
elif
dims
==
3
:
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
'cuda'
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim4
,
dim3
),
device
=
'cuda'
).
to
(
torch
.
int8
)
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim4
,
dim3
),
device
=
"cuda"
).
to
(
torch
.
int8
)
C1
=
torch
.
matmul
(
A
.
float
(),
B
.
t
().
float
())
A2
,
SA
=
F
.
transform
(
A
,
'
col32
'
)
B2
,
SB
=
F
.
transform
(
B
,
'
colx
'
)
A2
,
SA
=
F
.
transform
(
A
,
"
col32
"
)
B2
,
SB
=
F
.
transform
(
B
,
"
colx
"
)
if
dims
==
2
:
C2
,
SC
=
F
.
transform
(
torch
.
zeros
(
A
.
shape
[
0
],
B
.
shape
[
0
],
dtype
=
torch
.
int32
,
device
=
'cuda'
),
'col32'
)
C2
,
SC
=
F
.
transform
(
torch
.
zeros
(
A
.
shape
[
0
],
B
.
shape
[
0
],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"col32"
,
)
else
:
C2
,
SC
=
F
.
transform
(
torch
.
zeros
(
A
.
shape
[
0
],
A
.
shape
[
1
],
B
.
shape
[
0
],
dtype
=
torch
.
int32
,
device
=
'cuda'
),
'col32'
)
C2
,
SC
=
F
.
transform
(
torch
.
zeros
(
A
.
shape
[
0
],
A
.
shape
[
1
],
B
.
shape
[
0
],
dtype
=
torch
.
int32
,
device
=
"cuda"
),
"col32"
,
)
F
.
igemmlt
(
A2
,
B2
,
C2
,
SA
,
SB
,
SC
)
C3
,
S
=
F
.
transform
(
C2
,
'
row
'
,
state
=
SC
)
#torch.testing.assert_allclose(C1, C3.float())
#print(C1)
#print(C2)
#print(C3)
C3
,
S
=
F
.
transform
(
C2
,
"
row
"
,
state
=
SC
)
#
torch.testing.assert_allclose(C1, C3.float())
#
print(C1)
#
print(C2)
#
print(C3)
allclose
=
torch
.
allclose
(
C1
,
C3
.
float
())
if
allclose
:
print
(
C1
)
...
...
@@ -33,29 +47,29 @@ def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
print
(
C3
)
## transposed
#A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
#if dims == 2:
#
A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
#
if dims == 2:
# B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
# C1 = torch.matmul(A.float(), B.float().t())
#elif dims == 3:
#
elif dims == 3:
# B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
# C1 = torch.matmul(B.float(), A.t().float())
# C1 = C1.permute([2, 0, 1])
#A2, SA = F.transform(A, 'col32')
#B2, SB = F.transform(B, 'colx')
#if dims == 2:
#
A2, SA = F.transform(A, 'col32')
#
B2, SB = F.transform(B, 'colx')
#
if dims == 2:
# C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
#else:
#
else:
# C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
# state = (C2.shape, 'row', A.shape[0])
# C2, SC = F.transform(C2, 'col32', state=state)
#F.igemmlt(A2, B2, C2, SA, SB, SC)
#C3, S = F.transform(C2, 'row', state=SC, ld=[0])
#torch.testing.assert_allclose(C1, C3.float())
#
F.igemmlt(A2, B2, C2, SA, SB, SC)
#
C3, S = F.transform(C2, 'row', state=SC, ld=[0])
#
torch.testing.assert_allclose(C1, C3.float())
## weight update
#if dims == 3:
#
if dims == 3:
# A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
# B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
# C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())
...
...
@@ -73,18 +87,18 @@ dims = (2, 3)
ldb
=
[
0
]
n
=
2
dim1
=
torch
.
randint
(
1
,
256
,
size
=
(
n
,)).
tolist
()
dim2
=
torch
.
randint
(
32
,
512
,
size
=
(
n
,)).
tolist
()
dim3
=
torch
.
randint
(
32
,
1024
,
size
=
(
n
,)).
tolist
()
dim4
=
torch
.
randint
(
32
,
1024
,
size
=
(
n
,)).
tolist
()
values
=
list
(
product
(
dim1
,
dim2
,
dim3
,
dim4
,
dims
,
ldb
))
dim1
=
torch
.
randint
(
1
,
256
,
size
=
(
n
,)).
tolist
()
dim2
=
torch
.
randint
(
32
,
512
,
size
=
(
n
,)).
tolist
()
dim3
=
torch
.
randint
(
32
,
1024
,
size
=
(
n
,)).
tolist
()
dim4
=
torch
.
randint
(
32
,
1024
,
size
=
(
n
,)).
tolist
()
values
=
list
(
product
(
dim1
,
dim2
,
dim3
,
dim4
,
dims
,
ldb
))
for
ldb
in
range
(
32
,
4096
,
32
):
#
for ldb in [None]:
#
for ldb in [None]:
val
=
test_igemmlt
(
2
,
2
,
2
,
2
,
2
,
ldb
)
if
val
:
print
(
val
,
ldb
)
else
:
print
(
'
nope
'
,
ldb
)
#for val in values:
#test_igemmlt(*val)
print
(
"
nope
"
,
ldb
)
#
for val in values:
#
test_igemmlt(*val)
setup.py
View file @
bfa0e332
...
...
@@ -2,18 +2,20 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
os
import
glob
from
setuptools
import
setup
,
find_package
s
import
o
s
from
setuptools
import
find_packages
,
setup
libs
=
list
(
glob
.
glob
(
'
./bitsandbytes/libbitsandbytes*.so
'
))
libs
=
list
(
glob
.
glob
(
"
./bitsandbytes/libbitsandbytes*.so
"
))
libs
=
[
os
.
path
.
basename
(
p
)
for
p
in
libs
]
print
(
'libs:'
,
libs
)
print
(
"libs:"
,
libs
)
def
read
(
fname
):
return
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
fname
)).
read
()
setup
(
name
=
f
"bitsandbytes"
,
version
=
f
"0.31.0"
,
...
...
@@ -27,11 +29,11 @@ setup(
entry_points
=
{
"console_scripts"
:
[
"debug_cuda = bitsandbytes.debug_cli:cli"
],
},
package_data
=
{
''
:
libs
},
long_description
=
read
(
'
README.md
'
),
long_description_content_type
=
'
text/markdown
'
,
package_data
=
{
""
:
libs
},
long_description
=
read
(
"
README.md
"
),
long_description_content_type
=
"
text/markdown
"
,
classifiers
=
[
"Development Status :: 4 - Beta"
,
'
Topic :: Scientific/Engineering :: Artificial Intelligence
'
"
Topic :: Scientific/Engineering :: Artificial Intelligence
"
,
],
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment