bitsandbytes · Commit dc96e9e7
authored Jul 11, 2023 by Tim Dettmers

Test for bloom that fails with inference kernels.

parent ae7cd6ad
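For context: the new test distinguishes the two code paths purely by batch size. With inference_kernel=True it generates three times from a single sequence (the path that dispatches to the specialized 4-bit inference kernels, which is what the commit title says fails for bloom), and with inference_kernel=False it generates once from a batch of three, which takes the standard kernels. A minimal, self-contained sketch of the setup being exercised — model name, prompt, and generation length are taken from the diff below; the config here is an illustrative stand-in for the test file's get_4bit_config(), not the test itself:

    import torch
    import transformers
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Illustrative 4-bit config (assumption: mirrors get_4bit_config() in the test file).
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type='nf4',            # the test also covers 'fp4' and a 16-bit baseline
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        'bigscience/bloom-1b7',
        quantization_config=bnb_config,
        device_map='auto',
    ).eval()
    tokenizer = transformers.AutoTokenizer.from_pretrained('bigscience/bloom-1b7')

    inputs = tokenizer('3.14159', return_tensors='pt').to('cuda:0')
    # Batch size 1: single-sequence generation, i.e. the inference-kernel path under test.
    out = model.generate(inputs['input_ids'], max_new_tokens=20)
    print(tokenizer.decode(out[0], skip_special_tokens=True))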
Showing 1 changed file with 65 additions and 36 deletions.

tests/test_generation.py  (+65 -36)
@@ -2,6 +2,9 @@ import pytest
 import torch
 import math
+from itertools import product
+
+import transformers
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -11,7 +14,7 @@ from transformers import (
     set_seed,
 )
-import transformers
+
 
 def get_4bit_config():
@@ -26,15 +29,23 @@ def get_4bit_config():
     )
 
-def get_model(model_name_or_path='huggyllama/llama-7b', bnb_config=get_4bit_config()):
-    model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
-        quantization_config=bnb_config,
-        max_memory={0:'48GB'},
-        device_map='auto'
-        ).eval()
-    return model
+def get_model_and_tokenizer(config):
+    model_name_or_path, quant_type = config
+    bnb_config = get_4bit_config()
+    if quant_type == '16bit':
+        bnb_config.load_in_4bit = False
+    else:
+        bnb_config.bnb_4bit_quant_type = quant_type
+    model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+        quantization_config=bnb_config,
+        max_memory={0:'48GB'},
+        device_map='auto',
+        torch_dtype=torch.bfloat16
+        ).eval()
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
+
+    return model, tokenizer
 
 def get_prompt_for_generation_eval(text, add_roles=True):
     description = (
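The new helper replaces the old get_model(): it takes a (model_name_or_path, quant_type) tuple, builds the quantization config itself, and returns the tokenizer alongside the model. '16bit' switches 4-bit loading off to give an unquantized bfloat16 baseline, while 'nf4' and 'fp4' select the 4-bit data type. A brief usage sketch — the tuples come from the parametrization added in the next hunk; these calls are illustrative:

    # 4-bit NF4 quantization of the bloom checkpoint under test:
    model, tokenizer = get_model_and_tokenizer(('bigscience/bloom-1b7', 'nf4'))
    # Unquantized bfloat16 baseline of the same model:
    model, tokenizer = get_model_and_tokenizer(('bigscience/bloom-1b7', '16bit'))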
@@ -53,48 +64,66 @@ def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_f
     outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config)
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-name_or_path = 'huggyllama/llama-7b'
-#name_or_path = 'AI-Sweden/gpt-sw3-126m'
-
-@pytest.fixture(scope='session')
-def model():
-    bnb_config = get_4bit_config()
-    bnb_config.bnb_4bit_compute_dtype = torch.float32
-    bnb_config.load_in_4bit = True
-    model = get_model(name_or_path)
-    return model
-
-@pytest.fixture(scope='session')
-def tokenizer():
-    tokenizer = transformers.AutoTokenizer.from_pretrained(name_or_path)
-    return tokenizer
-
+models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7']
+dtypes = ['nf4', 'fp4', '16bit']
+load_in_4bit = [True, False]
+values = list(product(models, dtypes))
+strfunc = lambda lst: [str(x) for x in lst]
+ids = ['_'.join(strfunc(x)) for x in values]
+
+@pytest.fixture(scope='session', params=values, ids=ids)
+def model_and_tokenizer(request):
+    model, tokenizer = get_model_and_tokenizer(request.param)
+    print('')
+    yield model, tokenizer
+    del model
+
+@pytest.mark.parametrize("inference_kernel", [True, False], ids=['inference_kernel_True', 'inference_kernel_False'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-def test_pi(model, tokenizer, dtype):
+def test_pi(model_and_tokenizer, dtype, inference_kernel):
+    model, tokenizer = model_and_tokenizer
+
     generation_config = transformers.GenerationConfig(
-        max_new_tokens=128,
+        max_new_tokens=20,
         do_sample=True,
         top_p=0.9,
         temperature=0.7,
     )
-    generation_config.max_new_tokens = 50
+    generation_config.max_new_tokens = 20
     #text = 'Please write down the first 50 digits of pi.'
     #text = get_prompt_for_generation_eval(text)
     #text += ' Sure, here the first 50 digits of pi: 3.14159'
+    n_cases = 3
     text = '3.14159'
-    model.config.quantization_config.bnb_4bit_compute_dtype = dtype
+    if hasattr(model.config, 'quantization_config'):
+        model.config.quantization_config.bnb_4bit_compute_dtype = dtype
+
+    if not inference_kernel:
+        text = [text]*n_cases
+
     inputs = tokenizer(text, return_tensors="pt").to('cuda:0')
-    outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config)
-    textout = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    print('')
-    print(textout)
-    print(math.pi)
-
-    assert textout[:len(str(math.pi))] == str(math.pi)
+    x = inputs['input_ids']
+    failure_count = 0
+    outputs = []
+    if inference_kernel:
+        for i in range(n_cases):
+            output = model.generate(x, generation_config=generation_config)
+            textout = tokenizer.decode(output[0], skip_special_tokens=True)
+            outputs.append(textout)
+    else:
+        outputs = model.generate(x, generation_config=generation_config)
+        outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
+
+    assert len(outputs) == n_cases
+    for i in range(n_cases):
+        if not outputs[i][:len(str(math.pi))] == str(math.pi):
+            failure_count += 1
+
+    if failure_count > 1:
+        print(math.pi)
+        for out in outputs:
+            print(out)
+        raise ValueError(f'Failure count: {failure_count}/{n_cases}')
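Taken together, the fixture and marks expand test_pi into 36 cases: 2 models × 3 quantization types from the session-scoped fixture, times 3 compute dtypes and 2 inference_kernel settings from the parametrize marks. A sketch of how the fixture ids are built, restating the product/ids logic from the diff:

    from itertools import product

    models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7']
    dtypes = ['nf4', 'fp4', '16bit']
    values = list(product(models, dtypes))
    ids = ['_'.join(str(part) for part in pair) for pair in values]
    # ids == ['huggyllama/llama-7b_nf4', 'huggyllama/llama-7b_fp4',
    #         'huggyllama/llama-7b_16bit', 'bigscience/bloom-1b7_nf4',
    #         'bigscience/bloom-1b7_fp4', 'bigscience/bloom-1b7_16bit']

Because do_sample=True makes each generation stochastic, the rewritten test tolerates a single miss: it only raises when more than one of the three completions fails to continue the prompt with the digits of str(math.pi), rather than asserting on one sample as the old version did.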