Unverified Commit 2c778428 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[Fix common tests on GPU] send model, ids to torch_device (#4014)

parent 6faca88e
......@@ -19,6 +19,7 @@ import os.path
import random
import tempfile
import unittest
from typing import List
from transformers import is_torch_available
......@@ -629,10 +630,10 @@ class ModelTesterMixin:
# iterate over all generative models
for model_class in self.all_generative_model_classes:
model = model_class(config)
model = model_class(config).to(torch_device)
if config.bos_token_id is None:
# if bos token id is not defined mobel needs input_ids
# if bos token id is not defined, model needs input_ids
with self.assertRaises(AssertionError):
model.generate(do_sample=True, max_length=5)
# num_return_sequences = 1
......@@ -651,7 +652,10 @@ class ModelTesterMixin:
# check bad words tokens language generation
# create list of 1-seq bad token and list of 2-seq of bad tokens
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
bad_words_ids = [
self._generate_random_bad_tokens(1, model.config),
self._generate_random_bad_tokens(2, model.config),
]
output_tokens = model.generate(
input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
)
......@@ -661,10 +665,12 @@ class ModelTesterMixin:
def test_lm_head_model_random_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
input_ids = (inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]).to(
torch_device
)
for model_class in self.all_generative_model_classes:
model = model_class(config)
model = model_class(config).to(torch_device)
if config.bos_token_id is None:
# if bos token id is not defined mobel needs input_ids, num_return_sequences = 1
......@@ -684,7 +690,10 @@ class ModelTesterMixin:
# check bad words tokens language generation
# create list of 1-seq bad token and list of 2-seq of bad tokens
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
bad_words_ids = [
self._generate_random_bad_tokens(1, model.config),
self._generate_random_bad_tokens(2, model.config),
]
output_tokens = model.generate(
input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
)
......@@ -692,20 +701,13 @@ class ModelTesterMixin:
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
def _generate_random_bad_tokens(self, num_bad_tokens, model):
def _generate_random_bad_tokens(self, num_bad_tokens: int, config) -> List[int]:
# special tokens cannot be bad tokens
special_tokens = []
if model.config.bos_token_id is not None:
special_tokens.append(model.config.bos_token_id)
if model.config.pad_token_id is not None:
special_tokens.append(model.config.pad_token_id)
if model.config.eos_token_id is not None:
special_tokens.append(model.config.eos_token_id)
special_tokens = [x for x in [config.bos_token_id, config.eos_token_id, config.pad_token_id] if x is not None]
# create random bad tokens that are not special tokens
bad_tokens = []
while len(bad_tokens) < num_bad_tokens:
token = ids_tensor((1, 1), self.model_tester.vocab_size).squeeze(0).numpy()[0]
token = ids_tensor((1, 1), self.model_tester.vocab_size).squeeze(0).cpu().numpy()[0]
if token not in special_tokens:
bad_tokens.append(token)
return bad_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment