OpenDAS / bitsandbytes

Commit ea7c14f8
Authored Aug 01, 2022 by Titus von Koeller

    reran black with linelength 80 for greater readability

Parent: 3fd06fb6
Showing 17 changed files with 665 additions and 203 deletions.
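The whole commit is the output of rerunning black with its line-length option set to 80. The exact invocation is not recorded on this page, so the following is only a hedged sketch of how the reformatting could be reproduced through black's Python API (the CLI equivalent would be along the lines of `black --line-length 80 .`); the sample source string is invented for illustration.

# Hypothetical reproduction of the reformatting step; the commit does not
# record the command actually used, so treat this as an assumption.
import black

src = "def f(a, b, c):\n    return F.vectorwise_quant(a, dim=b, quant_type=c)\n"

# black.Mode(line_length=80) mirrors `--line-length 80` on the command line.
formatted = black.format_str(src, mode=black.Mode(line_length=80))
print(formatted)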
bitsandbytes/__init__.py             +7    -2
bitsandbytes/autograd/_functions.py  +36   -9
bitsandbytes/cuda_setup.py           +38   -7
bitsandbytes/functional.py           +99   -29
bitsandbytes/nn/modules.py           +28   -6
bitsandbytes/optim/adagrad.py        +9    -3
bitsandbytes/optim/adam.py           +21   -6
bitsandbytes/optim/lars.py           +15   -5
bitsandbytes/optim/optimizer.py      +56   -21
bitsandbytes/optim/rmsprop.py        +9    -3
bitsandbytes/utils.py                +1    -1
quicktest.py                         +14   -6
tests/test_autograd.py               +74   -22
tests/test_cuda_setup_evaluator.py   +26   -7
tests/test_functional.py             +138  -49
tests/test_modules.py                +39   -11
tests/test_optim.py                  +55   -16
bitsandbytes/__init__.py
@@ -3,8 +3,13 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .autograd._functions import (MatmulLtState, bmm_cublas, matmul,
-                                  matmul_cublas, mm_cublas)
+from .autograd._functions import (
+    MatmulLtState,
+    bmm_cublas,
+    matmul,
+    matmul_cublas,
+    mm_cublas,
+)
 from .cextension import COMPILED_WITH_CUDA
 from .nn import modules
bitsandbytes/autograd/_functions.py
@@ -111,7 +111,9 @@ class MatMul8bit(torch.autograd.Function):
             qgrad_output, S1 = F.vectorwise_quant(
                 grad_output, dim=dims, quant_type=quant_type
             )
-            qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
+            qA, S2 = F.vectorwise_quant(
+                A, dim=dims, quant_type=quant_type
+            )
             igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
             grad_B = F.vectorwise_mm_dequant(
                 igrad_B,
@@ -146,7 +148,11 @@ class MatMul8bit(torch.autograd.Function):
             qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
             igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
             grad_A = F.vectorwise_mm_dequant(
-                igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
+                igrad_A,
+                S1,
+                S3.permute(permute_dim),
+                grad_output.dtype,
+                quant_type,
             )

         return grad_A, grad_B, None, None, None
@@ -211,7 +217,9 @@ class MatMul8bitLt(torch.autograd.Function):
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
-        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
+        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
+            A, threshold=state.threshold
+        )

         if state.threshold > 0.0 and coo_tensorA is not None:
             if state.has_fp16_weights:
@@ -225,7 +233,9 @@ class MatMul8bitLt(torch.autograd.Function):
             if state.CxB is None:
                 # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                 # we also need to convert it to the turing/ampere format
-                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+                state.CxB, state.SB = F.transform(
+                    state.CB, to_order=formatB
+                )
                 # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
                 # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
                 # # generate outlier index and subB
@@ -259,7 +269,13 @@ class MatMul8bitLt(torch.autograd.Function):
             if (state.is_training and not has_grad) or state.CxB is None:
                 state.reset_grads()
-                CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B)
+                (
+                    CB,
+                    state.CBt,
+                    state.SCB,
+                    state.SCBt,
+                    coo_tensorB,
+                ) = F.double_quant(B)
                 state.CxB, state.SB = F.transform(CB, to_order=formatB)
         else:
             has_grad = False
@@ -277,7 +293,10 @@ class MatMul8bitLt(torch.autograd.Function):
                 # state.idx = outlier_idx
                 outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
                 state.subB = (
-                    (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
+                    (outliers * state.SCB.view(-1, 1) / 127.0)
+                    .t()
+                    .contiguous()
+                    .half()
                 )
                 CA[:, state.idx.long()] = 0
                 CAt[:, state.idx.long()] = 0
@@ -325,10 +344,14 @@ class MatMul8bitLt(torch.autograd.Function):
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert state.has_fp16_weights, "Backprop only supported for fp16 weights."
+        assert (
+            state.has_fp16_weights
+        ), "Backprop only supported for fp16 weights."

         if len(grad_output.shape) == 3:
-            grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
+            grad_output = grad_output.view(
+                -1, grad_output.shape[-1]
+            ).contiguous()

         grad_A = grad_B = None
@@ -359,7 +382,11 @@ matmul = MatMul8bitLt.apply

 def matmul(
-    A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
+    A: tensor,
+    B: tensor,
+    out: tensor = None,
+    state: MatmulLtState = None,
+    threshold=0.0,
 ):
     state = state or MatmulLtState()
     if threshold > 0.0:
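The MatMul8bit backward pass above follows a quantize, integer-matmul, dequantize pattern (F.vectorwise_quant, F.igemm, F.vectorwise_mm_dequant). A self-contained sketch of the same idea in plain PyTorch follows; the per-row absmax scaling is an assumed stand-in for the library's vectorwise quantization, not the exact implementation.

# Plain-PyTorch sketch of the quantize -> int matmul -> dequantize pattern.
# The absmax scaling here is an assumption about the "linear" quant_type.
import torch


def vectorwise_quant(x, dim):
    scale = x.abs().amax(dim=dim, keepdim=True)      # per-slice absmax
    q = torch.round(x / scale * 127).to(torch.int8)  # int8 codes
    return q, scale


A = torch.randn(32, 64)
B = torch.randn(64, 16)

qA, sA = vectorwise_quant(A, dim=1)   # scale over the reduction dimension
qB, sB = vectorwise_quant(B, dim=0)

# Integer matmul emulated with 64-bit accumulation, then rescaled to float.
iC = qA.long() @ qB.long()
C_approx = iC.float() * (sA / 127) * (sB / 127)

print((C_approx - A @ B).abs().max())  # small quantization error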
bitsandbytes/cuda_setup.py
 """
-build is dependent on
-- compute capability
-    - dependent on GPU family
+extract factors the build is dependent on:
+[X] compute capability
+    [ ] TODO: Q - What if we have multiple GPUs of different makes?
 - CUDA version
 - Software:
     - CPU-only: only CPU quantization functions (no optimizer, no matrix multipl)
@@ -19,6 +19,8 @@ evaluation:
 """
 import ctypes
+import shlex
+import subprocess
 from os import environ as env
 from pathlib import Path
 from typing import Set, Union
@@ -26,10 +28,31 @@ from typing import Set, Union
 from .utils import print_err, warn_of_missing_prerequisite


+def execute_and_return(command_string: str) -> Tuple[str, str]:
+    def _decode(subprocess_err_out_tuple):
+        return tuple(
+            to_decode.decode("UTF-8").strip()
+            for to_decode in subprocess_err_out_tuple
+        )
+
+    def execute_and_return_decoded_std_streams(command_string):
+        return _decode(
+            subprocess.Popen(
+                shlex.split(command_string),
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+            ).communicate()
+        )
+
+    std_out, std_err = execute_and_return_decoded_std_streams()
+    return std_out, std_err
+
+
 def check_cuda_result(cuda, result_val):
     if result_val != 0:
         # TODO: undefined name 'error_str'
         cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
-        print(f"Count not initialize CUDA - failure!")
+        print("Count not initialize CUDA - failure!")
         raise Exception("CUDA exception!")
     return result_val
@@ -53,7 +76,9 @@ def get_compute_capability():
     result = ctypes.c_int()
     device = ctypes.c_int()
+    # TODO: local variable 'context' is assigned to but never used
     context = ctypes.c_void_p()
+    # TODO: local variable 'error_str' is assigned to but never used
     error_str = ctypes.c_char_p()

     result = check_cuda_result(cuda, cuda.cuInit(0))
@@ -61,7 +86,9 @@ def get_compute_capability():
     result = check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
     ccs = []
     for i in range(nGpus.value):
-        result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
+        result = check_cuda_result(
+            cuda, cuda.cuDeviceGet(ctypes.byref(device), i)
+        )
         result = check_cuda_result(
             cuda,
             cuda.cuDeviceComputeCapability(
@@ -114,11 +141,15 @@ def get_cuda_runtime_lib_path(
     } - non_existent_directories

     if len(cuda_runtime_libs) > 1:
-        err_msg = f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        err_msg = (
+            f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        )
         raise FileNotFoundError(err_msg)

     elif len(cuda_runtime_libs) < 1:
-        err_msg = f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        err_msg = (
+            f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        )
         raise FileNotFoundError(err_msg)

     single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs))
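The execute_and_return helper added above is a standard shlex plus subprocess decode pattern. A small standalone sketch of the same idea, using an illustrative `ls -l` command rather than anything from the library, is:

# Standalone sketch of the decode-and-strip pattern used by execute_and_return.
import shlex
import subprocess
from typing import Tuple


def run(command_string: str) -> Tuple[str, str]:
    # communicate() returns (stdout_bytes, stderr_bytes); both are decoded
    # from UTF-8 and stripped, mirroring the helper in cuda_setup.py.
    proc = subprocess.Popen(
        shlex.split(command_string),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    out, err = proc.communicate()
    return out.decode("UTF-8").strip(), err.decode("UTF-8").strip()


print(run("ls -l")[0])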
bitsandbytes/functional.py
@@ -17,14 +17,29 @@ if COMPILED_WITH_CUDA:
     """C FUNCTIONS FOR OPTIMIZERS"""
     str2optimizer32bit = {}
     str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
-    str2optimizer32bit["momentum"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
-    str2optimizer32bit["rmsprop"] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
-    str2optimizer32bit["adagrad"] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
-    str2optimizer32bit["lars"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
+    str2optimizer32bit["momentum"] = (
+        lib.cmomentum32bit_g32,
+        lib.cmomentum32bit_g16,
+    )
+    str2optimizer32bit["rmsprop"] = (
+        lib.crmsprop32bit_g32,
+        lib.crmsprop32bit_g16,
+    )
+    str2optimizer32bit["adagrad"] = (
+        lib.cadagrad32bit_g32,
+        lib.cadagrad32bit_g16,
+    )
+    str2optimizer32bit["lars"] = (
+        lib.cmomentum32bit_g32,
+        lib.cmomentum32bit_g16,
+    )
     str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)

     str2optimizer8bit = {}
-    str2optimizer8bit["adam"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+    str2optimizer8bit["adam"] = (
+        lib.cadam_static_8bit_g32,
+        lib.cadam_static_8bit_g16,
+    )
     str2optimizer8bit["momentum"] = (
         lib.cmomentum_static_8bit_g32,
         lib.cmomentum_static_8bit_g16,
@@ -33,7 +48,10 @@ if COMPILED_WITH_CUDA:
         lib.crmsprop_static_8bit_g32,
         lib.crmsprop_static_8bit_g16,
     )
-    str2optimizer8bit["lamb"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+    str2optimizer8bit["lamb"] = (
+        lib.cadam_static_8bit_g32,
+        lib.cadam_static_8bit_g16,
+    )
     str2optimizer8bit["lars"] = (
         lib.cmomentum_static_8bit_g32,
         lib.cmomentum_static_8bit_g16,
@@ -137,7 +155,9 @@ def create_dynamic_map(signed=True, n=7):
     if not signed:
         additional_items = 2 * additional_items
     for i in range(n):
-        fraction_items = 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
+        fraction_items = (
+            2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
+        )
         boundaries = torch.linspace(0.1, 1, fraction_items)
         means = (boundaries[:-1] + boundaries[1:]) / 2.0
         data += ((10 ** (-(n - 1) + i)) * means).tolist()
@@ -272,7 +292,13 @@ def get_transform_buffer(

 def nvidia_transform(
-    A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+    A,
+    to_order,
+    from_order="row",
+    out=None,
+    transpose=False,
+    state=None,
+    ld=None,
 ):
     if state is None:
         state = (A.shape, from_order)
@@ -352,7 +378,11 @@ def estimate_quantiles(

 def quantize_blockwise(
-    A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None
+    A: Tensor,
+    code: Tensor = None,
+    absmax: Tensor = None,
+    rand=None,
+    out: Tensor = None,
 ) -> Tensor:
     """
     Quantize tensor A in blocks of size 4096 values.
@@ -629,7 +659,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
     """
     if out is None:
         out = torch.zeros_like(A, dtype=torch.float32)
-    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
+    lib.cdequantize(
+        get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())
+    )
     return out
@@ -1005,7 +1037,9 @@ def histogram_scatter_add_2d(
     )


-def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
+def check_matmul(
+    A, B, out, transposed_A, transposed_B, expected_type=torch.int8
+):
     if not torch.cuda.is_initialized():
         torch.cuda.init()
     if A.dtype != expected_type or B.dtype != expected_type:
@@ -1097,7 +1131,11 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8

 def igemm(
-    A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+    A: Tensor,
+    B: Tensor,
+    out: Tensor = None,
+    transposed_A=False,
+    transposed_B=False,
 ):
     sout = check_matmul(A, B, out, transposed_A, transposed_B)
     if out is None:
@@ -1193,7 +1231,11 @@ def igemm(

 def batched_igemm(
-    A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+    A: Tensor,
+    B: Tensor,
+    out: Tensor = None,
+    transposed_A=False,
+    transposed_B=False,
 ):
     if not len(A.shape) == 3 or not len(B.shape) == 3:
         raise ValueError(
@@ -1392,9 +1434,13 @@ def mm_dequant(
     if out is None:
         out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
     if new_row_stats is None:
-        new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
+        new_row_stats = torch.empty(
+            out_shape[0], dtype=torch.float32, device=A.device
+        )
     if new_col_stats is None:
-        new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
+        new_col_stats = torch.empty(
+            out_shape[1], dtype=torch.float32, device=A.device
+        )
     assert (
         new_row_stats.shape[0] == row_stats.shape[0]
     ), f"{new_row_stats.shape} vs {row_stats.shape}"
@@ -1440,13 +1486,13 @@ def get_colrow_absmax(
     col_tiles = (cols + 255) // 256
     tiled_rows = ((rows + 15) // 16) * 16
     if row_stats is None:
-        row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0)
+        row_stats = torch.empty(
+            (rows,), dtype=torch.float32, device=device
+        ).fill_(-50000.0)
     if col_stats is None:
-        col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0)
+        col_stats = torch.empty(
+            (cols,), dtype=torch.float32, device=device
+        ).fill_(-50000.0)

     if nnz_block_ptr is None and threshold > 0.0:
         nnz_block_ptr = torch.zeros(
@@ -1462,7 +1508,13 @@ def get_colrow_absmax(
     prev_device = pre_call(A.device)
     lib.cget_col_row_stats(
-        ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols
+        ptrA,
+        ptrRowStats,
+        ptrColStats,
+        ptrNnzrows,
+        ct.c_float(threshold),
+        rows,
+        cols,
     )
     post_call(prev_device)
@@ -1526,7 +1578,9 @@ class CSCSparseTensor(object):
 def coo2csr(cooA):
     values, counts = torch.unique(cooA.rowidx, return_counts=True)
     values.add_(1)
-    rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device)
+    rowptr = torch.zeros(
+        (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
+    )
     rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
     rowptr.cumsum_(0)
     return CSRSparseTensor(
@@ -1540,10 +1594,14 @@ def coo2csc(cooA):
     values = cooA.values[col2rowidx]
     colvalues, counts = torch.unique(val, return_counts=True)
     colvalues.add_(1)
-    colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device)
+    colptr = torch.zeros(
+        (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
+    )
     colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
     colptr.cumsum_(0)
-    return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)
+    return CSCSparseTensor(
+        cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values
+    )


 def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
@@ -1568,7 +1626,9 @@ def double_quant(
     rows = A.shape[0]
     if row_stats is None or col_stats is None:
-        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
+        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(
+            A, threshold=threshold
+        )

     if out_col is None:
         out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
@@ -1663,7 +1723,13 @@ def get_special_format_str():

 def transform(
-    A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+    A,
+    to_order,
+    from_order="row",
+    out=None,
+    transpose=False,
+    state=None,
+    ld=None,
 ):
     if state is None:
         state = (A.shape, from_order)
@@ -1716,7 +1782,9 @@ def transform(
 def spmm_coo(cooA, B, out=None):
     if out is None:
-        out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
+        out = torch.empty(
+            (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
+        )
     nnz = cooA.nnz
     assert cooA.rowidx.numel() == nnz
     assert cooA.colidx.numel() == nnz
@@ -1982,7 +2050,9 @@ def extract_outliers(A, SA, idx):
     assert formatA in ["col_turing", "col_ampere"]
     assert A.device.type == "cuda"

-    out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
+    out = torch.zeros(
+        (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device
+    )

     idx_size = ct.c_int32(idx.numel())
     rows = ct.c_int32(shapeA[0])
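The create_dynamic_map hunk above splits each quantization "exponent bucket" into a set of evenly spaced fractions scaled by a power of ten. A small sketch that reproduces just that inner-loop arithmetic (with illustrative values n=7 and signed=True; the full function also appends additional items and the value 0, which this sketch omits) is:

# Sketch of the per-bucket value generation from create_dynamic_map above.
import torch

n = 7
signed = True
data = []
for i in range(n):
    fraction_items = (
        2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
    )
    boundaries = torch.linspace(0.1, 1, fraction_items)
    means = (boundaries[:-1] + boundaries[1:]) / 2.0
    # Each bucket i scales its fractions by a power of ten.
    data += ((10 ** (-(n - 1) + i)) * means).tolist()

print(len(data), min(data), max(data))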
bitsandbytes/nn/modules.py
@@ -2,8 +2,19 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set,
-                    Tuple, TypeVar, Union, overload)
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    overload,
+)

 import torch
 import torch.nn.functional as F
@@ -131,7 +142,12 @@ class Embedding(torch.nn.Embedding):
 class Int8Params(torch.nn.Parameter):
     def __new__(
-        cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None
+        cls,
+        data=None,
+        requires_grad=True,
+        has_fp16_weights=False,
+        CB=None,
+        SCB=None,
     ):
         cls.has_fp16_weights = has_fp16_weights
         cls.CB = None
@@ -186,7 +202,9 @@ class Int8Params(torch.nn.Parameter):
             return self.cuda(device)
         else:
             new_param = Int8Params(
-                super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+                super().to(
+                    device=device, dtype=dtype, non_blocking=non_blocking
+                ),
                 requires_grad=self.requires_grad,
                 has_fp16_weights=self.has_fp16_weights,
             )
@@ -206,7 +224,9 @@ class Linear8bitLt(nn.Linear):
         threshold=0.0,
         index=None,
     ):
-        super(Linear8bitLt, self).__init__(input_features, output_features, bias)
+        super(Linear8bitLt, self).__init__(
+            input_features, output_features, bias
+        )
         self.state = bnb.MatmulLtState()
         self.index = index
@@ -215,7 +235,9 @@ class Linear8bitLt(nn.Linear):
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True

-        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights
+        )

     def init_8bit_state(self):
         self.state.CB = self.weight.CB
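As a usage sketch for the Linear8bitLt constructor touched above: the argument names come from the hunk, but the sizes and the threshold value are illustrative choices, and a CUDA build of bitsandbytes with fp16 inputs is assumed.

# Illustrative use of the Linear8bitLt layer whose __init__ is reformatted
# above; sizes and threshold are example values, a CUDA GPU is assumed.
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(
    1024,              # input_features
    4096,              # output_features
    bias=True,
    has_fp16_weights=False,
    threshold=6.0,
).cuda()

x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
print(layer(x).shape)  # expected: torch.Size([8, 4096])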
bitsandbytes/optim/adagrad.py
@@ -23,7 +23,9 @@ class Adagrad(Optimizer1State):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if initial_accumulator_value != 0.0:
@@ -63,7 +65,9 @@ class Adagrad8bit(Optimizer1State):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if initial_accumulator_value != 0.0:
@@ -104,7 +108,9 @@ class Adagrad32bit(Optimizer1State):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if initial_accumulator_value != 0.0:
bitsandbytes/optim/adam.py
@@ -140,7 +140,11 @@ class AnalysisAdam(torch.optim.Optimizer):
         savedir=None,
     ):
         defaults = dict(
-            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsgrad=amsgrad,
         )
         super(AnalysisAdam, self).__init__(params, defaults)
         self.analysis = bnb_analysis
@@ -198,7 +202,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                 state["relerrors"] = torch.zeros(
                     (256, 256), device=p_data_fp32.device
                 )
-                state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
+                state["counts"] = torch.zeros(
+                    (256, 256), device=p_data_fp32.device
+                )
                 if amsgrad:
                     # Maintains max of all exp. moving avg. of sq. grad. values
                     state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
@@ -214,7 +220,9 @@ class AnalysisAdam(torch.optim.Optimizer):
             beta1, beta2 = group["betas"]
             bias_correction1 = 1 - beta1 ** state["step"]
             bias_correction2 = 1 - beta2 ** state["step"]
-            step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+            step_size = (
+                group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+            )
             e = state["abserrors"]
             rele = state["relerrors"]
             counts = state["counts"]
@@ -235,7 +243,10 @@ class AnalysisAdam(torch.optim.Optimizer):
             denom = exp_avg_sq.sqrt().add_(group["eps"])
             update_fp32 = exp_avg / denom

-            if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
+            if (
+                p_data_fp32.numel() <= 8192
+                or p_data_fp32.numel() > 50000 * 1000
+            ):
                 # embedding layer or too small
                 p_data_fp32 += -step_size * update_fp32
             else:
@@ -274,7 +285,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                     # 3. dequantize
                     # Error will be calculated automatically!
                 else:
-                    raise ValueError(f"Invalid analysis value: {self.analysis}!")
+                    raise ValueError(
+                        f"Invalid analysis value: {self.analysis}!"
+                    )

                 denom = state2.sqrt().add_(group["eps"])
                 update_8bit = state1 / denom
@@ -296,7 +309,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                 if self.savedir != "" and state["step"] % 100 == 0:
                     if not os.path.exists(self.savedir):
                         os.makedirs(self.savedir)
-                    shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
+                    shapestr = "_".join(
+                        [str(dim) for dim in p_data_fp32.shape]
+                    )
                     pathe = os.path.join(
                         self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
                     )
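The step_size expression wrapped in the AnalysisAdam hunk above is the standard bias-corrected Adam step size, lr * sqrt(1 - beta2^t) / (1 - beta1^t) with t = state["step"]. A quick numeric check with the common Adam defaults and an example step count t = 10:

# Numeric check of the bias-corrected step size; lr and betas are the usual
# Adam defaults, t is an arbitrary example step count.
import math

lr, beta1, beta2, t = 1e-3, 0.9, 0.999, 10
bias_correction1 = 1 - beta1 ** t
bias_correction2 = 1 - beta2 ** t
step_size = lr * math.sqrt(bias_correction2) / bias_correction1
print(step_size)  # roughly 1.5e-4 for these values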
bitsandbytes/optim/lars.py
@@ -24,7 +24,9 @@ class LARS(Optimizer1State):
         max_unorm=0.02,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"LARS without momentum is not supported!")
+            raise NotImplementedError(
+                f"LARS without momentum is not supported!"
+            )
         super(LARS, self).__init__(
             "lars",
             params,
@@ -56,7 +58,9 @@ class LARS8bit(Optimizer1State):
         max_unorm=0.02,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"LARS without momentum is not supported!")
+            raise NotImplementedError(
+                f"LARS without momentum is not supported!"
+            )
         super(LARS8bit, self).__init__(
             "lars",
             params,
@@ -88,7 +92,9 @@ class LARS32bit(Optimizer1State):
         max_unorm=0.02,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"LARS without momentum is not supported!")
+            raise NotImplementedError(
+                f"LARS without momentum is not supported!"
+            )
         super(LARS32bit, self).__init__(
             "lars",
             params,
@@ -121,7 +127,9 @@ class PytorchLARS(Optimizer):
         if momentum < 0.0:
             raise ValueError("Invalid momentum value: {}".format(momentum))
         if weight_decay < 0.0:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )

         defaults = dict(
             lr=lr,
@@ -132,7 +140,9 @@ class PytorchLARS(Optimizer):
             max_unorm=max_unorm,
         )
         if nesterov and (momentum <= 0 or dampening != 0):
-            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+            raise ValueError(
+                "Nesterov momentum requires a momentum and zero dampening"
+            )
         super(PytorchLARS, self).__init__(params, defaults)

     def __setstate__(self, state):
bitsandbytes/optim/optimizer.py
@@ -46,9 +46,13 @@ class GlobalOptimManager(object):
         for group_index, group in enumerate(param_groups):
             for p_index, p in enumerate(group["params"]):
                 if id(p) in self.pid2config:
-                    self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
+                    self.index2config[(group_index, p_index)] = self.pid2config[
+                        id(p)
+                    ]

-    def override_config(self, parameters, key=None, value=None, key_value_dict=None):
+    def override_config(
+        self, parameters, key=None, value=None, key_value_dict=None
+    ):
         """
         Overrides initial optimizer config for specific parameters.
@@ -136,7 +140,8 @@ class Optimizer8bit(torch.optim.Optimizer):
         if len(groups) != len(saved_groups):
             raise ValueError(
-                "loaded state dict has a different number of " "parameter groups"
+                "loaded state dict has a different number of "
+                "parameter groups"
             )
         param_lens = (len(g["params"]) for g in groups)
         saved_lens = (len(g["params"]) for g in saved_groups)
@@ -192,7 +197,9 @@ class Optimizer8bit(torch.optim.Optimizer):
             new_group["params"] = group["params"]
             return new_group

-        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+        param_groups = [
+            update_group(g, ng) for g, ng in zip(groups, saved_groups)
+        ]
         self.__setstate__({"state": state, "param_groups": param_groups})

     def to_gpu(self):
@@ -222,9 +229,9 @@ class Optimizer8bit(torch.optim.Optimizer):
                         # found the matching parameter
                         # init override
                         self.mng.pid2config[id(p)] = config
-                        self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[
-                            id(p)
-                        ]
+                        self.mng.index2config[
+                            (gindex, pindex)
+                        ] = self.mng.pid2config[id(p)]
                         found = True

     @torch.no_grad()
@@ -280,7 +287,9 @@ class Optimizer8bit(torch.optim.Optimizer):
         raise NotImplementedError(f"init_state method needs to be overidden")

     def update_step(self, group, p, gindex, pindex):
-        raise NotImplementedError(f"The update_step method needs to be overidden")
+        raise NotImplementedError(
+            f"The update_step method needs to be overidden"
+        )


 class Optimizer2State(Optimizer8bit):
@@ -310,9 +319,13 @@ class Optimizer2State(Optimizer8bit):
         betas = [float(b) for b in betas]
         for i in range(len(betas)):
             if not 0.0 <= betas[i] < 1.0:
-                raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+                raise ValueError(
+                    f"Invalid beta parameter at index {i}: {betas[i]}"
+                )
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super(Optimizer2State, self).__init__(params, defaults, optim_bits)
@@ -351,7 +364,9 @@ class Optimizer2State(Optimizer8bit):
         state = self.state[p]
         state["step"] = 0

-        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+        if dtype == torch.float32 or (
+            dtype == torch.uint8 and p.numel() < 4096
+        ):
             state["state1"] = torch.zeros_like(
                 p,
                 memory_format=torch.preserve_format,
@@ -368,8 +383,12 @@ class Optimizer2State(Optimizer8bit):
             if state["step"] == 0:
                 if "dynamic" not in self.name2qmap:
                     self.fill_qmap()
-                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
-                self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
+                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+                    p.device
+                )
+                self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
+                    p.device
+                )

             state["state1"] = torch.zeros_like(
                 p,
@@ -399,11 +418,15 @@ class Optimizer2State(Optimizer8bit):
                     (blocks,), dtype=torch.float32, device=p.device
                 )
             else:
-                state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["max1"] = torch.zeros(
+                    (1,), dtype=torch.float32, device=p.device
+                )
                 state["new_max1"] = torch.zeros(
                     (1,), dtype=torch.float32, device=p.device
                 )
-                state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["max2"] = torch.zeros(
+                    (1,), dtype=torch.float32, device=p.device
+                )
                 state["new_max2"] = torch.zeros(
                     (1,), dtype=torch.float32, device=p.device
                 )
@@ -470,7 +493,9 @@ class Optimizer2State(Optimizer8bit):
                 state["new_max2"],
                 config["weight_decay"],
                 gnorm_scale=gnorm_scale,
-                unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+                unorm_vec=state["unorm_vec"]
+                if config["max_unorm"] > 0.0
+                else None,
                 max_unorm=config["max_unorm"],
             )
@@ -522,9 +547,13 @@ class Optimizer1State(Optimizer8bit):
             raise ValueError("Invalid epsilon value: {}".format(eps))
         for i in range(len(betas)):
             if not 0.0 <= betas[i] < 1.0:
-                raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+                raise ValueError(
+                    f"Invalid beta parameter at index {i}: {betas[i]}"
+                )
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super(Optimizer1State, self).__init__(params, defaults, optim_bits)
@@ -563,7 +592,9 @@ class Optimizer1State(Optimizer8bit):
         state = self.state[p]
         state["step"] = 0

-        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+        if dtype == torch.float32 or (
+            dtype == torch.uint8 and p.numel() < 4096
+        ):
             state["state1"] = torch.zeros_like(
                 p,
                 memory_format=torch.preserve_format,
@@ -574,7 +605,9 @@ class Optimizer1State(Optimizer8bit):
             if state["step"] == 0:
                 if "dynamic" not in self.name2qmap:
                     self.fill_qmap()
-                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
+                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+                    p.device
+                )

             state["state1"] = torch.zeros_like(
                 p,
@@ -593,7 +626,9 @@ class Optimizer1State(Optimizer8bit):
                     (blocks,), dtype=torch.float32, device=p.device
                 )
             else:
-                state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["max1"] = torch.zeros(
+                    (1,), dtype=torch.float32, device=p.device
+                )
                 state["new_max1"] = torch.zeros(
                     (1,), dtype=torch.float32, device=p.device
                 )
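For the override_config signature reformatted above, a minimal usage sketch follows. It assumes the GlobalOptimManager singleton API (get_instance, register_parameters) that bitsandbytes exposes alongside override_config, a CUDA build of the library, and an illustrative embedding model; the key "optim_bits" and the value 32 are example choices.

# Illustrative per-parameter override via GlobalOptimManager.override_config,
# whose signature (parameters, key, value, key_value_dict) appears above.
import torch
import bitsandbytes as bnb

model = torch.nn.Embedding(10000, 512).cuda()

mng = bnb.optim.GlobalOptimManager.get_instance()
mng.register_parameters(model.parameters())
# Keep this one parameter in 32-bit optimizer state while the rest use 8-bit.
mng.override_config(model.weight, "optim_bits", 32)

optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)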
bitsandbytes/optim/rmsprop.py
@@ -22,7 +22,9 @@ class RMSprop(Optimizer1State):
         block_wise=True,
     ):
         if alpha == 0:
-            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+            raise NotImplementedError(
+                f"RMSprop with alpha==0.0 is not supported!"
+            )
         if centered:
             raise NotImplementedError(f"Centered RMSprop is not supported!")
         super(RMSprop, self).__init__(
@@ -56,7 +58,9 @@ class RMSprop8bit(Optimizer1State):
         block_wise=True,
     ):
         if alpha == 0:
-            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+            raise NotImplementedError(
+                f"RMSprop with alpha==0.0 is not supported!"
+            )
         if centered:
             raise NotImplementedError(f"Centered RMSprop is not supported!")
         super(RMSprop8bit, self).__init__(
@@ -91,7 +95,9 @@ class RMSprop32bit(Optimizer1State):
     ):
         if alpha == 0:
-            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+            raise NotImplementedError(
+                f"RMSprop with alpha==0.0 is not supported!"
+            )
         if centered:
             raise NotImplementedError(f"Centered RMSprop is not supported!")
         super(RMSprop32bit, self).__init__(
bitsandbytes/utils.py
-import sys
 import shlex
 import subprocess
+import sys


 def execute_and_return(command_string: str) -> Tuple[str, str]:
quicktest.py
@@ -14,23 +14,31 @@ def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
             torch.int8
         )
     elif dims == 3:
-        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8)
-    B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
+        A = torch.randint(
+            -128, 127, size=(dim1, dim2, dim3), device="cuda"
+        ).to(torch.int8)
+    B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
+        torch.int8
+    )
     C1 = torch.matmul(A.float(), B.t().float())

     A2, SA = F.transform(A, "col32")
     B2, SB = F.transform(B, "colx")
     if dims == 2:
         C2, SC = F.transform(
-            torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"),
+            torch.zeros(
+                A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"
+            ),
             "col32",
         )
     else:
         C2, SC = F.transform(
             torch.zeros(
-                A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device="cuda"
+                A.shape[0],
+                A.shape[1],
+                B.shape[0],
+                dtype=torch.int32,
+                device="cuda",
             ),
             "col32",
         )
tests/test_autograd.py
@@ -18,9 +18,13 @@ req_grad_str = ["FF", "TF", "TT", "FT"]
 transpose = [(False, False), (False, True), (True, True), (True, False)]
 str_transpose = ["FF", "FT", "TT", "TF"]
 dtype = [torch.float32, torch.float16]
-values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
+values = list(
+    product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
+)
 str_values = list(
-    product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose)
+    product(
+        dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
+    )
 )
 names = [
     "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
@@ -31,7 +35,9 @@ names = [
 @pytest.mark.parametrize(
-    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names
+    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
+    values,
+    ids=names,
 )
 def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
     dim2 = dim2 - (dim2 % 16)
@@ -79,7 +85,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             A.grad = None
             B.grad = None

-            loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+            loss_torch = torch.nn.functional.mse_loss(
+                out_torch, target
+            ).mean()
             loss_torch.backward()
             gradA2 = A.grad
             gradB2 = B.grad
@@ -87,25 +95,35 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             B.grad = None

             if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                 assert (idx == 0).sum().item() < n * 0.1
                 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                 assert (idx == 0).sum().item() < n * 0.02
-                torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+                torch.testing.assert_allclose(
+                    gradB1, gradB2, atol=0.18, rtol=0.3
+                )

     # batched matrix multiply
     if funcs[0] in [torch.bmm, torch.matmul]:
         A = torch.randn(
-            size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
+            size=(dim1, dim2, dim3),
+            device="cuda",
+            requires_grad=req_grad[0],
         )
         B = torch.randn(
-            size=(dim1, dim3, dim4), device="cuda", requires_grad=req_grad[1]
+            size=(dim1, dim3, dim4),
+            device="cuda",
+            requires_grad=req_grad[1],
         )
         target = torch.randn(
-            size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
+            size=(dim1, dim2, dim4),
+            device="cuda",
+            requires_grad=req_grad[1],
         )
         torch.nn.init.xavier_uniform_(B)
@@ -115,7 +133,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
         n = out_bnb.numel()
         idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
         assert (idx == 0).sum().item() < n * 0.01
-        torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)
+        torch.testing.assert_allclose(
+            out_bnb, out_torch, atol=0.027, rtol=0.2
+        )

         if any(req_grad):
             out_bnb.data.copy_(out_torch)
@@ -127,7 +147,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             A.grad = None
             B.grad = None

-            loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+            loss_torch = torch.nn.functional.mse_loss(
+                out_torch, target
+            ).mean()
             loss_torch.backward()
             gradA2 = A.grad
             gradB2 = B.grad
@@ -135,7 +157,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             B.grad = None

             if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
@@ -146,12 +170,16 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
     if funcs[0] in [torch.matmul]:
         dim1 = dim1 - (dim1 % 16)
         A = torch.randn(
-            size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
+            size=(dim1, dim2, dim3),
+            device="cuda",
+            requires_grad=req_grad[0],
         )
         dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
         B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
         target = torch.randn(
-            size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
+            size=(dim1, dim2, dim4),
+            device="cuda",
+            requires_grad=req_grad[1],
         )
         torch.nn.init.xavier_uniform_(B)
@@ -178,7 +206,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             A.grad = None
             B.grad = None

-            loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+            loss_torch = torch.nn.functional.mse_loss(
+                out_torch, target
+            ).mean()
             loss_torch.backward()
             gradA2 = A.grad
             gradB2 = B.grad
@@ -186,7 +216,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             B.grad = None

             if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
@@ -258,7 +290,16 @@ names = [
     ids=names,
 )
 def test_matmullt(
-    dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights
+    dim1,
+    dim2,
+    dim3,
+    dim4,
+    funcs,
+    dtype,
+    req_grad,
+    transpose,
+    decomp,
+    has_fp16_weights,
 ):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
@@ -278,7 +319,10 @@ def test_matmullt(
             size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
         )
         target = torch.randn(
-            size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype
+            size=(dim2, dim4),
+            device="cuda",
+            requires_grad=req_grad[1],
+            dtype=dtype,
         )
         torch.nn.init.xavier_uniform_(B)
         B2 = B.clone()
@@ -317,14 +361,18 @@ def test_matmullt(
         if any(req_grad):
             out_bnb.data.copy_(out_torch)
             torch.cuda.synchronize()
-            loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+            loss_bnb = torch.nn.functional.mse_loss(
+                out_bnb, target
+            ).mean()
             loss_bnb.backward()
             gradA1 = A.grad
             gradB1 = B.grad
             A.grad = None
             B.grad = None

-            loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+            loss_torch = torch.nn.functional.mse_loss(
+                out_torch, target
+            ).mean()
             loss_torch.backward()
             gradA2 = A.grad
             gradB2 = B.grad
@@ -332,7 +380,9 @@ def test_matmullt(
             B.grad = None

             if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
             if req_grad[1]:
                 n = gradB1.numel()
                 assert torch.abs(gradB1).sum() > 0.0
@@ -341,4 +391,6 @@ def test_matmullt(
                 assert (idx == 0).sum().item() < n * 0.1
                 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                 assert (idx == 0).sum().item() < n * 0.02
-                torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+                torch.testing.assert_allclose(
+                    gradB1, gradB2, atol=0.18, rtol=0.3
+                )
tests/test_cuda_setup_evaluator.py
@@ -3,8 +3,12 @@ from typing import List, NamedTuple
 import pytest

-from bitsandbytes.cuda_setup import (CUDA_RUNTIME_LIB, evaluate_cuda_setup,
-                                     get_cuda_runtime_lib_path, tokenize_paths)
+from bitsandbytes.cuda_setup import (
+    CUDA_RUNTIME_LIB,
+    evaluate_cuda_setup,
+    get_cuda_runtime_lib_path,
+    tokenize_paths,
+)


 class InputAndExpectedOutput(NamedTuple):
@@ -13,11 +17,26 @@ class InputAndExpectedOutput(NamedTuple):
 HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
-    (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
-    (f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
-    (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", f"dir/with/{CUDA_RUNTIME_LIB}"),
-    (f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
-    (f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", f"dir/with/{CUDA_RUNTIME_LIB}"),
+    (
+        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
+        f"dir/with/{CUDA_RUNTIME_LIB}",
+    ),
+    (
+        f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
+        f"dir/with/{CUDA_RUNTIME_LIB}",
+    ),
+    (
+        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
+        f"dir/with/{CUDA_RUNTIME_LIB}",
+    ),
+    (
+        f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
+        f"dir/with/{CUDA_RUNTIME_LIB}",
+    ),
+    (
+        f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
+        f"dir/with/{CUDA_RUNTIME_LIB}",
+    ),
     (
         f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
         f"dir/with/{CUDA_RUNTIME_LIB}",
tests/test_functional.py
View file @
ea7c14f8
...
...
@@ -86,7 +86,9 @@ def teardown():
pass
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
[
"float"
,
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
[
"float"
,
"half"
]
)
def
test_estimate_quantiles
(
dtype
):
A
=
torch
.
rand
(
1024
,
1024
,
device
=
"cuda"
)
A
=
A
.
to
(
dtype
)
...
...
@@ -190,7 +192,9 @@ def test_dynamic_blockwise_stochastic_quantization():
)
@
pytest
.
mark
.
parametrize
(
"gtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
[
"float"
,
"half"
])
@
pytest
.
mark
.
parametrize
(
"gtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
[
"float"
,
"half"
]
)
def
test_percentile_clipping
(
gtype
):
gnorm_vec1
=
torch
.
zeros
(
100
,
device
=
"cuda"
)
gnorm_vec2
=
torch
.
zeros
(
100
,
device
=
"cuda"
)
...
...
@@ -270,7 +274,13 @@ def mean(xx):
dim1
=
[
1024
*
2
]
dim2
=
[
1024
*
16
]
methods
=
[
(
lambda
x
,
dim
:
quant
(
x
),
lambda
x
,
dim
:
quant
(
x
),
dequant
,
dequant
,
mm_dequant
)
(
lambda
x
,
dim
:
quant
(
x
),
lambda
x
,
dim
:
quant
(
x
),
dequant
,
dequant
,
mm_dequant
,
)
]
methods
.
append
((
quant_multi
,
quant_multi
,
dequant
,
dequant
,
mm_dequant
))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
...
...
@@ -279,11 +289,14 @@ batched = [False, True]
values
=
list
(
product
(
dim1
,
dim2
,
methods
,
batched
))
values_names
=
list
(
product
(
dim1
,
dim2
,
method_names
,
batched
))
names
=
[
"dim1_{0}_dim2_{1}_quant_{2}_batched_{3}"
.
format
(
*
vals
)
for
vals
in
values_names
"dim1_{0}_dim2_{1}_quant_{2}_batched_{3}"
.
format
(
*
vals
)
for
vals
in
values_names
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, quant_methods, batched"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, quant_methods, batched"
,
values
,
ids
=
names
)
def
test_approx_igemm
(
dim1
,
dim2
,
quant_methods
,
batched
):
dim1
=
dim1
-
(
dim1
%
32
)
dim2
=
dim2
-
(
dim2
%
32
)
...
...
@@ -339,14 +352,18 @@ names = [
]
@
pytest
.
mark
.
parametrize
(
"hidden_dim, batch_dim, transpose, seq_dim"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
parametrize
(
"hidden_dim, batch_dim, transpose, seq_dim"
,
values
,
ids
=
names
)
def
test_igemm
(
hidden_dim
,
batch_dim
,
transpose
,
seq_dim
):
hidden_dim
=
hidden_dim
-
(
hidden_dim
%
32
)
batch_dim
=
batch_dim
-
(
batch_dim
%
16
)
seq_dim
=
seq_dim
-
(
seq_dim
%
16
)
for
i
in
range
(
k
):
shapeA
=
(
(
batch_dim
,
hidden_dim
)
if
not
transpose
[
0
]
else
(
hidden_dim
,
batch_dim
)
(
batch_dim
,
hidden_dim
)
if
not
transpose
[
0
]
else
(
hidden_dim
,
batch_dim
)
)
shapeB
=
(
(
32
*
random
.
randint
(
1
,
4
),
hidden_dim
)
...
...
@@ -394,7 +411,9 @@ seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim
=
torch
.
randint
(
32
,
1024
*
4
,
size
=
(
n
,)).
tolist
()
batch_dim
=
torch
.
randint
(
2
,
16
,
size
=
(
n
,)).
tolist
()
values
=
list
(
product
(
seq_dim
,
hidden_dim
,
batch_dim
))
names
=
[
"seq_dim{0}_hidden_dim{1}_batch_dim{2}"
.
format
(
*
vals
)
for
vals
in
values
]
names
=
[
"seq_dim{0}_hidden_dim{1}_batch_dim{2}"
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"seq_dim, hidden_dim, batch_dim"
,
values
,
ids
=
names
)
...
...
@@ -406,11 +425,13 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
batch_dim
,
seq_dim
,
hidden_dim
),
device
=
"cuda"
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
batch_dim
,
seq_dim
,
1024
),
device
=
"cuda"
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
batch_dim
,
seq_dim
,
1024
),
device
=
"cuda"
)
.
to
(
torch
.
int8
)
out2
=
torch
.
einsum
(
"bsi, bso->io"
,
A
.
float
(),
B
.
float
())
iout
=
torch
.
empty
(
A
.
shape
[
2
],
B
.
shape
[
2
],
dtype
=
torch
.
int32
,
device
=
A
.
device
)
iout
=
torch
.
empty
(
A
.
shape
[
2
],
B
.
shape
[
2
],
dtype
=
torch
.
int32
,
device
=
A
.
device
)
out
=
F
.
igemm
(
A
,
B
,
out
=
iout
)
torch
.
testing
.
assert_allclose
(
out
.
float
(),
out2
)
...
...
@@ -428,7 +449,9 @@ names = [
]
@
pytest
.
mark
.
parametrize
(
"seq_dim, hidden_dim, batch_dim, transpose"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
parametrize
(
"seq_dim, hidden_dim, batch_dim, transpose"
,
values
,
ids
=
names
)
def
test_minmax_igemm
(
seq_dim
,
hidden_dim
,
batch_dim
,
transpose
):
def
min_max
(
x
):
maxA
=
torch
.
amax
(
x
,
dim
=
2
,
keepdim
=
True
)
...
...
@@ -444,7 +467,9 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
errs2
=
[]
relerrs2
=
[]
for
i
in
range
(
k
):
A
=
torch
.
normal
(
0.0
,
0.5
,
size
=
(
batch_dim
,
seq_dim
,
hidden_dim
),
device
=
"cuda"
)
A
=
torch
.
normal
(
0.0
,
0.5
,
size
=
(
batch_dim
,
seq_dim
,
hidden_dim
),
device
=
"cuda"
)
if
transpose
:
B
=
torch
.
normal
(
0
,
0.5
,
size
=
(
256
,
hidden_dim
),
device
=
"cuda"
)
else
:
...
...
@@ -504,7 +529,8 @@ dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose
=
[(
False
,
False
),
(
True
,
False
),
(
False
,
True
),
(
True
,
True
)]
values
=
list
(
product
(
dim1
,
dim2
,
dim3
,
dim4
,
transpose
))
names
=
[
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}"
.
format
(
*
vals
)
for
vals
in
values
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}"
.
format
(
*
vals
)
for
vals
in
values
]
...
...
@@ -529,7 +555,9 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
out2
=
torch
.
bmm
(
A
.
permute
([
0
,
2
,
1
]).
float
(),
B
.
float
())
out
=
F
.
igemm
(
A
.
permute
([
0
,
2
,
1
]),
B
)
elif
transpose
[
0
]
and
transpose
[
1
]:
out2
=
torch
.
bmm
(
A
.
permute
([
0
,
2
,
1
]).
float
(),
B
.
permute
([
0
,
2
,
1
]).
float
())
out2
=
torch
.
bmm
(
A
.
permute
([
0
,
2
,
1
]).
float
(),
B
.
permute
([
0
,
2
,
1
]).
float
()
)
out
=
F
.
igemm
(
A
.
permute
([
0
,
2
,
1
]),
B
.
permute
([
0
,
2
,
1
]))
torch
.
testing
.
assert_allclose
(
out
.
float
(),
out2
.
float
())
...
...
@@ -563,7 +591,9 @@ a_order = ["row"]
out_order
=
[
"col"
,
"row"
,
"col32"
]
transpose
=
[
False
]
dims
=
[
2
,
3
]
values
=
list
(
product
(
dim1
,
dim2
,
dim3
,
dims
,
dtype
,
a_order
,
out_order
,
transpose
))
values
=
list
(
product
(
dim1
,
dim2
,
dim3
,
dims
,
dtype
,
a_order
,
out_order
,
transpose
)
)
names
=
[
"dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}"
.
format
(
...
...
@@ -574,9 +604,13 @@ names = [
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose"
,
values
,
ids
=
names
"dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose"
,
values
,
ids
=
names
,
)
def
test_nvidia_transform
(
dim1
,
dim2
,
dim3
,
dims
,
dtype
,
orderA
,
orderOut
,
transpose
):
def
test_nvidia_transform
(
dim1
,
dim2
,
dim3
,
dims
,
dtype
,
orderA
,
orderOut
,
transpose
):
if
dims
==
3
and
out_order
!=
"col32"
:
return
if
dtype
==
torch
.
int32
and
out_order
!=
"col32"
:
...
...
@@ -586,7 +620,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
if
dims
==
2
:
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
),
device
=
"cuda"
).
to
(
dtype
)
elif
dims
==
3
:
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
to
(
dtype
)
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
to
(
dtype
)
out
,
S
=
F
.
nvidia_transform
(
A
,
to_order
=
orderOut
)
...
...
@@ -598,7 +634,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
if
dims
==
2
:
n
=
A
.
shape
[
0
]
*
(
A
.
shape
[
1
]
+
(
32
-
(
A
.
shape
[
1
]
%
32
)))
elif
dims
==
3
:
n
=
A
.
shape
[
0
]
*
A
.
shape
[
1
]
*
(
A
.
shape
[
2
]
+
(
32
-
(
A
.
shape
[
2
]
%
32
)))
n
=
(
A
.
shape
[
0
]
*
A
.
shape
[
1
]
*
(
A
.
shape
[
2
]
+
(
32
-
(
A
.
shape
[
2
]
%
32
)))
)
assert
out
.
numel
()
==
n
elif
orderOut
==
"col_turing"
:
# 32 col 8 row tiles
...
...
@@ -613,7 +653,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
j
=
col
coltile
=
(
col
//
32
)
+
(
1
if
col
%
32
!=
0
else
0
)
rowtile
=
((
row
//
8
)
+
(
1
if
row
%
8
!=
0
else
0
))
*
total_coltile
rowtile
=
(
(
row
//
8
)
+
(
1
if
row
%
8
!=
0
else
0
)
)
*
total_coltile
offset
=
32
*
8
*
(
rowtile
+
coltile
)
col2
=
col
%
32
row2
=
(
row
%
8
)
*
32
...
...
@@ -624,7 +666,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
# torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
if
orderOut
==
"col32"
:
out2
,
S
=
F
.
nvidia_transform
(
out
,
from_order
=
orderOut
,
to_order
=
"row"
,
state
=
S
)
out2
,
S
=
F
.
nvidia_transform
(
out
,
from_order
=
orderOut
,
to_order
=
"row"
,
state
=
S
)
torch
.
testing
.
assert_allclose
(
A
,
out2
)
...
...
@@ -657,10 +701,12 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
torch
.
int8
)
elif
dims
==
3
:
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim4
,
dim3
),
device
=
"cuda"
).
to
(
torch
.
int8
)
A
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim4
,
dim3
),
device
=
"cuda"
).
to
(
torch
.
int8
)
C1
=
torch
.
matmul
(
A
.
float
(),
B
.
t
().
float
())
A2
,
SA
=
F
.
transform
(
A
,
"col32"
)
...
...
@@ -670,7 +716,9 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
torch
.
testing
.
assert_allclose
(
C1
,
C3
.
float
())
# transpose
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim3
,
dim4
),
device
=
"cuda"
).
to
(
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
size
=
(
dim3
,
dim4
),
device
=
"cuda"
).
to
(
torch
.
int8
)
C1
=
torch
.
matmul
(
A
.
float
(),
B
.
float
())
B2t
,
SBt
=
F
.
transform
(
B
,
"col_turing"
,
transpose
=
True
)
...
...
@@ -688,7 +736,8 @@ dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values
=
list
(
product
(
dim1
,
dim2
,
dim3
,
dim4
,
dims
))
names
=
[
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}"
.
format
(
*
vals
)
for
vals
in
values
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}"
.
format
(
*
vals
)
for
vals
in
values
]
...
...
@@ -699,7 +748,9 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
if
dims
==
2
:
A
=
torch
.
normal
(
0
,
0.5
,
size
=
(
dim1
,
dim3
),
device
=
"cuda"
).
half
()
elif
dims
==
3
:
A
=
torch
.
normal
(
0
,
0.5
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
half
()
A
=
torch
.
normal
(
0
,
0.5
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
"cuda"
).
half
()
B
=
torch
.
randn
((
dim4
,
dim3
),
device
=
"cuda"
).
half
()
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
C1
=
torch
.
matmul
(
A
,
B
.
t
())
...
...
@@ -742,7 +793,9 @@ values = [
# values = list(product(batch, seq, model, hidden))
names
=
[
"batch_{0}_seq_{1}_model_{2}_hidden_{3}"
.
format
(
*
vals
)
for
vals
in
values
]
names
=
[
"batch_{0}_seq_{1}_model_{2}_hidden_{3}"
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"batch, seq, model, hidden"
,
values
,
ids
=
names
)
...
...
@@ -909,7 +962,9 @@ dims = (2,)
# ldb = list(range(256, 1*1024, 256))
formatB
=
[
"col_turing"
,
"col_ampere"
]
values
=
list
(
product
(
dim1
,
dim4
,
dims
,
formatB
))
names
=
[
"dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}"
.
format
(
*
vals
)
for
vals
in
values
]
names
=
[
"dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}"
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim4, dims, formatB"
,
values
,
ids
=
names
)
...
...
@@ -992,7 +1047,9 @@ def test_colrow_absmax(dim1, dim2, dims):
torch
.
testing
.
assert_allclose
(
row_stats1_trunc
,
row_stats2
)
torch
.
testing
.
assert_allclose
(
nnz_block_ptr1
,
nnz_block_ptr2
)
row_stats2
,
col_stats2
,
nnz_block_ptr2
=
F
.
get_colrow_absmax
(
A
,
threshold
=
0.0
)
row_stats2
,
col_stats2
,
nnz_block_ptr2
=
F
.
get_colrow_absmax
(
A
,
threshold
=
0.0
)
torch
.
testing
.
assert_allclose
(
col_stats1
,
col_stats2
)
torch
.
testing
.
assert_allclose
(
row_stats1
,
row_stats2
)
...
...
@@ -1023,8 +1080,12 @@ def test_double_quant(dim1, dim2):
torch
.
testing
.
assert_allclose
(
CAt
,
out_col1
,
atol
=
1
,
rtol
=
0
)
n
=
CAt
.
numel
()
num_not_close_rows
=
(
torch
.
isclose
(
CA
,
out_row1
,
atol
=
1
)
==
0
).
sum
().
item
()
num_not_close_cols
=
(
torch
.
isclose
(
CAt
,
out_col1
,
atol
=
1
)
==
0
).
sum
().
item
()
num_not_close_rows
=
(
(
torch
.
isclose
(
CA
,
out_row1
,
atol
=
1
)
==
0
).
sum
().
item
()
)
num_not_close_cols
=
(
(
torch
.
isclose
(
CAt
,
out_col1
,
atol
=
1
)
==
0
).
sum
().
item
()
)
# allow for 1:500 error due to rounding differences
min_error
=
1
/
500
...
...
@@ -1123,7 +1184,9 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
-       outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
+       outC32, SC = F.igemmlt(
+           A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
+       )
        C3, S = F.nvidia_transform(outC32, "row", state=SC)
        maxval = torch.abs(C3).max()
        if maxval == 127:
...
...
@@ -1204,7 +1267,9 @@ def test_row_scale_bench(dim1, dim4, inner):
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
-       outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
+       outC32, SC = F.igemmlt(
+           A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
+       )
    torch.cuda.synchronize()
    print("row-wise", time.time() - t0)
...
...
@@ -1230,7 +1295,9 @@ a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
-values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
+values = list(
+    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
+)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
        *vals
...
...
@@ -1240,14 +1307,20 @@ names = [
@pytest.mark.parametrize(
-   "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names
+   "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
+   values,
+   ids=names,
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    for i in range(k):
        if dims == 2:
-           A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
+           A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
+               dtype
+           )
        elif dims == 3:
-           A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)
+           A = torch.randint(
+               10, 99, size=(dim1, dim2, dim3), device="cuda"
+           ).to(dtype)
        A.view(-1)[-1] = -1
        if transpose:
...
...
@@ -1282,7 +1355,9 @@ names = [
]
-@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names)
+@pytest.mark.parametrize(
+    "dim1, dim2, dtype, orderA, orderOut", values, ids=names
+)
def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
    for i in range(1):
        A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype)
...
...
@@ -1332,17 +1407,23 @@ def test_coo_double_quant(dim1, dim2):
        idx = torch.abs(A) >= threshold
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
-       CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
+       CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
+           A, threshold=threshold
+       )
        if coo_tensor is not None:
            A1 = A * idx
            A2 = torch.zeros_like(A)
-           A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
+           A2[
+               coo_tensor.rowidx.long(), coo_tensor.colidx.long()
+           ] = coo_tensor.values
            torch.testing.assert_allclose(A1, A2)
            A1 = A * (idx == 0)
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
-           torch.testing.assert_allclose(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
+           torch.testing.assert_allclose(
+               A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
+           )
            n = 2
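The hunk above exercises F.double_quant(A, threshold=threshold), which returns the above-threshold entries as a sparse COO tensor and int8-quantizes the rest. The sketch below reproduces that split in plain PyTorch on the CPU so the assertions are easier to follow; the threshold and shapes are illustrative and no bitsandbytes kernels are involved.

    import torch

    A = torch.randn(4, 8, dtype=torch.float32)
    threshold = 6.0  # illustrative; the real tests sweep their own thresholds

    # Outlier mask: entries whose magnitude reaches the threshold.
    idx = torch.abs(A) >= threshold

    # Dense part: row-wise absmax int8 quantization of the non-outliers.
    dense = A * (idx == 0)
    row_absmax = dense.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    CA = torch.round(dense / row_absmax * 127).to(torch.int8)

    # Sparse part: the outliers kept in full precision, COO style.
    rowidx, colidx = idx.nonzero(as_tuple=True)
    values = A[rowidx, colidx]

    # Reconstruction mirrors the assertions in the hunk above.
    A2 = torch.zeros_like(A)
    A2[rowidx, colidx] = values
    dequant = CA.float() * row_absmax / 127
    torch.testing.assert_close(A * idx, A2)
    torch.testing.assert_close(dense, dequant, rtol=0.05, atol=0.1)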
...
...
@@ -1454,7 +1535,9 @@ def test_integrated_sparse_decomp(dim1, dim2):
        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
-       CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
+       CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
+           A, threshold=threshold
+       )
        C32A, SA = F.transform(CA, "col32")
        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
...
...
@@ -1494,7 +1577,9 @@ dim2 = [12288]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
-names = ["dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values
+]
@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
...
...
@@ -1536,7 +1621,9 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    std = out1.std()
    out1 /= std
    out2 /= std
-   assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
+   assert_all_approx_close(
+       out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
+   )
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
    idx_col = torch.randint(0, A2.shape[-1], size=(15,))
...
...
@@ -1734,7 +1821,9 @@ values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 4096, 4*4096))
# values.append((batch_size, seqdim, 5140, 4*5140))
# values.append((batch_size, seqdim, 12288, 4*12288))
-names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]
+names = [
+    "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
+]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
...
...
tests/test_modules.py
View file @
ea7c14f8
...
...
@@ -48,7 +48,9 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
class LinearFunction(torch.autograd.Function):
    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
-       round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+       round_func = (
+           LinearFunction.round_stoachastic if stochastic else torch.round
+       )
        norm = math.sqrt(math.pi) / math.sqrt(2.0)
        # std = torch.abs(x).mean()*norm
        std = torch.std(x)
...
...
@@ -116,7 +118,9 @@ class LinearFunction(torch.autograd.Function):
        return x.to(dtype)

    def get_8bit_linear(x, stochastic=False):
-       round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+       round_func = (
+           LinearFunction.round_stoachastic if stochastic else torch.round
+       )
        max1 = torch.abs(x).max()
        x = x / max1 * 127
        x = round_func(x) / 127 * max1
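The get_8bit_linear helper re-wrapped above is a plain absmax round trip: scale into [-127, 127], round, scale back. Below is a standalone CPU sketch of that round trip; it uses deterministic rounding only, whereas round_stoachastic in the diff refers to a stochastic variant that is not reproduced here.

    import torch

    def absmax_quant_dequant(x: torch.Tensor) -> torch.Tensor:
        """Quantize to 8-bit levels with a single absmax scale, then dequantize."""
        max1 = torch.abs(x).max()
        x = x / max1 * 127      # map into [-127, 127]
        x = torch.round(x)      # deterministic rounding
        return x / 127 * max1   # map back to the original scale

    x = torch.randn(16, 32)
    x_hat = absmax_quant_dequant(x)
    # The reconstruction error is bounded by half a quantization step.
    assert (x - x_hat).abs().max() <= torch.abs(x).max() / 127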
...
...
@@ -125,7 +129,9 @@ class LinearFunction(torch.autograd.Function):
    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
-       round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+       round_func = (
+           LinearFunction.round_stoachastic if stochastic else torch.round
+       )
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x * 127) / max1
...
...
@@ -209,7 +215,9 @@ class LinearFunction(torch.autograd.Function):
        weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
        x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
        outputq = bnb.functional.igemm(x8, weight8.t())
-       output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
+       output = LinearFunction.dequant(
+           outputq, S1, S2, x.dtype, args.quant_type
+       )
        # if torch.rand(1) < 0.01:
        # output32 = torch.matmul(x, weight.t())
        # err = torch.abs(output-output32).float()
...
...
@@ -261,7 +269,9 @@ class LinearFunction(torch.autograd.Function):
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )
-           grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
+           grad_output8, S1 = LinearFunction.quant(
+               grad_output, args.quant_type, dim=2
+           )
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(
...
...
@@ -338,8 +348,12 @@ def test_linear8bit():
    loss2.backward()
    loss3.backward()
-   assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
-   assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
+   assert_all_approx_close(
+       l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
+   )
+   assert_all_approx_close(
+       l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
+   )
    assert_all_approx_close(
        l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
    )
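These assertions rely on the assert_all_approx_close(a, b, atol, rtol, count) helper whose signature appears in the hunk header of this file. Judging from how it is called, it tolerates up to `count` mismatching elements before failing; the sketch below is an assumption about such a helper, not a copy of the real one.

    import torch

    def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
        # Count elements outside the tolerance band; only fail when more than
        # `count` of them disagree, to absorb a handful of rounding outliers.
        not_close = ~torch.isclose(a, b, atol=atol, rtol=rtol)
        if not_close.sum().item() > count:
            torch.testing.assert_close(a, b, atol=atol, rtol=rtol)

    # Example: two tensors that differ in a single element still pass with count=2.
    a = torch.zeros(100)
    b = a.clone()
    b[0] = 1.0
    assert_all_approx_close(a, b, atol=0.01, rtol=0, count=2)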
...
...
@@ -388,7 +402,9 @@ def test_linear8bitlt_accumulated_gradient():
    l1 = torch.nn.Sequential(
        *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
    )
-   l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
+   l2 = torch.nn.Sequential(
+       *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
+   )
    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
...
...
@@ -462,7 +478,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
-   mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
+   mlp = (
+       MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+       .cuda()
+       .half()
+   )
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
...
...
@@ -475,7 +495,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
-   mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
+   mlp = (
+       MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+       .half()
+       .cuda()
+   )
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
...
...
@@ -488,7 +512,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
-   mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda")
+   mlp = (
+       MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+       .half()
+       .to("cuda")
+   )
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
...
...
tests/test_optim.py
View file @
ea7c14f8
...
...
@@ -103,20 +103,26 @@ str2statenames["adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
-str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
+str2statenames["momentum8bit"] = [
+    ("momentum_buffer", "state1", "qmap1", "max1")
+]
str2statenames["momentum8bit_blockwise"] = [
    ("momentum_buffer", "state1", "qmap1", "absmax1")
]
str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
-str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
+str2statenames["rmsprop8bit_blockwise"] = [
+    ("square_avg", "state1", "qmap1", "absmax1")
+]

dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"]
values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]

@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
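The str2statenames entries reformatted above pair each optimizer with tuples of (reference state key, bitsandbytes state key, quantization-map key, max/absmax key). The toy sketch below only illustrates that data layout and how a comparison loop can iterate over it; the dictionaries stand in for the real optimizer state, which holds tensors.

    # Excerpted shape of the mapping, as in the hunk above.
    str2statenames = {
        "momentum8bit": [("momentum_buffer", "state1", "qmap1", "max1")],
        "rmsprop8bit_blockwise": [("square_avg", "state1", "qmap1", "absmax1")],
    }

    # Toy stand-ins for torch_optimizer.state[p] and bnb_optimizer.state[p].
    torch_state = {"momentum_buffer": [0.1, 0.2]}
    bnb_state = {"state1": [0.1, 0.2], "qmap1": "q", "max1": 0.2}

    for torch_name, bnb_name, qmap_key, max_key in str2statenames["momentum8bit"]:
        assert torch_state[torch_name] == bnb_state[bnb_name]
        assert qmap_key in bnb_state and max_key in bnb_state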
...
...
@@ -203,9 +209,13 @@ def test_global_config(dim1, dim2, gtype):
    eps = 1e-8
    bnb.optim.GlobalOptimManager.get_instance().initialize()
-   bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
+   bnb.optim.GlobalOptimManager.get_instance().override_config(
+       p3, "optim_bits", 8
+   )
-   bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
+   bnb.optim.GlobalOptimManager.get_instance().register_parameters(
+       [p1, p2, p3]
+   )
    p1 = p1.cuda()
    p2 = p2.cuda()
    p3 = p3.cuda()
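The two calls re-wrapped here are the per-parameter override API: override_config(p, "optim_bits", 8) pins a single parameter's optimizer state to 8 bit, and register_parameters makes the manager aware of the parameters before they move to the GPU. The sketch below follows the library's documented GlobalOptimManager pattern with illustrative model sizes; it needs a CUDA device and bitsandbytes installed, and the ordering mirrors the calls visible in the hunk.

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Sequential(torch.nn.Linear(4096, 64), torch.nn.Linear(64, 64))

    mng = bnb.optim.GlobalOptimManager.get_instance()
    # Keep this one weight's optimizer state in 8 bit; other parameters follow
    # the optimizer's own optim_bits setting.
    mng.override_config(model[0].weight, "optim_bits", 8)
    mng.register_parameters(model.parameters())  # register while still on CPU

    model = model.cuda()
    adam = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=32)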
...
...
@@ -245,7 +255,9 @@ optimizer_names = [
    "rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
...
...
@@ -329,8 +341,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
            bnb_optimizer = str2optimizers[optim_name][1]([p2])
            bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            rm_path(path)
-           torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
-           torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
+           torch.testing.assert_allclose(
+               raws1cpy, bnb_optimizer.state[p2][name2]
+           )
+           torch.testing.assert_allclose(
+               qmap1, bnb_optimizer.state[p2][qmap]
+           )
            if "blockwise" in optim_name:
                s1 = F.dequantize_blockwise(
...
...
@@ -349,12 +365,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
            num_not_close = (
                torch.isclose(
-                   torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
+                   torch_optimizer.state[p1][name1],
+                   s1,
+                   atol=atol,
+                   rtol=rtol,
                )
                == 0
            )
            assert num_not_close.sum().item() < 20
-           torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
+           torch.testing.assert_allclose(
+               p1, p2.float(), atol=patol, rtol=prtol
+           )
        # the parameters diverge quickly. Here we keep them close
        # together so we can test against the Adam error
...
...
@@ -375,7 +396,10 @@ dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals)
+    for vals in values
+]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
...
...
@@ -391,7 +415,12 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    p2 = p1.clone()
    adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
    adam2 = bnb.optim.Adam(
-       [p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5
+       [p2],
+       lr,
+       (beta1, beta2),
+       eps,
+       optim_bits=optim_bits,
+       percentile_clipping=5,
    )
    gnorm_vec = torch.zeros(100).cuda()
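The constructor being re-wrapped here shows the two knobs this test sweeps: optim_bits and percentile_clipping. Below is a usage sketch with illustrative hyperparameter values (the test's actual lr, beta1, beta2, and eps are defined outside this hunk); it requires a CUDA device and bitsandbytes.

    import torch
    import bitsandbytes as bnb

    p = torch.nn.Parameter(torch.randn(1024, 64, device="cuda"))
    adam = bnb.optim.Adam(
        [p],
        lr=1e-3,                 # illustrative values, not the test's constants
        betas=(0.9, 0.995),
        eps=1e-8,
        optim_bits=8,            # keep optimizer state in 8 bit
        percentile_clipping=5,   # clip gradients at the 5th percentile of recent norms
    )

    p.grad = torch.randn_like(p) * 0.1
    adam.step()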
...
...
@@ -399,7 +428,9 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    for i in range(50):
        step += 1
-       g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
+       g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (
+           0.01 * i
+       )
        g2 = g1.clone()
        p2.grad = g2
...
...
@@ -430,10 +461,16 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
        elif optim_bits == 8:
            torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
            torch.testing.assert_allclose(
-               adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3
+               adam1.state[p1]["state1"],
+               adam2.state[p2]["state1"],
+               atol=2,
+               rtol=1e-3,
            )
            torch.testing.assert_allclose(
-               adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3
+               adam1.state[p1]["state2"],
+               adam2.state[p2]["state2"],
+               atol=2,
+               rtol=1e-3,
            )
            adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
            adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
...
...
@@ -463,7 +500,9 @@ gtype = [torch.float32, torch.float16]
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
...
...