OpenDAS / bitsandbytes

Commit 90b0ac57, authored Jul 11, 2023 by Tim Dettmers

Fixed missing bias in bnb.matmul_4bit for inference; more tests.

Parent: dc96e9e7
Showing 4 changed files with 46 additions and 12 deletions (+46, -12):

    bitsandbytes/autograd/_functions.py   +4   -1
    bitsandbytes/functional.py            +0   -2
    tests/test_functional.py             +29   -1
    tests/test_generation.py             +13   -8
bitsandbytes/autograd/_functions.py

@@ -571,6 +571,9 @@ def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
             warn(f'Some matrices hidden dimension is not a multiple of {blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
         else:
-            return F.gemv_4bit(A, B.t(), out, state=quant_state)
+            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
+            if bias is not None:
+                out += bias
+            return out
     else:
         return MatMul4Bit.apply(A, B, out, bias, quant_state)
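This is the actual fix: a single-token input (A.numel() == A.shape[-1] with requires_grad=False) dispatches to the gemv_4bit fast kernel, which previously returned its result before the bias was applied, so 4-bit layers silently dropped their bias during generation. A minimal sketch of the repaired behavior, assuming a CUDA device and this commit's bitsandbytes API (dim, W, bias, and the tolerances are illustrative, not from the commit):

# Sketch: exercise the gemv_4bit fast path that the fix above repairs.
# Assumptions: CUDA device, bitsandbytes at this commit; dim/W/bias illustrative.
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

dim = 128  # multiple of the 64-element quantization blocksize
A = torch.randn(1, 1, dim, dtype=torch.float16, device='cuda')  # single token -> fast path
W = torch.randn(dim, dim, dtype=torch.float16, device='cuda')
bias = torch.randn(dim, dtype=torch.float16, device='cuda')

qW, state = F.quantize_4bit(W, quant_type='nf4')
out = bnb.matmul_4bit(A, qW.t(), state, bias=bias)  # bias is now added before returning

# Reference result via dequantize-then-matmul; loose tolerances absorb 4-bit error.
ref = torch.matmul(A, F.dequantize_4bit(qW, state).t()) + bias
torch.testing.assert_close(out, ref, rtol=0.1, atol=0.5)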
bitsandbytes/functional.py

@@ -1512,8 +1512,6 @@ def gemv_4bit(
     return out

-
-
 def igemm(
     A: Tensor,
     B: Tensor,
tests/test_functional.py

@@ -2364,7 +2364,7 @@ def test_normal_map_tree():
 @pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
 def test_gemv_4bit(dtype, storage_type, double_quant, kind):
-    for dim in [128, 256, 512, 1024, 2048, 4096, 6144]:
+    for dim in [128, 256, 512, 1024]:
     #for dim in [4*1024]:
     #for dim in [1*128]:
         errs1 = []

@@ -2525,3 +2525,31 @@ def test_managed():
     # assert (A==17).sum().item() == n*n
     # torch.testing.assert_close(A, torch.ones(A.shape)*289)

+@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
+@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True'])
+def test_gemv_eye_4bit(storage_type, dtype, double_quant):
+    dims = 10
+    torch.random.manual_seed(np.random.randint(0, 412424242))
+    dims = torch.randint(0, 8192, size=(dims,)).tolist()
+    dims = [dim + (64 - (dim % 64)) for dim in dims]
+    #for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
+    for dim in dims:
+        A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device='cuda')
+        B = torch.eye(dim, dtype=dtype, device='cuda')
+
+        qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+        C3 = torch.matmul(A, B.t())
+        C2 = bnb.matmul_4bit(A, qB.t(), state)
+        A.requires_grad = True
+        C1 = bnb.matmul_4bit(A, qB.t(), state)
+
+        torch.testing.assert_close(A, C3)
+        torch.testing.assert_close(A, C1)
+        torch.testing.assert_close(A, C2)
+        #torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
+        #torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
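The new test_gemv_eye_4bit leans on the fact that both the nf4 and fp4 codebooks contain 0 and 1 as exact code values and that each quantization block of torch.eye has absmax 0 or 1, so an identity matrix survives 4-bit quantization losslessly; multiplying by the quantized identity must therefore reproduce A exactly, across random dimensions rounded up to a multiple of the 64-element blocksize. A condensed sketch of that invariant, under the same assumptions as the test (CUDA device, this commit's API):

# Sketch: the quantized-identity invariant behind test_gemv_eye_4bit.
# Assumption: eye() quantizes losslessly because 0 and 1 are exact code values.
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

dim = 256  # multiple of 64 so the efficient inference kernel is used
A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=torch.float16, device='cuda')
I = torch.eye(dim, dtype=torch.float16, device='cuda')
qI, state = F.quantize_4bit(I, quant_type='nf4')

out = bnb.matmul_4bit(A, qI.t(), state)
torch.testing.assert_close(out, A)  # A times the (quantized) identity is A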
tests/test_generation.py

@@ -65,7 +65,7 @@ def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_for_generation_eval):
     return tokenizer.decode(outputs[0], skip_special_tokens=True)

 models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7']
-dtypes = ['nf4', 'fp4', '16bit']
+dtypes = ['nf4', 'fp4']
 load_in_4bit = [True, False]
 values = list(product(models, dtypes))
 strfunc = lambda lst: [str(x) for x in lst]

@@ -73,14 +73,17 @@ ids = ['_'.join(strfunc(x)) for x in values]
 @pytest.fixture(scope='session', params=values, ids=ids)
 def model_and_tokenizer(request):
     model, tokenizer = get_model_and_tokenizer(request.param)
-    yield model, tokenizer
+    yield request.param, model, tokenizer
     del model

+@pytest.mark.parametrize("DQ", [True, False], ids=['DQ_True', 'DQ_False'])
 @pytest.mark.parametrize("inference_kernel", [True, False], ids=['inference_kernel_True', 'inference_kernel_False'])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-def test_pi(model_and_tokenizer, dtype, inference_kernel):
+#@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
+def test_pi(model_and_tokenizer, inference_kernel, DQ):
     print('')
+    dtype = torch.float16
-    model, tokenizer = model_and_tokenizer
+    fixture_config, model, tokenizer = model_and_tokenizer

     generation_config = transformers.GenerationConfig(
         max_new_tokens=20,

@@ -94,16 +97,16 @@ def test_pi(model_and_tokenizer, dtype, inference_kernel):
     #text = 'Please write down the first 50 digits of pi.'
     #text = get_prompt_for_generation_eval(text)
     #text += ' Sure, here the first 50 digits of pi: 3.14159'
-    n_cases = 3
+    n_cases = 6
     text = '3.14159'
     if hasattr(model.config, 'quantization_config'):
         model.config.quantization_config.bnb_4bit_compute_dtype = dtype
+        model.config.quantization_config.bnb_4bit_use_double_quant = DQ

     if not inference_kernel:
         text = [text]*n_cases

     inputs = tokenizer(text, return_tensors="pt").to('cuda:0')
     x = inputs['input_ids']
-    failure_count = 0
     outputs = []
     if inference_kernel:
         for i in range(n_cases):

@@ -116,10 +119,12 @@ def test_pi(model_and_tokenizer, dtype, inference_kernel):
     assert len(outputs) == n_cases

+    failure_count = 0
     for i in range(n_cases):
         if not outputs[i][:len(str(math.pi))] == str(math.pi):
             failure_count += 1

-    if failure_count > 1:
+    failure_max = (2 if fixture_config[0] == 'huggyllama/llama-7b' else 4)
+    if failure_count > failure_max:
         print(math.pi)
         for out in outputs:
             print(out)
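The reworked test now toggles double quantization (DQ) by mutating model.config.quantization_config on the already-loaded fixture model, pins the compute dtype to fp16, and scales the allowed failure count by model. For reference, a sketch of setting the same two knobs at load time instead, assuming a transformers version with BitsAndBytesConfig 4-bit support (the model name follows the test's parametrization):

# Sketch: load-time equivalent of the config mutation in test_pi.
# Assumptions: transformers with BitsAndBytesConfig 4-bit support installed.
import torch
import transformers

quant_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',             # or 'fp4', as parametrized above
    bnb_4bit_compute_dtype=torch.float16,  # the test pins dtype to fp16
    bnb_4bit_use_double_quant=True,        # the DQ flag the test toggles
)
model = transformers.AutoModelForCausalLM.from_pretrained(
    'huggyllama/llama-7b',
    quantization_config=quant_config,
    device_map='auto',
)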