OpenDAS / bitsandbytes

Commit a24aae30
Authored Jul 06, 2023 by Jeongseok Kang

    Merge branch 'main' into fix/libcuda-to-torch

Parents: 2b4cc256, 4395d68c
Changes: 48 files in this merge commit; this page shows 8 changed files with 1016 additions and 335 deletions (+1016, -335).
deploy.sh                                  +0    -11
examples/int8_inference_huggingface.py     +27   -0
setup.py                                   +2    -2
tests/test_autograd.py                     +205  -10
tests/test_functional.py                   +486  -165
tests/test_modules.py                      +129  -41
tests/test_optim.py                        +108  -106
tests/test_triton.py                       +59   -0
deploy.sh  (+0 -11)

@@ -139,17 +139,6 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then
 fi
 
-make clean
-export CUDA_HOME=$BASE_PATH/cuda-10.2
-make cuda10x_nomatmul CUDA_VERSION=102
-
-if [ ! -f "./bitsandbytes/libbitsandbytes_cuda102_nocublaslt.so" ]; then
-    # Control will enter here if $DIRECTORY doesn't exist.
-    echo "Compilation unsuccessul!" 1>&2
-    exit 64
-fi
-
 make clean
 export CUDA_HOME=$BASE_PATH/cuda-11.0
 make cuda110_nomatmul CUDA_VERSION=110
examples/int8_inference_huggingface.py  (new file, mode 100644, +27 -0)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MAX_NEW_TOKENS = 128
model_name = 'decapoda-research/llama-7b-hf'

text = 'Hamburg is in which country?\n'
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer(text, return_tensors="pt").input_ids

free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'

n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='auto',
    load_in_8bit=True,
    max_memory=max_memory
)
generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
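
Note (added for context, not part of the commit): the example above caps per-GPU memory at roughly 2 GB below the currently free memory before loading the model in 8-bit. A minimal sketch of just that memory-map computation, using only standard torch calls; the device sizes in the comment are hypothetical:

import torch

# Build a per-GPU max_memory map that leaves ~2 GB of headroom,
# mirroring the computation in the example script above.
free_bytes, total_bytes = torch.cuda.mem_get_info()   # free/total bytes on the current device
free_in_GB = int(free_bytes / 1024**3)                # whole gigabytes, rounded down
n_gpus = torch.cuda.device_count()
max_memory = {i: f'{free_in_GB - 2}GB' for i in range(n_gpus)}
print(max_memory)  # e.g. {0: '22GB', 1: '22GB'} on a hypothetical pair of 24 GB GPUs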
setup.py  (+2 -2)

@@ -18,10 +18,10 @@ def read(fname):
 setup(
     name=f"bitsandbytes",
-    version=f"0.38.0",
+    version=f"0.39.1",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
-    description="8-bit optimizers and matrix multiplication routines.",
+    description="k-bit optimizers and matrix multiplication routines.",
     license="MIT",
     keywords="gpu optimizers optimization 8-bit quantization compression",
     url="https://github.com/TimDettmers/bitsandbytes",
tests/test_autograd.py  (+205 -10)

@@ -97,7 +97,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 B.grad = None
 
                 if req_grad[0]:
-                    torch.testing.assert_allclose(
+                    torch.testing.assert_close(
                         gradA1, gradA2, atol=0.015, rtol=0.1
                     )
                 if req_grad[1]:
@@ -106,7 +106,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                     assert (idx == 0).sum().item() < n * 0.1
                     idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                     assert (idx == 0).sum().item() < n * 0.02
-                    torch.testing.assert_allclose(
+                    torch.testing.assert_close(
                         gradB1, gradB2, atol=0.18, rtol=0.3
                     )
@@ -135,7 +135,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             n = out_bnb.numel()
             idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
             assert (idx == 0).sum().item() < n * 0.01
-            torch.testing.assert_allclose(
+            torch.testing.assert_close(
                 out_bnb, out_torch, atol=0.027, rtol=0.2
             )
@@ -159,7 +159,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 B.grad = None
 
                 if req_grad[0]:
-                    torch.testing.assert_allclose(
+                    torch.testing.assert_close(
                         gradA1, gradA2, atol=0.015, rtol=0.1
                     )
                 if req_grad[1]:
@@ -218,7 +218,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 B.grad = None
 
                 if req_grad[0]:
-                    torch.testing.assert_allclose(
+                    torch.testing.assert_close(
                         gradA1, gradA2, atol=0.015, rtol=0.1
                     )
                 if req_grad[1]:
@@ -239,8 +239,8 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
 dim2.append(0)
 
 decomp = [0.0, 6.0]
-funcs = [(torch.matmul, bnb.matmul)]
-str_funcs = ["matmul"]
+funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)]
+str_funcs = ["matmullt", 'switchback_bnb']
 req_grad = [(False, False), (True, False), (True, True), (False, True)]
 req_grad = list(product([True, False], repeat=3))
 req_grad_str = []
@@ -407,7 +407,7 @@ def test_matmullt(
                     bias.grad = None
 
                 if req_grad[0]:
-                    torch.testing.assert_close(
+                    torch.testing.assert_close(
                         gradA1, gradA2, atol=0.015, rtol=0.1
                     )
                 if req_grad[1]:
@@ -423,9 +423,204 @@ def test_matmullt(
                     assert (idx == 0).sum().item() <= n * 0.1
                     idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                     assert (idx == 0).sum().item() <= n * 0.02
-                    torch.testing.assert_allclose(
+                    torch.testing.assert_close(
                         gradB1, gradB2, atol=0.18, rtol=0.3
                     )
                 if req_grad[2]:
-                    torch.testing.assert_allclose(gradBias1, gradBias2)
+                    torch.testing.assert_close(gradBias1, gradBias2)
+
+
+n = 1
+k = 3
+dim1 = torch.randint(16, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 96, size=(n,)).tolist()
+dim3 = torch.randint(32, 96, size=(n,)).tolist()
+dim4 = torch.randint(32, 96, size=(n,)).tolist()
+
+dim2.append(0)
+
+funcs = [(torch.matmul, bnb.matmul_4bit)]
+str_funcs = ["matmul"]
+req_grad = list(product([True, False], repeat=3))
+req_grad_str = []
+for c in req_grad:
+    strval = ''
+    for v in c:
+        if v == True:
+            strval += 'T'
+        else:
+            strval += 'F'
+    req_grad_str.append(strval)
+
+transpose = [(False, True), (False, False)]
+str_transpose = ["NT", "NN"]
+dtype = [torch.float16, torch.float32]
+compress_statistics = [False, True]
+has_fp16_weights = [True, False]
+has_bias = [True, False]
+quant_type = ['fp4', 'nf4']
+values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type))
+str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type))
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values]
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names)
+def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):
+    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
+    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
+    if has_bias == False:
+        req_grad = list(req_grad)
+        req_grad[2] = False
+
+    for i in range(k):
+        # normal multiply
+        if funcs[0] in [torch.mm, torch.matmul]:
+            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
+            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            bias = None
+            bias2 = None
+            if has_bias:
+                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
+                bias2 = bias.clone()
+            torch.nn.init.xavier_uniform_(B)
+
+            B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type)
+
+            if not transpose[0] and transpose[1]:
+                out_torch = funcs[0](A, B.t())
+                out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
+            elif not transpose[0] and not transpose[1]:
+                out_torch = funcs[0](A, B)
+                out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
+
+            if has_bias:
+                out_torch += bias
+
+            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
+
+            n = out_bnb.numel()
+            err = torch.abs(out_bnb - out_torch).float().mean().item()
+            if n > 0:
+                assert err < 0.115
+                #assert err < 0.20
+            if any(req_grad):
+                out_bnb.data.copy_(out_torch)
+                torch.cuda.synchronize()
+                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+                loss_bnb.backward()
+                gradA1 = A.grad
+                gradB1 = B.grad
+                A.grad = None
+                B.grad = None
+                if has_bias:
+                    gradBias1 = bias.grad
+                    bias.grad = None
+
+                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+                loss_torch.backward()
+                gradA2 = A.grad
+                gradB2 = B.grad
+                A.grad = None
+                B.grad = None
+                if has_bias:
+                    gradBias2 = bias.grad
+                    bias.grad = None
+
+                if req_grad[0]:
+                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
+
+                if req_grad[2]:
+                    torch.testing.assert_close(gradBias1, gradBias2)
+
+
+funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)]
+str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global']
+req_grad = list(product([True, False], repeat=3))
+req_grad_str = []
+for c in req_grad:
+    strval = ''
+    for v in c:
+        if v == True:
+            strval += 'T'
+        else:
+            strval += 'F'
+    req_grad_str.append(strval)
+
+transpose = [(False, True), (False, False)]
+str_transpose = ["NT", "NN"]
+dtype = [torch.float16, torch.float32]
+has_fp16_weights = [True, False]
+values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
+str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values]
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
+def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
+    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
+    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
+    req_grad = list(req_grad)
+    req_grad[2] = False
+
+    for i in range(k):
+        # normal multiply
+        if funcs[0] in [torch.mm, torch.matmul]:
+            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
+            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            torch.nn.init.xavier_uniform_(B)
+
+            fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
+            bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)
+
+            if not transpose[0] and transpose[1]:
+                out_torch = funcs[0](A, B.t())
+                out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
+            elif not transpose[0] and not transpose[1]:
+                out_torch = funcs[0](A, B)
+                out_bnb = funcs[1](A, B, fw_code, bw_code)
+
+            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
+
+            n = out_bnb.numel()
+            err = torch.abs(out_bnb - out_torch).float().mean().item()
+            if n > 0:
+                assert err < 0.115
+                #assert err < 0.20
+            if any(req_grad):
+                out_bnb.data.copy_(out_torch)
+                torch.cuda.synchronize()
+                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+                loss_bnb.backward()
+                gradA1 = A.grad
+                gradB1 = B.grad
+                A.grad = None
+                B.grad = None
+
+                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+                loss_torch.backward()
+                gradA2 = A.grad
+                gradB2 = B.grad
+                A.grad = None
+                B.grad = None
+
+                if req_grad[0]:
+                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
+
+                if req_grad[1]:
+                    n = gradB1.numel()
+                    if dim2 > 0:
+                        assert torch.abs(gradB1).sum() > 0.0
+                        assert torch.abs(gradB2).sum() > 0.0
+                    else:
+                        assert torch.abs(gradB1).sum() == 0.0
+                        assert torch.abs(gradB2).sum() == 0.0
+                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+                    assert (idx == 0).sum().item() <= n * 0.1
+                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+                    assert (idx == 0).sum().item() <= n * 0.02
+                    grad_err = (gradB1 - gradB2).abs().mean()
+                    assert grad_err.item() < 0.003
+                    torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
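
Note (added for context, not part of the commit): the new test_matmul_4bit above exercises the 4-bit path by quantizing B and comparing against a plain fp16 matmul. A minimal standalone sketch of that pattern, distilled from the test body; shapes are arbitrary, and it assumes a CUDA device and the same bitsandbytes calls the test uses:

import torch
import bitsandbytes as bnb

# Quantize a weight matrix to NF4 and multiply, as test_matmul_4bit does.
A = torch.randn(32, 64, device="cuda", dtype=torch.float16)
B = torch.randn(48, 64, device="cuda", dtype=torch.float16)       # (out_features, in_features)

B4, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=True, quant_type='nf4')

out_ref = torch.matmul(A, B.t())                    # fp16 reference
out_4bit = bnb.matmul_4bit(A, B4.t(), quant_state)  # 4-bit weights, fp16 activations
print((out_ref - out_4bit).abs().mean())            # small quantization error expected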
tests/test_functional.py  (+486 -165)

This diff is collapsed in the page capture and is not shown here.
tests/test_modules.py  (+129 -41)

@@ -44,7 +44,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
     sumval = (idx == 0).sum().item()
     if sumval > count:
         print(f"Too many values not close: assert {sumval} < {count}")
-        torch.testing.assert_allclose(a, b, rtol, atol)
+        torch.testing.assert_close(a, b, rtol, atol)
 
 
 class LinearFunction(torch.autograd.Function):
@@ -330,18 +330,15 @@ def test_linear8bitlt_inference(threshold):
 def test_linear8bitlt_accumulated_gradient():
-    l1 = torch.nn.Sequential(
-        *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
-    )
-    l2 = torch.nn.Sequential(
-        *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
-    )
-    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
-    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
-    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
-    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
-    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
-    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
+    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
+    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
+    l1[0].weight.data.copy_(l2[0].weight.data)
+    l1[1].weight.data.copy_(l2[1].weight.data)
+    l1[0].bias.data.copy_(l2[0].bias.data)
+    l1[1].bias.data.copy_(l2[1].bias.data)
+    opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
+    opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)
 
     acc_steps = 10
@@ -371,26 +368,17 @@ def test_linear8bitlt_accumulated_gradient():
             # we do this copy because otherwise we have small divergences over time that add up
             l1[0].weight.data.copy_(l2[0].weight.data)
             l1[1].weight.data.copy_(l2[1].weight.data)
+            l1[0].bias.data.copy_(l2[0].bias.data)
+            l1[1].bias.data.copy_(l2[1].bias.data)
         else:
-            torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
-            torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
+            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
+            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)
 
 
-threshold = [0.0, 2.0]
-values = threshold
-names = [f"threshold_{vals}" for vals in values]
-@pytest.mark.parametrize("threshold", values, ids=names)
+@pytest.mark.parametrize("threshold", [0.0, 2.0])
 @pytest.mark.parametrize("memory_efficient_backward", [False])
 def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
-    l1 = (
-        bnb.nn.Linear8bitLt(
-            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
-        )
-        .cuda()
-        .half()
-    )
+    l1 = (bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
     assert l1.weight.dtype == torch.int8
 
     l1.eval()
@@ -446,13 +434,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
 
-    mlp = (
-        MLP8bit(
-            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
-        )
-        .half()
-        .to("cuda")
-    )
+    mlp = (MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda"))
 
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -499,15 +481,16 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
         grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
         scale = grad_ref.abs().mean()
 
-        torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
+        torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
         idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
         assert (idx == 0).sum().item() <= b1.numel() * 0.005
 
 
-def test_linear8bitlt_fp32_bias():
+@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
+def test_linear_kbit_fp32_bias(module):
     # casts model to fp16 -> int8 automatically
-    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
-    assert l1.weight.dtype == torch.int8
+    l1 = module(32, 64).cuda()
+    assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias.dtype == torch.float32
 
     for i in range(100):
@@ -517,11 +500,116 @@ def test_linear8bitlt_fp32_bias():
         assert l1.bias.dtype == torch.float16
 
     # casts model to fp16 -> int8 automatically
-    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
-    assert l1.weight.dtype == torch.int8
+    l1 = module(32, 64, bias=False).cuda()
+    assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias is None
 
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = l1(b1)
         assert l1.bias is None
+
+
+modules = []
+modules.append(bnb.nn.Linear8bitLt)
+modules.append(bnb.nn.Linear4bit)
+modules.append(bnb.nn.LinearFP4)
+modules.append(bnb.nn.LinearNF4)
+modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
+modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
+names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C']
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("module", modules, ids=names)
+def test_kbit_backprop(module):
+    b = 17
+    dim1 = 37
+    dim2 = 83
+
+    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)])
+    ref[1].weight.requires_grad = False
+    torch.nn.init.kaiming_normal_(ref[0].weight)
+    torch.nn.init.kaiming_normal_(ref[1].weight)
+    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)])
+    kbit[0].weight.detach().copy_(ref[0].weight)
+    kbit[1].weight.detach().copy_(ref[1].weight)
+    kbit[0].bias.detach().copy_(ref[0].bias)
+    kbit[1].bias.detach().copy_(ref[1].bias)
+    ref = ref.half().cuda()
+    kbit = kbit.half().cuda()
+    errs1 = []
+    errs2 = []
+    relerrs1 = []
+    relerrs2 = []
+    for i in range(100):
+        batch = torch.randn(b, dim1).half().cuda()
+        out1 = ref(batch)
+        out2 = kbit(batch)
+        out1.mean().backward()
+        out2.mean().backward()
+        grad1 = ref[0].weight.grad
+        grad2 = kbit[0].weight.grad
+        bgrad1 = ref[0].bias.grad
+        bgrad2 = kbit[0].bias.grad
+
+        err1 = (out1 - out2).abs().float()
+        err2 = (grad1 - grad2).abs().float()
+        relerr1 = (err1 / (out1.abs().float() + 1e-9))
+        relerr2 = (err2 / (grad1.abs().float() + 1e-9))
+        errs1.append(err1.mean().item())
+        errs2.append(err2.mean().item())
+        relerrs1.append(relerr1.mean().item())
+        relerrs2.append(relerr2.mean().item())
+
+        if isinstance(module, bnb.nn.Linear8bitLt):
+            torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05)
+            torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
+        else:
+            torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05)
+            torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
+        ref.zero_grad()
+        kbit.zero_grad()
+
+        assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
+        assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0
+    print('out', sum(errs1)/len(errs1))
+    print('grad', sum(errs2)/len(errs2))
+    print('rel out', sum(relerrs1)/len(relerrs1))
+    print('rel grad', sum(relerrs2)/len(relerrs2))
+
+
+def test_fp8linear():
+
+    b = 10
+    h = 1024
+    inp = torch.randn(b, h).cuda()
+    fp32 = torch.nn.Linear(h, h*2).cuda()
+    fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda()
+    fp32b = torch.nn.Linear(h*2, h).cuda()
+    fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda()
+
+    fp8.weight.data.copy_(fp32.weight.data)
+    fp8.bias.data.copy_(fp32.bias.data)
+    fp8b.weight.data.copy_(fp32b.weight.data)
+    fp8b.bias.data.copy_(fp32b.bias.data)
+
+    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
+    b = fp8b(torch.nn.functional.gelu(fp8(inp)))
+
+    err = (a - b).abs().mean()
+
+    a.mean().backward()
+    b.mean().backward()
+
+    graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
+    bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()
+
+    assert err < 0.05
+    assert graderr < 0.00002
+    assert bgraderr < 0.00002
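
Note (added for context, not part of the commit): test_kbit_backprop above parametrizes over the k-bit linear layers (Linear8bitLt, Linear4bit, LinearFP4, LinearNF4, with and without compressed statistics). A minimal construction sketch following the same pattern as the test; it assumes a CUDA device:

import torch
import bitsandbytes as bnb

# Build an NF4 linear layer the way test_kbit_backprop does and run one forward pass.
layer = bnb.nn.LinearNF4(83, 10, compress_statistics=True)
layer = layer.half().cuda()   # weights are quantized when the module is moved to the GPU

x = torch.randn(17, 83).half().cuda()
y = layer(x)
print(y.shape)                # torch.Size([17, 10])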
tests/test_optim.py  (+108 -106)

This diff is collapsed in the page capture and is not shown here.
tests/test_triton.py  (new file, mode 100644, +59 -0)

import pytest
import torch

from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
from bitsandbytes.nn import Linear8bitLt


@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
                    reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
@pytest.mark.parametrize("vector_wise_quantization", [False, True])
def test_switchback(vector_wise_quantization):
    for dim in [83]:
        for batch in [13]:

            standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
            switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
            baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
            switchback.weight.data.copy_(standard.weight)
            switchback.bias.data.copy_(standard.bias)
            baseline.weight.data.copy_(standard.weight)
            baseline.bias.data.copy_(standard.bias)

            x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True)
            x2 = x1.clone().detach().requires_grad_(True)
            x3 = x1.clone().detach().requires_grad_(True)

            out_standard = standard(x1)
            (2**10 * out_standard.abs().mean()).backward()

            print(x2.dtype)
            out_sb = switchback(x2)
            (2**10 * out_sb.abs().mean()).backward()

            out_baseline = baseline(x3)
            (2**10 * out_baseline.abs().mean()).backward()

            err_sb = (out_standard - out_sb).abs().mean()
            err_baseline = (out_standard - out_baseline).abs().mean()
            print('OUT', err_sb, err_baseline)
            assert err_sb < 2 * err_baseline

            err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean()
            err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean()

            print('GW2', err_sb, err_baseline)
            assert err_sb < 2 * err_baseline

            err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean()
            err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean()

            print('GW1', err_sb, err_baseline)
            assert err_sb < 2 * err_baseline

            err_sb = (x1.grad - x2.grad).abs().mean()
            err_baseline = (x1.grad - x3.grad).abs().mean()

            print('GX1', err_sb, err_baseline)
            assert err_sb < 2 * err_baseline
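
Note (added for context, not part of the commit): the new Triton test compares SwitchBackLinear against a standard fp16 Linear and a Linear8bitLt baseline. A minimal usage sketch following the same setup as the test; it requires Triton and a GPU with compute capability 8.0 or higher, per the skip condition above:

import torch
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear

# Mirror the test setup: copy weights from a standard fp16 layer into SwitchBackLinear.
dim, batch = 83, 13
standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=True).cuda().half()
switchback.weight.data.copy_(standard.weight)
switchback.bias.data.copy_(standard.bias)

x = torch.randn(batch, dim).cuda().half()
print((standard(x) - switchback(x)).abs().mean())  # expect a small quantization error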