"examples/research_projects/bertabs/README.md" did not exist on "7e73c128055a61c4232f145ff9ffb99a56928510"
Unverified commit a0f86743, authored by Joao Gante, committed by GitHub

Generate: TF contrastive search with XLA support (#20050)

* Add contrastive search
parent 504db92e
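For reference, a minimal sketch of the usage this commit enables, mirroring the slow tests in the diff below: `penalty_alpha` weighs the degeneration penalty and `top_k` sets the candidate pool size, and the same call can be compiled with XLA via `tf.function(..., jit_compile=True)`. The checkpoint and prompt are illustrative, taken from the tests.

```python
import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("DeepMind Technologies is", return_tensors="tf")

# Eager TF contrastive search: penalty_alpha and top_k select the new mode
outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_length=64)

# The same call, compiled with XLA
xla_generate = tf.function(model.generate, jit_compile=True)
outputs_xla = xla_generate(**inputs, penalty_alpha=0.6, top_k=4, max_length=64)

print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```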
@@ -651,6 +651,7 @@ class GenerationMixin:
input_ids: Optional[torch.LongTensor] = None,
**model_kwargs,
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
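As a standalone illustration of the expansion above (not part of the diff): `repeat_interleave` along dim 0 repeats each batch row `expand_size` times, turning a `[batch_size, ...]` tensor into `[batch_size * expand_size, ...]` with each row's copies adjacent.

```python
import torch

input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])  # [batch_size=2, seq_len=3]
expanded = input_ids.repeat_interleave(2, dim=0)  # expand_size=2
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [4, 5, 6],
#         [4, 5, 6]])  -> shape [batch_size * expand_size, seq_len]
```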
@@ -1860,7 +1861,7 @@ class GenerationMixin:
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
>>> model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
>>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
>>> # set pad_token_id to eos_token_id because OPT does not have a PAD token
>>> model.config.pad_token_id = model.config.eos_token_id
>>> input_prompt = "DeepMind Company is"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt")
@@ -1916,7 +1917,7 @@ class GenerationMixin:
if this_peer_finished_flag.item() == 0.0:
break
- # if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values;
+ # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
if model_kwargs.get("past") is None:
@@ -2014,7 +2015,7 @@ class GenerationMixin:
full_hidden_states = outputs.hidden_states
context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
- # compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the
+ # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
# model confidence
selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)
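For readers unfamiliar with `_ranking_fast`, here is a hedged sketch of the scoring it performs, following the contrastive search objective (the library's actual helper may differ in detail): each candidate token is scored as `(1 - penalty_alpha) * confidence - penalty_alpha * degeneration_penalty`, where the penalty is the maximum cosine similarity between the candidate's hidden state and the context hidden states.

```python
import torch

def rank_candidates(context_hidden, next_hidden, top_k_probs, alpha, top_k):
    # context_hidden: [batch * top_k, ctx_len, hidden]
    # next_hidden:    [batch * top_k, 1, hidden]
    # top_k_probs:    [batch, top_k] model confidence of each candidate
    context_norm = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    next_norm = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    # max cosine similarity between each candidate and any context position
    cosine = torch.matmul(context_norm, next_norm.transpose(1, 2)).squeeze(-1)
    degeneration_penalty = cosine.max(dim=-1).values  # [batch * top_k]
    score = (1.0 - alpha) * top_k_probs.view(-1) - alpha * degeneration_penalty
    # best-scoring candidate within each batch item's group of top_k candidates
    return torch.stack(torch.split(score, top_k)).argmax(dim=-1)  # [batch]
```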
@@ -550,6 +550,100 @@ class TFBartModelIntegrationTest(unittest.TestCase):
def tok(self):
return BartTokenizer.from_pretrained("facebook/bart-large")
@slow
def test_contrastive_search_bart(self):
article = (
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
" year later, she got married again in Westchester County, but to a different man and without divorcing"
" her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
" once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
" license application, according to court documents. Prosecutors said the marriages were part of an"
" immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
" her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
" arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
" York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
" Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
" occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
" married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
" said the immigration scam involved some of her husbands, who filed for permanent residence status"
" shortly after the marriages. Any divorces happened only after such filings were approved. It was"
" unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
" Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
" Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
" native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
" up to four years in prison. Her next court appearance is scheduled for May 18."
)
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
bart_model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
input_ids = bart_tokenizer(
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="tf"
).input_ids
outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"Liana Barrientos, 39, pleaded not guilty to charges related to false marriage statements. "
"Prosecutors say she married at least 10 times, sometimes within two weeks of each other. She is "
"accused of being part of an immigration scam to get permanent residency. If convicted, she faces up "
"to four years in"
],
)
@slow
def test_contrastive_search_bart_xla(self):
article = (
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
" year later, she got married again in Westchester County, but to a different man and without divorcing"
" her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
" once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
" license application, according to court documents. Prosecutors said the marriages were part of an"
" immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
" her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
" arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
" York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
" Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
" occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
" married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
" said the immigration scam involved some of her husbands, who filed for permanent residence status"
" shortly after the marriages. Any divorces happened only after such filings were approved. It was"
" unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
" Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
" Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
" native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
" up to four years in prison. Her next court appearance is scheduled for May 18."
)
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
bart_model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
input_ids = bart_tokenizer(
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="tf"
).input_ids
xla_generate = tf.function(bart_model.generate, jit_compile=True)
# no_repeat_ngram_size set to 0 because it isn't compatible with XLA, but doesn't change the original output
outputs = xla_generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64, no_repeat_ngram_size=0)
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"Liana Barrientos, 39, pleaded not guilty to charges related to false marriage statements. "
"Prosecutors say she married at least 10 times, sometimes within two weeks of each other. She is "
"accused of being part of an immigration scam to get permanent residency. If convicted, she faces up "
"to four years in"
],
)
@slow
@require_tf
@@ -663,3 +663,72 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
output_ids = xla_generate(**input_ids, do_sample=False, num_beams=2)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_strings)
@slow
def test_contrastive_search_gpt2(self):
article = (
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
gpt2_model = TFGPT2LMHeadModel.from_pretrained("gpt2-large")
input_ids = gpt2_tokenizer(article, return_tensors="tf")
outputs = gpt2_model.generate(**input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
"United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as "
"Google Now, which helps users find the information they're looking for on the web. But the company "
"is not the only one to collect data on its users. Facebook, for example, has its own facial "
"recognition technology, as well as a database of millions of photos that it uses to personalize its "
"News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates "
"concerned about the company's ability to keep users' information private. In a blog post last "
'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our '
'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with '
'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at '
'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, '
"but said in a statement to The Associated Press that"
],
)
@slow
def test_contrastive_search_gpt2_xla(self):
article = (
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
gpt2_model = TFGPT2LMHeadModel.from_pretrained("gpt2-large")
input_ids = gpt2_tokenizer(article, return_tensors="tf")
xla_generate = tf.function(gpt2_model.generate, jit_compile=True)
outputs = xla_generate(**input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
"United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as "
"Google Now, which helps users find the information they're looking for on the web. But the company "
"is not the only one to collect data on its users. Facebook, for example, has its own facial "
"recognition technology, as well as a database of millions of photos that it uses to personalize its "
"News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates "
"concerned about the company's ability to keep users' information private. In a blog post last "
'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our '
'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with '
'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at '
'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, '
"but said in a statement to The Associated Press that"
],
)
@@ -1783,7 +1783,7 @@ class TFModelTesterMixin:
model.compile(optimizer="sgd", run_eagerly=True)
model.train_on_batch(test_batch, test_batch_labels)
- def _test_xla_generate(self, num_beams, num_return_sequences, max_length):
+ def _test_xla_generate(self, num_beams, num_return_sequences, max_length, **generate_kwargs):
def _generate_and_check_results(model, config, inputs_dict):
if "input_ids" in inputs_dict:
inputs = inputs_dict["input_ids"]
@@ -1801,9 +1801,9 @@ class TFModelTesterMixin:
else:
raise ValueError("No valid generate input found in inputs_dict")
- generated = model.generate(inputs).numpy()
+ generated = model.generate(inputs, **generate_kwargs).numpy()
generate_xla = tf.function(model.generate, jit_compile=True)
- generated_xla = generate_xla(inputs).numpy()
+ generated_xla = generate_xla(inputs, **generate_kwargs).numpy()
self.assertListEqual(generated.tolist(), generated_xla.tolist())
for model_class in self.all_generative_model_classes:
@@ -1844,6 +1844,19 @@ class TFModelTesterMixin:
max_length = 10
self._test_xla_generate(num_beams, num_return_sequences, max_length)
def test_xla_generate_contrastive(self):
"""
Similar to `test_xla_generate_fast`, but for contrastive search -- contrastive search directly manipulates the
model cache and other outputs, and this test ensures that they are in a valid format that is also supported
by XLA.
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception.
"""
num_beams = 1
num_return_sequences = 1
max_length = 10
self._test_xla_generate(num_beams, num_return_sequences, max_length, penalty_alpha=0.5, top_k=5)
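Note (an assumption based on the parameters used throughout this diff, not stated in it explicitly): passing `penalty_alpha > 0` together with `top_k > 1` and no sampling is what routes `generate()` into contrastive search, so these two kwargs are enough to exercise the new code path through the shared `_test_xla_generate` harness.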
@slow
def test_xla_generate_slow(self):
"""