Unverified Commit 2a8115f0 authored by Suraj Patil, committed by GitHub

[WIP] GPT Neo cleanup (#10985)

* better names

* add attention mixin

* all slow tests in one class

* make helper methods static so we can test

* add local attention tests

* better names

* doc

* apply review suggestions
parent 76800fb8
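
For context on the two static helpers the new tests exercise: below is a minimal sketch of their apparent behavior, assuming block_length is chosen as the largest value not exceeding window_size that evenly divides seq_length. The function names are illustrative re-implementations for this description, not the library code:

import torch
import torch.nn.functional as F

def get_block_length_and_num_blocks(seq_length, window_size):
    # pick the largest block_length <= window_size that divides seq_length,
    # so the sequence splits into a whole number of blocks (assumption, see above)
    block_length = window_size
    while seq_length % block_length != 0:
        block_length -= 1
    return block_length, seq_length // block_length

def look_back(hidden_states, block_length, window_size):
    # left-pad the sequence dimension with window_size zeros, then unfold it into
    # overlapping chunks of (window_size + block_length) positions, so that each
    # block also sees the window_size positions before it
    padded = F.pad(hidden_states, (0, 0, window_size, 0))
    # unfold over dim 1 -> (batch, num_blocks, hidden, window_size + block_length)
    blocks = padded.unfold(1, window_size + block_length, block_length)
    return blocks.transpose(-2, -1)  # (batch, num_blocks, window_size + block_length, hidden)

On a (1, 8, 4) input with window_size = 4, this returns shape (1, 2, 8, 4), and the last block ends with the last window_size + block_length hidden states, which are exactly the properties test_look_back asserts below.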

@@ -18,6 +18,7 @@
import unittest

from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester

@@ -35,6 +36,7 @@ if is_torch_available():
        GPTNeoForCausalLM,
        GPTNeoModel,
    )
    from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin, GPTNeoLocalSelfAttention


class GPTNeoModelTester:

@@ -430,11 +432,164 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
        # check attn size
        self.assertListEqual(shapes, expected_shape)


@require_torch
class GPTNeoLocalAttentionTest(unittest.TestCase):
    def _get_hidden_states(self):
        return torch.tensor(
            [
                [
                    [0.4983, -0.7584, -1.6944, 0.5440],
                    [2.6918, 0.4206, 0.4176, 0.2055],
                    [-0.0071, -0.0405, -1.4920, -0.3630],
                    [1.0492, 0.1599, -1.7648, 0.2419],
                    [-1.8348, 2.0514, -0.1946, 0.3203],
                    [0.7672, -1.1600, -1.7118, -0.9056],
                    [0.2986, 0.5372, 0.7729, -0.1927],
                    [0.0285, 0.2629, -1.1156, -1.1992],
                ]
            ],
            dtype=torch.float32,
            device=torch_device,
        )

    def test_look_back(self):
        hidden_states = self._get_hidden_states()
        batch_size, seq_length, hidden_size = hidden_states.shape

        # check when seq_length is divisible by window_size
        window_size = 4
        block_length, num_block = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)
        blocked_hidden_states = GPTNeoAttentionMixin._look_back(hidden_states, block_length, window_size)
        expected_shape = [batch_size, num_block, window_size + block_length, hidden_size]
        self.assertListEqual(list(blocked_hidden_states.shape), expected_shape)
        # The last block should contain the last (window_size + block_length) hidden_states
        self.assertTrue(
            torch.all(blocked_hidden_states[:, -1, ...] == hidden_states[:, -(window_size + block_length) :, ...])
        )

        # check when seq_length is not divisible by window_size
        window_size = 3
        block_length, num_block = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)
        blocked_hidden_states = GPTNeoAttentionMixin._look_back(hidden_states, block_length, window_size)
        expected_shape = [batch_size, num_block, window_size + block_length, hidden_size]
        self.assertListEqual(list(blocked_hidden_states.shape), expected_shape)
        # The last block should contain the last (window_size + block_length) hidden_states
        self.assertTrue(
            torch.all(blocked_hidden_states[:, -1, ...] == hidden_states[:, -(window_size + block_length) :, ...])
        )

        # check when window_size > seq_length
        window_size = 19
        block_length, num_block = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)
        blocked_hidden_states = GPTNeoAttentionMixin._look_back(hidden_states, block_length, window_size)
        expected_shape = [batch_size, num_block, window_size + block_length, hidden_size]
        self.assertListEqual(list(blocked_hidden_states.shape), expected_shape)

        # when window_size > seq_length, num_block becomes 1; in this case
        # the first window_size values in blocked_hidden_states are all zeros
        # and the last block_length values are equal to the hidden_states
        values = blocked_hidden_states[:, -1, :window_size, ...]
        expected_values = torch.zeros_like(values)
        self.assertTrue(torch.all(values == expected_values))

        self.assertTrue(torch.all(blocked_hidden_states[:, -1, -block_length:, ...] == hidden_states))
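
Concretely, for the 8-token hidden states above (still assuming the divisibility rule sketched earlier): window_size = 4 gives block_length = 4 and num_block = 2; window_size = 3 gives block_length = 2 and num_block = 4, since 2 is the largest value not exceeding 3 that divides 8; and window_size = 19 gives block_length = 8 and num_block = 1, which is why the single block starts with 19 zero rows.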

    def test_create_attention_mask(self):
        config = GPTNeoConfig.from_pretrained("valhalla/gpt-neo-random-tiny")
        layer = GPTNeoLocalSelfAttention(config)
        window_size = config.window_size
        batch_size, seq_length = 8, 1
        block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)

        causal_mask = layer._create_attention_mask(batch_size, seq_length, num_blocks, block_length, torch_device)
        # check shapes
        expected_shape = [batch_size, num_blocks, 1, block_length, window_size + block_length]
        self.assertListEqual(list(causal_mask.shape), expected_shape)
        # first window_size tokens in the first block are always padded
        # and should not be attended
        self.assertTrue(torch.all(causal_mask[:, 0, :, :, :window_size] == 0))
        # each window can attend at most window_size tokens
        self.assertTrue(torch.all(torch.sum(causal_mask, dim=4) <= config.window_size))

        # check if a user-provided attention_mask is handled correctly
        attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=torch_device)
        attention_mask[:, -3:] = 0  # don't attend last 3 tokens
        causal_mask = layer._create_attention_mask(
            batch_size, seq_length, num_blocks, block_length, torch_device, attention_mask
        )
        # last 3 tokens will be in the last block and should have 0s in causal_mask
        self.assertTrue(torch.all(causal_mask[:, -1, :, :, -3:] == 0))
        # check shapes
        expected_shape = [batch_size, num_blocks, 1, block_length, window_size + block_length]
        self.assertListEqual(list(causal_mask.shape), expected_shape)
        # first window_size tokens in the first block are always padded
        # and should not be attended
        self.assertTrue(torch.all(causal_mask[:, 0, :, :, :window_size] == 0))
        # each window can attend at most window_size tokens
        self.assertTrue(torch.all(torch.sum(causal_mask, dim=4) <= config.window_size))
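
The mask semantics these assertions pin down can be reproduced in a small standalone sketch. This is a hypothetical reconstruction from the tested behavior, not the library's _create_attention_mask: key positions are blocked with the same look-back as the hidden states, a query may only attend keys at or before its own position and at most window_size back, and padded or user-masked key slots are zeroed.

import torch
import torch.nn.functional as F

def local_causal_mask(batch_size, seq_length, num_blocks, block_length, window_size, device, attention_mask=None):
    # block the position indices the same way the hidden states are blocked,
    # padding with -1 so padded key slots can be recognized below
    indices = torch.arange(seq_length, dtype=torch.long, device=device).repeat(batch_size, 1)
    query_ids = indices.view(batch_size, num_blocks, block_length)
    key_ids = F.pad(indices, (window_size, 0), value=-1).unfold(1, window_size + block_length, block_length)
    causal = query_ids.unsqueeze(-1) >= key_ids.unsqueeze(-2)  # no attending ahead of the query
    in_window = key_ids.unsqueeze(-2) > query_ids.unsqueeze(-1) - window_size  # at most window_size back
    not_pad = (key_ids >= 0).unsqueeze(-2)  # drop the left padding slots
    mask = causal & in_window & not_pad
    if attention_mask is not None:
        # a user-provided mask zeroes its keys wherever a block sees them
        keep = F.pad(attention_mask, (window_size, 0), value=0).unfold(1, window_size + block_length, block_length)
        mask = mask & (keep > 0).unsqueeze(-2)
    # add the broadcastable head dim -> (batch, num_blocks, 1, block_length, window_size + block_length)
    return mask.unsqueeze(2)

Under this construction each query row allows at most window_size keys, matching the torch.sum(causal_mask, dim=4) bound checked above.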

    def test_local_attn_probs(self):
        model = GPTNeoModel.from_pretrained("valhalla/gpt-neo-random-tiny").eval()
        layer = model.h[1].attn.attention.to(torch_device)
        hidden_states = self._get_hidden_states()
        hidden_states = torch.cat([hidden_states, hidden_states - 0.5], dim=2)
        batch_size, seq_length, hidden_size = hidden_states.shape
        mask_tokens = 3
        attention_mask = torch.ones(batch_size, seq_length, device=torch_device, dtype=torch.long)
        attention_mask[:, -mask_tokens:] = 0  # don't attend the last mask_tokens
        _, attn_probs = layer(hidden_states, attention_mask=attention_mask, output_attentions=True)

        # the last 3 tokens will be in the last block, and should have 0 attn_probs
        self.assertTrue(torch.all(attn_probs[:, -1, :, -mask_tokens:, -mask_tokens:] == 0))
        # the first config.window_size tokens in the first block are always padded
        # and should have 0 attn_probs
        self.assertTrue(torch.all(attn_probs[:, 0, :, : model.config.window_size, : model.config.window_size] == 0))
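
Note the indexing these assertions imply: for the local attention module, attn_probs apparently comes back in blocked form, (batch, num_blocks, num_heads, block_length, window_size + block_length), so attn_probs[:, -1, ...] selects the last block, not the last head.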


@require_torch
class GPTNeoModelLanguageGenerationTest(unittest.TestCase):
    @cached_property
    def model(self):
        return GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B").to(torch_device)

    @cached_property
    def tokenizer(self):
        return GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
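
Using cached_property here defers loading the 1.3B checkpoint until a test actually touches self.model or self.tokenizer, so collecting or skipping the slow tests stays cheap.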

    @slow
    def test_lm_generate_gpt_neo(self):
        for checkpointing in [True, False]:
            model = self.model
            model.config.gradient_checkpointing = checkpointing
            input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
            # fmt: off
            # The dog-eared copy of the book, which is a collection of essays by the late author,
            expected_output_ids = [464, 3290, 12, 3380, 4866, 286, 262, 1492, 11, 543, 318, 257, 4947, 286, 27126, 416, 262, 2739, 1772, 11]
            # fmt: on
            output_ids = model.generate(input_ids, do_sample=False)
            self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
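
The expected list is 20 ids including the 2-token prompt; with do_sample=False the decode is greedy and deterministic, so an exact list comparison is safe (this assumes generate stops at the default max_length of 20, which matches the length of expected_output_ids).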

    @slow
    def test_gpt_neo_sample(self):
        model = self.model
        tokenizer = self.tokenizer

        torch.manual_seed(0)
        tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
        input_ids = tokenized.input_ids.to(torch_device)
        output_ids = model.generate(input_ids, do_sample=True)
        output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        EXPECTED_OUTPUT_STR = "Today is a nice day and if you don’t get the memo here is what you can"
        self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
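
Seeding with torch.manual_seed(0) makes the sampled continuation reproducible on a fixed setup; the hard-coded EXPECTED_OUTPUT_STR is tied to that seed and can differ across devices or PyTorch versions, a standard caveat for sampling tests.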

    @slow
    def test_batch_generation(self):
-        model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
-        model.to(torch_device)
-        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+        model = self.model
+        tokenizer = self.tokenizer
        tokenizer.padding_side = "left"

@@ -479,33 +634,3 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
        for model_name in GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = GPTNeoModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

-@require_torch
-class GPTNeoModelLanguageGenerationTest(unittest.TestCase):
-    @slow
-    def test_lm_generate_gpt_neo(self):
-        for checkpointing in [True, False]:
-            model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B", gradient_checkpointing=checkpointing)
-            model.to(torch_device)
-            input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
-            # fmt: off
-            expected_output_ids = [464, 3290, 12, 3380, 4866, 286, 262, 1492, 11, 543, 318, 257, 4947, 286, 27126, 416, 262, 2739, 1772, 11]  # The dog-eared copy of the book, which is a collection of essays by the late author,
-            # fmt: on
-            output_ids = model.generate(input_ids, do_sample=False)
-            self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
-
-    @slow
-    def test_gpt_neo_sample(self):
-        tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
-        model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
-        model.to(torch_device)
-
-        torch.manual_seed(0)
-        tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
-        input_ids = tokenized.input_ids.to(torch_device)
-        output_ids = model.generate(input_ids, do_sample=True)
-        output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-
-        EXPECTED_OUTPUT_STR = "Today is a nice day and if you don’t get the memo here is what you can"
-        self.assertEqual(output_str, EXPECTED_OUTPUT_STR)