OpenDAS / bitsandbytes · Commits · ea7c14f8

Commit ea7c14f8, authored Aug 01, 2022 by Titus von Koeller

    reran black with linelength 80 for greater readability

Parent: 3fd06fb6

Showing 17 changed files with 665 additions and 203 deletions (+665, −203)
bitsandbytes/__init__.py               +7   −2
bitsandbytes/autograd/_functions.py    +36  −9
bitsandbytes/cuda_setup.py             +38  −7
bitsandbytes/functional.py             +99  −29
bitsandbytes/nn/modules.py             +28  −6
bitsandbytes/optim/adagrad.py          +9   −3
bitsandbytes/optim/adam.py             +21  −6
bitsandbytes/optim/lars.py             +15  −5
bitsandbytes/optim/optimizer.py        +56  −21
bitsandbytes/optim/rmsprop.py          +9   −3
bitsandbytes/utils.py                  +1   −1
quicktest.py                           +14  −6
tests/test_autograd.py                 +74  −22
tests/test_cuda_setup_evaluator.py     +26  −7
tests/test_functional.py               +138 −49
tests/test_modules.py                  +39  −11
tests/test_optim.py                    +55  −16
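
The commit message says the whole diff was produced by re-running the Black formatter with a line length of 80. A minimal sketch of how such a pass could be reproduced from Python, assuming the `black` package is installed; the sample source string below is illustrative only, not taken from the repository:

    import black

    # A hypothetical over-80-column line, similar in shape to the lines rewrapped in this commit.
    src = 'str2optimizer32bit["momentum"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)\n'

    # Reformat with the same line-length setting the commit message describes.
    formatted = black.format_str(src, mode=black.Mode(line_length=80))
    print(formatted)

Running Black over the whole tree with that setting (rather than a single string) is presumably what generated the hunks that follow.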
bitsandbytes/__init__.py

@@ -3,8 +3,13 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .autograd._functions import (MatmulLtState, bmm_cublas, matmul,
-                                  matmul_cublas, mm_cublas)
+from .autograd._functions import (
+    MatmulLtState,
+    bmm_cublas,
+    matmul,
+    matmul_cublas,
+    mm_cublas,
+)
 from .cextension import COMPILED_WITH_CUDA
 from .nn import modules
...
bitsandbytes/autograd/_functions.py

@@ -111,7 +111,9 @@ class MatMul8bit(torch.autograd.Function):
-            qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+            qgrad_output, S1 = F.vectorwise_quant(
+                grad_output, dim=dims, quant_type=quant_type
+            )
             qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
             igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
             grad_B = F.vectorwise_mm_dequant(
                 igrad_B,
...
@@ -146,7 +148,11 @@ class MatMul8bit(torch.autograd.Function):
             qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
             igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
             grad_A = F.vectorwise_mm_dequant(
-                igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
+                igrad_A,
+                S1,
+                S3.permute(permute_dim),
+                grad_output.dtype,
+                quant_type,
             )

         return grad_A, grad_B, None, None, None
...
@@ -211,7 +217,9 @@ class MatMul8bitLt(torch.autograd.Function):
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
-        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
+        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
+            A, threshold=state.threshold
+        )

         if state.threshold > 0.0 and coo_tensorA is not None:
             if state.has_fp16_weights:
...
@@ -225,7 +233,9 @@ class MatMul8bitLt(torch.autograd.Function):
             if state.CxB is None:
                 # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                 # we also need to convert it to the turing/ampere format
                 state.CxB, state.SB = F.transform(
                     state.CB, to_order=formatB
                 )
                 # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
                 # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
                 #     # generate outlier index and subB
...
@@ -259,7 +269,13 @@ class MatMul8bitLt(torch.autograd.Function):
             if (state.is_training and not has_grad) or state.CxB is None:
                 state.reset_grads()
-                CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B)
+                (
+                    CB,
+                    state.CBt,
+                    state.SCB,
+                    state.SCBt,
+                    coo_tensorB,
+                ) = F.double_quant(B)
                 state.CxB, state.SB = F.transform(CB, to_order=formatB)
         else:
             has_grad = False
...
@@ -277,7 +293,10 @@ class MatMul8bitLt(torch.autograd.Function):
                 # state.idx = outlier_idx
                 outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
                 state.subB = (
-                    (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
+                    (outliers * state.SCB.view(-1, 1) / 127.0)
+                    .t()
+                    .contiguous()
+                    .half()
                 )
                 CA[:, state.idx.long()] = 0
                 CAt[:, state.idx.long()] = 0
...
@@ -325,10 +344,14 @@ class MatMul8bitLt(torch.autograd.Function):
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert state.has_fp16_weights, "Backprop only supported for fp16 weights."
+        assert (
+            state.has_fp16_weights
+        ), "Backprop only supported for fp16 weights."

         if len(grad_output.shape) == 3:
-            grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
+            grad_output = grad_output.view(
+                -1, grad_output.shape[-1]
+            ).contiguous()

         grad_A = grad_B = None
...
@@ -359,7 +382,11 @@ matmul = MatMul8bitLt.apply
 def matmul(
-    A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
+    A: tensor,
+    B: tensor,
+    out: tensor = None,
+    state: MatmulLtState = None,
+    threshold=0.0,
 ):
     state = state or MatmulLtState()
     if threshold > 0.0:
...
bitsandbytes/cuda_setup.py

 """
-build is dependent on
-- compute capability
-    - dependent on GPU family
+extract factors the build is dependent on:
+[X] compute capability
+    [ ] TODO: Q - What if we have multiple GPUs of different makes?
 - CUDA version
 - Software:
     - CPU-only: only CPU quantization functions (no optimizer, no matrix multipl)
...
@@ -19,6 +19,8 @@ evaluation:
 """

 import ctypes
+import shlex
+import subprocess
 from os import environ as env
 from pathlib import Path
 from typing import Set, Union
...
@@ -26,10 +28,31 @@ from typing import Set, Union
 from .utils import print_err, warn_of_missing_prerequisite


+def execute_and_return(command_string: str) -> Tuple[str, str]:
+    def _decode(subprocess_err_out_tuple):
+        return tuple(
+            to_decode.decode("UTF-8").strip()
+            for to_decode in subprocess_err_out_tuple
+        )
+
+    def execute_and_return_decoded_std_streams(command_string):
+        return _decode(
+            subprocess.Popen(
+                shlex.split(command_string),
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+            ).communicate()
+        )
+
+    std_out, std_err = execute_and_return_decoded_std_streams()
+    return std_out, std_err
+
+
 def check_cuda_result(cuda, result_val):
     if result_val != 0:
+        # TODO: undefined name 'error_str'
         cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
-        print(f"Count not initialize CUDA - failure!")
+        print("Count not initialize CUDA - failure!")
         raise Exception("CUDA exception!")
     return result_val
...
@@ -53,7 +76,9 @@ def get_compute_capability():
     result = ctypes.c_int()
     device = ctypes.c_int()
+    # TODO: local variable 'context' is assigned to but never used
     context = ctypes.c_void_p()
+    # TODO: local variable 'error_str' is assigned to but never used
     error_str = ctypes.c_char_p()

     result = check_cuda_result(cuda, cuda.cuInit(0))
...
@@ -61,7 +86,9 @@ def get_compute_capability():
     result = check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
     ccs = []
     for i in range(nGpus.value):
-        result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
+        result = check_cuda_result(
+            cuda, cuda.cuDeviceGet(ctypes.byref(device), i)
+        )
         result = check_cuda_result(
             cuda,
             cuda.cuDeviceComputeCapability(
...
@@ -114,11 +141,15 @@ def get_cuda_runtime_lib_path(
     } - non_existent_directories

     if len(cuda_runtime_libs) > 1:
-        err_msg = f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        err_msg = (
+            f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        )
         raise FileNotFoundError(err_msg)

     elif len(cuda_runtime_libs) < 1:
-        err_msg = f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        err_msg = (
+            f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        )
         raise FileNotFoundError(err_msg)

     single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs))
...
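
The `execute_and_return` helper added in the cuda_setup.py hunks above shells out via `shlex`/`subprocess` and returns decoded stdout/stderr. Below is a standalone sketch of that same pattern, not the library's exact code; the `python --version` call is only an illustrative command:

    import shlex
    import subprocess
    from typing import Tuple


    def execute_and_return(command_string: str) -> Tuple[str, str]:
        # Split the command safely, run it, and capture both output streams.
        std_out, std_err = subprocess.Popen(
            shlex.split(command_string),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        ).communicate()
        # Decode bytes to stripped strings, mirroring the _decode helper in the diff.
        return std_out.decode("UTF-8").strip(), std_err.decode("UTF-8").strip()


    if __name__ == "__main__":
        out, err = execute_and_return("python --version")
        print(out or err)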
bitsandbytes/functional.py

@@ -17,14 +17,29 @@ if COMPILED_WITH_CUDA:
     """C FUNCTIONS FOR OPTIMIZERS"""
     str2optimizer32bit = {}
     str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
-    str2optimizer32bit["momentum"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
-    str2optimizer32bit["rmsprop"] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
-    str2optimizer32bit["adagrad"] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
-    str2optimizer32bit["lars"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
+    str2optimizer32bit["momentum"] = (
+        lib.cmomentum32bit_g32,
+        lib.cmomentum32bit_g16,
+    )
+    str2optimizer32bit["rmsprop"] = (
+        lib.crmsprop32bit_g32,
+        lib.crmsprop32bit_g16,
+    )
+    str2optimizer32bit["adagrad"] = (
+        lib.cadagrad32bit_g32,
+        lib.cadagrad32bit_g16,
+    )
+    str2optimizer32bit["lars"] = (
+        lib.cmomentum32bit_g32,
+        lib.cmomentum32bit_g16,
+    )
     str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)

     str2optimizer8bit = {}
-    str2optimizer8bit["adam"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+    str2optimizer8bit["adam"] = (
+        lib.cadam_static_8bit_g32,
+        lib.cadam_static_8bit_g16,
+    )
     str2optimizer8bit["momentum"] = (
         lib.cmomentum_static_8bit_g32,
         lib.cmomentum_static_8bit_g16,
...
@@ -33,7 +48,10 @@ if COMPILED_WITH_CUDA:
         lib.crmsprop_static_8bit_g32,
         lib.crmsprop_static_8bit_g16,
     )
-    str2optimizer8bit["lamb"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+    str2optimizer8bit["lamb"] = (
+        lib.cadam_static_8bit_g32,
+        lib.cadam_static_8bit_g16,
+    )
     str2optimizer8bit["lars"] = (
         lib.cmomentum_static_8bit_g32,
         lib.cmomentum_static_8bit_g16,
...
@@ -137,7 +155,9 @@ def create_dynamic_map(signed=True, n=7):
     if not signed:
         additional_items = 2 * additional_items
     for i in range(n):
-        fraction_items = 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
+        fraction_items = (
+            2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
+        )
         boundaries = torch.linspace(0.1, 1, fraction_items)
         means = (boundaries[:-1] + boundaries[1:]) / 2.0
         data += ((10 ** (-(n - 1) + i)) * means).tolist()
...
@@ -272,7 +292,13 @@ def get_transform_buffer(
 def nvidia_transform(
-    A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+    A,
+    to_order,
+    from_order="row",
+    out=None,
+    transpose=False,
+    state=None,
+    ld=None,
 ):
     if state is None:
         state = (A.shape, from_order)
...
@@ -352,7 +378,11 @@ def estimate_quantiles(
 def quantize_blockwise(
-    A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None
+    A: Tensor,
+    code: Tensor = None,
+    absmax: Tensor = None,
+    rand=None,
+    out: Tensor = None,
 ) -> Tensor:
     """
     Quantize tensor A in blocks of size 4096 values.
...
@@ -629,7 +659,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
     """
     if out is None:
         out = torch.zeros_like(A, dtype=torch.float32)
-    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
+    lib.cdequantize(
+        get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())
+    )
     return out
...
@@ -1005,7 +1037,9 @@ def histogram_scatter_add_2d(
     )


-def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
+def check_matmul(
+    A, B, out, transposed_A, transposed_B, expected_type=torch.int8
+):
     if not torch.cuda.is_initialized():
         torch.cuda.init()
     if A.dtype != expected_type or B.dtype != expected_type:
...
@@ -1097,7 +1131,11 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
 def igemm(
-    A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+    A: Tensor,
+    B: Tensor,
+    out: Tensor = None,
+    transposed_A=False,
+    transposed_B=False,
 ):
     sout = check_matmul(A, B, out, transposed_A, transposed_B)
     if out is None:
...
@@ -1193,7 +1231,11 @@ def igemm(
 def batched_igemm(
-    A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+    A: Tensor,
+    B: Tensor,
+    out: Tensor = None,
+    transposed_A=False,
+    transposed_B=False,
 ):
     if not len(A.shape) == 3 or not len(B.shape) == 3:
         raise ValueError(
...
@@ -1392,9 +1434,13 @@ def mm_dequant(
     if out is None:
         out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
     if new_row_stats is None:
-        new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
+        new_row_stats = torch.empty(
+            out_shape[0], dtype=torch.float32, device=A.device
+        )
     if new_col_stats is None:
-        new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
+        new_col_stats = torch.empty(
+            out_shape[1], dtype=torch.float32, device=A.device
+        )
     assert (
         new_row_stats.shape[0] == row_stats.shape[0]
     ), f"{new_row_stats.shape} vs {row_stats.shape}"
...
@@ -1440,13 +1486,13 @@ def get_colrow_absmax(
     col_tiles = (cols + 255) // 256
     tiled_rows = ((rows + 15) // 16) * 16
     if row_stats is None:
-        row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(
-            -50000.0
-        )
+        row_stats = torch.empty(
+            (rows,), dtype=torch.float32, device=device
+        ).fill_(-50000.0)
     if col_stats is None:
-        col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(
-            -50000.0
-        )
+        col_stats = torch.empty(
+            (cols,), dtype=torch.float32, device=device
+        ).fill_(-50000.0)
     if nnz_block_ptr is None and threshold > 0.0:
         nnz_block_ptr = torch.zeros(
...
@@ -1462,7 +1508,13 @@ def get_colrow_absmax(
     prev_device = pre_call(A.device)
     lib.cget_col_row_stats(
-        ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols
+        ptrA,
+        ptrRowStats,
+        ptrColStats,
+        ptrNnzrows,
+        ct.c_float(threshold),
+        rows,
+        cols,
     )
     post_call(prev_device)
...
@@ -1526,7 +1578,9 @@ class CSCSparseTensor(object):
 def coo2csr(cooA):
     values, counts = torch.unique(cooA.rowidx, return_counts=True)
     values.add_(1)
-    rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device)
+    rowptr = torch.zeros(
+        (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
+    )
     rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
     rowptr.cumsum_(0)
     return CSRSparseTensor(
...
@@ -1540,10 +1594,14 @@ def coo2csc(cooA):
     values = cooA.values[col2rowidx]
     colvalues, counts = torch.unique(val, return_counts=True)
     colvalues.add_(1)
-    colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device)
+    colptr = torch.zeros(
+        (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
+    )
     colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
     colptr.cumsum_(0)
-    return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)
+    return CSCSparseTensor(
+        cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values
+    )


 def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
...
@@ -1568,7 +1626,9 @@ def double_quant(
     rows = A.shape[0]

     if row_stats is None or col_stats is None:
-        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
+        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(
+            A, threshold=threshold
+        )

     if out_col is None:
         out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
...
@@ -1663,7 +1723,13 @@ def get_special_format_str():
 def transform(
-    A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+    A,
+    to_order,
+    from_order="row",
+    out=None,
+    transpose=False,
+    state=None,
+    ld=None,
 ):
     if state is None:
         state = (A.shape, from_order)
...
@@ -1716,7 +1782,9 @@ def transform(
 def spmm_coo(cooA, B, out=None):
     if out is None:
-        out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
+        out = torch.empty(
+            (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
+        )
     nnz = cooA.nnz
     assert cooA.rowidx.numel() == nnz
     assert cooA.colidx.numel() == nnz
...
@@ -1982,7 +2050,9 @@ def extract_outliers(A, SA, idx):
     assert formatA in ["col_turing", "col_ampere"]
     assert A.device.type == "cuda"

-    out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
+    out = torch.zeros(
+        (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device
+    )

     idx_size = ct.c_int32(idx.numel())
     rows = ct.c_int32(shapeA[0])
...
bitsandbytes/nn/modules.py

@@ -2,8 +2,19 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set,
-                    Tuple, TypeVar, Union, overload)
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    overload,
+)

 import torch
 import torch.nn.functional as F
...
@@ -131,7 +142,12 @@ class Embedding(torch.nn.Embedding):
 class Int8Params(torch.nn.Parameter):
     def __new__(
-        cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None
+        cls,
+        data=None,
+        requires_grad=True,
+        has_fp16_weights=False,
+        CB=None,
+        SCB=None,
     ):
         cls.has_fp16_weights = has_fp16_weights
         cls.CB = None
...
@@ -186,7 +202,9 @@ class Int8Params(torch.nn.Parameter):
             return self.cuda(device)
         else:
             new_param = Int8Params(
-                super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+                super().to(
+                    device=device, dtype=dtype, non_blocking=non_blocking
+                ),
                 requires_grad=self.requires_grad,
                 has_fp16_weights=self.has_fp16_weights,
             )
...
@@ -206,7 +224,9 @@ class Linear8bitLt(nn.Linear):
         threshold=0.0,
         index=None,
     ):
-        super(Linear8bitLt, self).__init__(input_features, output_features, bias)
+        super(Linear8bitLt, self).__init__(
+            input_features, output_features, bias
+        )
         self.state = bnb.MatmulLtState()
         self.index = index
...
@@ -215,7 +235,9 @@ class Linear8bitLt(nn.Linear):
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True

-        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights
+        )

     def init_8bit_state(self):
         self.state.CB = self.weight.CB
...
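
For context, the `Linear8bitLt` constructor reformatted above takes `input_features`, `output_features`, `bias`, plus `has_fp16_weights`, `threshold`, and `index`. A hypothetical usage sketch follows; the keyword choices and values are assumptions based only on the parameters visible in this diff, and a CUDA device is required:

    import torch
    import bitsandbytes as bnb

    # Assumed construction of an 8-bit linear layer; threshold=6.0 is a commonly
    # cited value for the outlier decomposition path referenced in the autograd hunks.
    layer = bnb.nn.Linear8bitLt(
        768, 3072, bias=True, has_fp16_weights=False, threshold=6.0
    ).cuda()

    x = torch.randn(4, 768, dtype=torch.float16, device="cuda")
    y = layer(x)  # forward pass goes through the int8 matmul (MatMul8bitLt)
    print(y.shape)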
bitsandbytes/optim/adagrad.py

@@ -23,7 +23,9 @@ class Adagrad(Optimizer1State):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if initial_accumulator_value != 0.0:
...
@@ -63,7 +65,9 @@ class Adagrad8bit(Optimizer1State):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if initial_accumulator_value != 0.0:
...
@@ -104,7 +108,9 @@ class Adagrad32bit(Optimizer1State):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if initial_accumulator_value != 0.0:
...
bitsandbytes/optim/adam.py

@@ -140,7 +140,11 @@ class AnalysisAdam(torch.optim.Optimizer):
         savedir=None,
     ):
         defaults = dict(
-            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsgrad=amsgrad,
         )
         super(AnalysisAdam, self).__init__(params, defaults)
         self.analysis = bnb_analysis
...
@@ -198,7 +202,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                 state["relerrors"] = torch.zeros(
                     (256, 256), device=p_data_fp32.device
                 )
-                state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
+                state["counts"] = torch.zeros(
+                    (256, 256), device=p_data_fp32.device
+                )
                 if amsgrad:
                     # Maintains max of all exp. moving avg. of sq. grad. values
                     state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
...
@@ -214,7 +220,9 @@ class AnalysisAdam(torch.optim.Optimizer):
             beta1, beta2 = group["betas"]
             bias_correction1 = 1 - beta1 ** state["step"]
             bias_correction2 = 1 - beta2 ** state["step"]
-            step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+            step_size = (
+                group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+            )
             e = state["abserrors"]
             rele = state["relerrors"]
             counts = state["counts"]
...
@@ -235,7 +243,10 @@ class AnalysisAdam(torch.optim.Optimizer):
             denom = exp_avg_sq.sqrt().add_(group["eps"])
             update_fp32 = exp_avg / denom

-            if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
+            if (
+                p_data_fp32.numel() <= 8192
+                or p_data_fp32.numel() > 50000 * 1000
+            ):
                 # embedding layer or too small
                 p_data_fp32 += -step_size * update_fp32
             else:
...
@@ -274,7 +285,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                 # 3. dequantize
                 # Error will be calculated automatically!
                 else:
-                    raise ValueError(f"Invalid analysis value: {self.analysis}!")
+                    raise ValueError(
+                        f"Invalid analysis value: {self.analysis}!"
+                    )

                 denom = state2.sqrt().add_(group["eps"])
                 update_8bit = state1 / denom
...
@@ -296,7 +309,9 @@ class AnalysisAdam(torch.optim.Optimizer):
             if self.savedir != "" and state["step"] % 100 == 0:
                 if not os.path.exists(self.savedir):
                     os.makedirs(self.savedir)
-                shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
+                shapestr = "_".join(
+                    [str(dim) for dim in p_data_fp32.shape]
+                )
                 pathe = os.path.join(
                     self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
                 )
...
bitsandbytes/optim/lars.py

@@ -24,7 +24,9 @@ class LARS(Optimizer1State):
         max_unorm=0.02,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"LARS without momentum is not supported!")
+            raise NotImplementedError(
+                f"LARS without momentum is not supported!"
+            )
         super(LARS, self).__init__(
             "lars",
             params,
...
@@ -56,7 +58,9 @@ class LARS8bit(Optimizer1State):
         max_unorm=0.02,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"LARS without momentum is not supported!")
+            raise NotImplementedError(
+                f"LARS without momentum is not supported!"
+            )
         super(LARS8bit, self).__init__(
             "lars",
             params,
...
@@ -88,7 +92,9 @@ class LARS32bit(Optimizer1State):
         max_unorm=0.02,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"LARS without momentum is not supported!")
+            raise NotImplementedError(
+                f"LARS without momentum is not supported!"
+            )
         super(LARS32bit, self).__init__(
             "lars",
             params,
...
@@ -121,7 +127,9 @@ class PytorchLARS(Optimizer):
         if momentum < 0.0:
             raise ValueError("Invalid momentum value: {}".format(momentum))
         if weight_decay < 0.0:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )

         defaults = dict(
             lr=lr,
...
@@ -132,7 +140,9 @@ class PytorchLARS(Optimizer):
             max_unorm=max_unorm,
         )
         if nesterov and (momentum <= 0 or dampening != 0):
-            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+            raise ValueError(
+                "Nesterov momentum requires a momentum and zero dampening"
+            )
         super(PytorchLARS, self).__init__(params, defaults)

     def __setstate__(self, state):
...
bitsandbytes/optim/optimizer.py

@@ -46,9 +46,13 @@ class GlobalOptimManager(object):
         for group_index, group in enumerate(param_groups):
             for p_index, p in enumerate(group["params"]):
                 if id(p) in self.pid2config:
-                    self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
+                    self.index2config[(group_index, p_index)] = self.pid2config[
+                        id(p)
+                    ]

-    def override_config(self, parameters, key=None, value=None, key_value_dict=None):
+    def override_config(
+        self, parameters, key=None, value=None, key_value_dict=None
+    ):
         """
         Overrides initial optimizer config for specific parameters.
...
@@ -136,7 +140,8 @@ class Optimizer8bit(torch.optim.Optimizer):
         if len(groups) != len(saved_groups):
             raise ValueError(
-                "loaded state dict has a different number of " "parameter groups"
+                "loaded state dict has a different number of "
+                "parameter groups"
             )
         param_lens = (len(g["params"]) for g in groups)
         saved_lens = (len(g["params"]) for g in saved_groups)
...
@@ -192,7 +197,9 @@ class Optimizer8bit(torch.optim.Optimizer):
                 new_group["params"] = group["params"]
             return new_group

-        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+        param_groups = [
+            update_group(g, ng) for g, ng in zip(groups, saved_groups)
+        ]
         self.__setstate__({"state": state, "param_groups": param_groups})

     def to_gpu(self):
...
@@ -222,9 +229,9 @@ class Optimizer8bit(torch.optim.Optimizer):
                         # found the matching parameter
                         # init override
                         self.mng.pid2config[id(p)] = config
-                        self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[
-                            id(p)
-                        ]
+                        self.mng.index2config[
+                            (gindex, pindex)
+                        ] = self.mng.pid2config[id(p)]
                         found = True

     @torch.no_grad()
...
@@ -280,7 +287,9 @@ class Optimizer8bit(torch.optim.Optimizer):
         raise NotImplementedError(f"init_state method needs to be overidden")

     def update_step(self, group, p, gindex, pindex):
-        raise NotImplementedError(f"The update_step method needs to be overidden")
+        raise NotImplementedError(
+            f"The update_step method needs to be overidden"
+        )


 class Optimizer2State(Optimizer8bit):
...
@@ -310,9 +319,13 @@ class Optimizer2State(Optimizer8bit):
             betas = [float(b) for b in betas]
         for i in range(len(betas)):
             if not 0.0 <= betas[i] < 1.0:
-                raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+                raise ValueError(
+                    f"Invalid beta parameter at index {i}: {betas[i]}"
+                )
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super(Optimizer2State, self).__init__(params, defaults, optim_bits)
...
@@ -351,7 +364,9 @@ class Optimizer2State(Optimizer8bit):
         state = self.state[p]
         state["step"] = 0

-        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+        if dtype == torch.float32 or (
+            dtype == torch.uint8 and p.numel() < 4096
+        ):
             state["state1"] = torch.zeros_like(
                 p,
                 memory_format=torch.preserve_format,
...
@@ -368,8 +383,12 @@ class Optimizer2State(Optimizer8bit):
             if state["step"] == 0:
                 if "dynamic" not in self.name2qmap:
                     self.fill_qmap()
-                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
-                self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
+                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+                    p.device
+                )
+                self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
+                    p.device
+                )

             state["state1"] = torch.zeros_like(
                 p,
...
@@ -399,11 +418,15 @@ class Optimizer2State(Optimizer8bit):
                     (blocks,), dtype=torch.float32, device=p.device
                 )
             else:
-                state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["max1"] = torch.zeros(
+                    (1,), dtype=torch.float32, device=p.device
+                )
                 state["new_max1"] = torch.zeros(
                     (1,), dtype=torch.float32, device=p.device
                 )
-                state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["max2"] = torch.zeros(
+                    (1,), dtype=torch.float32, device=p.device
+                )
                 state["new_max2"] = torch.zeros(
                     (1,), dtype=torch.float32, device=p.device
                 )
...
@@ -470,7 +493,9 @@ class Optimizer2State(Optimizer8bit):
                 state["new_max2"],
                 config["weight_decay"],
                 gnorm_scale=gnorm_scale,
-                unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+                unorm_vec=state["unorm_vec"]
+                if config["max_unorm"] > 0.0
+                else None,
                 max_unorm=config["max_unorm"],
             )
...
@@ -522,9 +547,13 @@ class Optimizer1State(Optimizer8bit):
             raise ValueError("Invalid epsilon value: {}".format(eps))
         for i in range(len(betas)):
             if not 0.0 <= betas[i] < 1.0:
-                raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+                raise ValueError(
+                    f"Invalid beta parameter at index {i}: {betas[i]}"
+                )
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super(Optimizer1State, self).__init__(params, defaults, optim_bits)
...
@@ -563,7 +592,9 @@ class Optimizer1State(Optimizer8bit):
         state = self.state[p]
         state["step"] = 0

-        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+        if dtype == torch.float32 or (
+            dtype == torch.uint8 and p.numel() < 4096
+        ):
             state["state1"] = torch.zeros_like(
                 p,
                 memory_format=torch.preserve_format,
...
@@ -574,7 +605,9 @@ class Optimizer1State(Optimizer8bit):
             if state["step"] == 0:
                 if "dynamic" not in self.name2qmap:
                     self.fill_qmap()
-                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
+                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+                    p.device
+                )

             state["state1"] = torch.zeros_like(
                 p,
...
@@ -593,7 +626,9 @@ class Optimizer1State(Optimizer8bit):
                     (blocks,), dtype=torch.float32, device=p.device
                 )
             else:
-                state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["max1"] = torch.zeros(
+                    (1,), dtype=torch.float32, device=p.device
+                )
                 state["new_max1"] = torch.zeros(
                     (1,), dtype=torch.float32, device=p.device
                 )
...
bitsandbytes/optim/rmsprop.py

@@ -22,7 +22,9 @@ class RMSprop(Optimizer1State):
         block_wise=True,
     ):
         if alpha == 0:
-            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+            raise NotImplementedError(
+                f"RMSprop with alpha==0.0 is not supported!"
+            )
         if centered:
             raise NotImplementedError(f"Centered RMSprop is not supported!")
         super(RMSprop, self).__init__(
...
@@ -56,7 +58,9 @@ class RMSprop8bit(Optimizer1State):
         block_wise=True,
     ):
         if alpha == 0:
-            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+            raise NotImplementedError(
+                f"RMSprop with alpha==0.0 is not supported!"
+            )
         if centered:
             raise NotImplementedError(f"Centered RMSprop is not supported!")
         super(RMSprop8bit, self).__init__(
...
@@ -91,7 +95,9 @@ class RMSprop32bit(Optimizer1State):
     ):
         if alpha == 0:
-            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+            raise NotImplementedError(
+                f"RMSprop with alpha==0.0 is not supported!"
+            )
         if centered:
             raise NotImplementedError(f"Centered RMSprop is not supported!")
         super(RMSprop32bit, self).__init__(
...
bitsandbytes/utils.py

-import sys
 import shlex
 import subprocess
+import sys


 def execute_and_return(command_string: str) -> Tuple[str, str]:
...
quicktest.py

@@ -14,23 +14,31 @@ def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
             torch.int8
         )
     elif dims == 3:
-        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
-            torch.int8
-        )
+        A = torch.randint(
+            -128, 127, size=(dim1, dim2, dim3), device="cuda"
+        ).to(torch.int8)
-    B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
+    B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
+        torch.int8
+    )
     C1 = torch.matmul(A.float(), B.t().float())

     A2, SA = F.transform(A, "col32")
     B2, SB = F.transform(B, "colx")
     if dims == 2:
         C2, SC = F.transform(
-            torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"),
+            torch.zeros(
+                A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"
+            ),
             "col32",
         )
     else:
         C2, SC = F.transform(
             torch.zeros(
-                A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device="cuda"
+                A.shape[0],
+                A.shape[1],
+                B.shape[0],
+                dtype=torch.int32,
+                device="cuda",
             ),
             "col32",
         )
...
tests/test_autograd.py

@@ -18,9 +18,13 @@ req_grad_str = ["FF", "TF", "TT", "FT"]
 transpose = [(False, False), (False, True), (True, True), (True, False)]
 str_transpose = ["FF", "FT", "TT", "TF"]
 dtype = [torch.float32, torch.float16]
-values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
+values = list(
+    product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
+)
 str_values = list(
-    product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose)
+    product(
+        dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
+    )
 )
 names = [
     "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
...
@@ -31,7 +35,9 @@ names = [
 @pytest.mark.parametrize(
-    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names
+    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
+    values,
+    ids=names,
 )
 def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
     dim2 = dim2 - (dim2 % 16)
...
@@ -79,7 +85,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            A.grad = None
            B.grad = None

            loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
            loss_torch.backward()
            gradA2 = A.grad
            gradB2 = B.grad
...
@@ -87,25 +95,35 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            B.grad = None

            if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02
-                torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+                torch.testing.assert_allclose(
+                    gradB1, gradB2, atol=0.18, rtol=0.3
+                )

     # batched matrix multiply
     if funcs[0] in [torch.bmm, torch.matmul]:
         A = torch.randn(
-            size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
+            size=(dim1, dim2, dim3),
+            device="cuda",
+            requires_grad=req_grad[0],
         )
         B = torch.randn(
-            size=(dim1, dim3, dim4), device="cuda", requires_grad=req_grad[1]
+            size=(dim1, dim3, dim4),
+            device="cuda",
+            requires_grad=req_grad[1],
         )
         target = torch.randn(
-            size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
+            size=(dim1, dim2, dim4),
+            device="cuda",
+            requires_grad=req_grad[1],
         )
         torch.nn.init.xavier_uniform_(B)
...
@@ -115,7 +133,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.01
-            torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)
+            torch.testing.assert_allclose(
+                out_bnb, out_torch, atol=0.027, rtol=0.2
+            )

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
...
@@ -127,7 +147,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            A.grad = None
            B.grad = None

            loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
            loss_torch.backward()
            gradA2 = A.grad
            gradB2 = B.grad
...
@@ -135,7 +157,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            B.grad = None

            if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
...
@@ -146,12 +170,16 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
     if funcs[0] in [torch.matmul]:
         dim1 = dim1 - (dim1 % 16)
         A = torch.randn(
-            size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
+            size=(dim1, dim2, dim3),
+            device="cuda",
+            requires_grad=req_grad[0],
         )
         dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
         B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
         target = torch.randn(
-            size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
+            size=(dim1, dim2, dim4),
+            device="cuda",
+            requires_grad=req_grad[1],
         )
         torch.nn.init.xavier_uniform_(B)
...
@@ -178,7 +206,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            A.grad = None
            B.grad = None

            loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
            loss_torch.backward()
            gradA2 = A.grad
            gradB2 = B.grad
...
@@ -186,7 +216,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            B.grad = None

            if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
...
@@ -258,7 +290,16 @@ names = [
     ids=names,
 )
 def test_matmullt(
-    dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights
+    dim1,
+    dim2,
+    dim3,
+    dim4,
+    funcs,
+    dtype,
+    req_grad,
+    transpose,
+    decomp,
+    has_fp16_weights,
 ):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
...
@@ -278,7 +319,10 @@ def test_matmullt(
         size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
     )
     target = torch.randn(
-        size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype
+        size=(dim2, dim4),
+        device="cuda",
+        requires_grad=req_grad[1],
+        dtype=dtype,
     )
     torch.nn.init.xavier_uniform_(B)
     B2 = B.clone()
...
@@ -317,14 +361,18 @@ def test_matmullt(
     if any(req_grad):
...
req_grad
):
out_bnb
.
data
.
copy_
(
out_torch
)
out_bnb
.
data
.
copy_
(
out_torch
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
loss_bnb
=
torch
.
nn
.
functional
.
mse_loss
(
out_bnb
,
target
).
mean
()
loss_bnb
=
torch
.
nn
.
functional
.
mse_loss
(
out_bnb
,
target
).
mean
()
loss_bnb
.
backward
()
loss_bnb
.
backward
()
gradA1
=
A
.
grad
gradA1
=
A
.
grad
gradB1
=
B
.
grad
gradB1
=
B
.
grad
A
.
grad
=
None
A
.
grad
=
None
B
.
grad
=
None
B
.
grad
=
None
loss_torch
=
torch
.
nn
.
functional
.
mse_loss
(
out_torch
,
target
).
mean
()
loss_torch
=
torch
.
nn
.
functional
.
mse_loss
(
out_torch
,
target
).
mean
()
loss_torch
.
backward
()
loss_torch
.
backward
()
gradA2
=
A
.
grad
gradA2
=
A
.
grad
gradB2
=
B
.
grad
gradB2
=
B
.
grad
...
@@ -332,7 +380,9 @@ def test_matmullt(
...
@@ -332,7 +380,9 @@ def test_matmullt(
B
.
grad
=
None
B
.
grad
=
None
if
req_grad
[
0
]:
if
req_grad
[
0
]:
torch
.
testing
.
assert_allclose
(
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
torch
.
testing
.
assert_allclose
(
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
if
req_grad
[
1
]:
if
req_grad
[
1
]:
n
=
gradB1
.
numel
()
n
=
gradB1
.
numel
()
assert
torch
.
abs
(
gradB1
).
sum
()
>
0.0
assert
torch
.
abs
(
gradB1
).
sum
()
>
0.0
...
@@ -341,4 +391,6 @@ def test_matmullt(
...
@@ -341,4 +391,6 @@ def test_matmullt(
assert
(
idx
==
0
).
sum
().
item
()
<
n
*
0.1
assert
(
idx
==
0
).
sum
().
item
()
<
n
*
0.1
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.10
,
rtol
=
0.3
)
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.10
,
rtol
=
0.3
)
assert
(
idx
==
0
).
sum
().
item
()
<
n
*
0.02
assert
(
idx
==
0
).
sum
().
item
()
<
n
*
0.02
torch
.
testing
.
assert_allclose
(
gradB1
,
gradB2
,
atol
=
0.18
,
rtol
=
0.3
)
torch
.
testing
.
assert_allclose
(
gradB1
,
gradB2
,
atol
=
0.18
,
rtol
=
0.3
)
tests/test_cuda_setup_evaluator.py
View file @
ea7c14f8
...
@@ -3,8 +3,12 @@ from typing import List, NamedTuple
import pytest
from bitsandbytes.cuda_setup import (
    CUDA_RUNTIME_LIB,
    evaluate_cuda_setup,
    get_cuda_runtime_lib_path,
    tokenize_paths,
)
class InputAndExpectedOutput(NamedTuple):
...
@@ -13,11 +17,26 @@ class InputAndExpectedOutput(NamedTuple):
HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
    (
        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    (
        f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    (
        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    (
        f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    (
        f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    (
        f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
        f"dir/with/{CUDA_RUNTIME_LIB}",
...
tests/test_functional.py
View file @
ea7c14f8
...
@@ -86,7 +86,9 @@ def teardown():
    pass
@pytest.mark.parametrize(
    "dtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_estimate_quantiles(dtype):
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
...
@@ -190,7 +192,9 @@ def test_dynamic_blockwise_stochastic_quantization():
    )
@pytest.mark.parametrize(
    "gtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_percentile_clipping(gtype):
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
...
@@ -270,7 +274,13 @@ def mean(xx):
dim1 = [1024 * 2]
dim2 = [1024 * 16]
methods = [
    (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    )
]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
...
@@ -279,11 +289,14 @@ batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
    "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals) for vals in values_names
]
@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
...
@@ -339,14 +352,18 @@ names = [
]
@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
    for i in range(k):
        shapeA = (
            (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
        )
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
...
@@ -394,7 +411,9 @@ seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = ["seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
...
@@ -406,11 +425,13 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    A = torch.randint(
        -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
    ).to(torch.int8)
    B = torch.randint(
        -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
    ).to(torch.int8)
    out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
    iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
    out = F.igemm(A, B, out=iout)
    torch.testing.assert_allclose(out.float(), out2)
...
@@ -428,7 +449,9 @@ names = [
]
@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
...
@@ -444,7 +467,9 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
...
@@ -504,7 +529,8 @@ dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals) for vals in values
]
...
@@ -529,7 +555,9 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
        out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
        out = F.igemm(A.permute([0, 2, 1]), B)
    elif transpose[0] and transpose[1]:
        out2 = torch.bmm(
            A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
        )
        out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
    torch.testing.assert_allclose(out.float(), out2.float())
...
@@ -563,7 +591,9 @@ a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(
    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(
...
@@ -574,9 +604,13 @@ names = [
@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names,
)
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    if dims == 3 and out_order != "col32":
        return
    if dtype == torch.int32 and out_order != "col32":
...
@@ -586,7 +620,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)
    out, S = F.nvidia_transform(A, to_order=orderOut)
...
@@ -598,7 +634,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = (
                A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
            )
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
...
@@ -613,7 +653,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
            j = col
            coltile = (col // 32) + (1 if col % 32 != 0 else 0)
            rowtile = (
                (row // 8) + (1 if row % 8 != 0 else 0)
            ) * total_coltile
            offset = 32 * 8 * (rowtile + coltile)
            col2 = col % 32
            row2 = (row % 8) * 32
...
@@ -624,7 +666,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
            # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
    if orderOut == "col32":
        out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
        torch.testing.assert_allclose(A, out2)
...
@@ -657,10 +701,12 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
            torch.int8
        )
    elif dims == 3:
        A = torch.randint(
            -128, 127, size=(dim1, dim2, dim3), device="cuda"
        ).to(torch.int8)
    B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
    C1 = torch.matmul(A.float(), B.t().float())
    A2, SA = F.transform(A, "col32")
...
@@ -670,7 +716,9 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
    torch.testing.assert_allclose(C1, C3.float())
    # transpose
    B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8)
    C1 = torch.matmul(A.float(), B.float())
    B2t, SBt = F.transform(B, "col_turing", transpose=True)
...
@@ -688,7 +736,8 @@ dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals) for vals in values
]
...
@@ -699,7 +748,9 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
    if dims == 2:
        A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
    elif dims == 3:
        A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
    B = torch.randn((dim4, dim3), device="cuda").half()
    torch.nn.init.xavier_uniform_(B)
    C1 = torch.matmul(A, B.t())
...
@@ -742,7 +793,9 @@ values = [
# values = list(product(batch, seq, model, hidden))
names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
...
@@ -909,7 +962,9 @@ dims = (2,)
# ldb = list(range(256, 1*1024, 256))
formatB = ["col_turing", "col_ampere"]
values = list(product(dim1, dim4, dims, formatB))
names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)
...
@@ -992,7 +1047,9 @@ def test_colrow_absmax(dim1, dim2, dims):
        torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
        torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)
    row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
    torch.testing.assert_allclose(col_stats1, col_stats2)
    torch.testing.assert_allclose(row_stats1, row_stats2)
...
@@ -1023,8 +1080,12 @@ def test_double_quant(dim1, dim2):
        torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)
        n = CAt.numel()
        num_not_close_rows = (
            (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
        )
        num_not_close_cols = (
            (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
        )
        # allow for 1:500 error due to rounding differences
        min_error = 1 / 500
...
@@ -1123,7 +1184,9 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
        outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
        C3, S = F.nvidia_transform(outC32, "row", state=SC)
        maxval = torch.abs(C3).max()
        if maxval == 127:
...
@@ -1204,7 +1267,9 @@ def test_row_scale_bench(dim1, dim4, inner):
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
    torch.cuda.synchronize()
    print("row-wise", time.time() - t0)
...
@@ -1230,7 +1295,9 @@ a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(
    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
        *vals
...
@@ -1240,14 +1307,20 @@ names = [
@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names,
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    for i in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
        elif dims == 3:
            A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)
        A.view(-1)[-1] = -1
        if transpose:
...
@@ -1282,7 +1355,9 @@ names = [
]
@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names)
def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
    for i in range(1):
        A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype)
...
@@ -1332,17 +1407,23 @@ def test_coo_double_quant(dim1, dim2):
        idx = torch.abs(A) >= threshold
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
        if coo_tensor is not None:
            A1 = A * idx
            A2 = torch.zeros_like(A)
            A2[
                coo_tensor.rowidx.long(), coo_tensor.colidx.long()
            ] = coo_tensor.values
            torch.testing.assert_allclose(A1, A2)
            A1 = A * (idx == 0)
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            torch.testing.assert_allclose(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
    n = 2
...
@@ -1454,7 +1535,9 @@ def test_integrated_sparse_decomp(dim1, dim2):
        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
        C32A, SA = F.transform(CA, "col32")
        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
...
@@ -1494,7 +1577,9 @@ dim2 = [12288]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = ["dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
...
@@ -1536,7 +1621,9 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
    idx_col = torch.randint(0, A2.shape[-1], size=(15,))
...
@@ -1734,7 +1821,9 @@ values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 4096, 4*4096))
# values.append((batch_size, seqdim, 5140, 4*5140))
# values.append((batch_size, seqdim, 12288, 4*12288))
names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
...
tests/test_modules.py
View file @
ea7c14f8
...
@@ -48,7 +48,9 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
class LinearFunction(torch.autograd.Function):
    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
        round_func = (
            LinearFunction.round_stoachastic if stochastic else torch.round
        )
        norm = math.sqrt(math.pi) / math.sqrt(2.0)
        # std = torch.abs(x).mean()*norm
        std = torch.std(x)
...
@@ -116,7 +118,9 @@ class LinearFunction(torch.autograd.Function):
        return x.to(dtype)
    def get_8bit_linear(x, stochastic=False):
        round_func = (
            LinearFunction.round_stoachastic if stochastic else torch.round
        )
        max1 = torch.abs(x).max()
        x = x / max1 * 127
        x = round_func(x) / 127 * max1
...
@@ -125,7 +129,9 @@ class LinearFunction(torch.autograd.Function):
    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
        round_func = (
            LinearFunction.round_stoachastic if stochastic else torch.round
        )
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x * 127) / max1
...
@@ -209,7 +215,9 @@ class LinearFunction(torch.autograd.Function):
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
            # if torch.rand(1) < 0.01:
            # output32 = torch.matmul(x, weight.t())
            # err = torch.abs(output-output32).float()
...
@@ -261,7 +269,9 @@ class LinearFunction(torch.autograd.Function):
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )
            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(
...
@@ -338,8 +348,12 @@ def test_linear8bit():
    loss2.backward()
    loss3.backward()
    assert_all_approx_close(
        l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
    )
    assert_all_approx_close(
        l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
    )
    assert_all_approx_close(
        l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
    )
...
@@ -388,7 +402,9 @@ def test_linear8bitlt_accumulated_gradient():
    l1 = torch.nn.Sequential(
        *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
    )
    l2 = torch.nn.Sequential(
        *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
    )
    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
...
@@ -462,7 +478,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .cuda()
        .half()
    )
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
...
@@ -475,7 +495,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .half()
        .cuda()
    )
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
...
@@ -488,7 +512,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .half()
        .to("cuda")
    )
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
...
tests/test_optim.py
View file @
ea7c14f8
...
@@ -103,20 +103,26 @@ str2statenames["adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit"] = [
    ("momentum_buffer", "state1", "qmap1", "max1")
]
str2statenames["momentum8bit_blockwise"] = [
    ("momentum_buffer", "state1", "qmap1", "absmax1")
]
str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [
    ("square_avg", "state1", "qmap1", "absmax1")
]
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
...
@@ -203,9 +209,13 @@ def test_global_config(dim1, dim2, gtype):
    eps = 1e-8
    bnb.optim.GlobalOptimManager.get_instance().initialize()
    bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
    bnb.optim.GlobalOptimManager.get_instance().register_parameters(
        [p1, p2, p3]
    )
    p1 = p1.cuda()
    p2 = p2.cuda()
    p3 = p3.cuda()
...
@@ -245,7 +255,9 @@ optimizer_names = [
    "rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
...
@@ -329,8 +341,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
            bnb_optimizer = str2optimizers[optim_name][1]([p2])
            bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            rm_path(path)
            torch.testing.assert_allclose(
                raws1cpy, bnb_optimizer.state[p2][name2]
            )
            torch.testing.assert_allclose(
                qmap1, bnb_optimizer.state[p2][qmap]
            )
            if "blockwise" in optim_name:
                s1 = F.dequantize_blockwise(
...
@@ -349,12 +365,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
            num_not_close = (
                torch.isclose(
                    torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol,
                )
                == 0
            )
            assert num_not_close.sum().item() < 20
        torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
        # the parameters diverge quickly. Here we keep them close
        # together so we can test against the Adam error
...
@@ -375,7 +396,10 @@ dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
...
@@ -391,7 +415,12 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    p2 = p1.clone()
    adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
    adam2 = bnb.optim.Adam(
        [p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5,
    )
    gnorm_vec = torch.zeros(100).cuda()
...
@@ -399,7 +428,9 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    for i in range(50):
        step += 1
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
        g2 = g1.clone()
        p2.grad = g2
...
@@ -430,10 +461,16 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
        elif optim_bits == 8:
            torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
            torch.testing.assert_allclose(
                adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3,
            )
            torch.testing.assert_allclose(
                adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3,
            )
            adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
            adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
...
@@ -463,7 +500,9 @@ gtype = [torch.float32, torch.float16]
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
...