Unverified commit e9e6efdc, authored by Sam Shleifer, committed by GitHub

BartForSequenceClassification: fix num_labels, add test (#3110)

parent f631e01d
@@ -1324,7 +1324,7 @@ class BartForSequenceClassification(PretrainedBartModel):
         # Prepend logits
         outputs = (logits,) + outputs[1:]  # Add hidden states and attention if they are here
         if labels is not None:  # prepend loss to output,
-            loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
             outputs = (loss,) + outputs
         return outputs
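The one-line fix above reads the label count from the config rather than from self.num_labels, an attribute this model evidently does not set on itself. As a minimal plain-PyTorch sketch of what that loss line computes (illustrative names, not the transformers source):

import torch
import torch.nn.functional as F

# F.cross_entropy expects logits of shape (N, num_labels) and integer
# targets of shape (N,), hence the .view(-1, num_labels) / .view(-1) pair.
num_labels = 3        # the real model reads this from config.num_labels
batch_size = 2
logits = torch.randn(batch_size, num_labels)  # classification-head output
labels = torch.tensor([2, 0])                 # one class id per example

loss = F.cross_entropy(logits.view(-1, num_labels), labels.view(-1))
print(loss.item())  # a plain Python float, as the new test asserts below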
@@ -171,7 +171,7 @@ class BartHeadTests(unittest.TestCase):
     vocab_size = 99

-    def test_lm_forward(self):
+    def _get_config_and_data(self, output_past=False):
         input_ids = torch.tensor(
             [
                 [71, 82, 18, 33, 46, 91, 2],
@@ -191,9 +191,8 @@ class BartHeadTests(unittest.TestCase):
             dtype=torch.long,
             device=torch_device,
         )
-        batch_size = input_ids.shape[0]
-        decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
+        batch_size = input_ids.shape[0]
         config = BartConfig(
             vocab_size=self.vocab_size,
             d_model=24,
@@ -204,14 +203,25 @@ class BartHeadTests(unittest.TestCase):
             encoder_ffn_dim=32,
             decoder_ffn_dim=32,
             max_position_embeddings=48,
+            output_past=output_past,
         )
+        return config, input_ids, batch_size
+
+    def test_sequence_classification_forward(self):
+        config, input_ids, batch_size = self._get_config_and_data()
+        labels = _long_tensor([2] * batch_size).to(torch_device)
         model = BartForSequenceClassification(config)
         model.to(torch_device)
-        outputs = model.forward(input_ids=input_ids, decoder_input_ids=input_ids)
-        logits = outputs[0]
+        outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels)
+        logits = outputs[1]
         expected_shape = torch.Size((batch_size, config.num_labels))
         self.assertEqual(logits.shape, expected_shape)
+        loss = outputs[0]
+        self.assertIsInstance(loss.item(), float)

+    def test_lm_forward(self):
+        config, input_ids, batch_size = self._get_config_and_data(output_past=False)
+        decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
         lm_model = BartForMaskedLM(config)
         lm_model.to(torch_device)
         loss, logits, enc_features = lm_model.forward(
...
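For context, a usage sketch of the new test path, assuming the tuple-style outputs the diff shows (loss at index 0, logits at index 1 when labels are passed). The encoder/decoder layer and head counts are guesses here, since the diff truncates that part of the BartConfig call:

import torch
from transformers import BartConfig, BartForSequenceClassification

# Tiny config mirroring the test fixture; the layer/head counts below are
# assumptions filling the truncated portion of the diff.
config = BartConfig(
    vocab_size=99,
    d_model=24,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    max_position_embeddings=48,
)
model = BartForSequenceClassification(config)

# Each row ends with token id 2, BartConfig's default eos_token_id, which
# the classification head uses to pool the sequence representation.
input_ids = torch.tensor([[71, 82, 18, 33, 46, 91, 2]], dtype=torch.long)
labels = torch.tensor([2], dtype=torch.long)  # valid: BartConfig defaults num_labels to 3

outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels)
loss, logits = outputs[0], outputs[1]
assert logits.shape == (1, config.num_labels)
print(loss.item())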