chenpangpang / transformers

Commit 41f7b7ae (unverified)
Authored Mar 06, 2024 by Joao Gante; committed by GitHub, Mar 06, 2024
Generate: add tests for caches with `pad_to_multiple_of` (#29462)
Parent: 2890116a

Showing 1 changed file with 72 additions and 2 deletions (+72 −2)
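For context, a minimal illustration (not part of the commit) of what the `pad_to_multiple_of` tokenizer argument exercised by the new tests does: the batch is padded up to the next multiple of the given value, so the sequence length grows even though the texts are unchanged. The checkpoint, padding side, and pad token mirror the test code below; the printed shapes are indicative only.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)
texts = ["The best color is", "We should not undermind the issues at hand"]

plain = tokenizer(texts, padding=True, return_tensors="pt")
expanded = tokenizer(texts, padding=True, return_tensors="pt", pad_to_multiple_of=32)

print(plain.input_ids.shape)     # e.g. torch.Size([2, 11]): padded to the longest text
print(expanded.input_ids.shape)  # torch.Size([2, 32]): padded up to a multiple of 32

The new tests assert that this extra left-padding leaves greedy generation unchanged, for both the dynamic and the static cache.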
tests/test_cache_utils.py
@@ -291,7 +291,7 @@ class CacheIntegrationTest(unittest.TestCase):
     @require_torch_gpu
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
-    def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
+    def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
         EXPECTED_GENERATION = [
             "The best color is the one that complements the skin tone of the",
             "We should not undermind the issues at hand.\nWe should not undermind the issues",
@@ -331,7 +331,7 @@ class CacheIntegrationTest(unittest.TestCase):
     @require_torch_gpu
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
-    def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
+    def test_static_cache_greedy_decoding_pad_right(self, attn_implementation):
         EXPECTED_GENERATION = [
             "The best color isЋ the one that complements the skin tone of",
             "We should not undermind the issues at hand.\nWe should not undermind the issues",
@@ -382,6 +382,76 @@ class CacheIntegrationTest(unittest.TestCase):
             with self.subTest(f"{attn_implementation}, static, compiled"):
                 self.assertListEqual(decoded, EXPECTED_GENERATION)
 
+    def test_dynamic_cache_extra_left_padding(self):
+        """Tests that adding extra left-padding does not affect the generation with the dynamic cache"""
+        EXPECTED_GENERATION = [
+            "The best color is the one that complements the skin tone of the",
+            "We should not undermind the issues at hand.\nWe should not undermind the issues",
+        ]
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            "NousResearch/Llama-2-7b-chat-hf",
+            torch_dtype=torch.bfloat16,
+        ).to(torch_device)
+        inputs = tokenizer(
+            ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
+        ).to(model.device)
+
+        gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
+        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+        self.assertListEqual(decoded, EXPECTED_GENERATION)
+
+        # Now with extra left-padding
+        inputs_expanded = tokenizer(
+            ["The best color is", "We should not undermind the issues at hand"],
+            padding=True,
+            return_tensors="pt",
+            pad_to_multiple_of=32,
+        ).to(model.device)
+        self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1])
+        gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10)
+        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+        self.assertListEqual(decoded, EXPECTED_GENERATION)
+
+    def test_static_cache_extra_left_padding(self):
+        """Tests that adding extra left-padding does not affect the generation with the static cache"""
+        EXPECTED_GENERATION = [
+            "The best color is the one that complements the skin tone of the",
+            "We should not undermind the issues at hand.\nWe should not undermind the issues",
+        ]
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            "NousResearch/Llama-2-7b-chat-hf",
+            torch_dtype=torch.bfloat16,
+        ).to(torch_device)
+        inputs = tokenizer(
+            ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
+        ).to(model.device)
+
+        model.generation_config.cache_implementation = "static"
+
+        gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
+        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+        self.assertListEqual(decoded, EXPECTED_GENERATION)
+
+        # Now with extra left-padding
+        inputs_expanded = tokenizer(
+            ["The best color is", "We should not undermind the issues at hand"],
+            padding=True,
+            return_tensors="pt",
+            pad_to_multiple_of=32,
+        ).to(model.device)
+        self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1])
+        gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10)
+        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+        self.assertListEqual(decoded, EXPECTED_GENERATION)
+
     @unittest.skip("TODO @gante static cache's does not support beam search yet")
     def test_static_cache_beam_search(self):
         pass
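Outside the commit itself, a minimal sketch of the invariance the new tests assert: with greedy decoding and a left-padded batch, adding extra left-padding via pad_to_multiple_of should not change the generated text. Assumptions here (not in the original tests): the tiny random checkpoint hf-internal-testing/tiny-random-LlamaForCausalLM so the snippet runs on CPU, the eos token reused as pad token, and a transformers version with static-cache support (>= 4.38).

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # assumption: small stand-in for Llama-2
tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # defensive: make sure a pad token is set
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model.generation_config.cache_implementation = "static"  # same switch as in test_static_cache_extra_left_padding

texts = ["The best color is", "We should not undermind the issues at hand"]
plain = tokenizer(texts, padding=True, return_tensors="pt")
expanded = tokenizer(texts, padding=True, return_tensors="pt", pad_to_multiple_of=32)

out_plain = model.generate(**plain, do_sample=False, max_new_tokens=10)
out_expanded = model.generate(**expanded, do_sample=False, max_new_tokens=10)

# The two decoded lists should match; a mismatch is the regression the new tests guard against.
print(tokenizer.batch_decode(out_plain, skip_special_tokens=True))
print(tokenizer.batch_decode(out_expanded, skip_special_tokens=True))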