Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a0f86743
Unverified
Commit
a0f86743
authored
Nov 07, 2022
by
Joao Gante
Committed by
GitHub
Nov 07, 2022
Browse files
Generate: TF contrastive search with XLA support (#20050)
* Add contrastive search
parent
504db92e
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
770 additions
and
46 deletions
+770
-46
src/transformers/generation_tf_utils.py
src/transformers/generation_tf_utils.py
+587
-40
src/transformers/generation_utils.py
src/transformers/generation_utils.py
+4
-3
tests/models/bart/test_modeling_tf_bart.py
tests/models/bart/test_modeling_tf_bart.py
+94
-0
tests/models/gpt2/test_modeling_tf_gpt2.py
tests/models/gpt2/test_modeling_tf_gpt2.py
+69
-0
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+16
-3
No files found.
src/transformers/generation_tf_utils.py
View file @
a0f86743
This diff is collapsed.
Click to expand it.
src/transformers/generation_utils.py
View file @
a0f86743
...
...
@@ -651,6 +651,7 @@ class GenerationMixin:
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
**
model_kwargs
,
)
->
Tuple
[
torch
.
LongTensor
,
Dict
[
str
,
Any
]]:
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
if
input_ids
is
not
None
:
input_ids
=
input_ids
.
repeat_interleave
(
expand_size
,
dim
=
0
)
...
...
@@ -1860,7 +1861,7 @@ class GenerationMixin:
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
>>> model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
>>> # set pad_token_id to eos_token_id because
G
PT
2
does not have a PAD token
>>> # set pad_token_id to eos_token_id because
O
PT does not have a PAD token
>>> model.config.pad_token_id = model.config.eos_token_id
>>> input_prompt = "DeepMind Company is"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt")
...
...
@@ -1916,7 +1917,7 @@ class GenerationMixin:
if
this_peer_finished_flag
.
item
()
==
0.0
:
break
# if the first step in the loop, encode all the prefix and obtain
three parameters
: (1) past_key_values;
# if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
if
model_kwargs
.
get
(
"past"
)
is
None
:
...
...
@@ -2014,7 +2015,7 @@ class GenerationMixin:
full_hidden_states
=
outputs
.
hidden_states
context_hidden
=
last_hidden_states
.
repeat_interleave
(
top_k
,
dim
=
0
)
# compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the
# compute the degenerati
o
n penalty and re-rank the candidates based on the degeneration penalty and the
# model confidence
selected_idx
=
_ranking_fast
(
context_hidden
,
next_hidden
,
top_k_probs
,
penalty_alpha
,
top_k
)
...
...
tests/models/bart/test_modeling_tf_bart.py
View file @
a0f86743
...
...
@@ -550,6 +550,100 @@ class TFBartModelIntegrationTest(unittest.TestCase):
def
tok
(
self
):
return
BartTokenizer
.
from_pretrained
(
"facebook/bart-large"
)
@
slow
def
test_contrastive_search_bart
(
self
):
article
=
(
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
" year later, she got married again in Westchester County, but to a different man and without divorcing"
" her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
" once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
" license application, according to court documents. Prosecutors said the marriages were part of an"
" immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
" her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
" arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
" York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
" Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
" occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
" married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
" said the immigration scam involved some of her husbands, who filed for permanent residence status"
" shortly after the marriages. Any divorces happened only after such filings were approved. It was"
" unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
" Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
" Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
" native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
" up to four years in prison. Her next court appearance is scheduled for May 18."
)
bart_tokenizer
=
BartTokenizer
.
from_pretrained
(
"facebook/bart-large-cnn"
)
bart_model
=
TFBartForConditionalGeneration
.
from_pretrained
(
"facebook/bart-large-cnn"
)
input_ids
=
bart_tokenizer
(
article
,
add_special_tokens
=
False
,
truncation
=
True
,
max_length
=
512
,
return_tensors
=
"tf"
).
input_ids
outputs
=
bart_model
.
generate
(
input_ids
,
penalty_alpha
=
0.5
,
top_k
=
5
,
max_length
=
64
)
generated_text
=
bart_tokenizer
.
batch_decode
(
outputs
,
skip_special_tokens
=
True
)
self
.
assertListEqual
(
generated_text
,
[
"Liana Barrientos, 39, pleaded not guilty to charges related to false marriage statements. "
"Prosecutors say she married at least 10 times, sometimes within two weeks of each other. She is "
"accused of being part of an immigration scam to get permanent residency. If convicted, she faces up "
"to four years in"
],
)
@
slow
def
test_contrastive_search_bart_xla
(
self
):
article
=
(
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
" year later, she got married again in Westchester County, but to a different man and without divorcing"
" her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
" once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
" license application, according to court documents. Prosecutors said the marriages were part of an"
" immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
" her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
" arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
" York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
" Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
" occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
" married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
" said the immigration scam involved some of her husbands, who filed for permanent residence status"
" shortly after the marriages. Any divorces happened only after such filings were approved. It was"
" unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
" Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
" Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
" native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
" up to four years in prison. Her next court appearance is scheduled for May 18."
)
bart_tokenizer
=
BartTokenizer
.
from_pretrained
(
"facebook/bart-large-cnn"
)
bart_model
=
TFBartForConditionalGeneration
.
from_pretrained
(
"facebook/bart-large-cnn"
)
input_ids
=
bart_tokenizer
(
article
,
add_special_tokens
=
False
,
truncation
=
True
,
max_length
=
512
,
return_tensors
=
"tf"
).
input_ids
xla_generate
=
tf
.
function
(
bart_model
.
generate
,
jit_compile
=
True
)
# no_repeat_ngram_size set to 0 because it isn't compatible with XLA, but doesn't change the original output
outputs
=
xla_generate
(
input_ids
,
penalty_alpha
=
0.5
,
top_k
=
5
,
max_length
=
64
,
no_repeat_ngram_size
=
0
)
generated_text
=
bart_tokenizer
.
batch_decode
(
outputs
,
skip_special_tokens
=
True
)
self
.
assertListEqual
(
generated_text
,
[
"Liana Barrientos, 39, pleaded not guilty to charges related to false marriage statements. "
"Prosecutors say she married at least 10 times, sometimes within two weeks of each other. She is "
"accused of being part of an immigration scam to get permanent residency. If convicted, she faces up "
"to four years in"
],
)
@
slow
@
require_tf
...
...
tests/models/gpt2/test_modeling_tf_gpt2.py
View file @
a0f86743
...
...
@@ -663,3 +663,72 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
output_ids
=
xla_generate
(
**
input_ids
,
do_sample
=
False
,
num_beams
=
2
)
output_strings
=
tokenizer
.
batch_decode
(
output_ids
,
skip_special_tokens
=
True
)
self
.
assertListEqual
(
output_strings
,
expected_output_strings
)
@
slow
def
test_contrastive_search_gpt2
(
self
):
article
=
(
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
)
gpt2_tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
"gpt2-large"
)
gpt2_model
=
TFGPT2LMHeadModel
.
from_pretrained
(
"gpt2-large"
)
input_ids
=
gpt2_tokenizer
(
article
,
return_tensors
=
"tf"
)
outputs
=
gpt2_model
.
generate
(
**
input_ids
,
penalty_alpha
=
0.6
,
top_k
=
4
,
max_length
=
256
)
generated_text
=
gpt2_tokenizer
.
batch_decode
(
outputs
,
skip_special_tokens
=
True
)
self
.
assertListEqual
(
generated_text
,
[
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
"United Kingdom
\n\n
Google has a lot of data on its users and uses it to improve its products, such as "
"Google Now, which helps users find the information they're looking for on the web. But the company "
"is not the only one to collect data on its users. Facebook, for example, has its own facial "
"recognition technology, as well as a database of millions of photos that it uses to personalize its "
"News Feed.
\n\n
Facebook's use of data is a hot topic in the tech industry, with privacy advocates "
"concerned about the company's ability to keep users' information private. In a blog post last "
'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our '
'data use and how we use it."
\n\n
"We have made it clear that we do not sell or share your data with '
'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at '
'privacy@facebook.com."
\n\n
Google declined to comment on the privacy implications of its use of data, '
"but said in a statement to The Associated Press that"
],
)
@
slow
def
test_contrastive_search_gpt2_xla
(
self
):
article
=
(
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
)
gpt2_tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
"gpt2-large"
)
gpt2_model
=
TFGPT2LMHeadModel
.
from_pretrained
(
"gpt2-large"
)
input_ids
=
gpt2_tokenizer
(
article
,
return_tensors
=
"tf"
)
xla_generate
=
tf
.
function
(
gpt2_model
.
generate
,
jit_compile
=
True
)
outputs
=
xla_generate
(
**
input_ids
,
penalty_alpha
=
0.6
,
top_k
=
4
,
max_length
=
256
)
generated_text
=
gpt2_tokenizer
.
batch_decode
(
outputs
,
skip_special_tokens
=
True
)
self
.
assertListEqual
(
generated_text
,
[
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
"United Kingdom
\n\n
Google has a lot of data on its users and uses it to improve its products, such as "
"Google Now, which helps users find the information they're looking for on the web. But the company "
"is not the only one to collect data on its users. Facebook, for example, has its own facial "
"recognition technology, as well as a database of millions of photos that it uses to personalize its "
"News Feed.
\n\n
Facebook's use of data is a hot topic in the tech industry, with privacy advocates "
"concerned about the company's ability to keep users' information private. In a blog post last "
'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our '
'data use and how we use it."
\n\n
"We have made it clear that we do not sell or share your data with '
'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at '
'privacy@facebook.com."
\n\n
Google declined to comment on the privacy implications of its use of data, '
"but said in a statement to The Associated Press that"
],
)
tests/test_modeling_tf_common.py
View file @
a0f86743
...
...
@@ -1783,7 +1783,7 @@ class TFModelTesterMixin:
model
.
compile
(
optimizer
=
"sgd"
,
run_eagerly
=
True
)
model
.
train_on_batch
(
test_batch
,
test_batch_labels
)
def
_test_xla_generate
(
self
,
num_beams
,
num_return_sequences
,
max_length
):
def
_test_xla_generate
(
self
,
num_beams
,
num_return_sequences
,
max_length
,
**
generate_kwargs
):
def
_generate_and_check_results
(
model
,
config
,
inputs_dict
):
if
"input_ids"
in
inputs_dict
:
inputs
=
inputs_dict
[
"input_ids"
]
...
...
@@ -1801,9 +1801,9 @@ class TFModelTesterMixin:
else
:
raise
ValueError
(
"No valid generate input found in inputs_dict"
)
generated
=
model
.
generate
(
inputs
).
numpy
()
generated
=
model
.
generate
(
inputs
,
**
generate_kwargs
).
numpy
()
generate_xla
=
tf
.
function
(
model
.
generate
,
jit_compile
=
True
)
generated_xla
=
generate_xla
(
inputs
).
numpy
()
generated_xla
=
generate_xla
(
inputs
,
**
generate_kwargs
).
numpy
()
self
.
assertListEqual
(
generated
.
tolist
(),
generated_xla
.
tolist
())
for
model_class
in
self
.
all_generative_model_classes
:
...
...
@@ -1844,6 +1844,19 @@ class TFModelTesterMixin:
max_length
=
10
self
.
_test_xla_generate
(
num_beams
,
num_return_sequences
,
max_length
)
def
test_xla_generate_contrastive
(
self
):
"""
Similar to `test_xla_generate_fast`, but for contrastive search -- contrastive search directly manipulates the
model cache and other outputs, and this test ensures that they are in a valid format that is also supported
by XLA.
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
"""
num_beams
=
1
num_return_sequences
=
1
max_length
=
10
self
.
_test_xla_generate
(
num_beams
,
num_return_sequences
,
max_length
,
penalty_alpha
=
0.5
,
top_k
=
5
)
@
slow
def
test_xla_generate_slow
(
self
):
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment