OpenDAS / bitsandbytes · Commits

Commit bfa0e332, authored Aug 01, 2022 by Titus von Koeller
ran black and isort for coherent code formatting
Parent: 597a8521
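A formatting-only pass like this one can be reproduced with the Python APIs of the two tools named in the commit message. The snippet below is a minimal sketch under assumed defaults; the exact tool versions and any project-specific configuration used for this commit are not recorded on this page.

# Illustrative sketch only: run isort, then black, on a small code string.
# Tool configuration is an assumption, not taken from this commit.
import black
import isort

source = "from .nn import modules\nfrom .autograd._functions import mm_cublas, bmm_cublas\n"
sorted_source = isort.code(source)                               # sort and group imports
formatted = black.format_str(sorted_source, mode=black.Mode())   # normalize quotes and wrapping
print(formatted)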
Changes: 25
Showing 20 changed files with 2449 additions and 991 deletions (+2449 -991). The hunks below are shown in their post-formatting (black/isort) form.

bitsandbytes/__init__.py             +11  -9
bitsandbytes/autograd/_functions.py  +88  -46
bitsandbytes/cextension.py           +15  -8
bitsandbytes/cuda_setup.py           +46  -32
bitsandbytes/debug_cli.py            +0   -1
bitsandbytes/functional.py           +981 -414
bitsandbytes/nn/__init__.py          +4   -4
bitsandbytes/nn/modules.py           +147 -50
bitsandbytes/optim/__init__.py       +3   -3
bitsandbytes/optim/adagrad.py        +93  -21
bitsandbytes/optim/adam.py           +128 -51
bitsandbytes/optim/adamw.py          +85  -19
bitsandbytes/optim/lamb.py           +97  -20
bitsandbytes/optim/lars.py           +126 -41
bitsandbytes/optim/optimizer.py      +380 -185
bitsandbytes/optim/rmsprop.py        +94  -21
bitsandbytes/optim/sgd.py            +88  -21
bitsandbytes/utils.py                +3   -1
quicktest.py                         +47  -33
setup.py                             +13  -11
bitsandbytes/__init__.py
@@ -3,14 +3,16 @@

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .autograd._functions import (
    MatmulLtState,
    bmm_cublas,
    matmul,
    matmul_cublas,
    mm_cublas,
)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules

if COMPILED_WITH_CUDA:
    from .optim import adam

__pdoc__ = {
    "libbitsandbytes": False,
    "optim.optimizer.Optimizer8bit": False,
    "optim.optimizer.MockArgs": False,
}
bitsandbytes/autograd/_functions.py

Top of file:

from dataclasses import dataclass

import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F

tensor = torch.Tensor

"""
    This class pools outlier dimensions across layers.

    This is particularly important for small models where outlier features
    are less systematic and occur with low frequency.
"""


class GlobalOutlierPooler(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.outliers = set()

@@ -29,25 +32,29 @@ class GlobalOutlierPooler(object):

        return cls._instance

    def add_outliers(self, outlier_idx, feature_dim):
        if self.model_dim is None:
            self.model_dim = feature_dim
        if feature_dim != self.model_dim:
            return  # we do not encode outliers for the 2nd FFN layer

        self.outliers.update(outlier_idx.tolist())

    def get_current_outlier_idx(self):
        return torch.Tensor(list(self.outliers)).to(torch.int64)


class MatMul8bit(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):

        if precision[0] != 8:
            with torch.no_grad():
                output = torch.matmul(A, B)
        else:
            if len(B.shape) == 2:
                dim = 0
            else:
                dim = 1
            qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
            qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
            iout = F.igemm(qA, qB)

@@ -84,21 +91,41 @@ class MatMul8bit(torch.autograd.Function):

            else:
                if len(B.shape) == 2 and len(A.shape) == 3:
                    grad_output = grad_output.contiguous()
                    if not grad_output.is_contiguous():
                        grad_output.contiguous()
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output.view(-1, grad_output.shape[2]),
                        dim=0,
                        quant_type=quant_type,
                    )
                    if not A.is_contiguous():
                        A = A.contiguous()
                    qA, S2 = F.vectorwise_quant(
                        A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
                    )
                    igrad_B = F.igemm(qA.t(), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B, S2.t(), S1, grad_output.dtype, quant_type
                    )
                else:
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output, dim=dims, quant_type=quant_type
                    )
                    qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
                    igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B,
                        S2.permute(permute_dim),
                        S1,
                        grad_output.dtype,
                        quant_type,
                    )

        if A.requires_grad:
            if len(grad_output.shape) == 3:
                dims = [2]
            else:
                dims = [1]

            if len(B.shape) == 3:
                # bio -> boi

@@ -113,10 +140,14 @@ class MatMul8bit(torch.autograd.Function):

                with torch.no_grad():
                    grad_A = torch.matmul(grad_output, B.permute(permute_dim))
            else:
                qgrad_output, S1 = F.vectorwise_quant(
                    grad_output, dim=dims, quant_type=quant_type
                )
                qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
                grad_A = F.vectorwise_mm_dequant(
                    igrad_A,
                    S1,
                    S3.permute(permute_dim),
                    grad_output.dtype,
                    quant_type,
                )

        return grad_A, grad_B, None, None, None

@@ -125,6 +156,7 @@ mm_cublas = MatMul8bit.apply

bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply


@dataclass
class MatmulLtState:
    CB = None

@@ -159,7 +191,6 @@ class MatmulLtState:

class MatMul8bitLt(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, state=MatmulLtState()):
        # 1. Quantize A

@@ -171,11 +202,15 @@ class MatMul8bitLt(torch.autograd.Function):

        requires_gradB = B.requires_grad
        formatB = state.formatB
        input_shape = A.shape
        if state.outlier_pool is None:
            state.outlier_pool = GlobalOutlierPooler.get_instance()
        assert (
            A.dtype == torch.float16
        ), f"The input data type needs to be fp16 but {A.dtype} was found!"

        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)

        if state.threshold > 0.0 and coo_tensorA is not None:

@@ -191,8 +226,8 @@ class MatMul8bitLt(torch.autograd.Function):

                # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                # we also need to convert it to the turing/ampere format
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
                # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
                # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
                #    # generate outlier index and subB
                #    outlier_idx = torch.unique(coo_tensorA.colidx).long()
                #    state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])

@@ -203,24 +238,24 @@ class MatMul8bitLt(torch.autograd.Function):

                #    state.idx = outlier_idx
                #    state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()

            # if state.idx is not None:
            #    # extract outliers
            #    CA[:, state.idx] = 0
            #    CAt[:, state.idx] = 0
            #    subA = A[:, state.idx]
            # else:
            #    subA = None
        else:
            if not state.has_fp16_weights and state.CxB is None:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None

        # 2. Quantize B
        if state.has_fp16_weights:
            has_grad = True if (getattr(B, "grad", None) is not None) else False
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed:
                B = B.contiguous()

            if (state.is_training and not has_grad) or state.CxB is None:
                state.reset_grads()

@@ -234,14 +269,16 @@ class MatMul8bitLt(torch.autograd.Function):

                outlier_idx = torch.unique(coo_tensorA.colidx)
                state.idx = outlier_idx
                # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
                # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
                #    # do not use pool for 2nd FFN layer
                #    state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
                # else:
                #    state.idx = outlier_idx
            outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
            state.subB = (
                (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
            )
            CA[:, state.idx.long()] = 0
            CAt[:, state.idx.long()] = 0
            subA = A[:, state.idx.long()]

@@ -254,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function):

            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
        C32A, SA = F.transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
        output = F.mm_dequant(out32, Sout32, SCA, state.SCB)

@@ -277,7 +314,7 @@ class MatMul8bitLt(torch.autograd.Function):

            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
        clone_func = torch.clone
        return clone_func(output.view(output_shape))

@@ -288,7 +325,7 @@ class MatMul8bitLt(torch.autograd.Function):

        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state
        assert state.has_fp16_weights, "Backprop only supported for fp16 weights."

        if len(grad_output.shape) == 3:
            grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()

@@ -298,18 +335,22 @@ class MatMul8bitLt(torch.autograd.Function):

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
        if req_gradB:
            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
            if state.threshold > 0.0 and subA is not None:
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if req_gradA:
            C32grad, Sgrad = F.transform(Cgrad, "col32")
            if state.CxBt is None:
                state.CxBt, state.SBt = F.transform(
                    state.CBt, to_order=formatB, transpose=True
                )
            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
                ctx.grad_shape
            )

        return grad_A, grad_B, None, None, None, None, None

@@ -317,9 +358,10 @@ class MatMul8bitLt(torch.autograd.Function):

matmul = MatMul8bitLt.apply


def matmul(
    A: tensor,
    B: tensor,
    out: tensor = None,
    state: MatmulLtState = None,
    threshold=0.0,
):
    state = state or MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return MatMul8bitLt.apply(A, B, out, state)
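For orientation, a hedged usage sketch of the matmul wrapper touched above follows. The shapes, device, and threshold value are illustrative assumptions; the fp16 requirement comes from the assert in MatMul8bitLt.forward shown in the hunks.

# Usage sketch (assumes a CUDA device and a bitsandbytes build with GPU support).
import torch
import bitsandbytes as bnb

A = torch.randn(4, 16, dtype=torch.float16, device="cuda")                      # activations
B = torch.randn(8, 16, dtype=torch.float16, device="cuda", requires_grad=True)  # fp16 weight
out = bnb.matmul(A, B, threshold=6.0)  # routes through MatMul8bitLt.apply; result shape (4, 8)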
bitsandbytes/cextension.py

Top of file:

import ctypes as ct
import os
from warnings import warn

from bitsandbytes.cuda_setup import evaluate_cuda_setup

@@ -8,17 +9,21 @@ class CUDALibrary_Singleton(object):

    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.context = {}
        binary_name = evaluate_cuda_setup()
        if not os.path.exists(os.path.dirname(__file__) + f"/{binary_name}"):
            print(f"TODO: compile library for specific version: {binary_name}")
            print("defaulting to libbitsandbytes.so")
            self.lib = ct.cdll.LoadLibrary(
                os.path.dirname(__file__) + "/libbitsandbytes.so"
            )
        else:
            self.lib = ct.cdll.LoadLibrary(
                os.path.dirname(__file__) + f"/{binary_name}"
            )

    @classmethod
    def get_instance(cls):

@@ -35,6 +40,8 @@ try:

    lib.get_cusparse.restype = ct.c_void_p
    COMPILED_WITH_CUDA = True
except AttributeError:
    warn(
        "The installed version of bitsandbytes was compiled without GPU support. "
        "8-bit optimizers and GPU quantization are unavailable."
    )
    COMPILED_WITH_CUDA = False
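Because COMPILED_WITH_CUDA is the flag the rest of the package keys off (see the __init__.py hunk above), a guarded import like the sketch below can be useful. It is only an illustration and uses no names beyond those visible in this diff.

# Sketch: guard GPU-only features on the flag set in cextension.py.
import bitsandbytes as bnb

if bnb.COMPILED_WITH_CUDA:
    from bitsandbytes.optim import adam  # 8-bit optimizers available
else:
    print("bitsandbytes was built without GPU support; GPU quantization is unavailable")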
bitsandbytes/cuda_setup.py
@@ -18,31 +18,36 @@ evaluation:

    - based on that set the default path
"""

import ctypes
import shlex
import subprocess
from os import environ as env
from pathlib import Path
from typing import Set, Union

from .utils import print_err, warn_of_missing_prerequisite


def execute_and_return(strCMD):
    proc = subprocess.Popen(
        shlex.split(strCMD), stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    out, err = proc.communicate()
    out, err = out.decode("UTF-8").strip(), err.decode("UTF-8").strip()
    return out, err


def check_cuda_result(cuda, result_val):
    if result_val != 0:
        cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
        print(f"Count not initialize CUDA - failure!")
        raise Exception("CUDA exception!")
    return result_val


# taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
def get_compute_capability():
    libnames = ("libcuda.so", "libcuda.dylib", "cuda.dll")
    for libname in libnames:
        try:
            cuda = ctypes.CDLL(libname)

@@ -51,8 +56,7 @@ def get_compute_capability():

        else:
            break
    else:
        raise OSError("could not load any of: " + " ".join(libnames))

    nGpus = ctypes.c_int()
    cc_major = ctypes.c_int()

@@ -69,39 +73,43 @@ def get_compute_capability():

    ccs = []
    for i in range(nGpus.value):
        result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
        result = check_cuda_result(
            cuda,
            cuda.cuDeviceComputeCapability(
                ctypes.byref(cc_major), ctypes.byref(cc_minor), device
            ),
        )
        ccs.append(f"{cc_major.value}.{cc_minor.value}")

    # TODO: handle different compute capabilities; for now, take the max
    ccs.sort()
    # return ccs[-1]
    return ccs


CUDA_RUNTIME_LIB: str = "libcudart.so"


def tokenize_paths(paths: str) -> Set[Path]:
    return {Path(ld_path) for ld_path in paths.split(":") if ld_path}


def get_cuda_runtime_lib_path(
    # TODO: replace this with logic for all paths in env vars
    LD_LIBRARY_PATH: Union[str, None] = env.get("LD_LIBRARY_PATH")
) -> Union[Path, None]:
    """# TODO: add doc-string"""

    if not LD_LIBRARY_PATH:
        warn_of_missing_prerequisite(
            "LD_LIBRARY_PATH is completely missing from environment!"
        )
        return None

    ld_library_paths: Set[Path] = tokenize_paths(LD_LIBRARY_PATH)

    non_existent_directories: Set[Path] = {
        path for path in ld_library_paths if not path.exists()
    }

    if non_existent_directories:

@@ -111,7 +119,8 @@ def get_cuda_runtime_lib_path(

    )

    cuda_runtime_libs: Set[Path] = {
        path / CUDA_RUNTIME_LIB
        for path in ld_library_paths
        if (path / CUDA_RUNTIME_LIB).is_file()
    } - non_existent_directories

@@ -126,26 +135,31 @@ def get_cuda_runtime_lib_path(

    single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs))
    return single_cuda_runtime_lib_dir


def evaluate_cuda_setup():
    cuda_path = get_cuda_runtime_lib_path()
    cc = get_compute_capability()
    binary_name = "libbitsandbytes_cpu.so"

    if not (has_gpu := bool(cc)):
        print(
            "WARNING: No GPU detected! Check our CUDA paths. Processing to load CPU-only library..."
        )
        return binary_name

    has_cublaslt = cc in ["7.5", "8.0", "8.6"]

    # TODO:
    # (1) Model missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
    # (2) Multiple CUDA versions installed

    cuda_home = str(Path(cuda_path).parent.parent)
    ls_output, err = execute_and_return(f"{cuda_home}/bin/nvcc --version")
    cuda_version = (
        ls_output.split("\n")[3].split(",")[-1].strip().lower().replace("v", "")
    )
    major, minor, revision = cuda_version.split(".")
    cuda_version_string = f"{major}{minor}"

    binary_name = f'libbitsandbytes_cuda{cuda_version_string}_{("cublaslt" if has_cublaslt else "")}.so'
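For reference, evaluate_cuda_setup above is the function cextension.py calls to pick a shared library. A hedged inspection sketch follows; the printed values are purely illustrative of the pattern built by the f-string above and will differ per machine.

# Sketch: query which shared library the setup logic would load.
from bitsandbytes.cuda_setup import evaluate_cuda_setup, get_compute_capability

print(get_compute_capability())  # e.g. ["8.6"] on an Ampere GPU (illustrative)
print(evaluate_cuda_setup())     # e.g. "libbitsandbytes_cuda117_cublaslt.so" (illustrative)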
bitsandbytes/debug_cli.py

import typer

cli = typer.Typer()
bitsandbytes/functional.py
@@ -9,47 +9,68 @@ from typing import Tuple

import torch
from torch import Tensor

from .cextension import COMPILED_WITH_CUDA, lib

name2qmap = {}

if COMPILED_WITH_CUDA:
    """C FUNCTIONS FOR OPTIMIZERS"""
    str2optimizer32bit = {}
    str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
    str2optimizer32bit["momentum"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
    str2optimizer32bit["rmsprop"] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
    str2optimizer32bit["adagrad"] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
    str2optimizer32bit["lars"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
    str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)

    str2optimizer8bit = {}
    str2optimizer8bit["adam"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
    str2optimizer8bit["momentum"] = (
        lib.cmomentum_static_8bit_g32,
        lib.cmomentum_static_8bit_g16,
    )
    str2optimizer8bit["rmsprop"] = (
        lib.crmsprop_static_8bit_g32,
        lib.crmsprop_static_8bit_g16,
    )
    str2optimizer8bit["lamb"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
    str2optimizer8bit["lars"] = (
        lib.cmomentum_static_8bit_g32,
        lib.cmomentum_static_8bit_g16,
    )

    str2optimizer8bit_blockwise = {}
    str2optimizer8bit_blockwise["adam"] = (
        lib.cadam_8bit_blockwise_fp32,
        lib.cadam_8bit_blockwise_fp16,
    )
    str2optimizer8bit_blockwise["momentum"] = (
        lib.cmomentum_8bit_blockwise_fp32,
        lib.cmomentum_8bit_blockwise_fp16,
    )
    str2optimizer8bit_blockwise["rmsprop"] = (
        lib.crmsprop_8bit_blockwise_fp32,
        lib.crmsprop_8bit_blockwise_fp16,
    )
    str2optimizer8bit_blockwise["adagrad"] = (
        lib.cadagrad_8bit_blockwise_fp32,
        lib.cadagrad_8bit_blockwise_fp16,
    )


class CUBLAS_Context(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.context = {}
        # prev_device = torch.cuda.current_device()
        # for i in range(torch.cuda.device_count()):
        #    torch.cuda.set_device(torch.device('cuda', i))
        #    self.context.append(ct.c_void_p(lib.get_context()))
        # torch.cuda.set_device(prev_device)

    @classmethod
    def get_instance(cls):

@@ -66,11 +87,12 @@ class CUBLAS_Context(object):

        torch.cuda.set_device(prev_device)
        return self.context[device.index]


class Cusparse_Context(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.context = ct.c_void_p(lib.get_cusparse())

@@ -82,14 +104,16 @@ class Cusparse_Context(object):

            cls._instance.initialize()
        return cls._instance


def create_linear_map(signed=True):
    if signed:
        return torch.linspace(-1.0, 1.0, 256)
    else:
        return torch.linspace(0.0, 1.0, 256)


def create_dynamic_map(signed=True, n=7):
    """
    Creates the dynamic quantiztion map.

    The dynamic data type is made up of a dynamic exponent and

@@ -103,46 +127,54 @@ def create_dynamic_map(signed=True, n=7):

    For more details see
    (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]
    """
    data = []
    # these are additional items that come from the case
    # where all the exponent bits are zero and no
    # indicator bit is present
    additional_items = 2 ** (7 - n) - 1
    if not signed:
        additional_items = 2 * additional_items
    for i in range(n):
        fraction_items = (
            2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
        )
        boundaries = torch.linspace(0.1, 1, fraction_items)
        means = (boundaries[:-1] + boundaries[1:]) / 2.0
        data += ((10 ** (-(n - 1) + i)) * means).tolist()
        if signed:
            data += (-(10 ** (-(n - 1) + i)) * means).tolist()

    if additional_items > 0:
        boundaries = torch.linspace(0.1, 1, additional_items + 1)
        means = (boundaries[:-1] + boundaries[1:]) / 2.0
        data += ((10 ** (-(n - 1) + i)) * means).tolist()
        if signed:
            data += (-(10 ** (-(n - 1) + i)) * means).tolist()

    data.append(0)
    data.append(1.0)
    data.sort()
    return Tensor(data)


def get_special_format_str():
    major, minor = torch.cuda.get_device_capability()
    if major < 7:
        print(
            f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
        )
        assert major >= 7

    if major == 7:
        return "col_turing"
    elif major == 8:
        return "col_ampere"
    else:
        return "col_turing"


def get_ptr(A: Tensor) -> ct.c_void_p:
    """
    Get the ctypes pointer from a PyTorch Tensor.

    Parameters

@@ -153,31 +185,39 @@ def get_ptr(A: Tensor) -> ct.c_void_p:

    Returns
    -------
    ctypes.c_void_p
    """
    if A is None:
        return None
    else:
        return ct.c_void_p(A.data.storage().data_ptr())


def pre_call(device):
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    return prev_device


def post_call(prev_device):
    torch.cuda.set_device(prev_device)


def get_transform_func(dtype, orderA, orderOut, transpose=False):
    name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
    if not hasattr(lib, name):
        print(name)
        raise ValueError(
            f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}"
        )
    else:
        return getattr(lib, name)


class GlobalData(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.data = {}

@@ -190,15 +230,17 @@ class GlobalData(object):

        return cls._instance


def get_transform_buffer(
    shape, dtype, device, to_order, from_order="row", transpose=False
):
    # init_func = torch.empty
    init_func = torch.zeros
    dims = len(shape)

    if dims == 2:
        rows = shape[0]
    elif dims == 3:
        rows = shape[0] * shape[1]
    cols = shape[-1]

    state = (shape, to_order)

@@ -209,30 +251,39 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order='row', trans

        cols = tmp
        state = (shape[::-1], to_order)

    if to_order == "row" or to_order == "col":
        return init_func(shape, dtype=dtype, device=device), state
    elif to_order == "col32":
        # blocks of 32 columns (padded)
        cols = 32 * ((cols + 31) // 32)
        return init_func((rows, cols), dtype=dtype, device=device), state
    elif to_order == "col_turing":
        # blocks of 32 columns and 8 rows
        cols = 32 * ((cols + 31) // 32)
        rows = 8 * ((rows + 7) // 8)
        return init_func((rows, cols), dtype=dtype, device=device), state
    elif to_order == "col_ampere":
        # blocks of 32 columns and 32 rows
        cols = 32 * ((cols + 31) // 32)
        rows = 32 * ((rows + 31) // 32)
        return init_func((rows, cols), dtype=dtype, device=device), state
    else:
        raise NotImplementedError(f"To_order not supported: {to_order}")


def nvidia_transform(
    A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
):
    if state is None:
        state = (A.shape, from_order)
    else:
        from_order = state[1]
    if out is None:
        out, new_state = get_transform_buffer(
            state[0], A.dtype, A.device, to_order, state[1]
        )
    else:
        new_state = (state[1], to_order)
    func = get_transform_func(A.dtype, from_order, to_order, transpose)
    shape = state[0]

@@ -242,10 +293,10 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s

    elif ld is not None:
        n = math.prod(shape)
        dim1 = math.prod([shape[i] for i in ld])
        dim2 = ct.c_int32(n // dim1)
        dim1 = ct.c_int32(dim1)
    else:
        dim1 = ct.c_int32(shape[0] * shape[1])
        dim2 = ct.c_int32(shape[2])

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

@@ -253,11 +304,13 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s

    ptrOut = get_ptr(out)
    func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)

    return out, new_state


def estimate_quantiles(
    A: Tensor, out: Tensor = None, offset: float = 1 / 512
) -> Tensor:
    """
    Estimates 256 equidistant quantiles on the input tensor eCDF.

    Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles

@@ -282,18 +335,26 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens

    -------
    torch.Tensor:
        The 256 quantiles in float32 datatype.
    """
    if out is None:
        out = torch.zeros((256,), dtype=torch.float32, device=A.device)
    if A.dtype == torch.float32:
        lib.cestimate_quantiles_fp32(
            get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
        )
    elif A.dtype == torch.float16:
        lib.cestimate_quantiles_fp16(
            get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
        )
    else:
        raise NotImplementedError(f"Not supported data type {A.dtype}")
    return out


def quantize_blockwise(
    A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None
) -> Tensor:
    """
    Quantize tensor A in blocks of size 4096 values.

    Quantizes tensor A by dividing it into blocks of 4096 values.

@@ -319,51 +380,96 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N

        The 8-bit tensor.
    tuple(torch.Tensor, torch.Tensor):
        The quantization state to undo the quantization.
    """

    if code is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]
        code = code.to(A.device)

    if absmax is None:
        n = A.numel()
        num_blocks = 4096
        blocks = n // num_blocks
        blocks += 1 if n % num_blocks > 0 else 0
        absmax = torch.zeros((blocks,), device=A.device)

    if out is None:
        out = torch.zeros_like(A, dtype=torch.uint8)

    if A.device.type != "cpu":
        if rand is not None:
            assert rand.numel() >= 1024
            rand_offset = random.randint(0, 1023)
            if A.dtype == torch.float32:
                lib.cquantize_blockwise_stochastic_fp32(
                    get_ptr(code),
                    get_ptr(A),
                    get_ptr(absmax),
                    get_ptr(out),
                    get_ptr(rand),
                    ct.c_int32(rand_offset),
                    ct.c_int(A.numel()),
                )
            elif A.dtype == torch.float16:
                lib.cquantize_blockwise_stochastic_fp16(
                    get_ptr(code),
                    get_ptr(A),
                    get_ptr(absmax),
                    get_ptr(out),
                    get_ptr(rand),
                    ct.c_int32(rand_offset),
                    ct.c_int(A.numel()),
                )
            else:
                raise ValueError(
                    f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
                )
        else:
            if A.dtype == torch.float32:
                lib.cquantize_blockwise_fp32(
                    get_ptr(code),
                    get_ptr(A),
                    get_ptr(absmax),
                    get_ptr(out),
                    ct.c_int(A.numel()),
                )
            elif A.dtype == torch.float16:
                lib.cquantize_blockwise_fp16(
                    get_ptr(code),
                    get_ptr(A),
                    get_ptr(absmax),
                    get_ptr(out),
                    ct.c_int(A.numel()),
                )
            else:
                raise ValueError(
                    f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
                )
    else:
        # cpu
        assert rand is None
        lib.cquantize_blockwise_cpu_fp32(
            get_ptr(code),
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_int(A.numel()),
        )

    return out, (absmax, code)


def dequantize_blockwise(
    A: Tensor,
    quant_state: Tuple[Tensor, Tensor] = None,
    absmax: Tensor = None,
    code: Tensor = None,
    out: Tensor = None,
    blocksize: int = 4096,
) -> Tensor:
    """
    Dequantizes blockwise quantized values.

    Dequantizes the tensor A with maximum absolute values absmax in

@@ -387,57 +493,94 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None,

    -------
    torch.Tensor:
        Dequantized tensor (default: float32)
    """
    assert quant_state is not None or absmax is not None
    if code is None and quant_state is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]
        code = code.to(A.device)

    if out is None:
        out = torch.zeros_like(A, dtype=torch.float32)
    if quant_state is None:
        quant_state = (absmax, code)

    if blocksize not in [2048, 4096]:
        raise ValueError(
            f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]"
        )

    if A.device.type != "cpu":
        if out.dtype == torch.float32:
            lib.cdequantize_blockwise_fp32(
                get_ptr(quant_state[1]),
                get_ptr(A),
                get_ptr(quant_state[0]),
                get_ptr(out),
                ct.c_int(blocksize),
                ct.c_int(A.numel()),
            )
        elif out.dtype == torch.float16:
            lib.cdequantize_blockwise_fp16(
                get_ptr(quant_state[1]),
                get_ptr(A),
                get_ptr(quant_state[0]),
                get_ptr(out),
                ct.c_int(blocksize),
                ct.c_int(A.numel()),
            )
        else:
            raise ValueError(
                f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
            )
    else:
        lib.cdequantize_blockwise_cpu_fp32(
            get_ptr(quant_state[1]),
            get_ptr(A),
            get_ptr(quant_state[0]),
            get_ptr(out),
            ct.c_int(A.numel()),
        )

    return out
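To make the round-trip contract of the two blockwise functions above concrete, here is a hedged usage sketch. The tensor size is arbitrary, and a CUDA build is assumed for brevity; only the signatures visible in the hunks above are used.

# Sketch: blockwise 8-bit quantization round trip.
import torch
import bitsandbytes.functional as F

x = torch.randn(4096 * 4, device="cuda")           # any float32 tensor
q, quant_state = F.quantize_blockwise(x)           # uint8 codes + (absmax, code) state
x_hat = F.dequantize_blockwise(q, quant_state)     # approximate reconstruction
print((x - x_hat).abs().max())                     # small quantization error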
def
quantize
(
A
:
Tensor
,
code
:
Tensor
=
None
,
out
:
Tensor
=
None
)
->
Tensor
:
def
quantize
(
A
:
Tensor
,
code
:
Tensor
=
None
,
out
:
Tensor
=
None
)
->
Tensor
:
if
code
is
None
:
if
code
is
None
:
if
'dynamic'
not
in
name2qmap
:
name2qmap
[
'dynamic'
]
=
create_dynamic_map
().
to
(
A
.
device
)
if
"dynamic"
not
in
name2qmap
:
code
=
name2qmap
[
'dynamic'
]
name2qmap
[
"dynamic"
]
=
create_dynamic_map
().
to
(
A
.
device
)
code
=
name2qmap
[
"dynamic"
]
code
=
code
.
to
(
A
.
device
)
code
=
code
.
to
(
A
.
device
)
absmax
=
torch
.
abs
(
A
).
max
()
absmax
=
torch
.
abs
(
A
).
max
()
inp
=
A
/
absmax
inp
=
A
/
absmax
out
=
quantize_no_absmax
(
inp
,
code
,
out
)
out
=
quantize_no_absmax
(
inp
,
code
,
out
)
return
out
,
(
absmax
,
code
)
return
out
,
(
absmax
,
code
)
def dequantize(
    A: Tensor,
    quant_state: Tuple[Tensor, Tensor] = None,
    absmax: Tensor = None,
    code: Tensor = None,
    out: Tensor = None,
) -> Tensor:
    assert quant_state is not None or absmax is not None
    if code is None and quant_state is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]
        code = code.to(A.device)

    if quant_state is None:
        quant_state = (absmax, code)
    out = dequantize_no_absmax(A, quant_state[1], out)
    return out * quant_state[0]
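

The pair above gives a simple absmax round trip: `quantize` rescales by the tensor's absolute maximum and maps to the dynamic 8-bit code, while `dequantize` reverses both steps. A minimal usage sketch (illustrative only; it assumes a CUDA device and that the module is imported as `bitsandbytes.functional`):

import torch
import bitsandbytes.functional as F

A = torch.randn(1024, device="cuda")
# quantize returns the uint8 tensor plus its quantization state (absmax, code)
A8, state = F.quantize(A)
A_restored = F.dequantize(A8, state)
# the reconstruction error is bounded by the 8-bit dynamic quantization grid
print((A - A_restored).abs().max())
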
def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
    """
    Quantizes input tensor to 8-bit.

    Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
    ...
@@ -456,13 +599,15 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
    -------
    torch.Tensor:
        Quantized 8-bit tensor.
    """
    if out is None:
        out = torch.zeros_like(A, dtype=torch.uint8)
    lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
    return out

def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
    """
    Dequantizes the 8-bit tensor to 32-bit.

    Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
    ...
@@ -481,17 +626,31 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
    -------
    torch.Tensor:
        32-bit output tensor.
    """
    if out is None:
        out = torch.zeros_like(A, dtype=torch.float32)
    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
    return out

def optimizer_update_32bit(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    beta1: float,
    eps: float,
    step: int,
    lr: float,
    state2: Tensor = None,
    beta2: float = 0.0,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    unorm_vec: Tensor = None,
    max_unorm: float = 0.0,
    skip_zeros=False,
) -> None:
    """
    Performs an inplace optimizer update with one or two optimizer states.

    Universal optimizer update for 32-bit state and 32/16-bit gradients/weights.
    ...
@@ -528,33 +687,84 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
        The maximum update norm relative to the weight norm.
    skip_zeros : bool
        Whether to skip zero-valued gradients or not (default: False).
    """

    param_norm = 0.0
    if max_unorm > 0.0:
        param_norm = torch.norm(p.data.float())

    if optimizer_name not in str2optimizer32bit:
        raise NotImplementedError(
            f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
        )

    if g.dtype == torch.float32 and state1.dtype == torch.float32:
        str2optimizer32bit[optimizer_name][0](
            get_ptr(g),
            get_ptr(p),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(eps),
            ct.c_float(weight_decay),
            ct.c_int32(step),
            ct.c_float(lr),
            ct.c_float(gnorm_scale),
            ct.c_bool(skip_zeros),
            ct.c_int32(g.numel()),
        )
    elif g.dtype == torch.float16 and state1.dtype == torch.float32:
        str2optimizer32bit[optimizer_name][1](
            get_ptr(g),
            get_ptr(p),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(eps),
            ct.c_float(weight_decay),
            ct.c_int32(step),
            ct.c_float(lr),
            ct.c_float(gnorm_scale),
            ct.c_bool(skip_zeros),
            ct.c_int32(g.numel()),
        )
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )
def optimizer_update_8bit(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    state2: Tensor,
    beta1: float,
    beta2: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: Tensor,
    qmap2: Tensor,
    max1: Tensor,
    max2: Tensor,
    new_max1: Tensor,
    new_max2: Tensor,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    unorm_vec: Tensor = None,
    max_unorm: float = 0.0,
) -> None:
    """
    Performs an inplace Adam update.

    Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights.
    ...
@@ -602,56 +812,135 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten
        The tensor for the update norm.
    max_unorm : float
        The maximum update norm relative to the weight norm.
    """

    param_norm = 0.0
    if max_unorm > 0.0:
        param_norm = torch.norm(p.data.float())

    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        str2optimizer8bit[optimizer_name][0](
            get_ptr(p),
            get_ptr(g),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(eps),
            ct.c_int32(step),
            ct.c_float(lr),
            get_ptr(qmap1),
            get_ptr(qmap2),
            get_ptr(max1),
            get_ptr(max2),
            get_ptr(new_max1),
            get_ptr(new_max2),
            ct.c_float(weight_decay),
            ct.c_float(gnorm_scale),
            ct.c_int32(g.numel()),
        )
    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
        str2optimizer8bit[optimizer_name][1](
            get_ptr(p),
            get_ptr(g),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(eps),
            ct.c_int32(step),
            ct.c_float(lr),
            get_ptr(qmap1),
            get_ptr(qmap2),
            get_ptr(max1),
            get_ptr(max2),
            get_ptr(new_max1),
            get_ptr(new_max2),
            ct.c_float(weight_decay),
            ct.c_float(gnorm_scale),
            ct.c_int32(g.numel()),
        )
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )

def optimizer_update_8bit_blockwise(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    state2: Tensor,
    beta1: float,
    beta2: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: Tensor,
    qmap2: Tensor,
    absmax1: Tensor,
    absmax2: Tensor,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    skip_zeros=False,
) -> None:

    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        str2optimizer8bit_blockwise[optimizer_name][0](
            get_ptr(p),
            get_ptr(g),
            get_ptr(state1),
            get_ptr(state2),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(eps),
            ct.c_int32(step),
            ct.c_float(lr),
            get_ptr(qmap1),
            get_ptr(qmap2),
            get_ptr(absmax1),
            get_ptr(absmax2),
            ct.c_float(weight_decay),
            ct.c_float(gnorm_scale),
            ct.c_bool(skip_zeros),
            ct.c_int32(g.numel()),
        )
    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
        str2optimizer8bit_blockwise[optimizer_name][1](
            get_ptr(p),
            get_ptr(g),
            get_ptr(state1),
            get_ptr(state2),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(eps),
            ct.c_int32(step),
            ct.c_float(lr),
            get_ptr(qmap1),
            get_ptr(qmap2),
            get_ptr(absmax1),
            get_ptr(absmax2),
            ct.c_float(weight_decay),
            ct.c_float(gnorm_scale),
            ct.c_bool(skip_zeros),
            ct.c_int32(g.numel()),
        )
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )

def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5):
    """Applies percentile clipping

    grad: torch.Tensor
    ...
@@ -663,11 +952,21 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
    """
    if grad.dtype == torch.float32:
        lib.cpercentile_clipping_g32(
            get_ptr(grad),
            get_ptr(gnorm_vec),
            ct.c_int32(step),
            ct.c_int32(grad.numel()),
        )
    elif grad.dtype == torch.float16:
        lib.cpercentile_clipping_g16(
            get_ptr(grad),
            get_ptr(gnorm_vec),
            ct.c_int32(step),
            ct.c_int32(grad.numel()),
        )
    else:
        raise ValueError(f"Gradient type {grad.dtype} not supported!")

    current_gnorm = torch.sqrt(gnorm_vec[step % 100])
    vals, idx = torch.sort(gnorm_vec)
    ...
@@ -675,31 +974,44 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
    gnorm_scale = 1.0
    if current_gnorm > clip_value:
        gnorm_scale = clip_value / current_gnorm
    return current_gnorm, clip_value, gnorm_scale
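

For context, `percentile_clipping` keeps a rolling window of the last 100 squared gradient norms (indexed by `step % 100`) and returns a `gnorm_scale` that callers can fold into the update, e.g. via the `gnorm_scale` argument of the optimizer kernels above. An illustrative sketch, assuming a CUDA device:

import torch
import bitsandbytes.functional as F

gnorm_vec = torch.zeros(100, device="cuda")   # rolling buffer of squared gradient norms
for step in range(1, 301):
    grad = torch.randn(4096, device="cuda")
    current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
        grad, gnorm_vec, step, percentile=5
    )
    grad.mul_(gnorm_scale)  # gnorm_scale < 1.0 only when the norm exceeds the clip value
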
def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
    assert len(histogram.shape) == 2
    assert histogram.dtype == torch.float32
    assert source.dtype == torch.float32
    assert index1.dtype == torch.int32
    assert index2.dtype == torch.int32

    assert histogram.device.type == "cuda"
    assert index1.device.type == "cuda"
    assert index2.device.type == "cuda"
    assert source.device.type == "cuda"

    maxdim1 = ct.c_int32(histogram.shape[0])
    n = ct.c_int32(index1.numel())
    lib.chistogram_scatter_add_2d(
        get_ptr(histogram),
        get_ptr(index1),
        get_ptr(index2),
        get_ptr(source),
        maxdim1,
        n,
    )
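

A small sketch of the scatter-add semantics; the expected values in the comment are my reading of the kernel, not taken from the source. Each `source[i]` is accumulated into `histogram[index1[i], index2[i]]`, with all four tensors required to live on the GPU and to have the dtypes asserted above.

import torch
import bitsandbytes.functional as F

hist = torch.zeros(4, 4, dtype=torch.float32, device="cuda")
i1 = torch.tensor([0, 0, 3], dtype=torch.int32, device="cuda")
i2 = torch.tensor([1, 1, 2], dtype=torch.int32, device="cuda")
src = torch.ones(3, dtype=torch.float32, device="cuda")

F.histogram_scatter_add_2d(hist, i1, i2, src)
# expected: hist[0, 1] == 2.0 and hist[3, 2] == 1.0, everything else 0
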
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
    if not torch.cuda.is_initialized():
        torch.cuda.init()
    if A.dtype != expected_type or B.dtype != expected_type:
        raise TypeError(
            f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}"
        )

    sA = A.shape
    sB = B.shape
    ...
@@ -709,64 +1021,101 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
    correct = True

    if len(sA) == 2 and len(sB) == 2:
        if not tA and not tB and A.shape[1] != B.shape[0]:
            correct = False
        elif tA and not tB and A.shape[0] != B.shape[0]:
            correct = False
        elif tA and tB and A.shape[0] != B.shape[1]:
            correct = False
        elif not tA and tB and A.shape[1] != B.shape[1]:
            correct = False
    elif len(sA) == 3 and len(sB) == 2:
        if not tA and not tB and A.shape[2] != B.shape[0]:
            correct = False
        elif tA and not tB and A.shape[1] != B.shape[0]:
            correct = False
        elif tA and tB and A.shape[1] != B.shape[1]:
            correct = False
        elif not tA and tB and A.shape[2] != B.shape[1]:
            correct = False
    elif len(sA) == 3 and len(sB) == 3:
        if not tA and not tB and A.shape[2] != B.shape[1]:
            correct = False
        elif tA and not tB and A.shape[1] != B.shape[1]:
            correct = False
        elif tA and tB and A.shape[1] != B.shape[2]:
            correct = False
        elif not tA and tB and A.shape[2] != B.shape[2]:
            correct = False

    if out is not None:
        sout = out.shape
        # special case common in backprop
        if not correct and len(sA) == 3 and len(sB) == 3:
            if (
                sout[0] == sA[2]
                and sout[1] == sB[2]
                and sA[0] == sB[0]
                and sA[1] == sB[1]
            ):
                correct = True
    else:
        if len(sA) == 2 and len(sB) == 2:
            if not tA and not tB:
                sout = (sA[0], sB[1])
            elif tA and tB:
                sout = (sA[1], sB[0])
            elif tA and not tB:
                sout = (sA[1], sB[1])
            elif not tA and tB:
                sout = (sA[0], sB[0])
        elif len(sA) == 3 and len(sB) == 2:
            if not tA and not tB:
                sout = (sA[0], sA[1], sB[1])
            elif tA and tB:
                sout = (sA[0], sA[2], sB[0])
            elif tA and not tB:
                sout = (sA[0], sA[2], sB[1])
            elif not tA and tB:
                sout = (sA[0], sA[1], sB[0])
        elif len(sA) == 3 and len(sB) == 3:
            if not tA and not tB:
                sout = (sA[0], sA[1], sB[2])
            elif tA and tB:
                sout = (sA[0], sA[2], sB[1])
            elif tA and not tB:
                sout = (sA[0], sA[2], sB[2])
            elif not tA and tB:
                sout = (sA[0], sA[1], sB[1])

    if not correct:
        raise ValueError(
            f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}."
        )

    return sout

def igemm(A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False):
    sout = check_matmul(A, B, out, transposed_A, transposed_B)
    if out is None:
        out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
    if len(A.shape) == 3 and len(B.shape) == 3:
        if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]:
            return batched_igemm(A, B, out)

    sA = A.shape
    sB = B.shape
    if transposed_A and len(sA) == 2:
        sA = (sA[1], sA[0])
    elif transposed_A and len(sA) == 3:
        sA = (sA[0], sA[2], sA[0])
    if transposed_B and len(sB) == 2:
        sB = (sB[1], sB[0])
    elif transposed_B and len(sB) == 3:
        sB = (sB[0], sB[2], sB[0])
    # this is a mess: cuBLAS expect column major, but PyTorch is row major.
    # So to perform the matrix multiplication, we have to treat A, B, and C matrices
    # (transpose of row major is column major)
    ...
@@ -777,23 +1126,28 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed
    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
    if len(sB) == 2:
        if B.stride()[0] == B.shape[1]:
            transposed_B = False
        elif B.stride()[1] == B.shape[0]:
            transposed_B = True
        if len(A.shape) == 2:
            if A.stride()[0] == A.shape[1]:
                transposed_A = False
            elif A.stride()[1] == A.shape[0]:
                transposed_A = True
        else:
            if A.stride()[1] == A.shape[2]:
                transposed_A = False
            elif A.stride()[2] == A.shape[1]:
                transposed_A = True

        if len(sA) == 2:
            n = sA[0]
            ldb = A.stride()[1 if transposed_A else 0]
        elif len(sA) == 3 and len(sB) == 2:
            n = sA[0] * sA[1]
            ldb = sA[2]

        m = sB[1]
        k = sB[0]
        lda = B.stride()[(1 if transposed_B else 0)]
        ...
@@ -802,34 +1156,52 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed
        # special case
        assert len(sA) == 3
        if not (sA[0] == sB[0] and sA[1] == sB[1]):
            raise ValueError(
                f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
            )

        transposed_A = True
        transposed_B = False

        m = sB[2]
        n = sA[2]
        k = sB[0] * sB[1]

        lda = m
        ldb = sA[2]
        ldc = m

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    # B^T @ A^T = C^T
    # [km, nk -> mn]
    lib.cigemm(
        ptr,
        ct.c_bool(transposed_B),
        ct.c_bool(transposed_A),
        ct.c_int32(m),
        ct.c_int32(n),
        ct.c_int32(k),
        get_ptr(B),
        get_ptr(A),
        get_ptr(out),
        ct.c_int32(lda),
        ct.c_int32(ldb),
        ct.c_int32(ldc),
    )
    return out
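

A short sketch of the plain int8 GEMM path (illustrative; assumes a CUDA device): inputs must be `torch.int8` and the result accumulates in `torch.int32`, matching what `check_matmul` enforces above.

import torch
import bitsandbytes.functional as F

A = torch.randint(-128, 127, (64, 32), dtype=torch.int8, device="cuda")
B = torch.randint(-128, 127, (32, 48), dtype=torch.int8, device="cuda")

C = F.igemm(A, B)                      # int32 result of shape (64, 48)

# integer matmul on the CPU as a reference check
ref = A.cpu().int() @ B.cpu().int()
print(bool((C.cpu() == ref).all()))
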
def batched_igemm(A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False):
    if not len(A.shape) == 3 or not len(B.shape) == 3:
        raise ValueError(
            f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}"
        )
    sout = check_matmul(A, B, out, transposed_A, transposed_B)
    if out is None:
        out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)

    if B.is_contiguous():
        lda = B.stride()[1]
    ...
@@ -886,17 +1258,33 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
    ldc = m

    strideA = B.shape[1] * B.shape[2]
    strideB = A.shape[1] * A.shape[2]
    strideC = A.shape[1] * B.shape[2]

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    lib.cbatched_igemm(
        ptr,
        ct.c_bool(transposed_B),
        ct.c_bool(transposed_A),
        ct.c_int32(m),
        ct.c_int32(n),
        ct.c_int32(k),
        get_ptr(B),
        get_ptr(A),
        get_ptr(out),
        ct.c_int32(lda),
        ct.c_int32(ldb),
        ct.c_int32(ldc),
        ct.c_long(strideA),
        ct.c_long(strideB),
        ct.c_long(strideC),
        ct.c_uint32(num_batch),
    )
    return out

def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
    shapeA = SA[0]
    shapeB = SB[0]
    ...
@@ -905,28 +1293,34 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
    if dimsA == 2:
        m = shapeA[0]
    elif dimsA == 3:
        m = shapeA[0] * shapeA[1]

    if dimsB == 2:
        rows = n = shapeB[0]
    elif dimsB == 3:
        rows = n = shapeB[0] * shapeB[1]

    if dimsA == 2 and out is None:
        out, Sout = get_transform_buffer(
            (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
        )
    elif dimsA == 3 and out is None:
        out, Sout = get_transform_buffer(
            (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
        )

    assert dimsB != 3, "len(B.shape)==3 not supported"
    assert A.device.type == "cuda"
    assert B.device.type == "cuda"
    assert A.dtype == torch.int8
    assert B.dtype == torch.int8
    assert out.dtype == dtype
    assert SA[1] == "col32"
    assert SB[1] in ["col_turing", "col_ampere"]
    assert Sout[1] == "col32"
    assert (
        shapeA[-1] == shapeB[-1]
    ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
    formatB = SB[1]
    prev_device = A.device
    torch.cuda.set_device(A.device)
    ...
@@ -937,53 +1331,76 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
    ptrC = get_ptr(out)

    k = shapeA[-1]
    lda = ct.c_int32(m * 32)
    if formatB == "col_turing":
        # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
        # n = rows
        ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
    else:
        # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
        # n = rows
        ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)
    ldc = ct.c_int32(m * 32)

    m = ct.c_int32(m)
    n = ct.c_int32(n)
    k = ct.c_int32(k)

    has_error = 0
    ptrRowScale = get_ptr(None)
    if formatB == "col_turing":
        if dtype == torch.int32:
            has_error = lib.cigemmlt_turing_32(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )
        else:
            has_error = lib.cigemmlt_turing_8(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )
    elif formatB == "col_ampere":
        if dtype == torch.int32:
            has_error = lib.cigemmlt_ampere_32(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )
        else:
            has_error = lib.cigemmlt_ampere_8(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )

    if has_error == 1:
        raise Exception("cublasLt ran into an error!")

    torch.cuda.set_device(prev_device)

    return out, Sout

def mm_dequant(
    A,
    quant_state,
    row_stats,
    col_stats,
    out=None,
    new_row_stats=None,
    new_col_stats=None,
):
    assert A.dtype == torch.int32
    out_shape = quant_state[0]
    if len(out_shape) == 3:
        out_shape = (out_shape[0] * out_shape[1], out_shape[2])

    if out is None:
        out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
    if new_row_stats is None:
        new_row_stats = torch.empty(
            out_shape[0], dtype=torch.float32, device=A.device
        )
    if new_col_stats is None:
        new_col_stats = torch.empty(
            out_shape[1], dtype=torch.float32, device=A.device
        )
    assert (
        new_row_stats.shape[0] == row_stats.shape[0]
    ), f"{new_row_stats.shape} vs {row_stats.shape}"
    assert (
        new_col_stats.shape[0] == col_stats.shape[0]
    ), f"{new_col_stats.shape} vs {col_stats.shape}"

    ptrA = get_ptr(A)
    ptrOut = get_ptr(out)
    ...
@@ -994,27 +1411,47 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non
    numRows = ct.c_int32(out_shape[0])
    numCols = ct.c_int32(out_shape[1])

    lib.cdequant_mm_int32_fp16(
        ptrA,
        ptrRowStats,
        ptrColStats,
        ptrOut,
        ptrNewRowStats,
        ptrNewColStats,
        numRows,
        numCols,
    )

    return out

def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
    assert A.dtype == torch.float16
    device = A.device

    cols = A.shape[-1]
    if len(A.shape) == 3:
        rows = A.shape[0] * A.shape[1]
    else:
        rows = A.shape[0]

    col_tiles = (cols + 255) // 256
    tiled_rows = ((rows + 15) // 16) * 16
    if row_stats is None:
        row_stats = torch.empty(
            (rows,), dtype=torch.float32, device=device
        ).fill_(-50000.0)
    if col_stats is None:
        col_stats = torch.empty(
            (cols,), dtype=torch.float32, device=device
        ).fill_(-50000.0)

    if nnz_block_ptr is None and threshold > 0.0:
        nnz_block_ptr = torch.zeros(
            ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device
        )

    ptrA = get_ptr(A)
    ptrRowStats = get_ptr(row_stats)
    ...
@@ -1024,16 +1461,17 @@ def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, thr
    cols = ct.c_int32(cols)
    prev_device = pre_call(A.device)
    lib.cget_col_row_stats(
        ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols
    )
    post_call(prev_device)

    if threshold > 0.0:
        nnz_block_ptr.cumsum_(0)

    return row_stats, col_stats, nnz_block_ptr

class COOSparseTensor(object):
    def __init__(self, rows, cols, nnz, rowidx, colidx, values):
        assert rowidx.dtype == torch.int32
        ...
@@ -1050,6 +1488,7 @@ class COOSparseTensor(object):
        self.colidx = colidx
        self.values = values


class CSRSparseTensor(object):
    def __init__(self, rows, cols, nnz, rowptr, colidx, values):
        assert rowptr.dtype == torch.int32
        ...
@@ -1057,7 +1496,7 @@ class CSRSparseTensor(object):
        assert values.dtype == torch.float16
        assert values.numel() == nnz
        assert colidx.numel() == nnz
        assert rowptr.numel() == rows + 1

        self.rows = rows
        self.cols = cols
        ...
@@ -1066,6 +1505,7 @@ class CSRSparseTensor(object):
        self.colidx = colidx
        self.values = values


class CSCSparseTensor(object):
    def __init__(self, rows, cols, nnz, colptr, rowidx, values):
        assert colptr.dtype == torch.int32
        ...
@@ -1073,7 +1513,7 @@ class CSCSparseTensor(object):
        assert values.dtype == torch.float16
        assert values.numel() == nnz
        assert rowidx.numel() == nnz
        assert colptr.numel() == cols + 1

        self.rows = rows
        self.cols = cols
        ...
@@ -1082,13 +1522,17 @@ class CSCSparseTensor(object):
        self.rowidx = rowidx
        self.values = values

def coo2csr(cooA):
    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    values.add_(1)
    rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device)
    rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
    rowptr.cumsum_(0)
    return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values)


def coo2csc(cooA):
    val, col2rowidx = torch.sort(cooA.colidx)
    ...
@@ -1096,11 +1540,12 @@ def coo2csc(cooA):
    values = cooA.values[col2rowidx]
    colvalues, counts = torch.unique(val, return_counts=True)
    colvalues.add_(1)
    colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device)
    colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
    colptr.cumsum_(0)
    return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)


def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
    rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
    colidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
    ...
@@ -1108,23 +1553,27 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
    return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
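

The three containers above are thin index/value holders; `coo2csr`/`coo2csc` only reorder and prefix-sum the indices. A toy conversion, assuming a CUDA device (the concrete indices are made up for illustration):

import torch
import bitsandbytes.functional as F

# a 4x8 matrix with two fp16 nonzeros at (0, 3) and (2, 7)
coo = F.coo_zeros(4, 8, 2, torch.device("cuda"))
coo.rowidx[:] = torch.tensor([0, 2], dtype=torch.int32, device="cuda")
coo.colidx[:] = torch.tensor([3, 7], dtype=torch.int32, device="cuda")
coo.values[:] = torch.tensor([1.0, 2.0], dtype=torch.float16, device="cuda")

csr = F.coo2csr(coo)
print(csr.rowptr)   # cumulative nonzero count per row, e.g. [0, 1, 1, 2, 2]
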
def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
    device = A.device
    assert A.dtype == torch.half
    assert device.type == "cuda"
    prev_device = pre_call(A.device)

    cols = A.shape[-1]
    if len(A.shape) == 3:
        rows = A.shape[0] * A.shape[1]
    else:
        rows = A.shape[0]

    if row_stats is None or col_stats is None:
        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)

    if out_col is None:
        out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
    if out_row is None:
        out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)

    coo_tensor = None
    ptrA = get_ptr(A)
    ...
@@ -1136,21 +1585,62 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None,
    if threshold > 0.0:
        nnz = nnz_row_ptr[-1].item()
        if nnz > 0:
            coo_tensor = coo_zeros(
                A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device
            )
            ptrRowIdx = get_ptr(coo_tensor.rowidx)
            ptrColIdx = get_ptr(coo_tensor.colidx)
            ptrVal = get_ptr(coo_tensor.values)
            ptrRowPtr = get_ptr(nnz_row_ptr)

            lib.cdouble_rowcol_quant(
                ptrA,
                ptrRowStats,
                ptrColStats,
                ptrOutCol,
                ptrOutRow,
                ptrRowIdx,
                ptrColIdx,
                ptrVal,
                ptrRowPtr,
                ct.c_float(threshold),
                ct.c_int32(rows),
                ct.c_int32(cols),
            )
            val, idx = torch.sort(coo_tensor.rowidx)
            coo_tensor.rowidx = val
            coo_tensor.colidx = coo_tensor.colidx[idx]
            coo_tensor.values = coo_tensor.values[idx]
        else:
            lib.cdouble_rowcol_quant(
                ptrA,
                ptrRowStats,
                ptrColStats,
                ptrOutCol,
                ptrOutRow,
                None,
                None,
                None,
                None,
                ct.c_float(0.0),
                ct.c_int32(rows),
                ct.c_int32(cols),
            )
    else:
        lib.cdouble_rowcol_quant(
            ptrA,
            ptrRowStats,
            ptrColStats,
            ptrOutCol,
            ptrOutRow,
            None,
            None,
            None,
            None,
            ct.c_float(threshold),
            ct.c_int32(rows),
            ct.c_int32(cols),
        )
    post_call(prev_device)

    return out_row, out_col, row_stats, col_stats, coo_tensor
    ...
@@ -1159,69 +1649,81 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None,
def get_special_format_str():
    major, minor = torch.cuda.get_device_capability()
    if major < 7:
        print(
            f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
        )
        assert major >= 7

    if major == 7:
        return "col_turing"
    elif major == 8:
        return "col_ampere"
    else:
        return "col_turing"

def transform(
    A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
):
    if state is None:
        state = (A.shape, from_order)
    else:
        from_order = state[1]
    if out is None:
        out, new_state = get_transform_buffer(
            state[0], A.dtype, A.device, to_order, state[1], transpose
        )
    else:
        new_state = (state[0], to_order)  # (shape, order)

    shape = state[0]
    if len(shape) == 2:
        dim1 = ct.c_int32(shape[0])
        dim2 = ct.c_int32(shape[1])
    else:
        dim1 = ct.c_int32(shape[0] * shape[1])
        dim2 = ct.c_int32(shape[2])

    ptrA = get_ptr(A)
    ptrOut = get_ptr(out)
    if to_order == "col32":
        if transpose:
            lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
        else:
            lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
    elif to_order == "col_turing":
        if transpose:
            lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
        else:
            lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
    elif to_order == "col_ampere":
        if transpose:
            lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
        else:
            lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
    elif to_order == "row":
        if from_order == "col_turing":
            lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
        elif from_order == "col_ampere":
            lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
    else:
        raise NotImplementedError(
            f"Transform function not implemented: From {from_order} to {to_order}"
        )

    return out, new_state
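

Put together, `double_quant`, `transform`, `igemmlt` and `mm_dequant` form the int8 matmul pipeline that the autograd layer drives. A condensed sketch of that flow (illustrative only; it assumes a GPU with int8 tensor cores and skips the outlier/threshold handling):

import torch
import bitsandbytes.functional as F

A = torch.randn(32, 64, dtype=torch.half, device="cuda")   # activations
W = torch.randn(48, 64, dtype=torch.half, device="cuda")   # weight, used as B^T

# row-wise int8 quantization plus the per-row absmax statistics
CA, _, statsA, _, _ = F.double_quant(A)
CW, _, statsW, _, _ = F.double_quant(W)

# move both operands into the tiled layouts the kernels expect
formatB = F.get_special_format_str()        # "col_turing" or "col_ampere"
C32A, SA = F.transform(CA, "col32")
CxW, SW = F.transform(CW, formatB)

# int8 x int8 -> int32 matmul, then dequantize back to fp16
out32, Sout32 = F.igemmlt(C32A, CxW, SA, SW)
out = F.mm_dequant(out32, Sout32, statsA, statsW)
print(out.shape, out.dtype)                 # torch.Size([32, 48]) torch.float16
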
def spmm_coo(cooA, B, out=None):
    if out is None:
        out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
    nnz = cooA.nnz
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
    assert cooA.cols == B.shape[0]

    transposed_B = False if B.is_contiguous() else True

    ldb = B.stride()[(1 if transposed_B else 0)]
    ldc = B.shape[1]
    ...
@@ -1240,19 +1742,37 @@ def spmm_coo(cooA, B, out=None):
    cldb = ct.c_int32(ldb)
    cldc = ct.c_int32(ldc)

    lib.cspmm_coo(
        ptr,
        ptrRowidx,
        ptrColidx,
        ptrValues,
        cnnz,
        crowsA,
        ccolsA,
        ccolsB,
        cldb,
        ptrB,
        cldc,
        ptrC,
        ct.c_bool(transposed_B),
    )

    return out

def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
    if out is None:
        out = torch.zeros(
            (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
        )
    nnz = cooA.nnz
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
    assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}"

    transposed_B = False if B.is_contiguous() else True

    ldb = B.stride()[(1 if transposed_B else 0)]
    ldc = B.shape[1]
    ...
@@ -1262,7 +1782,9 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
    max_count, max_idx = torch.sort(counts, descending=True)
    max_idx = max_idx.int()
    max_count = max_count.int()
    assert (
        max_count[0] <= 32
    ), f"Current max count per row is 8 but found {max_count[0]}."
    assert B.dtype in [torch.float16, torch.int8]
    ptrOffset = get_ptr(offset)
    ptrMaxCount = get_ptr(max_count)
    ...
@@ -1282,134 +1804,183 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
    ccolsB = ct.c_int32(B.shape[1])
    cldb = ct.c_int32(ldb)
    cldc = ct.c_int32(ldc)
    # print(cooA.rowidx[:64])
    # print(cooA.colidx[:64].sort()[0])

    if B.dtype == torch.float16:
        lib.cspmm_coo_very_sparse_naive_fp16(
            ptrMaxCount,
            ptrMaxIdx,
            ptrOffset,
            ptrRowidx,
            ptrColidx,
            ptrValues,
            ptrB,
            ptrC,
            ptrDequantStats,
            cnnz_rows,
            cnnz,
            crowsA,
            crowsB,
            ccolsB,
        )
    elif B.dtype == torch.int8:
        lib.cspmm_coo_very_sparse_naive_int8(
            ptrMaxCount,
            ptrMaxIdx,
            ptrOffset,
            ptrRowidx,
            ptrColidx,
            ptrValues,
            ptrB,
            ptrC,
            ptrDequantStats,
            cnnz_rows,
            cnnz,
            crowsA,
            crowsB,
            ccolsB,
        )
    # else: assertion error

    return out

C = 127.0
def vectorwise_quant(x, dim=1, quant_type="vector"):
    if quant_type == "linear":
        max1 = torch.abs(x).max().float()
        xq = torch.round(x / max1 * 127).to(torch.int8)
        return xq, max1
    elif quant_type in ["vector", "row"]:
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        xq = torch.round(x * (C / max1)).to(torch.int8)
        return xq, max1
    elif quant_type == "zeropoint":
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 255.0 / dyna
        minx = x.min()
        zpx = torch.round(minx * qx)
        x = torch.round(qx * x - zpx) + zpx
        return x, qx
    elif quant_type in ["vector-zeropoint", "row-zeropoint"]:
        dtype = x.dtype
        x = x.float()
        dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(
            x, dim=dim, keepdim=True
        )
        dyna[dyna == 0] = 1
        qx = 255.0 / dyna
        minx = torch.amin(x, dim=dim, keepdim=True)
        zpx = torch.round(minx * qx)
        x = torch.round(qx * x - zpx) + zpx
        return x, qx
    elif quant_type == "truncated-vector":
        with torch.no_grad():
            absx = torch.abs(x)
            max1 = torch.amax(absx, dim=dim, keepdim=True)
            max1 = max1 * 0.7
            idx = absx > max1.expand_as(absx)
            sign = torch.sign(x[idx])
            x[idx] = max1.expand_as(absx)[idx] * sign
            xq = torch.round(x / max1 * C).to(torch.int8)
        return xq, max1
    else:
        return None
def vectorwise_dequant(xq, max1, quant_type="vector"):
    if quant_type == "vector":
        x = (xq / C * max1).to(torch.float32)
        return x
    else:
        return None
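
Aside, not part of the commit: the pair above forms a round trip through int8 using the module-level constant C = 127.0. A minimal sketch, assuming only a working bitsandbytes install so that bitsandbytes.functional imports cleanly:

import torch
from bitsandbytes.functional import vectorwise_dequant, vectorwise_quant

x = torch.randn(4, 8)
xq, max1 = vectorwise_quant(x, dim=1, quant_type="vector")  # int8 codes + per-row absmax
x_hat = vectorwise_dequant(xq, max1)                        # back to fp32, x_hat ~ x
print((x - x_hat).abs().max())                              # error bounded by max1 / 127 / 2 per row
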
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
    if quant_type == "linear":
        norm = S1 * S2 / (C * C)
        # double cast needed to prevent overflows
        return (xq.float() * norm).to(dtype)
    elif quant_type == "zeropoint":
        norm = 1.0 / (S1 * S2)
        return (xq.float() * norm).to(dtype)
    elif quant_type == "row-zeropoint":
        norm = 1.0 / (S1 * S2)
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2:
            S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2:
            S2 = S2.squeeze(0)
        if len(S1.shape) == 2:
            x *= norm
        else:
            x *= norm
        return x.to(dtype)
    elif quant_type == "vector-zeropoint":
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2:
            S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2:
            S2 = S2.squeeze(0)
        if len(S1.shape) == 2:
            x *= 1.0 / S1
        else:
            x *= 1.0 / S1
        x *= 1.0 / S2.t()
        return x.to(dtype)
    elif quant_type == "row":
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2:
            S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2:
            S2 = S2.squeeze(0)
        if len(S1.shape) == 2:
            x *= S1 * S2 / (C * C)
        else:
            x *= S1 * S2 / (C * C)
        return x.to(dtype)
    elif quant_type in ["truncated-vector", "vector"]:
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2:
            S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2:
            S2 = S2.squeeze(0)
        if len(S1.shape) == 2:
            x *= S1 / C
        else:
            x *= S1 / C
        x *= S2 / C
        return x.to(dtype)
    else:
        return None
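
Context, not part of the commit: vectorwise_mm_dequant is the rescaling step after an integer matmul of vector-wise quantized inputs; with A scaled per row and B per column, the accumulator is multiplied back by S1 * S2 / 127^2. A sketch in which the int8 GEMM itself is only simulated with a float matmul (a real path would go through the library's integer kernels):

import torch
from bitsandbytes.functional import vectorwise_mm_dequant, vectorwise_quant

A = torch.randn(16, 32)
B = torch.randn(32, 8)
Aq, SA = vectorwise_quant(A, dim=1, quant_type="vector")   # per-row scales, shape (16, 1)
Bq, SB = vectorwise_quant(B, dim=0, quant_type="vector")   # per-column scales, shape (1, 8)

acc = Aq.float() @ Bq.float()                              # stand-in for the int8 x int8 -> int32 GEMM
out = vectorwise_mm_dequant(acc, SA, SB, dtype=torch.float32, quant_type="vector")
print((out - A @ B).abs().max())                           # small quantization error
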
def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
    offset = B.float().t().sum(0) * (SA[0] + SA[1])
    x = xq.float()
    if len(xq.shape) == 2 and len(SB.shape) == 3:
        SB = SB.squeeze(0)
    if len(SB.shape) == 2:
        x *= SB.t() / 127
    else:
        x *= SB / 127
    x *= SA[1] / 127
    x += offset
    return x.to(dtype)
def extract_outliers(A, SA, idx):
    shapeA = SA[0]
    formatA = SA[1]
    assert formatA in ["col_turing", "col_ampere"]
    assert A.device.type == "cuda"

    out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)

...
@@ -1420,13 +1991,9 @@ def extract_outliers(A, SA, idx):

    ptrIdx = get_ptr(idx)
    ptrOut = get_ptr(out)

    if formatA == "col_turing":
        lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
    elif formatA == "col_ampere":
        lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)

    return out
bitsandbytes/nn/__init__.py  View file @ bfa0e332
...
@@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import Int8Params, Linear8bit, Linear8bitLt, StableEmbedding
bitsandbytes/nn/modules.py  View file @ bfa0e332
...
@@ -2,38 +2,58 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set,
                    Tuple, TypeVar, Union, overload)

import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
from torch.nn.parameter import Parameter

import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager

T = TypeVar("T", bound="torch.nn.Module")


class StableEmbedding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        super(StableEmbedding, self).__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        self.norm = torch.nn.LayerNorm(embedding_dim)
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
        to make the Layer compatible with Pytorch < 1.9.
        This means that if this changes in future PyTorch releases this need to change too
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
...
@@ -41,29 +61,55 @@ class StableEmbedding(torch.nn.Embedding):

    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return self.norm(emb)


class Embedding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        super(Embedding, self).__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
        to make the Layer compatible with Pytorch < 1.9.
        This means that if this changes in future PyTorch releases this need to change too
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
...
@@ -71,13 +117,22 @@ class Embedding(torch.nn.Embedding):

    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return emb


class Int8Params(torch.nn.Parameter):
    def __new__(
        cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None
    ):
        cls.has_fp16_weights = has_fp16_weights
        cls.CB = None
        cls.SCB = None
...
@@ -96,14 +151,18 @@ class Int8Params(torch.nn.Parameter):
        del CBt
        del SCBt
        self.data = CB
        setattr(self, "CB", CB)
        setattr(self, "SCB", SCB)

        return self

    @overload
    def to(
        self: T,
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> T:
        ...

    @overload
...
@@ -115,23 +174,41 @@ class Int8Params(torch.nn.Parameter):
        ...

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )

        if (
            device is not None
            and device.type == "cuda"
            and self.data.device.type == "cpu"
        ):
            return self.cuda(device)
        else:
            new_param = Int8Params(
                super().to(
                    device=device, dtype=dtype, non_blocking=non_blocking
                ),
                requires_grad=self.requires_grad,
                has_fp16_weights=self.has_fp16_weights,
            )
            new_param.CB = self.CB
            new_param.SCB = self.SCB

            return new_param


class Linear8bitLt(nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        has_fp16_weights=True,
        threshold=0.0,
        index=None,
    ):
        super(Linear8bitLt, self).__init__(input_features, output_features, bias)
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
...
@@ -149,9 +226,10 @@ class Linear8bitLt(nn.Linear):

    def forward(self, x):
        self.state.is_training = self.training

        if self.weight.CB is not None:
            self.init_8bit_state()
        # assert not self.state.has_fp16_weights
        # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None

        out = bnb.matmul(x, self.weight, state=self.state)
...
@@ -166,8 +244,18 @@ class Linear8bitLt(nn.Linear):

        return out


class Linear8bit(nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        quant_type="vector",
        index=None,
        args=None,
        sparse_decomp=False,
    ):
        super(Linear8bit, self).__init__(input_features, output_features, bias)
        self.quant_type = quant_type
        self.index = index
...
@@ -178,15 +266,24 @@ class Linear8bit(nn.Linear):
                self.iter += 1
                if self.iter % self.args.clip_freq == 0:
                    with torch.no_grad():
                        maxval, maxidx = torch.topk(
                            torch.abs(self.weight.flatten()), k=self.args.clip_idx
                        )
                        if not dist.is_initialized() or dist.get_rank() == 0:
                            print("clip", maxval[-1].item())
                        self.weight.clip_(-maxval[-1], maxval[-1])

        if self.args is not None:
            out = bnb.nn.functional.sparse_decomposed_linear8bit(
                x,
                self.weight,
                self.bias,
                qval=self.args.sparse_decomp_val,
                quant_type=self.args.quant_type,
            )
        else:
            out = bnb.nn.functional.linear8bit(
                x, self.weight, self.bias, quant_type=self.args.quant_type
            )

        return out
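
Orientation note, not part of the formatting change: the modules above are intended as drop-in replacements for their torch counterparts. The following is a rough sketch rather than an official recipe, assuming a CUDA build of bitsandbytes, fp16 activations, and made-up layer sizes; the variable names are illustrative only:

import torch
import bitsandbytes as bnb

# Linear8bitLt with threshold > 0 enables the mixed int8/fp16 outlier decomposition.
fp16_linear = torch.nn.Linear(1024, 1024).half()
int8_linear = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False, threshold=6.0)
int8_linear.load_state_dict(fp16_linear.state_dict())
int8_linear = int8_linear.cuda()          # Int8Params.to() quantizes the weight on this call

x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
y = int8_linear(x)                        # forward goes through bnb.matmul with the layer's MatmulLtState

# StableEmbedding adds xavier init plus LayerNorm and registers its weight for 32-bit optimizer state.
emb = bnb.nn.StableEmbedding(10000, 512).cuda()
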
bitsandbytes/optim/__init__.py  View file @ bfa0e332
bitsandbytes/optim/adagrad.py  View file @ bfa0e332
...
@@ -4,9 +4,22 @@
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State


class Adagrad(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        lr_decay=0,
        weight_decay=0,
        initial_accumulator_value=0,
        eps=1e-10,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= weight_decay:
...
@@ -14,15 +27,39 @@ class Adagrad(Optimizer1State):
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if initial_accumulator_value != 0.0:
            raise ValueError("Initial accumulator value != 0.0 not supported!")
        if lr_decay != 0.0:
            raise ValueError("Lr Decay != 0.0 not supported!")
        super(Adagrad, self).__init__(
            "adagrad",
            params,
            lr,
            (0.0, 0.0),
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Adagrad8bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        lr_decay=0,
        weight_decay=0,
        initial_accumulator_value=0,
        eps=1e-10,
        optim_bits=8,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= weight_decay:
...
@@ -30,16 +67,40 @@ class Adagrad8bit(Optimizer1State):
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if initial_accumulator_value != 0.0:
            raise ValueError("Initial accumulator value != 0.0 not supported!")
        if lr_decay != 0.0:
            raise ValueError("Lr Decay != 0.0 not supported!")
        assert block_wise
        super(Adagrad8bit, self).__init__(
            "adagrad",
            params,
            lr,
            (0.0, 0.0),
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Adagrad32bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        lr_decay=0,
        weight_decay=0,
        initial_accumulator_value=0,
        eps=1e-10,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= weight_decay:
...
@@ -47,8 +108,19 @@ class Adagrad32bit(Optimizer1State):
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if initial_accumulator_value != 0.0:
            raise ValueError("Initial accumulator value != 0.0 not supported!")
        if lr_decay != 0.0:
            raise ValueError("Lr Decay != 0.0 not supported!")
        super(Adagrad32bit, self).__init__(
            "adagrad",
            params,
            lr,
            (0.0, 0.0),
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
bitsandbytes/optim/adam.py  View file @ bfa0e332
...
@@ -8,29 +8,97 @@ import os

import torch
import torch.distributed as dist

import bitsandbytes.functional as F
from bitsandbytes.optim.optimizer import Optimizer2State


class Adam(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(Adam, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Adam8bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(Adam8bit, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Adam32bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(Adam32bit, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class AnalysisAdam(torch.optim.Optimizer):
...
@@ -68,8 +136,8 @@ class AnalysisAdam(torch.optim.Optimizer):
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        bnb_analysis="dynamic-blockwise",
        savedir=None,
    ):
        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
...
@@ -124,9 +192,13 @@ class AnalysisAdam(torch.optim.Optimizer):
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                    state["abserrors"] = torch.zeros(
                        (256, 256), device=p_data_fp32.device
                    )
                    state["relerrors"] = torch.zeros(
                        (256, 256), device=p_data_fp32.device
                    )
                    state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
...
@@ -143,9 +215,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                e = state["abserrors"]
                rele = state["relerrors"]
                counts = state["counts"]

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(
...
@@ -156,77 +228,84 @@ class AnalysisAdam(torch.optim.Optimizer):
                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])
                update_fp32 = exp_avg / denom

                if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
                    # embedding layer or too small
                    p_data_fp32 += -step_size * update_fp32
                else:
                    if self.analysis == "dynamic-blockwise":
                        code1 = F.create_dynamic_map(signed=True).to(p.device)
                        code2 = F.create_dynamic_map(signed=False).to(p.device)
                        C1, S1 = F.quantize_blockwise(exp_avg, code=code1)
                        state1 = F.dequantize_blockwise(C1, S1)
                        C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2)
                        state2 = F.dequantize_blockwise(C2, S2)
                    elif self.analysis == "dynamic":
                        code1 = F.create_dynamic_map(signed=True).to(p.device)
                        code2 = F.create_dynamic_map(signed=False).to(p.device)
                        C1, S1 = F.quantize(exp_avg, code=code1)
                        state1 = F.dequantize(C1, S1)
                        C2, S2 = F.quantize(exp_avg_sq, code=code2)
                        state2 = F.dequantize(C2, S2)
                    elif self.analysis == "linear":
                        code1 = F.create_linear_map(signed=True).to(p.device)
                        code2 = F.create_linear_map(signed=False).to(p.device)
                        C1, S1 = F.quantize(exp_avg, code=code1)
                        state1 = F.dequantize(C1, S1)
                        C2, S2 = F.quantize(exp_avg_sq, code=code2)
                        state2 = F.dequantize(C2, S2)
                    elif self.analysis == "quantile":
                        code1 = F.estimate_quantiles(exp_avg)
                        code2 = F.estimate_quantiles(exp_avg_sq)
                        C1 = F.quantize_no_absmax(exp_avg, code=code1)
                        state1 = F.dequantize_no_absmax(C1, code1)
                        C2 = F.quantize_no_absmax(exp_avg_sq, code=code2)
                        state2 = F.dequantize_no_absmax(C2, code2)
                    elif self.analysis == "my-quantization-routine":
                        pass
                        # 1. get code
                        # 2. quantize
                        # 3. dequantize
                        # Error will be calculated automatically!
                    else:
                        raise ValueError(f"Invalid analysis value: {self.analysis}!")

                    denom = state2.sqrt().add_(group["eps"])
                    update_8bit = state1 / denom

                    abserr = torch.abs(update_8bit - update_fp32)
                    relerr = abserr / torch.abs(update_fp32 + 1e-6)

                    C1, C2 = C1.int(), C2.int()

                    F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr)
                    F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr)
                    F.histogram_scatter_add_2d(
                        counts, C1.int(), C2.int(), torch.ones_like(abserr)
                    )

                    p_data_fp32 += -step_size * update_fp32

                    if not dist.is_initialized() or dist.get_rank() == 0:
                        if self.savedir != "" and state["step"] % 100 == 0:
                            if not os.path.exists(self.savedir):
                                os.makedirs(self.savedir)
                            shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
                            pathe = os.path.join(
                                self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
                            )
                            pathrele = os.path.join(
                                self.savedir, f"{p_id}_{shapestr}_relerr.pkl"
                            )
                            pathcounts = os.path.join(
                                self.savedir, f"{p_id}_{shapestr}_counts.pkl"
                            )
                            torch.save(e, pathe)
                            torch.save(rele, pathrele)
                            torch.save(counts, pathcounts)
...
@@ -234,6 +313,4 @@ class AnalysisAdam(torch.optim.Optimizer):
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss
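
Usage note, not part of this formatting commit: the three wrappers above only differ in the optim_bits they forward to Optimizer2State. Below is a minimal sketch of swapping in the 8-bit variant; `model`, `inputs`, and `model.emb.weight` are hypothetical placeholders, and the per-parameter override follows my reading of the pattern documented in the project README:

import bitsandbytes as bnb

mng = bnb.optim.GlobalOptimManager.get_instance()
mng.register_parameters(model.parameters())              # needed for per-parameter overrides
mng.override_config(model.emb.weight, "optim_bits", 32)  # keep this parameter in 32-bit state

# drop-in replacement for torch.optim.Adam; optimizer states are block-wise quantized to 8 bit
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
# equivalent spelling: bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8)

loss = model(inputs).mean()                               # placeholder training step
loss.backward()
optimizer.step()
optimizer.zero_grad()
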
bitsandbytes/optim/adamw.py  View file @ bfa0e332
...
@@ -4,24 +4,90 @@
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State


class AdamW(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(AdamW, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class AdamW8bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(AdamW8bit, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class AdamW32bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(AdamW32bit, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
bitsandbytes/optim/lamb.py  View file @ bfa0e332
...
@@ -4,25 +4,102 @@
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State


class LAMB(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        bias_correction=True,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        adam_w_mode=True,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=False,
        max_unorm=1.0,
    ):
        super(LAMB, self).__init__(
            "lamb",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
            max_unorm=1.0,
        )


class LAMB8bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        bias_correction=True,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        adam_w_mode=True,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=False,
        max_unorm=1.0,
    ):
        super(LAMB8bit, self).__init__(
            "lamb",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
            max_unorm=1.0,
        )


class LAMB32bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        bias_correction=True,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        adam_w_mode=True,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=False,
        max_unorm=1.0,
    ):
        super(LAMB32bit, self).__init__(
            "lamb",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
            max_unorm=1.0,
        )
bitsandbytes/optim/lars.py
View file @
bfa0e332
...
@@ -3,41 +3,119 @@
...
@@ -3,41 +3,119 @@
# This source code is licensed under the MIT license found in the
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
import
torch
import
torch
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
from
bitsandbytes.optim.optimizer
import
Optimizer1State
from
bitsandbytes.optim.optimizer
import
Optimizer1State
class
LARS
(
Optimizer1State
):
class
LARS
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
def
__init__
(
weight_decay
=
0
,
nesterov
=
False
,
optim_bits
=
32
,
args
=
None
,
self
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
):
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
,
):
if
momentum
==
0
:
if
momentum
==
0
:
raise
NotImplementedError
(
f
'LARS without momentum is not supported!'
)
raise
NotImplementedError
(
f
"LARS without momentum is not supported!"
)
super
(
LARS
,
self
).
__init__
(
'lars'
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
super
(
LARS
,
self
).
__init__
(
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
)
"lars"
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
,
)
class
LARS8bit
(
Optimizer1State
):
class
LARS8bit
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
def
__init__
(
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
self
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
):
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
,
):
if
momentum
==
0
:
if
momentum
==
0
:
raise
NotImplementedError
(
f
'LARS without momentum is not supported!'
)
raise
NotImplementedError
(
f
"LARS without momentum is not supported!"
)
super
(
LARS8bit
,
self
).
__init__
(
'lars'
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
super
(
LARS8bit
,
self
).
__init__
(
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
)
"lars"
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
,
)
class
LARS32bit
(
Optimizer1State
):
class
LARS32bit
(
Optimizer1State
):
def
__init__
(
self
,
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
def
__init__
(
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
self
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
):
params
,
lr
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
max_unorm
=
0.02
,
):
if
momentum
==
0
:
if
momentum
==
0
:
raise
NotImplementedError
(
f
'LARS without momentum is not supported!'
)
raise
NotImplementedError
(
f
"LARS without momentum is not supported!"
)
super
(
LARS32bit
,
self
).
__init__
(
'lars'
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
super
(
LARS32bit
,
self
).
__init__
(
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
)
"lars"
,
params
,
lr
,
(
momentum
,
dampening
),
0.0
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
max_unorm
=
max_unorm
,
block_wise
=
False
,
)
class
PytorchLARS
(
Optimizer
):
class
PytorchLARS
(
Optimizer
):
def
__init__
(
self
,
params
,
lr
=
0.01
,
momentum
=
0
,
dampening
=
0
,
def
__init__
(
weight_decay
=
0
,
nesterov
=
False
,
max_unorm
=
0.02
):
self
,
params
,
lr
=
0.01
,
momentum
=
0
,
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
max_unorm
=
0.02
,
):
if
lr
<
0.0
:
if
lr
<
0.0
:
raise
ValueError
(
"Invalid learning rate: {}"
.
format
(
lr
))
raise
ValueError
(
"Invalid learning rate: {}"
.
format
(
lr
))
if
momentum
<
0.0
:
if
momentum
<
0.0
:
...
@@ -45,8 +123,14 @@ class PytorchLARS(Optimizer):
...
@@ -45,8 +123,14 @@ class PytorchLARS(Optimizer):
if
weight_decay
<
0.0
:
if
weight_decay
<
0.0
:
raise
ValueError
(
"Invalid weight_decay value: {}"
.
format
(
weight_decay
))
raise
ValueError
(
"Invalid weight_decay value: {}"
.
format
(
weight_decay
))
defaults
=
dict
(
lr
=
lr
,
momentum
=
momentum
,
dampening
=
dampening
,
defaults
=
dict
(
weight_decay
=
weight_decay
,
nesterov
=
nesterov
,
max_unorm
=
max_unorm
)
lr
=
lr
,
momentum
=
momentum
,
dampening
=
dampening
,
weight_decay
=
weight_decay
,
nesterov
=
nesterov
,
max_unorm
=
max_unorm
,
)
if
nesterov
and
(
momentum
<=
0
or
dampening
!=
0
):
if
nesterov
and
(
momentum
<=
0
or
dampening
!=
0
):
raise
ValueError
(
"Nesterov momentum requires a momentum and zero dampening"
)
raise
ValueError
(
"Nesterov momentum requires a momentum and zero dampening"
)
super
(
PytorchLARS
,
self
).
__init__
(
params
,
defaults
)
super
(
PytorchLARS
,
self
).
__init__
(
params
,
defaults
)
...
@@ -54,7 +138,7 @@ class PytorchLARS(Optimizer):
...
@@ -54,7 +138,7 @@ class PytorchLARS(Optimizer):
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
super
(
PytorchLARS
,
self
).
__setstate__
(
state
)
super
(
PytorchLARS
,
self
).
__setstate__
(
state
)
for
group
in
self
.
param_groups
:
for
group
in
self
.
param_groups
:
group
.
setdefault
(
'
nesterov
'
,
False
)
group
.
setdefault
(
"
nesterov
"
,
False
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
step
(
self
,
closure
=
None
):
def
step
(
self
,
closure
=
None
):
...
@@ -73,15 +157,16 @@ class PytorchLARS(Optimizer):
...
@@ -73,15 +157,16 @@ class PytorchLARS(Optimizer):
params_with_grad
=
[]
params_with_grad
=
[]
d_p_list
=
[]
d_p_list
=
[]
momentum_buffer_list
=
[]
momentum_buffer_list
=
[]
weight_decay
=
group
[
'
weight_decay
'
]
weight_decay
=
group
[
"
weight_decay
"
]
momentum
=
group
[
'
momentum
'
]
momentum
=
group
[
"
momentum
"
]
dampening
=
group
[
'
dampening
'
]
dampening
=
group
[
"
dampening
"
]
nesterov
=
group
[
'
nesterov
'
]
nesterov
=
group
[
"
nesterov
"
]
max_unorm
=
group
[
'
max_unorm
'
]
max_unorm
=
group
[
"
max_unorm
"
]
lr
=
group
[
'
lr
'
]
lr
=
group
[
"
lr
"
]
for
p
in
group
[
'params'
]:
for
p
in
group
[
"params"
]:
if
p
.
grad
is
None
:
continue
if
p
.
grad
is
None
:
continue
state
=
self
.
state
[
p
]
state
=
self
.
state
[
p
]
d_p
=
p
.
grad
d_p
=
p
.
grad
...
@@ -89,16 +174,16 @@ class PytorchLARS(Optimizer):
...
@@ -89,16 +174,16 @@ class PytorchLARS(Optimizer):
d_p
=
d_p
.
add
(
param
,
alpha
=
weight_decay
)
d_p
=
d_p
.
add
(
param
,
alpha
=
weight_decay
)
if
momentum
!=
0
:
if
momentum
!=
0
:
buf
=
state
.
get
(
'
momentum_buffer
'
,
None
)
buf
=
state
.
get
(
"
momentum_buffer
"
,
None
)
if
buf
is
None
:
if
buf
is
None
:
buf
=
torch
.
clone
(
d_p
).
detach
()
buf
=
torch
.
clone
(
d_p
).
detach
()
state
[
'
momentum_buffer
'
]
=
buf
state
[
"
momentum_buffer
"
]
=
buf
else
:
else
:
buf
.
mul_
(
momentum
).
add_
(
d_p
,
alpha
=
1
-
dampening
)
buf
.
mul_
(
momentum
).
add_
(
d_p
,
alpha
=
1
-
dampening
)
if
nesterov
:
if
nesterov
:
update
=
d_p
+
buf
*
momentum
update
=
d_p
+
buf
*
momentum
else
:
else
:
update
=
buf
update
=
buf
...
@@ -107,9 +192,9 @@ class PytorchLARS(Optimizer):
...
@@ -107,9 +192,9 @@ class PytorchLARS(Optimizer):
assert
p
.
dtype
==
torch
.
float32
assert
p
.
dtype
==
torch
.
float32
pnorm
=
torch
.
norm
(
p
.
detach
())
pnorm
=
torch
.
norm
(
p
.
detach
())
unorm
=
torch
.
norm
(
update
)
unorm
=
torch
.
norm
(
update
)
if
unorm
>
max_unorm
*
pnorm
:
if
unorm
>
max_unorm
*
pnorm
:
update_scale
=
max_unorm
*
pnorm
/
unorm
update_scale
=
max_unorm
*
pnorm
/
unorm
p
.
add_
(
update
,
alpha
=-
lr
*
update_scale
)
p
.
add_
(
update
,
alpha
=-
lr
*
update_scale
)
return
loss
return
loss
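For reference, the clipping branch above caps the parameter update so its norm never exceeds max_unorm times the parameter norm. The following standalone sketch restates that rule; the helper name and default value here are illustrative only and are not part of the library.

import torch


def clamp_update_norm(p, update, lr, max_unorm=0.02):
    # Scale the update so that ||update|| <= max_unorm * ||p||,
    # mirroring the clipping step in PytorchLARS.step above.
    update_scale = 1.0
    pnorm = torch.norm(p.detach())
    unorm = torch.norm(update)
    if unorm > max_unorm * pnorm:
        update_scale = max_unorm * pnorm / unorm
    p.add_(update, alpha=-lr * update_scale)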
bitsandbytes/optim/optimizer.py
View file @
bfa0e332
...
@@ -2,12 +2,15 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import abc as container_abcs
from collections import defaultdict
from copy import deepcopy
from itertools import chain

import torch

import bitsandbytes.functional as F


class MockArgs(object):
    def __init__(self, initial_data):
...
@@ -19,7 +22,7 @@ class GlobalOptimManager(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.pid2config = {}
...
@@ -38,15 +41,15 @@ class GlobalOptimManager(object):
    def register_parameters(self, params):
        param_groups = list(params)
        if not isinstance(param_groups[0], dict):
            param_groups = [{"params": param_groups}]

        for group_index, group in enumerate(param_groups):
            for p_index, p in enumerate(group["params"]):
                if id(p) in self.pid2config:
                    self.index2config[(group_index, p_index)] = self.pid2config[id(p)]

    def override_config(self, parameters, key=None, value=None, key_value_dict=None):
        """
        Overrides initial optimizer config for specific parameters.

        The key-values of the optimizer config for the input parameters are overidden
...
@@ -63,7 +66,7 @@ class GlobalOptimManager(object):
            The value for the hyperparamters.
        key_value_dict : dict
            A dictionary with multiple key-values to override.
        """
        self.uses_config_override = True
        if isinstance(parameters, torch.nn.Parameter):
            parameters = [parameters]
...
@@ -75,16 +78,16 @@ class GlobalOptimManager(object):
        if key_value_dict is not None:
            for p in parameters:
                if id(p) in self.pid2config:
                    self.pid2config[id(p)].update(key_value_dict)
                else:
                    self.pid2config[id(p)] = key_value_dict

    def register_module_override(self, module, param_name, config):
        self.module_weight_config_triple.append((module, param_name, config))


class Optimizer8bit(torch.optim.Optimizer):
    def __init__(self, params, defaults, optim_bits=32):
        super(Optimizer8bit, self).__init__(params, defaults)
        self.initialized = False
...
@@ -92,23 +95,32 @@ class Optimizer8bit(torch.optim.Optimizer):
        self.mng = GlobalOptimManager.get_instance()
        self.non_castable_tensor_keys = set(
            [
                "qmap1",
                "qmap2",
                "max1",
                "max2",
                "new_max1",
                "new_max2",
                "state1",
                "state2",
                "gnorm_vec",
                "absmax1",
                "absmax2",
                "unorm_vec",
            ]
        )

        if optim_bits == 8:
            self.fill_qmap()

    def fill_qmap(self):
        self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True)
        self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)

    def __setstate__(self, state):
        super(Optimizer8bit, self).__setstate__(state)

    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.
...
@@ -120,21 +132,28 @@ class Optimizer8bit(torch.optim.Optimizer):
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict["param_groups"]

        if len(groups) != len(saved_groups):
            raise ValueError(
                "loaded state dict has a different number of "
                "parameter groups"
            )
        param_lens = (len(g["params"]) for g in groups)
        saved_lens = (len(g["params"]) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError(
                "loaded state dict contains a parameter group "
                "that doesn't match the size of optimizer's group"
            )

        # Update the state
        id_map = {
            old_id: p
            for old_id, p in zip(
                chain.from_iterable((g["params"] for g in saved_groups)),
                chain.from_iterable((g["params"] for g in groups)),
            )
        }

        def cast(param, value):
            r"""Make a deep copy of value, casting all tensors to device of param."""
...
@@ -161,7 +180,7 @@ class Optimizer8bit(torch.optim.Optimizer):
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state = defaultdict(dict)
        for k, v in state_dict["state"].items():
            if k in id_map:
                param = id_map[k]
                state[param] = cast(param, v)
...
@@ -170,15 +189,15 @@ class Optimizer8bit(torch.optim.Optimizer):
        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            new_group["params"] = group["params"]
            return new_group

        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({"state": state, "param_groups": param_groups})

    def to_gpu(self):
        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group["params"]):
                if p in self.state:
                    values = self.state[p]
                    for k, v in values.items():
...
@@ -189,17 +208,23 @@ class Optimizer8bit(torch.optim.Optimizer):
        for module, attr, config in self.mng.module_weight_config_triple:
            pmodule = getattr(module, attr)
            assert pmodule is not None
            assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)
            found = False
            for gindex, group in enumerate(self.param_groups):
                if found:
                    break
                for pindex, p in enumerate(group["params"]):
                    if found:
                        break
                    if id(p) == id(pmodule):
                        # found the matching parameter
                        # init override
                        self.mng.pid2config[id(p)] = config
                        self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[
                            id(p)
                        ]
                        found = True

    @torch.no_grad()
...
@@ -223,7 +248,7 @@ class Optimizer8bit(torch.optim.Optimizer):
        self.initialized = True

        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group["params"]):
                if p.grad is None:
                    continue
                state = self.state[p]
...
@@ -236,58 +261,70 @@ class Optimizer8bit(torch.optim.Optimizer):
    def get_config(self, gindex, pindex, group):
        config = {}
        config["betas"] = group["betas"]
        config["eps"] = group["eps"]
        config["weight_decay"] = group["weight_decay"]
        config["lr"] = group["lr"]
        config["optim_bits"] = self.args.optim_bits
        config["min_8bit_size"] = self.args.min_8bit_size
        config["percentile_clipping"] = self.args.percentile_clipping
        config["block_wise"] = self.args.block_wise
        config["max_unorm"] = self.args.max_unorm
        config["skip_zeros"] = self.args.skip_zeros

        if (gindex, pindex) in self.mng.index2config:
            config.update(self.mng.index2config[(gindex, pindex)])
        return config

    def init_state(self, group, p, gindex, pindex):
        raise NotImplementedError(f"init_state method needs to be overidden")

    def update_step(self, group, p, gindex, pindex):
        raise NotImplementedError(f"The update_step method needs to be overidden")


class Optimizer2State(Optimizer8bit):
    def __init__(
        self,
        optimizer_name,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.0,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        max_unorm=0.0,
        skip_zeros=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if isinstance(betas, str):
            # format: '(beta1, beta2)'
            betas = betas.replace("(", "").replace(")", "").strip().split(",")
            betas = [float(b) for b in betas]
        for i in range(len(betas)):
            if not 0.0 <= betas[i] < 1.0:
                raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(Optimizer2State, self).__init__(params, defaults, optim_bits)

        if args is None:
            args = {}
            args["optim_bits"] = optim_bits
            args["percentile_clipping"] = 100
            args["min_8bit_size"] = min_8bit_size
            args["percentile_clipping"] = percentile_clipping
            args["block_wise"] = block_wise
            args["max_unorm"] = max_unorm
            args["skip_zeros"] = skip_zeros

            self.args = MockArgs(args)
        else:
...
@@ -299,50 +336,83 @@ class Optimizer2State(Optimizer8bit):
    def init_state(self, group, p, gindex, pindex):
        config = self.get_config(gindex, pindex, group)

        if config["optim_bits"] == 32:
            dtype = torch.float32
        elif config["optim_bits"] == 8:
            dtype = torch.uint8
        else:
            raise NotImplementedError(
                f'Amount of optimizer bits not supported: {config["optim_bits"]}'
            )

        if p.numel() < config["min_8bit_size"]:
            dtype = torch.float32

        state = self.state[p]
        state["step"] = 0

        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
            state["state1"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.float32,
                device=p.device,
            )
            state["state2"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.float32,
                device=p.device,
            )
        elif dtype == torch.uint8:
            if state["step"] == 0:
                if "dynamic" not in self.name2qmap:
                    self.fill_qmap()
                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
                self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)

            state["state1"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.uint8,
                device=p.device,
            )
            state["qmap1"] = self.name2qmap["dynamic"]

            state["state2"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.uint8,
                device=p.device,
            )
            state["qmap2"] = self.name2qmap["udynamic"]

            if config["block_wise"]:
                n = p.numel()
                blocks = n // 2048
                blocks += 1 if n % 2048 > 0 else 0

                state["absmax1"] = torch.zeros(
                    (blocks,), dtype=torch.float32, device=p.device
                )
                state["absmax2"] = torch.zeros(
                    (blocks,), dtype=torch.float32, device=p.device
                )
            else:
                state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
                state["new_max1"] = torch.zeros(
                    (1,), dtype=torch.float32, device=p.device
                )
                state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
                state["new_max2"] = torch.zeros(
                    (1,), dtype=torch.float32, device=p.device
                )

        if config["percentile_clipping"] < 100:
            state["gnorm_vec"] = torch.zeros((100,), device=p.device)

        if config["max_unorm"] > 0.0:
            state["unorm_vec"] = torch.zeros((1,), device=p.device)

    @torch.no_grad()
    def update_step(self, group, p, gindex, pindex):
...
@@ -351,41 +421,101 @@ class Optimizer2State(Optimizer8bit):
        config = self.get_config(gindex, pindex, group)

        state["step"] += 1
        step = state["step"]

        if config["percentile_clipping"] < 100:
            current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
                grad, state["gnorm_vec"], step, config["percentile_clipping"]
            )
        else:
            gnorm_scale = 1.0

        if state["state1"].dtype == torch.float:
            F.optimizer_update_32bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                config["betas"][0],
                config["eps"],
                step,
                config["lr"],
                state["state2"],
                config["betas"][1],
                config["weight_decay"],
                gnorm_scale,
                state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                max_unorm=config["max_unorm"],
                skip_zeros=config["skip_zeros"],
            )

        elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
            F.optimizer_update_8bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                state["state2"],
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                state["qmap2"],
                state["max1"],
                state["max2"],
                state["new_max1"],
                state["new_max2"],
                config["weight_decay"],
                gnorm_scale=gnorm_scale,
                unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                max_unorm=config["max_unorm"],
            )

            # swap maxes
            state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
            state["max2"], state["new_max2"] = state["new_max2"], state["max2"]
        elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
            F.optimizer_update_8bit_blockwise(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                state["state2"],
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                state["qmap2"],
                state["absmax1"],
                state["absmax2"],
                config["weight_decay"],
                gnorm_scale=gnorm_scale,
                skip_zeros=config["skip_zeros"],
            )


class Optimizer1State(Optimizer8bit):
    def __init__(
        self,
        optimizer_name,
        params,
        lr=1e-3,
        betas=(0.9, 0.0),
        eps=1e-8,
        weight_decay=0.0,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        max_unorm=0.0,
        skip_zeros=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
...
@@ -395,19 +525,18 @@ class Optimizer1State(Optimizer8bit):
                raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(Optimizer1State, self).__init__(params, defaults, optim_bits)

        if args is None:
            args = {}
            args["optim_bits"] = optim_bits
            args["percentile_clipping"] = 100
            args["min_8bit_size"] = min_8bit_size
            args["percentile_clipping"] = percentile_clipping
            args["block_wise"] = block_wise
            args["max_unorm"] = max_unorm
            args["skip_zeros"] = skip_zeros

            self.args = MockArgs(args)
        else:
...
@@ -419,43 +548,61 @@ class Optimizer1State(Optimizer8bit):
    def init_state(self, group, p, gindex, pindex):
        config = self.get_config(gindex, pindex, group)

        if config["optim_bits"] == 32:
            dtype = torch.float32
        elif config["optim_bits"] == 8:
            dtype = torch.uint8
        else:
            raise NotImplementedError(
                f'Amount of optimizer bits not supported: {config["optim_bits"]}'
            )

        if p.numel() < config["min_8bit_size"]:
            dtype = torch.float32

        state = self.state[p]
        state["step"] = 0

        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
            state["state1"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.float32,
                device=p.device,
            )
        elif dtype == torch.uint8:
            if state["step"] == 0:
                if "dynamic" not in self.name2qmap:
                    self.fill_qmap()
                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)

            state["state1"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.uint8,
                device=p.device,
            )
            state["qmap1"] = self.name2qmap["dynamic"]

            if config["block_wise"]:
                n = p.numel()
                blocks = n // 2048
                blocks += 1 if n % 2048 > 0 else 0

                state["absmax1"] = torch.zeros(
                    (blocks,), dtype=torch.float32, device=p.device
                )
            else:
                state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
                state["new_max1"] = torch.zeros(
                    (1,), dtype=torch.float32, device=p.device
                )

        if config["percentile_clipping"] < 100:
            state["gnorm_vec"] = torch.zeros((100,), device=p.device)

        if config["max_unorm"] > 0.0:
            state["unorm_vec"] = torch.zeros((1,), device=p.device)

    @torch.no_grad()
    def update_step(self, group, p, gindex, pindex):
...
@@ -464,29 +611,77 @@ class Optimizer1State(Optimizer8bit):
        config = self.get_config(gindex, pindex, group)

        state["step"] += 1
        step = state["step"]

        if config["percentile_clipping"] < 100:
            current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
                grad, state["gnorm_vec"], step, config["percentile_clipping"]
            )
        else:
            gnorm_scale = 1.0

        if state["state1"].dtype == torch.float:
            F.optimizer_update_32bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                config["betas"][0],
                config["eps"],
                step,
                config["lr"],
                None,
                0.0,
                config["weight_decay"],
                gnorm_scale,
                state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                max_unorm=config["max_unorm"],
                skip_zeros=config["skip_zeros"],
            )

        elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
            F.optimizer_update_8bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                None,
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                None,
                state["max1"],
                None,
                state["new_max1"],
                None,
                config["weight_decay"],
                gnorm_scale,
                state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                max_unorm=config["max_unorm"],
            )

            state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
        elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
            F.optimizer_update_8bit_blockwise(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                None,
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                None,
                state["absmax1"],
                None,
                config["weight_decay"],
                gnorm_scale=gnorm_scale,
                skip_zeros=config["skip_zeros"],
            )
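The per-parameter override machinery above is easiest to see in use. The sketch below is illustrative only: the model, layer sizes, and learning rate are placeholders, a CUDA build of bitsandbytes is assumed, and the override is issued before register_parameters so that the pid2config-to-index2config copy shown in register_parameters picks it up.

import torch
import bitsandbytes as bnb

# Hypothetical two-layer model used only for illustration.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 2)).cuda()

mng = bnb.optim.GlobalOptimManager.get_instance()

# Keep the second layer's optimizer state in 32 bit while the rest uses 8 bit.
mng.override_config(model[1].weight, "optim_bits", 32)

# Register parameters so the override is mapped to (group, param) indices
# before the optimizer builds its per-parameter config in get_config().
mng.register_parameters(model.parameters())

optimizer = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8)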
bitsandbytes/optim/rmsprop.py
View file @
bfa0e332
...
@@ -4,33 +4,106 @@
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State


class RMSprop(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        weight_decay=0,
        momentum=0,
        centered=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if alpha == 0:
            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
        if centered:
            raise NotImplementedError(f"Centered RMSprop is not supported!")
        super(RMSprop, self).__init__(
            "rmsprop",
            params,
            lr,
            (alpha, momentum),
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class RMSprop8bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        weight_decay=0,
        momentum=0,
        centered=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if alpha == 0:
            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
        if centered:
            raise NotImplementedError(f"Centered RMSprop is not supported!")
        super(RMSprop8bit, self).__init__(
            "rmsprop",
            params,
            lr,
            (alpha, momentum),
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class RMSprop32bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        weight_decay=0,
        momentum=0,
        centered=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if alpha == 0:
            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
        if centered:
            raise NotImplementedError(f"Centered RMSprop is not supported!")
        super(RMSprop32bit, self).__init__(
            "rmsprop",
            params,
            lr,
            (alpha, momentum),
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
bitsandbytes/optim/sgd.py
View file @
bfa0e332
...
@@ -4,29 +4,96 @@
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State


class SGD(Optimizer1State):
    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if momentum == 0:
            raise NotImplementedError(f"SGD without momentum is not supported!")
        super(SGD, self).__init__(
            "momentum",
            params,
            lr,
            (momentum, dampening),
            0.0,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class SGD8bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if momentum == 0:
            raise NotImplementedError(f"SGD without momentum is not supported!")
        super(SGD8bit, self).__init__(
            "momentum",
            params,
            lr,
            (momentum, dampening),
            0.0,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class SGD32bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if momentum == 0:
            raise NotImplementedError(f"SGD without momentum is not supported!")
        super(SGD32bit, self).__init__(
            "momentum",
            params,
            lr,
            (momentum, dampening),
            0.0,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
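The 8-bit optimizer classes above are meant as drop-in replacements for their torch.optim counterparts. A minimal usage sketch, assuming a CUDA build of the library; the model shape, batch size, and hyperparameters are placeholders chosen only for illustration.

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(128, 10).cuda()

# SGD8bit requires a nonzero momentum (see the check above).
optimizer = bnb.optim.SGD8bit(model.parameters(), lr=0.1, momentum=0.9)

x = torch.randn(16, 128, device="cuda")
loss = model(x).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()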
bitsandbytes/utils.py
View file @
bfa0e332
import sys


def print_err(s: str) -> None:
    print(s, file=sys.stderr)


def warn_of_missing_prerequisite(s: str) -> None:
    print_err("WARNING, missing pre-requisite: " + s)
quicktest.py
View file @
bfa0e332
from itertools import product

import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F


def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
    k = 25
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
        elif dims == 3:
            A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
                torch.int8
            )
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "colx")
        if dims == 2:
            C2, SC = F.transform(
                torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"),
                "col32",
            )
        else:
            C2, SC = F.transform(
                torch.zeros(
                    A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device="cuda"
                ),
                "col32",
            )
        F.igemmlt(A2, B2, C2, SA, SB, SC)
        C3, S = F.transform(C2, "row", state=SC)
        # torch.testing.assert_allclose(C1, C3.float())
        # print(C1)
        # print(C2)
        # print(C3)
        allclose = torch.allclose(C1, C3.float())
        if allclose:
            print(C1)
...
@@ -33,29 +47,29 @@ def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
            print(C3)

    ## transposed
    # A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
    # if dims == 2:
    #    B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
    #    C1 = torch.matmul(A.float(), B.float().t())
    # elif dims == 3:
    #    B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
    #    C1 = torch.matmul(B.float(), A.t().float())
    #    C1 = C1.permute([2, 0, 1])

    # A2, SA = F.transform(A, 'col32')
    # B2, SB = F.transform(B, 'colx')
    # if dims == 2:
    #    C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
    # else:
    #    C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
    #    state = (C2.shape, 'row', A.shape[0])
    #    C2, SC = F.transform(C2, 'col32', state=state)

    # F.igemmlt(A2, B2, C2, SA, SB, SC)
    # C3, S = F.transform(C2, 'row', state=SC, ld=[0])
    # torch.testing.assert_allclose(C1, C3.float())

    ## weight update
    # if dims == 3:
    #    A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
    #    B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
    #    C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())
...
@@ -73,18 +87,18 @@ dims = (2, 3)
ldb = [0]

n = 2
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()

values = list(product(dim1, dim2, dim3, dim4, dims, ldb))

for ldb in range(32, 4096, 32):
    # for ldb in [None]:
    val = test_igemmlt(2, 2, 2, 2, 2, ldb)
    if val:
        print(val, ldb)
    else:
        print("nope", ldb)
# for val in values:
# test_igemmlt(*val)
setup.py
View file @
bfa0e332
...
@@ -2,18 +2,20 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import glob
import os

from setuptools import find_packages, setup

libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.so"))
libs = [os.path.basename(p) for p in libs]
print("libs:", libs)


def read(fname):
    return open(os.path.join(os.path.dirname(__file__), fname)).read()


setup(
    name=f"bitsandbytes",
    version=f"0.31.0",
...
@@ -27,11 +29,11 @@ setup(
    entry_points={
        "console_scripts": ["debug_cuda = bitsandbytes.debug_cli:cli"],
    },
    package_data={"": libs},
    long_description=read("README.md"),
    long_description_content_type="text/markdown",
    classifiers=[
        "Development Status :: 4 - Beta",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)