Commit a0f86743 (unverified)
Authored Nov 07, 2022 by Joao Gante; committed by GitHub, Nov 07, 2022
Parent: 504db92e

Generate: TF contrastive search with XLA support (#20050)

* Add contrastive search

Showing 5 changed files with 770 additions and 46 deletions (+770 -46)
src/transformers/generation_tf_utils.py      +587  -40
src/transformers/generation_utils.py           +4   -3
tests/models/bart/test_modeling_tf_bart.py    +94   -0
tests/models/gpt2/test_modeling_tf_gpt2.py    +69   -0
tests/test_modeling_tf_common.py              +16   -3
src/transformers/generation_tf_utils.py (view file @ a0f86743)

This diff is collapsed.
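The collapsed diff above holds the bulk of this commit (+587 lines): the new TF contrastive search implementation. For orientation, a minimal usage sketch of the feature it adds, mirroring the integration tests further down; as in PyTorch, contrastive search is selected by passing `penalty_alpha` together with `top_k` to `generate()` (the model and prompt here are illustrative, not from the diff):

from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("DeepMind Company is", return_tensors="tf")
# penalty_alpha > 0 combined with top_k > 1 routes decoding through contrastive search.
outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_length=64)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])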
src/transformers/generation_utils.py (view file @ a0f86743)

@@ -651,6 +651,7 @@ class GenerationMixin:
         input_ids: Optional[torch.LongTensor] = None,
         **model_kwargs,
     ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
+        """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
         if input_ids is not None:
             input_ids = input_ids.repeat_interleave(expand_size, dim=0)
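The added docstring states the expansion contract. As a quick illustration of the `repeat_interleave(expand_size, dim=0)` call it documents: each batch row is duplicated consecutively, producing the [batch_size * expand_size, ...] layout that contrastive search uses to score top_k candidate continuations per row.

import torch

x = torch.tensor([[1, 2], [3, 4]])  # [batch_size=2, seq_len=2]
# Each row is repeated expand_size times in place, not tiled at the end.
print(x.repeat_interleave(2, dim=0))
# tensor([[1, 2],
#         [1, 2],
#         [3, 4],
#         [3, 4]])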
@@ -1860,7 +1861,7 @@ class GenerationMixin:
         >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
         >>> model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
-        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
+        >>> # set pad_token_id to eos_token_id because OPT does not have a PAD token
         >>> model.config.pad_token_id = model.config.eos_token_id
         >>> input_prompt = "DeepMind Company is"
         >>> input_ids = tokenizer(input_prompt, return_tensors="pt")
@@ -1916,7 +1917,7 @@ class GenerationMixin:
             if this_peer_finished_flag.item() == 0.0:
                 break
-            # if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values;
+            # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
             # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
             if model_kwargs.get("past") is None:
@@ -2014,7 +2015,7 @@ class GenerationMixin:
             full_hidden_states = outputs.hidden_states
             context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
-            # compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the
+            # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
             # model confidence
             selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)
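`_ranking_fast` implements the re-ranking the comment describes. Following the contrastive search objective (Su et al., 2022), each of the top_k candidates is scored as (1 - alpha) * model confidence minus alpha * its maximum cosine similarity to the previous context states, and the best candidate per batch row is kept. A hedged PyTorch sketch of what that helper computes; this is a reconstruction, not the verbatim library code, with tensor shapes inferred from the call site above:

import torch

def ranking_fast_sketch(context_hidden, next_hidden, next_top_k_probs, alpha, top_k):
    # context_hidden: [batch * top_k, seq_len, hidden]; next_hidden: [batch * top_k, 1, hidden]
    norm_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    # Cosine similarity of each candidate state against every previous context state.
    cosine_matrix = torch.matmul(norm_context, norm_next.transpose(1, 2)).squeeze(-1)  # [batch * top_k, seq_len]
    degeneration_penalty = cosine_matrix.max(dim=-1).values  # [batch * top_k]
    # Trade model confidence against the degeneration penalty.
    scores = (1.0 - alpha) * next_top_k_probs.view(-1) - alpha * degeneration_penalty
    # Regroup per batch row and pick the argmax candidate index.
    scores = torch.stack(torch.split(scores, top_k))  # [batch, top_k]
    return scores.max(dim=-1).indices  # [batch]

A larger `penalty_alpha` weights the penalty more heavily, trading raw model confidence for output diversity; `penalty_alpha=0.5, top_k=5` and `0.6, 4` are the settings exercised by the tests below.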
tests/models/bart/test_modeling_tf_bart.py (view file @ a0f86743)

@@ -550,6 +550,100 @@ class TFBartModelIntegrationTest(unittest.TestCase):
     def tok(self):
         return BartTokenizer.from_pretrained("facebook/bart-large")
 
+    @slow
+    def test_contrastive_search_bart(self):
+        article = (
+            " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
+            " year later, she got married again in Westchester County, but to a different man and without divorcing"
+            " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
+            ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
+            " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
+            ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
+            ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
+            " license application, according to court documents. Prosecutors said the marriages were part of an"
+            " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
+            " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
+            " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
+            " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
+            " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
+            " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
+            " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
+            " said the immigration scam involved some of her husbands, who filed for permanent residence status"
+            " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
+            " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
+            " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
+            ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
+            " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
+            " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
+            " up to four years in prison. Her next court appearance is scheduled for May 18."
+        )
+        bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
+        bart_model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
+        input_ids = bart_tokenizer(
+            article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="tf"
+        ).input_ids
+        outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
+        generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        self.assertListEqual(
+            generated_text,
+            [
+                "Liana Barrientos, 39, pleaded not guilty to charges related to false marriage statements. "
+                "Prosecutors say she married at least 10 times, sometimes within two weeks of each other. She is "
+                "accused of being part of an immigration scam to get permanent residency. If convicted, she faces up "
+                "to four years in"
+            ],
+        )
+
+    @slow
+    def test_contrastive_search_bart_xla(self):
+        article = (
+            " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
+            " year later, she got married again in Westchester County, but to a different man and without divorcing"
+            " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
+            ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
+            " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
+            ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
+            ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
+            " license application, according to court documents. Prosecutors said the marriages were part of an"
+            " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
+            " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
+            " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
+            " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
+            " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
+            " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
+            " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
+            " said the immigration scam involved some of her husbands, who filed for permanent residence status"
+            " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
+            " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
+            " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
+            ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
+            " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
+            " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
+            " up to four years in prison. Her next court appearance is scheduled for May 18."
+        )
+        bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
+        bart_model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
+        input_ids = bart_tokenizer(
+            article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="tf"
+        ).input_ids
+        xla_generate = tf.function(bart_model.generate, jit_compile=True)
+        # no_repeat_ngram_size set to 0 because it isn't compatible with XLA, but doesn't change the original output
+        outputs = xla_generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64, no_repeat_ngram_size=0)
+        generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        self.assertListEqual(
+            generated_text,
+            [
+                "Liana Barrientos, 39, pleaded not guilty to charges related to false marriage statements. "
+                "Prosecutors say she married at least 10 times, sometimes within two weeks of each other. She is "
+                "accused of being part of an immigration scam to get permanent residency. If convicted, she faces up "
+                "to four years in"
+            ],
+        )
+
     @slow
     @require_tf
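Both XLA tests compile `generate` once via `tf.function(..., jit_compile=True)` and zero out `no_repeat_ngram_size`, which is not XLA-compatible. One practical detail the tests sidestep (each uses a single fixed article): XLA retraces whenever the input shape changes, so padding inputs to a fixed length is the usual way to compile once and reuse the trace across calls. A hedged sketch of that pattern; the article text is a placeholder:

import tensorflow as tf
from transformers import BartTokenizer, TFBartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
xla_generate = tf.function(model.generate, jit_compile=True)

# Fixed-shape inputs: pad to max_length so the compiled trace is reused for every article.
inputs = tokenizer(
    ["Some article to summarize."], padding="max_length", truncation=True, max_length=512, return_tensors="tf"
)
# no_repeat_ngram_size=0 as in the tests above, since that logits processor is not XLA-compatible.
outputs = xla_generate(**inputs, penalty_alpha=0.5, top_k=5, max_length=64, no_repeat_ngram_size=0)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))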
tests/models/gpt2/test_modeling_tf_gpt2.py (view file @ a0f86743)

@@ -663,3 +663,72 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         output_ids = xla_generate(**input_ids, do_sample=False, num_beams=2)
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         self.assertListEqual(output_strings, expected_output_strings)
+
+    @slow
+    def test_contrastive_search_gpt2(self):
+        article = (
+            "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
+            "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
+        )
+        gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
+        gpt2_model = TFGPT2LMHeadModel.from_pretrained("gpt2-large")
+        input_ids = gpt2_tokenizer(article, return_tensors="tf")
+        outputs = gpt2_model.generate(**input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
+        generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        self.assertListEqual(
+            generated_text,
+            [
+                "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
+                "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
+                "United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as "
+                "Google Now, which helps users find the information they're looking for on the web. But the company "
+                "is not the only one to collect data on its users. Facebook, for example, has its own facial "
+                "recognition technology, as well as a database of millions of photos that it uses to personalize its "
+                "News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates "
+                "concerned about the company's ability to keep users' information private. In a blog post last "
+                'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our '
+                'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with '
+                'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at '
+                'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, '
+                "but said in a statement to The Associated Press that"
+            ],
+        )
+
+    @slow
+    def test_contrastive_search_gpt2_xla(self):
+        article = (
+            "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
+            "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
+        )
+        gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
+        gpt2_model = TFGPT2LMHeadModel.from_pretrained("gpt2-large")
+        input_ids = gpt2_tokenizer(article, return_tensors="tf")
+        xla_generate = tf.function(gpt2_model.generate, jit_compile=True)
+        outputs = xla_generate(**input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
+        generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        self.assertListEqual(
+            generated_text,
+            [
+                "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
+                "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
+                "United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as "
+                "Google Now, which helps users find the information they're looking for on the web. But the company "
+                "is not the only one to collect data on its users. Facebook, for example, has its own facial "
+                "recognition technology, as well as a database of millions of photos that it uses to personalize its "
+                "News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates "
+                "concerned about the company's ability to keep users' information private. In a blog post last "
+                'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our '
+                'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with '
+                'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at '
+                'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, '
+                "but said in a statement to The Associated Press that"
+            ],
+        )
tests/test_modeling_tf_common.py (view file @ a0f86743)

@@ -1783,7 +1783,7 @@ class TFModelTesterMixin:
             model.compile(optimizer="sgd", run_eagerly=True)
             model.train_on_batch(test_batch, test_batch_labels)
 
-    def _test_xla_generate(self, num_beams, num_return_sequences, max_length):
+    def _test_xla_generate(self, num_beams, num_return_sequences, max_length, **generate_kwargs):
         def _generate_and_check_results(model, config, inputs_dict):
             if "input_ids" in inputs_dict:
                 inputs = inputs_dict["input_ids"]

@@ -1801,9 +1801,9 @@ class TFModelTesterMixin:
             else:
                 raise ValueError("No valid generate input found in inputs_dict")
 
-            generated = model.generate(inputs).numpy()
+            generated = model.generate(inputs, **generate_kwargs).numpy()
             generate_xla = tf.function(model.generate, jit_compile=True)
-            generated_xla = generate_xla(inputs).numpy()
+            generated_xla = generate_xla(inputs, **generate_kwargs).numpy()
             self.assertListEqual(generated.tolist(), generated_xla.tolist())
 
         for model_class in self.all_generative_model_classes:

@@ -1844,6 +1844,19 @@ class TFModelTesterMixin:
         max_length = 10
         self._test_xla_generate(num_beams, num_return_sequences, max_length)
 
+    def test_xla_generate_contrastive(self):
+        """
+        Similar to `test_xla_generate_fast`, but for contrastive search -- contrastive search directly manipulates
+        the model cache and other outputs, and this test ensures that they are in a valid format that is also
+        supported by XLA.
+
+        Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
+        """
+        num_beams = 1
+        num_return_sequences = 1
+        max_length = 10
+        self._test_xla_generate(num_beams, num_return_sequences, max_length, penalty_alpha=0.5, top_k=5)
+
     @slow
     def test_xla_generate_slow(self):
         """
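The new `**generate_kwargs` pass-through is what lets `test_xla_generate_contrastive` reuse the generic eager-vs-XLA parity check with contrastive-search arguments. The same check, sketched standalone outside the test harness (model choice and prompt are placeholders):

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="tf")

# Any generate kwargs are forwarded to both paths so they decode under identical settings.
generate_kwargs = {"penalty_alpha": 0.5, "top_k": 5, "max_length": 10}
eager = model.generate(**inputs, **generate_kwargs).numpy()
compiled = tf.function(model.generate, jit_compile=True)(**inputs, **generate_kwargs).numpy()

# Contrastive search is deterministic, so eager and XLA must emit identical tokens.
assert eager.tolist() == compiled.tolist()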