Unverified Commit 9cda3620 authored by Julien Chaumond's avatar Julien Chaumond Committed by GitHub
Browse files

Fix (non-slow) tests on GPU (torch) (#3024)



* Fix tests on GPU (torch)

* Fix bart slow tests
Co-authored-by: default avatarSam Shleifer <sshleifer@gmail.com>
parent 9df74b8b
...@@ -86,7 +86,7 @@ def _prepare_bart_decoder_inputs( ...@@ -86,7 +86,7 @@ def _prepare_bart_decoder_inputs(
causal_lm_mask = None causal_lm_mask = None
new_shape = (bsz, tgt_len, tgt_len) new_shape = (bsz, tgt_len, tgt_len)
# make it broadcastable so can just be added to the attention coefficients # make it broadcastable so can just be added to the attention coefficients
decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape) decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape).to(device=input_ids.device)
assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len) assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len)
return decoder_input_ids, decoder_attn_mask return decoder_input_ids, decoder_attn_mask
......
...@@ -172,7 +172,7 @@ class BartHeadTests(unittest.TestCase): ...@@ -172,7 +172,7 @@ class BartHeadTests(unittest.TestCase):
vocab_size = 99 vocab_size = 99
def test_lm_forward(self): def test_lm_forward(self):
input_ids = torch.Tensor( input_ids = torch.tensor(
[ [
[71, 82, 18, 33, 46, 91, 2], [71, 82, 18, 33, 46, 91, 2],
[68, 34, 26, 58, 30, 82, 2], [68, 34, 26, 58, 30, 82, 2],
...@@ -187,8 +187,10 @@ class BartHeadTests(unittest.TestCase): ...@@ -187,8 +187,10 @@ class BartHeadTests(unittest.TestCase):
[21, 5, 62, 28, 14, 76, 2], [21, 5, 62, 28, 14, 76, 2],
[45, 98, 37, 86, 59, 48, 2], [45, 98, 37, 86, 59, 48, 2],
[70, 70, 50, 9, 28, 0, 2], [70, 70, 50, 9, 28, 0, 2],
] ],
).long() dtype=torch.long,
device=torch_device,
)
batch_size = input_ids.shape[0] batch_size = input_ids.shape[0]
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size) decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
...@@ -204,12 +206,14 @@ class BartHeadTests(unittest.TestCase): ...@@ -204,12 +206,14 @@ class BartHeadTests(unittest.TestCase):
max_position_embeddings=48, max_position_embeddings=48,
) )
model = BartForSequenceClassification(config) model = BartForSequenceClassification(config)
model.to(torch_device)
outputs = model.forward(input_ids=input_ids, decoder_input_ids=input_ids) outputs = model.forward(input_ids=input_ids, decoder_input_ids=input_ids)
logits = outputs[0] logits = outputs[0]
expected_shape = torch.Size((batch_size, config.num_labels)) expected_shape = torch.Size((batch_size, config.num_labels))
self.assertEqual(logits.shape, expected_shape) self.assertEqual(logits.shape, expected_shape)
lm_model = BartForMaskedLM(config) lm_model = BartForMaskedLM(config)
lm_model.to(torch_device)
loss, logits, enc_features = lm_model.forward( loss, logits, enc_features = lm_model.forward(
input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
) )
...@@ -292,6 +296,10 @@ def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): ...@@ -292,6 +296,10 @@ def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
raise AssertionError(msg) raise AssertionError(msg)
def _long_tensor(tok_lst):
return torch.tensor(tok_lst, dtype=torch.long, device=torch_device,)
TOLERANCE = 1e-4 TOLERANCE = 1e-4
...@@ -299,15 +307,15 @@ TOLERANCE = 1e-4 ...@@ -299,15 +307,15 @@ TOLERANCE = 1e-4
class BartModelIntegrationTest(unittest.TestCase): class BartModelIntegrationTest(unittest.TestCase):
@slow @slow
def test_inference_no_head(self): def test_inference_no_head(self):
model = BartModel.from_pretrained("bart-large") model = BartModel.from_pretrained("bart-large").to(torch_device)
input_ids = torch.Tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]).long() input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids) inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
with torch.no_grad(): with torch.no_grad():
output = model.forward(**inputs_dict)[0] output = model.forward(**inputs_dict)[0]
expected_shape = torch.Size((1, 11, 1024)) expected_shape = torch.Size((1, 11, 1024))
self.assertEqual(output.shape, expected_shape) self.assertEqual(output.shape, expected_shape)
expected_slice = torch.Tensor( expected_slice = torch.Tensor(
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]] [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], device=torch_device
) )
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE)) self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
...@@ -315,20 +323,22 @@ class BartModelIntegrationTest(unittest.TestCase): ...@@ -315,20 +323,22 @@ class BartModelIntegrationTest(unittest.TestCase):
def test_mnli_inference(self): def test_mnli_inference(self):
example_b = [0, 31414, 232, 328, 740, 1140, 69, 46078, 1588, 2, 1] example_b = [0, 31414, 232, 328, 740, 1140, 69, 46078, 1588, 2, 1]
input_ids = torch.Tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2], example_b]).long() input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2], example_b])
model = AutoModelForSequenceClassification.from_pretrained("bart-large-mnli") # eval called in from_pre model = AutoModelForSequenceClassification.from_pretrained("bart-large-mnli").to(
torch_device
) # eval called in from_pre
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids) inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
# Test that model hasn't changed # Test that model hasn't changed
with torch.no_grad(): with torch.no_grad():
batched_logits, features = model.forward(**inputs_dict) batched_logits, features = model.forward(**inputs_dict)
expected_shape = torch.Size((2, 3)) expected_shape = torch.Size((2, 3))
self.assertEqual(batched_logits.shape, expected_shape) self.assertEqual(batched_logits.shape, expected_shape)
expected_slice = torch.Tensor([[0.1907, 1.4342, -1.0289]]) expected_slice = torch.Tensor([[0.1907, 1.4342, -1.0289]]).to(torch_device)
logits_arr = batched_logits[0].detach() logits_arr = batched_logits[0].detach()
# Test that padding does not change results # Test that padding does not change results
input_ids_no_pad = torch.Tensor([example_b[:-1]]).long() input_ids_no_pad = _long_tensor([example_b[:-1]])
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids=input_ids_no_pad) inputs_dict = prepare_bart_inputs_dict(model.config, input_ids=input_ids_no_pad)
with torch.no_grad(): with torch.no_grad():
......
...@@ -68,7 +68,7 @@ class ModelTesterMixin: ...@@ -68,7 +68,7 @@ class ModelTesterMixin:
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**inputs_dict)
out_2 = outputs[0].numpy() out_2 = outputs[0].cpu().numpy()
out_2[np.isnan(out_2)] = 0 out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
...@@ -472,6 +472,7 @@ class ModelTesterMixin: ...@@ -472,6 +472,7 @@ class ModelTesterMixin:
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
config = copy.deepcopy(original_config) config = copy.deepcopy(original_config)
model = model_class(config) model = model_class(config)
model.to(torch_device)
model_vocab_size = config.vocab_size model_vocab_size = config.vocab_size
# Retrieve the embeddings and clone theme # Retrieve the embeddings and clone theme
......
...@@ -20,7 +20,7 @@ from transformers import is_torch_available ...@@ -20,7 +20,7 @@ from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
...@@ -125,6 +125,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -125,6 +125,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
decoder_lm_labels, decoder_lm_labels,
): ):
model = T5Model(config=config) model = T5Model(config=config)
model.to(torch_device)
model.eval() model.eval()
decoder_output, encoder_output = model( decoder_output, encoder_output = model(
encoder_input_ids=encoder_input_ids, encoder_input_ids=encoder_input_ids,
...@@ -157,6 +158,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -157,6 +158,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
decoder_lm_labels, decoder_lm_labels,
): ):
model = T5WithLMHeadModel(config=config) model = T5WithLMHeadModel(config=config)
model.to(torch_device)
model.eval() model.eval()
outputs = model( outputs = model(
encoder_input_ids=encoder_input_ids, encoder_input_ids=encoder_input_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment