chenpangpang / transformers · Commits · ddfaf119

Unverified commit ddfaf119, authored Jul 03, 2024 by Joao Gante, committed by GitHub on Jul 03, 2024
Gemma 2: Update slow tests (#31759)
gemma 2 slow tests
parent c1fe1259
Showing 2 changed files with 46 additions and 25 deletions.
src/transformers/pipelines/text_generation.py (+10, -10)
tests/models/gemma2/test_modeling_gemma2.py (+36, -15)
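The tests touched here are slow integration tests, normally run with RUN_SLOW=yes on a CUDA machine and a Hugging Face token that has access to the gated google/gemma-2-9b repository. A minimal sketch of invoking them programmatically (an illustrative invocation only; CI drives them through the usual RUN_SLOW=yes python -m pytest command line):

import os
import pytest

# RUN_SLOW must be set before transformers.testing_utils is imported by the test
# module, otherwise the slow-gated integration tests are skipped at collection time.
os.environ["RUN_SLOW"] = "yes"

# Run only the Gemma 2 integration tests updated by this commit.
pytest.main([
    "tests/models/gemma2/test_modeling_gemma2.py",
    "-k", "Gemma2IntegrationTest",
    "-v",
])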
src/transformers/pipelines/text_generation.py (view file @ ddfaf119)

@@ -272,12 +272,17 @@ class TextGenerationPipeline(Pipeline):
         max_length=None,
         **generate_kwargs,
     ):
-        if isinstance(prompt_text, Chat):
-            # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
-            tokenizer_kwargs = {}
-            for tokenizer_kwarg_name in ["truncation", "padding", "max_length"]:
-                if locals()[tokenizer_kwarg_name] is not None:
-                    tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
+        # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
+        tokenizer_kwargs = {
+            "add_special_tokens": add_special_tokens,
+            "truncation": truncation,
+            "padding": padding,
+            "max_length": max_length,
+        }
+        tokenizer_kwargs = {key: value for key, value in tokenizer_kwargs.items() if value is not None}
+
+        if isinstance(prompt_text, Chat):
+            tokenizer_kwargs.pop("add_special_tokens", None)  # ignore add_special_tokens on chats
             inputs = self.tokenizer.apply_chat_template(
                 prompt_text.messages,
                 add_generation_prompt=True,
@@ -286,11 +291,6 @@ class TextGenerationPipeline(Pipeline):
                 **tokenizer_kwargs,
             )
         else:
-            # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
-            tokenizer_kwargs = {}
-            for tokenizer_kwarg_name in ["add_special_tokens", "truncation", "padding", "max_length"]:
-                if locals()[tokenizer_kwarg_name] is not None:
-                    tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
             inputs = self.tokenizer(prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs)
 
         inputs["prompt_text"] = prompt_text
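The change above replaces the two locals()-based loops with a single explicit dict that is filtered for None values, and drops add_special_tokens for chat inputs before calling apply_chat_template. A minimal standalone sketch of that pattern (build_tokenizer_kwargs is a hypothetical helper for illustration, not part of the pipeline code):

def build_tokenizer_kwargs(add_special_tokens=None, truncation=None, padding=None, max_length=None, is_chat=False):
    # Collect the candidate tokenizer kwargs explicitly.
    tokenizer_kwargs = {
        "add_special_tokens": add_special_tokens,
        "truncation": truncation,
        "padding": padding,
        "max_length": max_length,
    }
    # Only keep non-None values, so the tokenizer's own defaults still apply.
    tokenizer_kwargs = {key: value for key, value in tokenizer_kwargs.items() if value is not None}
    if is_chat:
        # Mirror the diff's comment: add_special_tokens is ignored for chat inputs.
        tokenizer_kwargs.pop("add_special_tokens", None)
    return tokenizer_kwargs

# e.g. build_tokenizer_kwargs(truncation=True, add_special_tokens=False, is_chat=True) -> {"truncation": True}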
tests/models/gemma2/test_modeling_gemma2.py (view file @ ddfaf119)

@@ -16,7 +16,7 @@
 import unittest
 
-from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available
+from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
 from transformers.testing_utils import (
     require_read_token,
     require_torch,
@@ -102,41 +102,62 @@ class Gemma2IntegrationTest(unittest.TestCase):
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
 
     @require_read_token
-    def test_model_2b_bf16(self):
+    def test_model_9b_bf16(self):
         model_id = "google/gemma-2-9b"
         EXPECTED_TEXTS = [
-            "<bos>Hello I am doing a project for a class and I am trying to use the <code><a-image></code>",
-            "<pad><pad><bos>Hi today. So, I'm going to show you how to do a problem from the textbook. So",
+            "<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+            "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America",
         ]
 
-        model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
-            torch_device
-        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager"
+        ).to(torch_device)
 
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
 
         output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
 
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
     @require_read_token
-    def test_model_2b_fp16(self):
+    def test_model_9b_fp16(self):
         model_id = "google/gemma-2-9b"
         EXPECTED_TEXTS = [
-            "<bos>Hello I am doing a project on the effect of the temperature on the rate of a reaction. I am using a",
-            "<pad><pad><bos>Hi today I'm going to be talking about the 1000-4000-",
+            "<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+            "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America",
         ]
 
-        model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
-            torch_device
-        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16, attn_implementation="eager"
+        ).to(torch_device)
 
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
 
         output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
 
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
+    @require_read_token
+    def test_model_9b_pipeline_bf16(self):
+        # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
+        model_id = "google/gemma-2-9b"
+        # EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens
+        EXPECTED_TEXTS = [
+            "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+            "Hi today I'm going to be talking about the history of the United States. The United States of America",
+        ]
+
+        model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
+            torch_device
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+
+        output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True)
+
+        self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
+        self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
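Outside the test harness, the new pipeline check boils down to roughly the following (a sketch assuming access to the gated google/gemma-2-9b weights, a CUDA device with enough memory for the bf16 9B model, and stand-in prompts in place of the test class's self.input_text):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "google/gemma-2-9b"
input_text = ["Hello I am doing", "Hi today"]  # stand-ins for the test's self.input_text

model = AutoModelForCausalLM.from_pretrained(
    model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Greedy decoding, 20 new tokens, padding enabled for the batched prompts -- mirroring the test call.
output = pipe(input_text, max_new_tokens=20, do_sample=False, padding=True)
print(output[0][0]["generated_text"])
print(output[1][0]["generated_text"])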