Unverified Commit 7c7d2ec9 authored by Anton Lozhkov, committed by GitHub

[GPT-J] Use the `float16` checkpoints in integration tests (#13676)

* Use fp16 checkpoints

* Style

* Fix outputs and disable OOM tests

* Correct another output

* Use a random smaller model for generation tests

* repo quickfix

* fix gradient checkpointing
parent 0ecdf6de
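
The core change, repeated throughout the diff below, is loading GPT-J from the `float16` branch of the Hub repo in half precision rather than the default fp32 weights. A minimal sketch of that loading pattern (checkpoint name, revision, and kwargs are taken directly from the diff; the memory figure is an estimate assuming 6B parameters at 2 bytes per weight instead of 4):

```python
# Sketch of the fp16 loading pattern used in these tests.
# revision="float16" selects the half-precision branch of the Hub repo, and
# torch_dtype=torch.float16 keeps the weights in fp16 after loading
# (roughly 12 GB for 6B parameters instead of roughly 24 GB in fp32).
import torch
from transformers import AutoTokenizer, GPTJForCausalLM

model = GPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
```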
@@ -18,7 +18,7 @@ import datetime
 import unittest
 
 from transformers import GPTJConfig, is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import require_torch, slow, tooslow, torch_device
 
 from .test_configuration_common import ConfigTester
 from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
@@ -398,9 +398,9 @@ class GPTJModelTest(unittest.TestCase):
     @slow
     def test_batch_generation(self):
-        model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
+        model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16)
         model.to(torch_device)
-        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
+        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
         tokenizer.padding_side = "left"
@@ -458,7 +458,7 @@ class GPTJModelTest(unittest.TestCase):
     @slow
     def test_model_from_pretrained(self):
         for model_name in GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
-            model = GPTJModel.from_pretrained(model_name)
+            model = GPTJModel.from_pretrained(model_name, revision="float16", torch_dtype=torch.float16)
             self.assertIsNotNone(model)
@@ -467,42 +467,27 @@ class GPTJModelLanguageGenerationTest(unittest.TestCase):
     @slow
     def test_lm_generate_gptj(self):
         for checkpointing in [True, False]:
-            model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
+            model = GPTJForCausalLM.from_pretrained(
+                "EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16
+            )
             if checkpointing:
                 model.gradient_checkpointing_enable()
             else:
                 model.gradient_checkpointing_disable()
             model.to(torch_device)
             input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
-            expected_output_ids = [
-                464,
-                3290,
-                1528,
-                286,
-                3931,
-                389,
-                2402,
-                514,
-                11,
-                290,
-                326,
-                1724,
-                340,
-                447,
-                247,
-                82,
-                640,
-                284,
-                923,
-                3612,
-            ]  # The dog days of summer are upon us, and that means it’s time to start thinking
+            # fmt: off
+            # The dog is a man's best friend. It is a loyal companion, and it is a friend
+            expected_output_ids = [464, 3290, 318, 257, 582, 338, 1266, 1545, 13, 632, 318, 257, 9112, 15185, 11, 290, 340, 318, 257, 1545]
+            # fmt: on
             output_ids = model.generate(input_ids, do_sample=False)
             self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
 
-    @slow
+    @tooslow
     def test_gptj_sample(self):
-        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
-        model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
+        # Marked as @tooslow due to GPU OOM (issue #13676)
+        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
+        model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16)
         model.to(torch_device)
         torch.manual_seed(0)
@@ -519,7 +504,13 @@ class GPTJModelLanguageGenerationTest(unittest.TestCase):
         output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True)
         output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)
 
-        EXPECTED_OUTPUT_STR = "Today is a nice day and I've already been enjoying it. I walked to work with my wife"
+        if torch_device == "cuda":
+            EXPECTED_OUTPUT_STR = (
+                "Today is a nice day and I've already been enjoying it. I walked to work with my wife"
+            )
+        else:
+            EXPECTED_OUTPUT_STR = "Today is a nice day and one of those days that feels a bit more alive. I am ready"
+
         self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
         self.assertTrue(
             all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
@@ -527,8 +518,8 @@ class GPTJModelLanguageGenerationTest(unittest.TestCase):
     @slow
     def test_gptj_sample_max_time(self):
-        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
-        model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
+        tokenizer = AutoTokenizer.from_pretrained("anton-l/gpt-j-tiny-random")
+        model = GPTJForCausalLM.from_pretrained("anton-l/gpt-j-tiny-random")
         model.to(torch_device)
         torch.manual_seed(0)
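
The last hunk swaps the 6B checkpoint for `anton-l/gpt-j-tiny-random`, a randomly initialized tiny GPT-J: `test_gptj_sample_max_time` only checks generation timing, not output quality, so a real checkpoint is unnecessary. The body of that test is not shown in this diff; the following is a hedged sketch of the pattern it exercises, assuming the standard `max_time` argument of `generate()`, with an illustrative time budget and tolerance rather than the test's actual values:

```python
# Hedged sketch of a max_time generation check, assuming the standard
# `max_time` kwarg of GenerationMixin.generate(); the 0.5 s budget and the
# 1.5 s tolerance are illustrative, not the test's actual values.
import datetime

from transformers import AutoTokenizer, GPTJForCausalLM

tokenizer = AutoTokenizer.from_pretrained("anton-l/gpt-j-tiny-random")
model = GPTJForCausalLM.from_pretrained("anton-l/gpt-j-tiny-random")

input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids

start = datetime.datetime.now()
model.generate(input_ids, do_sample=True, max_time=0.5, max_length=256)
duration = datetime.datetime.now() - start

# Generation should stop shortly after the time budget expires, long before
# max_length would be reached.
assert duration < datetime.timedelta(seconds=1.5)
```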