Unverified Commit e9e6efdc authored by Sam Shleifer, committed by GitHub

BartForSequenceClassification: fix num_labels, add test (#3110)

parent f631e01d
@@ -1324,7 +1324,7 @@ class BartForSequenceClassification(PretrainedBartModel):
         # Prepend logits
         outputs = (logits,) + outputs[1:]  # Add hidden states and attention if they are here
         if labels is not None:  # prepend loss to output,
-            loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
             outputs = (loss,) + outputs
         return outputs
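
For context on the one-line fix above: the change reads num_labels from the model's config (where it is defined) rather than from the model instance. A minimal sketch in plain PyTorch, not the repository's code, of the shape contract the corrected F.cross_entropy call relies on; num_labels, batch_size, and the tensors here are illustrative:

import torch
import torch.nn.functional as F

# Hypothetical sizes; in the model these come from config.num_labels and the batch.
num_labels = 3
batch_size = 4

logits = torch.randn(batch_size, num_labels)        # classification head output
labels = torch.randint(num_labels, (batch_size,))   # one class id per example

# cross_entropy expects (N, num_labels) logits and (N,) integer labels;
# the .view(-1, num_labels) / .view(-1) reshape enforces exactly that.
loss = F.cross_entropy(logits.view(-1, num_labels), labels.view(-1))
print(loss.item())  # a plain Python float, as the new test asserts
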
@@ -171,7 +171,7 @@ class BartHeadTests(unittest.TestCase):
     vocab_size = 99
 
-    def test_lm_forward(self):
+    def _get_config_and_data(self, output_past=False):
         input_ids = torch.tensor(
             [
                 [71, 82, 18, 33, 46, 91, 2],
@@ -191,9 +191,8 @@ class BartHeadTests(unittest.TestCase):
             dtype=torch.long,
             device=torch_device,
         )
-        batch_size = input_ids.shape[0]
-        decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
+        batch_size = input_ids.shape[0]
         config = BartConfig(
             vocab_size=self.vocab_size,
             d_model=24,
@@ -204,14 +203,25 @@ class BartHeadTests(unittest.TestCase):
             encoder_ffn_dim=32,
             decoder_ffn_dim=32,
             max_position_embeddings=48,
+            output_past=output_past,
         )
+        return config, input_ids, batch_size
+
+    def test_sequence_classification_forward(self):
+        config, input_ids, batch_size = self._get_config_and_data()
+        labels = _long_tensor([2] * batch_size).to(torch_device)
         model = BartForSequenceClassification(config)
         model.to(torch_device)
-        outputs = model.forward(input_ids=input_ids, decoder_input_ids=input_ids)
-        logits = outputs[0]
+        outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels)
+        logits = outputs[1]
         expected_shape = torch.Size((batch_size, config.num_labels))
         self.assertEqual(logits.shape, expected_shape)
+        loss = outputs[0]
+        self.assertIsInstance(loss.item(), float)
 
+    def test_lm_forward(self):
+        config, input_ids, batch_size = self._get_config_and_data(output_past=False)
+        decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
         lm_model = BartForMaskedLM(config)
         lm_model.to(torch_device)
         loss, logits, enc_features = lm_model.forward(
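
The new test above also pins down the fixed head's output contract: when labels are passed, the loss is prepended, so outputs[0] is a scalar loss and outputs[1] the (batch_size, config.num_labels) logits. A hedged usage sketch of that contract, assuming the tuple-indexable returns the test relies on; the tiny config mirrors the test fixture, and the inputs are illustrative:

import torch
from transformers import BartConfig, BartForSequenceClassification

# Tiny randomly initialized model, mirroring the test fixture above.
config = BartConfig(
    vocab_size=99, d_model=24,
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    max_position_embeddings=48,
)
model = BartForSequenceClassification(config)
model.eval()

# Sequences end in eos (id 2), which Bart's classification head pools over.
input_ids = torch.tensor([[71, 82, 18, 33, 46, 91, 2]], dtype=torch.long)
labels = torch.tensor([1])

outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels)
loss, logits = outputs[0], outputs[1]  # loss is prepended when labels are given
assert logits.shape == (1, config.num_labels)
print(loss.item())
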