OpenDAS / bitsandbytes, commit c771b3a7
Authored Jul 22, 2022 by Tim Dettmers
Commit message: Most tests passing.
Parent: 4cd7ea62
Showing 16 changed files with 5269 additions and 159 deletions (+5269 -159).
bitsandbytes/__init__.py             +2    -1
bitsandbytes/autograd/__init__.py    +0    -0
bitsandbytes/autograd/_functions.py  +307  -0
bitsandbytes/cextension.py           +2    -0
bitsandbytes/functional.py           +867  -2
bitsandbytes/nn/__init__.py          +1    -1
bitsandbytes/nn/modules.py           +122  -2
csrc/kernels.cu                      +874  -0
csrc/kernels.cuh                     +12   -0
csrc/ops.cu                          +406  -0
csrc/ops.cuh                         +104  -0
csrc/pythonInterface.c               +126  -1
tests/test_autograd.py               +270  -0
tests/test_functional.py             +1704 -59
tests/test_modules.py                +453  -25
tests/test_optim.py                  +19   -68
bitsandbytes/__init__.py
@@ -4,12 +4,13 @@
 # LICENSE file in the root directory of this source tree.
 from .nn import modules
 from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState
 from .cextension import COMPILED_WITH_CUDA

 if COMPILED_WITH_CUDA:
     from .optim import adam

-__pdoc__ = {'libBitsNBytes': False,
+__pdoc__ = {'libbitsandbytes': False,
             'optim.optimizer.Optimizer8bit': False,
             'optim.optimizer.MockArgs': False}
bitsandbytes/autograd/__init__.py  0 → 100644  (new, empty file)

bitsandbytes/autograd/_functions.py  0 → 100644  (new file)
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

from dataclasses import dataclass

tensor = torch.Tensor

'''
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
'''
class GlobalOutlierPooler(object):
    _instance = None

    def __init__(self):
        raise RuntimeError('Call get_instance() instead')

    def initialize(self):
        self.outliers = set()
        self.model_dim = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def add_outliers(self, outlier_idx, feature_dim):
        if self.model_dim is None: self.model_dim = feature_dim
        if feature_dim != self.model_dim: return  # we do not encode outliers for the 2nd FFN layer

        self.outliers.update(outlier_idx.tolist())

    def get_current_outlier_idx(self):
        return torch.Tensor(list(self.outliers)).to(torch.int64)


class MatMul8bit(torch.autograd.Function):

    @staticmethod
    def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]):

        if precision[0] != 8:
            with torch.no_grad():
                output = torch.matmul(A, B)
        else:
            if len(B.shape) == 2: dim = 0
            else: dim = 1
            qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
            qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
            iout = F.igemm(qA, qB)
            output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)

        if A.requires_grad or B.requires_grad:
            ctx.save_for_backward(A, B)

        ctx.quant_type = quant_type
        ctx.precision = precision

        return output

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        quant_type = ctx.quant_type
        precision = ctx.precision
        grad_A = grad_B = None

        if B.requires_grad:
            if len(A.shape) == 3:
                dims = [0, 1]
                # bsi -> ibs
                permute_dim = [0, 2, 1]
            else:
                dims = [0]
                # bs -> sb
                permute_dim = [1, 0]

            if precision[1] != 8:
                with torch.no_grad():
                    grad_B = torch.matmul(A.permute(permute_dim), grad_output)
            else:
                if len(B.shape) == 2 and len(A.shape) == 3:
                    grad_output = grad_output.contiguous()
                    if not grad_output.is_contiguous(): grad_output.contiguous()
                    qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type)
                    if not A.is_contiguous(): A = A.contiguous()
                    qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
                    igrad_B = F.igemm(qA.t(), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
                else:
                    qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
                    qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
                    igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type)

        if A.requires_grad:
            if len(grad_output.shape) == 3: dims = [2]
            else: dims = [1]

            if len(B.shape) == 3:
                # bio -> boi
                permute_dim = [0, 2, 1]
                dim_B = dims
            else:
                # io -> oi
                permute_dim = [1, 0]
                dim_B = [1]

            if precision[2] != 8:
                with torch.no_grad():
                    grad_A = torch.matmul(grad_output, B.permute(permute_dim))
            else:
                qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
                qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
                grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type)

        return grad_A, grad_B, None, None, None


mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply


@dataclass
class MatmulLtState:
    CB = None
    CxB = None
    SB = None
    SCB = None

    CxBt = None
    SBt = None
    CBt = None

    subB = None

    outlier_pool = None
    has_accumulated_gradients = False
    threshold = 0.0
    idx = None
    is_training = True
    has_fp16_weights = True
    use_pool = False
    formatB = F.get_special_format_str()

    def reset_grads(self):
        self.CB = None
        self.CxB = None
        self.SB = None
        self.SCB = None

        self.CxBt = None
        self.SBt = None
        self.CBt = None


class MatMul8bitLt(torch.autograd.Function):

    @staticmethod
    def forward(ctx, A, B, out=None, state=MatmulLtState()):
        # 1. Quantize A
        # 2. Quantize B
        # 3. Matmul
        # 4. Mixed-precision decomposition matmul
        # 5. Save state
        requires_gradA = A.requires_grad
        requires_gradB = B.requires_grad
        formatB = state.formatB
        input_shape = A.shape
        if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance()
        assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!'

        # 1. Quantize A
        if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)

        if state.threshold > 0.0 and coo_tensorA is not None:
            if state.has_fp16_weights:
                idx = torch.unique(coo_tensorA.colidx).long()
                CA[:, idx] = 0
                CAt[:, idx] = 0
                subA = A[:, idx]
                state.subB = B[:, idx].t().contiguous()
                state.idx = idx
            else:
                if state.CxB is None:
                    # B is in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                    # we also need to convert it to the turing/ampere format
                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
                if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
                    # generate outlier index and subB
                    outlier_idx = torch.unique(coo_tensorA.colidx).long()
                    state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
                    if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
                        # do not use pool for 2nd FFN layer
                        state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
                    else:
                        state.idx = outlier_idx
                    state.subB = (state.CB[:, state.idx].float().t().contiguous() * (state.SCB / 127)).half()

                if state.idx is not None:
                    # extract outliers
                    CA[:, state.idx] = 0
                    CAt[:, state.idx] = 0
                    subA = A[:, state.idx]
                else:
                    subA = None
        else:
            if not state.has_fp16_weights and state.CxB is None:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None

        C32A, SA = F.transform(CA, 'col32')

        # 2. Quantize B
        if state.has_fp16_weights:
            has_grad = (True if (getattr(B, 'grad', None) is not None) else False)
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed: B = B.contiguous()

            if (state.is_training and not has_grad) or state.CxB is None:
                state.reset_grads()
                CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B)
                state.CxB, state.SB = F.transform(CB, to_order=formatB)
        else:
            has_grad = False

        shapeB = state.SB[0]

        if len(input_shape) == 3:
            output_shape = (input_shape[0], input_shape[1], shapeB[0])
        else:
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
        output = F.mm_dequant(out32, Sout32, SCA, state.SCB)

        # 4. Mixed-precision decomposition matmul
        if state.threshold > 0.0 and coo_tensorA is not None and subA is not None:
            output += torch.matmul(subA, state.subB)

        # 5. Save state
        ctx.state = state

        ctx.formatB = formatB
        ctx.grad_shape = input_shape
        ctx.req_grads = [requires_gradA, requires_gradB]

        if requires_gradA or requires_gradB:
            ctx.tensors = (CAt, subA)
            ctx.tensor_states = (SCAt, state.idx)
        else:
            ctx.tensors = [None, None]
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        #clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
        clone_func = torch.clone
        return clone_func(output.view(output_shape))

    @staticmethod
    def backward(ctx, grad_output):
        req_gradA, req_gradB = ctx.req_grads
        CAt, subA = ctx.tensors
        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state
        assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.'

        if len(grad_output.shape) == 3:
            grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()

        grad_A = grad_B = None

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
        if req_gradB:
            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
            C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True)
            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
            if state.threshold > 0.0 and subA is not None:
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if req_gradA:
            C32grad, Sgrad = F.transform(Cgrad, 'col32')
            if state.CxBt is None:
                state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)

        return grad_A, grad_B, None, None, None, None, None


matmul = MatMul8bitLt.apply


def matmul(A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0):
    state = state or MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return MatMul8bitLt.apply(A, B, out, state)
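A hedged usage sketch of the entry point this file adds (the shapes and the 6.0 threshold are illustrative assumptions, not from the diff): bnb.matmul routes through MatMul8bitLt, and a positive threshold enables the mixed-precision decomposition path in step 4 of forward().

    import torch
    import bitsandbytes as bnb

    A = torch.randn(8, 64, device='cuda', dtype=torch.float16)   # activations must be fp16 (asserted in forward)
    W = torch.randn(32, 64, device='cuda', dtype=torch.float16)  # weight laid out [out_features, in_features]

    state = bnb.MatmulLtState()
    out = bnb.matmul(A, W, state=state, threshold=6.0)           # outlier columns of A are kept in fp16
    print(out.shape)                                             # torch.Size([8, 32])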
bitsandbytes/cextension.py
@@ -6,6 +6,8 @@ lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
 try:
     lib.cadam32bit_g32
+    lib.get_context.restype = ct.c_void_p
+    lib.get_cusparse.restype = ct.c_void_p
     COMPILED_WITH_CUDA = True
 except AttributeError:
     warn("The installed version of bitsandbytes was compiled without GPU support. "
...
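A side note on why the two added restype lines matter (a sketch, not part of the diff): ctypes assumes every foreign function returns a C int, which would truncate the 64-bit Context*/ContextCusparse* pointers returned by get_context/get_cusparse; declaring restype as c_void_p preserves the full handle.

    import ctypes as ct

    lib = ct.cdll.LoadLibrary('./libbitsandbytes.so')  # loaded from the package dir in cextension.py
    lib.get_context.restype = ct.c_void_p
    context = lib.get_context()   # opaque Context* handle, kept intact as a 64-bit value
    # the handle is later passed back into calls such as lib.cigemm(context, ...)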
bitsandbytes/functional.py  (+867 -2; diff collapsed)
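The functional.py diff is collapsed above, but its new double_quant API is used throughout _functions.py. A hedged sketch of the call shape, with names inferred from the call sites (the collapsed diff is the authority):

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(8, 64, device='cuda', dtype=torch.float16)
    # row-quantized A, transpose-quantized A, their scale statistics, and a COO
    # tensor of outlier entries (None when no entries exceed the threshold)
    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=6.0)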
bitsandbytes/nn/__init__.py
@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import StableEmbedding, Embedding
+from .modules import StableEmbedding, Linear8bit, Linear8bitLt, Int8Params
bitsandbytes/nn/modules.py
@@ -3,14 +3,19 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import torch
 import bitsandbytes as bnb

-from typing import Optional
+from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict

-from torch import Tensor
+from torch import Tensor, device, dtype
 from torch import nn
 from torch.nn.parameter import Parameter
 import torch.nn.functional as F
 import torch.distributed as dist

 from bitsandbytes.optim import GlobalOptimManager

 T = TypeVar('T', bound='torch.nn.Module')

 class StableEmbedding(torch.nn.Embedding):
     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                  max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
...
@@ -70,3 +75,118 @@ class Embedding(torch.nn.Embedding):
             self.norm_type, self.scale_grad_by_freq, self.sparse)

         return emb


class Int8Params(torch.nn.Parameter):
    def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None):
        cls.has_fp16_weights = has_fp16_weights
        cls.CB = None
        cls.SCB = None
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def cuda(self, device):
        if self.has_fp16_weights:
            return super().cuda(device)
        else:
            # we store the 8-bit row-major weight
            # we convert this weight to the turing/ampere weight during the first inference pass
            B = self.data.contiguous().half().cuda(device)
            CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
            del CBt
            del SCBt
            self.data = CB
            setattr(self, 'CB', CB)
            setattr(self, 'SCB', SCB)

        return self

    @overload
    def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...) -> T:
        ...

    @overload
    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
        ...

    @overload
    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
        ...

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        if device is not None and device.type == 'cuda' and self.data.device.type == 'cpu':
            return self.cuda(device)
        else:
            new_param = Int8Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
                                   requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights)
            new_param.CB = self.CB
            new_param.SCB = self.SCB

            return new_param


class Linear8bitLt(nn.Linear):
    def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, threshold=0.0, index=None):
        super(Linear8bitLt, self).__init__(input_features, output_features, bias)
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x):
        self.state.is_training = self.training

        if self.weight.CB is not None: self.init_8bit_state()
        #assert not self.state.has_fp16_weights
        #if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None

        out = bnb.matmul(x, self.weight, state=self.state)

        if self.bias is not None:
            out += self.bias.unsqueeze(0).expand_as(out)

        if not self.state.has_fp16_weights and self.state.CB is not None:
            # we converted 8-bit row major to turing/ampere format in the first inference pass
            # we no longer need the row-major weight
            del self.state.CB
            self.weight.data = self.state.CxB

        return out


class Linear8bit(nn.Linear):
    def __init__(self, input_features, output_features, bias=True, quant_type='vector', index=None, args=None, sparse_decomp=False):
        super(Linear8bit, self).__init__(input_features, output_features, bias)
        self.quant_type = quant_type
        self.index = index
        self.args = args
        self.iter = 0

    def forward(self, x):
        self.iter += 1
        if self.iter % self.args.clip_freq == 0:
            with torch.no_grad():
                maxval, maxidx = torch.topk(torch.abs(self.weight.flatten()), k=self.args.clip_idx)
                if not dist.is_initialized() or dist.get_rank() == 0:
                    print('clip', maxval[-1].item())
                self.weight.clip_(-maxval[-1], maxval[-1])

        if self.args is not None:
            out = bnb.nn.functional.sparse_decomposed_linear8bit(x, self.weight, self.bias, qval=self.args.sparse_decomp_val, quant_type=self.args.quant_type)
        else:
            out = bnb.nn.functional.linear8bit(x, self.weight, self.bias, quant_type=self.args.quant_type)

        return out
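A hedged sketch of dropping the new Linear8bitLt into a model (dimensions are illustrative assumptions; inference-only, since backward() asserts fp16 weights when has_fp16_weights=False):

    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.Linear8bitLt(64, 32, bias=False, has_fp16_weights=False, threshold=6.0)
    layer = layer.cuda()                 # Int8Params.cuda() row-quantizes the weight to int8 here

    x = torch.randn(8, 64, device='cuda', dtype=torch.float16)
    with torch.no_grad():
        y = layer(x)                     # first pass moves the weight into the turing/ampere format
    print(y.shape)                       # torch.Size([8, 32])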
csrc/kernels.cu  (+874 -0; diff collapsed)
csrc/kernels.cuh
@@ -106,6 +106,18 @@ template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileCl
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);

template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);

template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS> __global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float *newRowStats, float *newcolStats, const int numRows, const int numCols, const int tileCols, const int n);

template <typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T *__restrict__ A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);

template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *__restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);

template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);

#endif
csrc/ops.cu
@@ -8,6 +8,7 @@
 #include <cub/device/device_scan.cuh>
 #include <limits>
 #include <BinSearch.h>
+#include <cassert>
 #include <common.h>
...
@@ -188,11 +189,416 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
{
  const int falpha = 1;
  const int fbeta = 0;
  const void *alpha = &falpha;
  const void *beta = &fbeta;
  cublasStatus_t status;

  status = cublasGemmEx(context->m_handle,
                        transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
                        transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                        m, n, k,
                        alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta,
                        C, CUDA_R_32I, ldc,
                        CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);

  if (status != CUBLAS_STATUS_SUCCESS)
  {
    std::cout << "CUBLAS ERROR: Status " << status << std::endl;
  }
}

void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
                    long long int strideA, long long int strideB, long long int strideC, int batchCount)
{
  const int falpha = 1;
  const int fbeta = 0;
  const void *alpha = &falpha;
  const void *beta = &fbeta;
  cublasStatus_t status;

  //cout << transposeA << transposeB << endl;
  //printf("%i %i %i\n", m,n,k);
  //printf("%i %i %i\n", lda,ldb,ldc);
  //printf("%i %i %i\n", strideA, strideB, strideC);
  //printf("%i\n", batchCount);

  status = cublasGemmStridedBatchedEx(context->m_handle,
                        transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
                        transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                        m, n, k,
                        alpha, A, CUDA_R_8I, lda, (long long int)strideA,
                        B, CUDA_R_8I, ldb, (long long int)strideB, beta,
                        C, CUDA_R_32I, ldc, (long long int)strideC, batchCount,
                        CUDA_R_32I, CUBLAS_GEMM_DEFAULT);

  if (status != CUBLAS_STATUS_SUCCESS)
  {
    std::cout << "CUBLAS ERROR: Status " << status << std::endl;
  }
}

int roundoff(int v, int d)
{
  return (v + d - 1) / d * d;
}

template<int ORDER> cublasLtOrder_t get_order()
{
  switch(ORDER)
  {
    case ROW:
      return CUBLASLT_ORDER_ROW;
      break;
    case COL:
      return CUBLASLT_ORDER_COL;
      break;
    case COL32:
      return CUBLASLT_ORDER_COL32;
      break;
    case COL_TURING:
      return CUBLASLT_ORDER_COL4_4R2_8C;
      break;
    case COL_AMPERE:
      return CUBLASLT_ORDER_COL32_2R_4R4;
      break;
  }
}

template cublasLtOrder_t get_order<ROW>();
template cublasLtOrder_t get_order<COL>();
template cublasLtOrder_t get_order<COL32>();
template cublasLtOrder_t get_order<COL_TURING>();
template cublasLtOrder_t get_order<COL_AMPERE>();


template<int ORDER> int get_leading_dim(int dim1, int dim2)
{
  switch(ORDER)
  {
    case ROW:
      return dim2;
      break;
    case COL:
      return dim1;
      break;
    case COL32:
      // 32*row tiles
      return dim1*32;
      break;
    case COL_TURING:
      return 32*roundoff(dim1, 8);
      break;
    case COL_AMPERE:
      // 32*32 tiles
      return 32*roundoff(dim1, 32);
      break;
  }
}

template int get_leading_dim<ROW>(int dim1, int dim2);
template int get_leading_dim<COL>(int dim1, int dim2);
template int get_leading_dim<COL32>(int dim1, int dim2);

template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
{
  cublasLtOrder_t orderA = get_order<SRC>();
  cublasLtOrder_t orderOut = get_order<TARGET>();
  int ldA = get_leading_dim<SRC>(dim1, dim2);
  int ldOut = get_leading_dim<TARGET>(dim1, dim2);

  cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL;
  cublasLtMatrixTransformDesc_t A2Out_desc = NULL;
  cublasOperation_t opTranspose = CUBLAS_OP_T;
  float transformAlpha = 1.0f, transformBeta = 0.0f;

  if(DTYPE == 8)
  {
    checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA));
    checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut));
  }
  else if(DTYPE == 32)
  {
    checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA));
    checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut));
  }
  else
  {
    printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE);
  }

  checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA)));
  checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut)));

  checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F));

  if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); }

  checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0));

  if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
  if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
  if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
}

template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int32_t, ROW, COL32, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);

template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
  int has_error = 0;
  cublasLtMatmulDesc_t matmulDesc = NULL;
  cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
  cublasOperation_t opT = CUBLAS_OP_T;
  cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
  cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
  cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C;
  cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4;

  has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda));
  has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb));

  has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
  if(FORMATB == COL_TURING)
    has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing)));
  else
    has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));

  if(DTYPE_OUT == 32)
  {
    has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I));
    has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
    has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));
    has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
    int alpha = 1, beta = 0;
    has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, &alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0));
  }
  else
  {
    has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F));
    has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
    has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc));
    has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
    if(!SCALE_ROWS)
    {
      float alpha = 1.0f, beta = 0.0f;
      has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, &alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
    }
    else
    {
      has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec)));
      has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
    }
  }

  if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc));
  if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc));
  if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc));
  if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
  if(has_error == 1)
    printf("error detected");

  return has_error;
}

int fill_up_to_nearest_multiple(int value, int multiple)
{
  return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
}

void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float *newRowStats, float *newcolStats, int numRows, int numCols)
{
  int threads = 512;
  int tileCols = fill_up_to_nearest_multiple(numCols, 32);
  int n = numRows*tileCols;
  int subtile_rows = 128;
  int tilesize = 32*subtile_rows;
  int num_blocks = numRows/subtile_rows;
  num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
  num_blocks = num_blocks*(tileCols/32);
  assert(threads <= tilesize);

  //cout << num_blocks << " blocks" << endl;

  kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

#define STATS_THREADS 64
#define STATS_ITEMS 4
#define STATS_ROWS 16
void getColRowStats(half *A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
{
  int tile_cols = STATS_THREADS*STATS_ITEMS;
  int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
  int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
  int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);

  if(nnz_threshold == 0.0)
    kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
  else if(nnz_threshold != 0.0)
    kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 1><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);

  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

void doubleRowColQuant(half *A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols)
{
  int threads = 64;
  int items_per_thread = 4;
  int tile_cols = threads*items_per_thread;
  int tile_rows = 16;
  int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
  int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
  int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);

  //cout << cols << " " << tiledCols << " " << tiledRows << endl;
  //cout << "num blocks " << num_blocks << endl;
  //cout << A << " " << out_col_normed << endl;
  if(threshold > 0.0f)
    kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
  else
    kDoubleRowColQuant<64, 4, 16, 64*4, 0><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);

  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char *A, char *out, int rows, int cols)
{
  int threads = 256;
  int items_per_thread = 8;
  // we load 128 column values per warp
  int tile_cols = 32*items_per_thread;
  int tile_rows = 32;
  int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
  int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
  int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
  int outCols = fill_up_to_nearest_multiple(cols, 32);
  int outRows = fill_up_to_nearest_multiple(rows, 32);
  if(FORMAT == COL_TURING)
  {
    if(TRANSPOSE)
      outRows = fill_up_to_nearest_multiple(cols, 8);
    else
      outRows = fill_up_to_nearest_multiple(rows, 8);
  }
  else if(FORMAT == COL_AMPERE)
  {
    if(TRANSPOSE)
      outRows = fill_up_to_nearest_multiple(cols, 32);
    else
      outRows = fill_up_to_nearest_multiple(rows, 32);
  }
  else
  {
    if(TRANSPOSE)
    {
      outCols = fill_up_to_nearest_multiple(rows, 32);
      outRows = cols;
    }
  }

  //cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl;
  //cout << "num blocks " << num_blocks << endl;
  //cout << A << " " << out_col_normed << endl;
  kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half *C, bool transposed_B)
{
    cusparseSpMatDescr_t descA;
    cusparseDnMatDescr_t descB, descC;

    float alpha = 1.0f;
    float beta = 0.0f;
    void *dBuffer = NULL;
    size_t bufferSize = 0;

    CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz,
                                      A_rowidx, A_colidx, A_vals,
                                      CUSPARSE_INDEX_32I,
                                      CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) );
    // Create dense matrix C
    CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C,
                                        CUDA_R_16F, CUSPARSE_ORDER_ROW) );
    // Create dense matrix B
    if(transposed_B)
    {
      int tmp = A_cols;
      A_cols = B_cols;
      B_cols = tmp;
    }

    CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B,
                                        CUDA_R_16F, CUSPARSE_ORDER_ROW) );
    // allocate an external buffer if needed
    CHECK_CUSPARSE( cusparseSpMM_bufferSize(
                                 handle,
                                 CUSPARSE_OPERATION_NON_TRANSPOSE,
                                 transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
                                 &alpha, descA, descB, &beta, descC, CUDA_R_32F,
                                 CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) );
    CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) );

    // execute SpMM
    CHECK_CUSPARSE( cusparseSpMM(handle,
                                 CUSPARSE_OPERATION_NON_TRANSPOSE,
                                 transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
                                 &alpha, descA, descB, &beta, descC, CUDA_R_32F,
                                 CUSPARSE_SPMM_ALG_DEFAULT, dBuffer) );

    // destroy matrix/vector descriptors
    CHECK_CUSPARSE( cusparseDestroySpMat(descA) );
    CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
    CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
    CUDA_CHECK_RETURN( cudaFree(dBuffer) );
}

template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{
  kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, 256>>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

//==============================================================
//                   TEMPLATE DEFINITIONS
//==============================================================

template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);

template int igemmlt<COL_TURING, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_TURING, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_TURING, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_AMPERE, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_AMPERE, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_AMPERE, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);

template void transformRowToFormat<COL32, 0>(char *A, char *out, int rows, int cols);
template void transformRowToFormat<COL32, 1>(char *A, char *out, int rows, int cols);
template void transformRowToFormat<COL_TURING, 0>(char *A, char *out, int rows, int cols);
template void transformRowToFormat<COL_TURING, 1>(char *A, char *out, int rows, int cols);
template void transformRowToFormat<COL_AMPERE, 0>(char *A, char *out, int rows, int cols);
template void transformRowToFormat<COL_AMPERE, 1>(char *A, char *out, int rows, int cols);

template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);
...
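The padding helper above drives every tile-size computation in this file; a quick worked check of the arithmetic (Python used only to mirror the C expression):

    def fill_up_to_nearest_multiple(value, multiple):
        return value + (0 if value % multiple == 0 else multiple - value % multiple)

    assert fill_up_to_nearest_multiple(100, 32) == 128   # 100 columns pad up to a 128-wide col32 tile grid
    assert fill_up_to_nearest_multiple(64, 32) == 64     # already-aligned values are unchanged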
csrc/ops.cuh
@@ -14,6 +14,11 @@
 #include <cuda_runtime_api.h>
 #include <cuda_fp16.h>
 #include <cublas_v2.h>
+#include <cublasLt.h>
+#include <cusparse.h>
+#include <vector>
+#include <functional>

 #define CUDA_CHECK_RETURN(value) { \
   cudaError_t _m_cudaStat = value; \
...
@@ -25,6 +30,34 @@
 #define THREADS_PER_BLOCKS (512)

#define CHECK_CUSPARSE(value) { \
  cusparseStatus_t _m_cudaStat = value; \
  if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \
    fprintf(stderr, "Error %s at line %d in file %s\n", \
        cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
    exit(1); \
  } }

#define THREADS_PER_BLOCKS (512)

inline void checkCudaStatus(cudaError_t status) {
    if (status != cudaSuccess) {
        printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status));
        throw std::logic_error("cuda API failed");
    }
}

inline int checkCublasStatus(cublasStatus_t status) {
    if (status != CUBLAS_STATUS_SUCCESS) {
        printf("cuBLAS API failed with status %d\n", status);
        //throw std::logic_error("cuBLAS API failed");
        return 1;
    }
    return 0;
}

typedef enum Operations_t
{
  ksmul = 0,
...
@@ -39,6 +72,57 @@ typedef enum Optimizer_t
    ADAGRAD = 4,
} Optimizer_t;

typedef enum Transform_t
{
  ROW = 0,
  COL = 1,
  COL32 = 2,
  COL_TURING = 3,
  COL_AMPERE = 4,
} Transform_t;

class Context
{
    public:
        cublasHandle_t m_handle;

        Context()
        {
            cublasHandle_t handle;
            cublasCreate_v2(&handle);
            m_handle = handle;
        }
};

class ContextLt
{
    public:
        cublasLtHandle_t m_handle;

        ContextLt()
        {
            cublasLtHandle_t handle;
            cublasLtCreate(&handle);
            m_handle = handle;
        }
};

class ContextCusparse
{
    public:
        cusparseHandle_t m_handle;

        ContextCusparse()
        {
            cusparseHandle_t handle;
            cusparseCreate(&handle);
            m_handle = handle;
        }
};

template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n);
...
@@ -70,4 +154,24 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);

void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
                    long long int strideA, long long int strideB, long long int strideC, int batchCount);

template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float *newRowStats, float *newcolStats, int numRows, int numCols);
void getColRowStats(half *A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
void doubleRowColQuant(half *A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);

template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char *A, char *out, int rows, int cols);

void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half *C, bool transposed_B);

template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);

#endif
csrc/pythonInterface.c
@@ -84,6 +84,52 @@ void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
#endif

#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
{ \
	transform<dtype, src, target, transpose, bits>(ltHandle, A, out, dim1, dim2); \
} \

MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8);
MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8);
MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8);
MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32);
MAKE_FUNC_TRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8);
MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8);
MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8);
MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32);

void transform_row2col32(char *A, char *out, int rows, int cols){ transformRowToFormat<COL32, 0>(A, out, rows, cols); }
void transform_row2col32T(char *A, char *out, int rows, int cols){ transformRowToFormat<COL32, 1>(A, out, rows, cols); }
void transform_row2turing(char *A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 0>(A, out, rows, cols); }
void transform_row2turingT(char *A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 1>(A, out, rows, cols); }
void transform_row2ampere(char *A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols); }
void transform_row2ampereT(char *A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols); }

int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

int igemmlt_turing_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_TURING, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

int igemmlt_turing_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_TURING, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

int igemmlt_ampere_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_AMPERE, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

int igemmlt_ampere_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_AMPERE, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

int igemmlt_ampere_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_AMPERE, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive<half, 16>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }

void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive<signed char, 8>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }

extern "C"
{
	#if BUILD_CUDA
...
@@ -155,7 +201,86 @@ extern "C"
	void cpercentile_clipping_g16(half *g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
	void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
	#endif

	void cigemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
	{ gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); }
	void cbatched_igemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
			long strideA, long strideB, long strideC, int batchCount)
	{ strided_gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount); }

	Context *get_context(){ return new Context(); }
	ContextCusparse *get_cusparse(){ return new ContextCusparse(); }

	int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
	//{ (cublasLtHandle_t)context->m_handle; return 0; }
	//{ return 0; }//igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

	int cigemmlt_turing_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt_turing_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

	int cigemmlt_turing_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt_turing_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

	int cigemmlt_ampere_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt_ampere_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

	int cigemmlt_ampere_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

	int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

	#define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
	void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \
	{ \
		transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \
	} \

	MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8)
	MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8)
	MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8)
	MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32)
	MAKE_FUNC_CTRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8)
	MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8)
	MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8)
	MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32)

	void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float *newRowStats, float *newcolStats, int numRows, int numCols)
	{ dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols); }
	void cget_col_row_stats(half *A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
	{ getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); }
	void cdouble_rowcol_quant(half *A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols)
	{ doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); }

	void ctransform_row2col32(char *A, char *out, int rows, int cols)
	{ transform_row2col32(A, out, rows, cols); }
	void ctransform_row2col32T(char *A, char *out, int rows, int cols)
	{ transform_row2col32T(A, out, rows, cols); }
	void ctransform_row2turing(char *A, char *out, int rows, int cols)
	{ transform_row2turing(A, out, rows, cols); }
	void ctransform_row2turingT(char *A, char *out, int rows, int cols)
	{ transform_row2turingT(A, out, rows, cols); }
	void ctransform_row2ampere(char *A, char *out, int rows, int cols)
	{ transform_row2ampere(A, out, rows, cols); }
	void ctransform_row2ampereT(char *A, char *out, int rows, int cols)
	{ transform_row2ampereT(A, out, rows, cols); }

	void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half *C, bool transposed_B)
	{ spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); }

	void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
	{ spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }

	void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
	{ spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
	#endif

	void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
	void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
}
...
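A hedged sketch of driving one of these extern "C" entry points from Python over ctypes (the get_ptr helper and the 64x64 shape are illustrative assumptions; the library's own wrappers in bitsandbytes/functional.py are the real call sites):

    import ctypes as ct
    import torch
    from bitsandbytes.cextension import lib

    def get_ptr(t):                    # raw device pointer for a torch tensor (hypothetical helper)
        return ct.c_void_p(t.data_ptr())

    A = torch.randint(-128, 127, (64, 64), dtype=torch.int8, device='cuda')
    out = torch.zeros(64, 64, dtype=torch.int8, device='cuda')
    # row-major int8 -> col32 tiled layout, matching transformRowToFormat<COL32, 0>
    lib.ctransform_row2col32(get_ptr(A), get_ptr(out), ct.c_int32(64), ct.c_int32(64))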
tests/test_autograd.py  0 → 100644  (new file)
import pytest
import torch
import bitsandbytes as bnb
from itertools import product

n = 1
k = 25
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ['bmm', 'matmul']
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ['FF', 'TF', 'TT', 'FT']
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ['FF', 'FT', 'TT', 'TF']
dtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
            dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
            A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0])
            B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
            target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1])
            torch.nn.init.xavier_uniform_(B)

            if not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B)
            elif not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B.t())
            elif transpose[0] and not transpose[1]:
                out_torch = funcs[0](A.t(), B)
                out_bnb = funcs[1](A.t(), B)
            elif transpose[0] and transpose[1]:
                out_torch = funcs[0](A.t(), B.t())
                out_bnb = funcs[1](A.t(), B.t())

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n*0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n*0.001

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

                if req_grad[0]:
                    torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
                if req_grad[1]:
                    n = gradB1.numel()
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (idx == 0).sum().item() < n*0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() < n*0.02
                    torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)

        # batched matrix multiply
        if funcs[0] in [torch.bmm, torch.matmul]:
            A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
            B = torch.randn(size=(dim1, dim3, dim4), device='cuda', requires_grad=req_grad[1])
            target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
            torch.nn.init.xavier_uniform_(B)

            out_torch = funcs[0](A, B)
            out_bnb = funcs[1](A, B)

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n*0.01
            torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

                if req_grad[0]:
                    torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
                if req_grad[1]:
                    n = gradB1.numel()
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (idx == 0).sum().item() < n*0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() < n*0.02

        if funcs[0] in [torch.matmul]:
            dim1 = dim1 - (dim1 % 16)
            A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
            dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
            B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
            target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
            torch.nn.init.xavier_uniform_(B)

            if transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B.t())
            else:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B)

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n*0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n*0.001

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

                if req_grad[0]:
                    torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
                if req_grad[1]:
                    n = gradB1.numel()
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (idx == 0).sum().item() < n*0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() < n*0.02
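# The repeated two-tier check in test_matmul above (a tight tolerance that at
# most ~1.75% of elements may miss, plus a loose tolerance that at most 0.1%
# may miss) is the same idiom each time. Written out as a helper purely to
# make the pattern explicit; this helper is not part of the original file:
def assert_mostly_close(a, b, atol, rtol, max_error_rate):
    # count elements outside the isclose tolerance and require the
    # failure rate to stay below max_error_rate
    n = a.numel()
    idx = torch.isclose(a, b, atol=atol, rtol=rtol)
    assert (idx == 0).sum().item() < n*max_error_rate

# the inline checks above amount to:
#   assert_mostly_close(out_bnb, out_torch, 0.01, 0.1, 0.0175)
#   assert_mostly_close(out_bnb, out_torch, 0.035, 0.2, 0.001)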
n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

#dim1 = (17,)
#dim2 = (7,)
#dim3 = (37,)
#dim4 = (23,)

decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
str_funcs = ['matmul']
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ['FF', 'TF', 'TT', 'FT']
transpose = [(False, True), (False, False)]
str_transpose = ['NT', 'NN']
dtype = [torch.float16]
has_fp16_weights = [True, False]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, decomp, has_fp16_weights))
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}'.format(*vals) for vals in str_values]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", values, ids=names)
def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights):
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1]//8,), device='cuda')

    for i in range(k):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0], dtype=dtype)
            if decomp == 6.0:
                with torch.no_grad():
                    A[:, outlier_dim] = 6.0
            B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1], dtype=dtype)
            target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1], dtype=dtype)
            torch.nn.init.xavier_uniform_(B)
            B2 = B.clone()

            state = bnb.MatmulLtState()
            state.threshold = decomp
            state.has_fp16_weights = has_fp16_weights
            if not has_fp16_weights:
                if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous()
                state.CB, CBt, state.SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B2)
                B2 = state.CB

            if not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B2, state=state)
            elif not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B2.t(), state=state)

            n = out_bnb.numel()
            err = torch.abs(out_bnb - out_torch).mean().item()
            #print(f'abs error {err:.4f}')
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n*0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n*0.001

            if has_fp16_weights:
                if any(req_grad):
                    out_bnb.data.copy_(out_torch)
                    torch.cuda.synchronize()
                    loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                    loss_bnb.backward()
                    gradA1 = A.grad
                    gradB1 = B.grad
                    A.grad = None
                    B.grad = None

                    loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                    loss_torch.backward()
                    gradA2 = A.grad
                    gradB2 = B.grad
                    A.grad = None
                    B.grad = None

                    if req_grad[0]:
                        torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
                    if req_grad[1]:
                        n = gradB1.numel()
                        assert torch.abs(gradB1).sum() > 0.0
                        assert torch.abs(gradB2).sum() > 0.0
                        idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                        assert (idx == 0).sum().item() < n*0.1
                        idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                        assert (idx == 0).sum().item() < n*0.02
                        torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
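For reference, the has_fp16_weights=False path that test_matmullt exercises boils down to quantizing B once and reusing it through a MatmulLtState. A sketch of that inference pattern, mirroring the NT branch of the test above (illustrative, not canonical API usage):

A = torch.randn(16, 32, device='cuda', dtype=torch.float16)   # fp16 activations
B = torch.randn(64, 32, device='cuda', dtype=torch.float16)   # weight, (out, in)

state = bnb.MatmulLtState()
state.threshold = 6.0                # outlier decomposition threshold
state.has_fp16_weights = False
state.CB, _, state.SCB, _, _ = bnb.functional.double_quant(B)  # int8 weight + scales

out = bnb.matmul(A, state.CB, state=state)   # later calls reuse the quantized weight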
tests/test_functional.py
This diff is collapsed.
tests/test_modules.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
import math    # used by LinearFunction.get_8bit_linear_trimmed below
import einops  # used by LinearFunction.fake_8bit_storage_with_max below
from itertools import product
from torch import nn

import bitsandbytes as bnb
class MockArgs(object):
    def __init__(self, initial_data):
        for key in initial_data:
            setattr(self, key, initial_data[key])
class MLP8bit(torch.nn.Module):
    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
        super(MLP8bit, self).__init__()
        self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold)
        self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x
def get_args():
    args = MockArgs([])
    args.quant_type = 'vector'
    args.use_8bit_training = 'full'
    args.clip_freq = 9999
    return args
def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f'Too many values not close: assert {sumval} < {count}')
        torch.testing.assert_allclose(a, b, rtol, atol)
class LinearFunction(torch.autograd.Function):

    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        norm = math.sqrt(math.pi)/math.sqrt(2.0)
        #std = torch.abs(x).mean()*norm
        std = torch.std(x)
        max1 = std*trim_value
        x = x/max1*127
        x = round_func(x)
        x[x > 127] = 127
        x[x < -127] = -127
        x = x/127*max1

        return x

    def quant(x, quant_type, dim=1):
        if quant_type == 'linear':
            max1 = torch.abs(x).max().float()
            xq = torch.round(x/max1*127).to(torch.int8)
            return xq, max1
        elif quant_type == 'vector':
            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
            xq = torch.round(x/max1*127).to(torch.int8)
            return xq, max1
        elif quant_type == 'min-max':
            maxA = torch.amax(x, dim=dim, keepdim=True).float()
            minA = torch.amin(x, dim=dim, keepdim=True).float()
            scale = (maxA-minA)/2.0
            xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8)
            return xq, (minA.float(), scale.float())
        else: return None

    def dequant(xq, S1, S2, dtype, quant_type):
        if quant_type == 'linear':
            norm = S1*S2/(127*127)
            # double cast needed to prevent overflows
            return (xq.float()*norm).to(dtype)
        elif quant_type == 'vector':
            x = xq.float()
            if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0)
            #print(x.shape, S1.shape, S2.shape)
            if len(S1.shape) == 2:
                x *= S1.t()/127
            else:
                x *= S1/127
            x *= S2/127
            return x.to(dtype)
        else: return None

    def dequant_min_max(xq, A, B, SA, SB, dtype):
        offset = B.float().t().sum(0)*(SA[0]+SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t()/127
        else:
            x *= SB/127
        x *= SA[1]/127
        x += offset
        return x.to(dtype)

    def get_8bit_linear(x, stochastic=False):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.abs(x).max()
        x = x/max1*127
        x = round_func(x)/127*max1
        #x = torch.round(x)/128*max1
        return x

    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x*127)/max1
        x = round_func(x)/127*max1
        return x

    @staticmethod
    def round_stoachastic(x):
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx - torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign*(torch.floor(absx) + (rdm < decimal).to(x.dtype))

    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
        code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_quantile(w, args):
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        #C = bnb.functional.quantize_no_absmax(code, w)
        #out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
        #print(out)
        #out = out.half()
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_stoachstic(w):
        rand = torch.rand(1024, device=w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
        out = bnb.functional.dequantize_blockwise(absmax, C)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
        blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]

        mask = torch.zeros_like(blocked_w)
        mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
        mask = mask.bool()

        # 1. zero out max values
        # 2. quantize + dequantize
        # 3. write back max values
        # 4. copy matrix back to weight
        values = blocked_w[mask]
        blocked_w[mask] = 0

        code = bnb.functional.create_dynamic_map()
        code = code.to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
        bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)

        blocked_w[mask] = values

        unblocked_w = blocked_w.flatten().view(w.shape)

        w.copy_(unblocked_w)
        return unblocked_w

    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
        if args.use_8bit_training != 'off':
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
            #if torch.rand(1) < 0.01:
                #output32 = torch.matmul(x, weight.t())
                #err = torch.abs(output-output32).float()
                #relerr = err/(torch.abs(output32).float()+1e-8)
                #print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
        else:
            #output = torch.matmul(x, weight.t())
            output = torch.einsum('bsi,oi->bso', x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        x, weight, bias = ctx.saved_tensors
        args = ctx.args
        stochastic = False
        grad_input = grad_weight = grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0)

        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == 'forward+wgrad':
            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
            #grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == 'full':
            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)

            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)

        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum('bsi,bso->oi', x, grad_output)

        return grad_input, grad_weight, grad_bias, None

# Removed by this commit (shown as deleted lines scattered through the diff):
#
# @pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding'])
# def test_embeddings(embcls):
#     bnb.optim.GlobalOptimManager.get_instance().initialize()
#     emb1 = torch.nn.Embedding(100, 512).cuda()
#     emb2 = embcls(100, 512).cuda()
#
#     adam1 = bnb.optim.Adam8bit(emb1.parameters())
#     adam2 = bnb.optim.Adam8bit(emb2.parameters())
#
#     batches = torch.randint(1, 100, size=(100, 4, 32)).cuda()
#
#     for i in range(100):
#         batch = batches[i]
#
#         embedded1 = emb1(batch)
#         embedded2 = emb2(batch)
#
#         l1 = embedded1.mean()
#         l2 = embedded2.mean()
#
#         l1.backward()
#         l2.backward()
#
#         adam1.step()
#         adam2.step()
#
#         adam1.zero_grad()
#         adam2.zero_grad()
#
#         assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8
#         assert adam2.state[emb2.weight]['state1'].dtype == torch.float32
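# A note on LinearFunction.round_stoachastic above: it rounds away from zero
# with probability equal to the fractional part, so it is unbiased in
# expectation. Illustrative sanity check (not part of the original file):
def _round_stoachastic_is_unbiased():
    x = torch.full((100000,), 0.3)
    r = LinearFunction.round_stoachastic(x)
    # 1.0 with probability 0.3, else 0.0, so the mean recovers the input
    assert abs(r.mean().item() - 0.3) < 0.01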
class Linear8bit(nn.Module):
    def __init__(self, input_features, output_features, bias=True, args=None):
        super(Linear8bit, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.args = args

        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter('bias', None)

        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        self.args.training = self.training

        return LinearFunction.apply(x, self.weight, self.bias, self.args)
def test_linear8bit():
    l0 = torch.nn.Linear(32, 64).cuda().half()
    l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half()
    l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
    l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half()

    l0.weight.data = l2.weight.data.clone()
    l0.bias.data = l2.bias.data.clone()

    l1.weight.data = l2.weight.data.clone()
    l1.bias.data = l2.bias.data.clone()

    l3.weight.data = l2.weight.data.clone()
    l3.bias.data = l2.bias.data.clone()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        t = torch.randn(16, 8, 64, device='cuda').half()
        b2 = b1.clone()
        b3 = b1.clone()
        b0 = b1.clone()

        o0 = l0(b0)
        o1 = l1(b1)
        o2 = l2(b2)
        o3 = l3(b3)

        assert_all_approx_close(o1, o2, atol=0.013, rtol=0.05, count=1)
        assert_all_approx_close(o3, o2, atol=0.013, rtol=0.05, count=1)

        loss0 = torch.nn.functional.mse_loss(o0, t)
        loss1 = torch.nn.functional.mse_loss(o1, t)
        loss2 = torch.nn.functional.mse_loss(o2, t)
        loss3 = torch.nn.functional.mse_loss(o3, t)

        loss0.backward()
        loss1.backward()
        loss2.backward()
        loss3.backward()

        assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
        assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
        assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
        assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)

        err1 = torch.abs(l0.weight.grad - l1.weight.grad).mean().item()
        err2 = torch.abs(l0.weight.grad - l2.weight.grad).mean().item()
        err3 = torch.abs(l0.weight.grad - l3.weight.grad).mean().item()

        assert err1*0.8 < err2
        assert err2*0.8 < err3
        assert err3*0.8 < err1

        l0.weight.grad = None
        l1.weight.grad = None
        l2.weight.grad = None
        l3.weight.grad = None
        l0.bias.grad = None
        l1.bias.grad = None
        l2.bias.grad = None
        l3.bias.grad = None
threshold = [0.0, 3.0]
values = threshold
names = ['threshold_{0}'.format(vals) for vals in values]

@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert l1.weight.device.type == 'cuda'
    assert l1.weight.dtype == torch.float16

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = l1(b1)
        if i == 1:
            assert l1.state.CxB is not None
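# Sketch of the inference pattern test_linear8bitlt_inference checks above:
# in eval mode the layer builds the layout-transformed int8 weight once and
# caches it in state.CxB, reusing it on later calls. Illustrative only, not
# part of the original file:
def _linear8bitlt_inference_sketch():
    layer = bnb.nn.Linear8bitLt(32, 64, threshold=6.0).cuda().half()
    layer.eval()
    with torch.no_grad():
        x = torch.randn(16, 8, 32, device='cuda').half()
        layer(x)                           # first call populates state.CxB
        assert layer.state.CxB is not None
        layer(x)                           # later calls reuse the cached weight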
def test_linear8bitlt_accumulated_gradient():
    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)

    acc_steps = 10

    for i in range(10):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
        loss2 = o2.mean()
        loss1.backward()
        loss2.backward()
        if i == 2:
            assert l1[0].state.CxB is not None
            assert l1[1].state.CxB is not None

        if i > 0 and i % acc_steps == 0:
            opt1.step()
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
            assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
        else:
            torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
            torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
threshold = [0.0, 2.0]
values = threshold
names = ['threshold_{0}'.format(vals) for vals in values]

@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold):
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = l1(b1)
        assert o1.dtype == torch.float16

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda')

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == 'cuda'
    assert mlp.fc2.weight.device.type == 'cuda'

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda')

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == 'cuda'
    assert mlp.fc2.weight.device.type == 'cuda'
tests/test_optim.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
import shutil
import uuid
import pytest
import ctypes
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
@@ -14,7 +11,9 @@ import bitsandbytes.functional as F
 from os.path import join
 from itertools import product

-import apex
+#import apex

+k = 20
+
 def get_temp_dir():
     path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4()))
@@ -26,55 +25,47 @@ def rm_path(path):
 str2optimizers = {}
 str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam)
-str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
-str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
+#str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
+#str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
 str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam)
-str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
-str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
+#str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
+#str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
 str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
 str2optimizers['adamw'] = (torch.optim.AdamW, bnb.optim.AdamW)
-str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
+#str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
 str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
 str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
-str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
+#str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
 str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False))
 str2optimizers['adagrad'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad(pxx, 0.01, block_wise=False))
 str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
 str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False))
 str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False))
-str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
+#str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
 str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))
 str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
 str2optimizers['adamw8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.AdamW8bit(pxx, block_wise=True))
 str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
 str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
 str2optimizers['adagrad8bit_blockwise'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad8bit(pxx, 0.01, block_wise=True))
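# How the table above is consumed (sketch): each entry pairs a reference
# optimizer constructor with its bitsandbytes counterpart; the tests build
# both on cloned parameters and step them in lockstep. Illustrative only,
# not part of the original file:
def _optimizer_pair_sketch():
    p1 = torch.randn(1024, 32, device='cuda', requires_grad=True)
    p2 = p1.detach().clone().requires_grad_(True)

    torch_optimizer = str2optimizers['adam'][0]([p1])  # reference implementation
    bnb_optimizer = str2optimizers['adam'][1]([p2])    # bitsandbytes implementation

    g = torch.randn(1024, 32, device='cuda')*0.01
    p1.grad = g.clone()
    p2.grad = g.clone()
    torch_optimizer.step()
    bnb_optimizer.step()
    torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)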
 str2statenames = {}
 str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
 str2statenames['adamw'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
 str2statenames['momentum'] = [('momentum_buffer', 'state1')]
 str2statenames['lars'] = [('momentum_buffer', 'state1')]
 str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
 str2statenames['rmsprop'] = [('square_avg', 'state1')]
 str2statenames['adagrad'] = [('sum', 'state1')]
 str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
 str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
 str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
 str2statenames['adamw8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
 str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
 str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
 str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
 str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')]
 str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')]
 str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')]
 dim1 = [1024]
 dim2 = [32, 1024, 4097, 1]
 gtype = [torch.float32, torch.float16]
 optimizer_names = ['adam', 'adamw', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad']
 optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb']
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]

 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -89,12 +80,12 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
     bnb_optimizer = str2optimizers[optim_name][1]([p2])

     if gtype == torch.float32:
-        atol, rtol = 2e-6, 1e-5
+        atol, rtol = 1e-6, 1e-5
     else:
         atol, rtol = 1e-4, 1e-3

-    for i in range(50):
+    for i in range(k):
         g = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.01
         p1.grad = g.clone().float()
         p2.grad = g.clone()
@@ -107,7 +98,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
         torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)

-        if i % 10 == 0 and i > 0:
+        if i % (k//5) == 0 and i > 0:
             path = get_temp_dir()
             torch.save(bnb_optimizer.state_dict(), join(path, 'opt.pt'))
             del bnb_optimizer
@@ -148,7 +139,6 @@ def test_global_config(dim1, dim2, gtype):
     eps = 1e-8

     bnb.optim.GlobalOptimManager.get_instance().initialize()
-    bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True)
     bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)

     bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
@@ -163,8 +153,6 @@ def test_global_config(dim1, dim2, gtype):
     else:
         atol, rtol = 1e-4, 1e-3

-    original_p2 = p2[mask].clone()

     for i in range(50):
         g1 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1 + 0.001
         g2 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1 + 0.001
@@ -173,38 +161,17 @@ def test_global_config(dim1, dim2, gtype):
         p2.grad = g2
         p3.grad = g3

-        if i > 30 and i % 10 == 0:
-            g1.data[mask] = 0.0
-            g2.data[mask] = 0.0
-            p1.grad = g1
-            p2.grad = g2
-            original_p1 = p1[mask].clone()
-            original_p2 = p2[mask].clone()
-            og_s1 = adam2.state[p2]['state1'][mask].clone()
-            og_s2 = adam2.state[p2]['state2'][mask].clone()
-            og_s11 = adam2.state[p1]['state1'][mask].clone()
-            og_s21 = adam2.state[p1]['state2'][mask].clone()

         adam2.step()

         assert adam2.state[p3]['state1'].dtype == torch.uint8
         assert adam2.state[p3]['state2'].dtype == torch.uint8

-        if i > 30 and i % 10 == 0:
-            torch.testing.assert_allclose(original_p2, p2[mask])
-            torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1)
-            torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2)
-            assert ((p1[mask] - original_p1) == 0.0).sum() < p1.numel()
-            assert ((adam2.state[p1]['state1'][mask] - og_s11) == 0.0).sum() == 0.0
-            assert ((adam2.state[p1]['state2'][mask] - og_s21) == 0.0).sum() == 0.0
 dim1 = [1024]
 dim2 = [32, 1024, 4097]
 gtype = [torch.float32, torch.float16]
 optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'adamw8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise']
 optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]

 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -370,13 +337,12 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
     if dim1 == 1 and dim2 == 1: return
     p1 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1
     bnb_optimizer = str2optimizers[optim_name][1]([p1])

     g = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.01
     p1.grad = g
-    for i in range(5000):
-        if i == 500:
+    for i in range(k):
+        if i == k//5:
             # 100 iterations for burn-in
             torch.cuda.synchronize()
             t0 = time.time()
@@ -386,23 +352,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
     torch.cuda.synchronize()
     s = time.time() - t0
     print('')
-    params = 4500*4096*4096
+    params = (k - k//5)*dim1*dim2
     print(optim_name, gtype, s/params)
     #assert s < 3.9
 def test_str_betas():
     betas = (0.80, 0.95)
     strbetas = '(0.80, 0.95)'

     layer = torch.nn.Linear(10, 10)

     base = bnb.optim.Adam(layer.parameters(), betas=betas)
     strbase = bnb.optim.Adam(layer.parameters(), betas=strbetas)
     assert base.defaults['betas'][0] == 0.8
     assert base.defaults['betas'][1] == 0.95
     assert strbase.defaults['betas'][0] == 0.8
     assert strbase.defaults['betas'][1] == 0.95