"web/git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "de2894ac2cdcf0ece2c0d2a576579768882e8882"
Unverified commit e2230ba7, authored by Simon Böhm, committed by GitHub

Fix BERT example code for NSP and Multiple Choice (#3953)

Change the example code to use encode_plus, since the token_type_ids
weren't being set correctly.
parent 3a5d1ea2
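For context, a minimal sketch (not part of this commit) of why encode_plus matters for these examples, assuming the transformers API of this era: tokenizer.encode() on a single string yields only input_ids, so the model falls back to all-zero token_type_ids and cannot tell the prompt segment from the candidate next sentence, whereas encode_plus() on a sentence pair returns token_type_ids that mark the two segments.

    # Sketch only, not part of the commit; assumes the transformers API of this era.
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    next_sentence = "The sky is blue due to the shorter wavelength of blue light."

    # encode_plus() on a (text, text_pair) pair also returns token_type_ids:
    # 0 for the prompt tokens, 1 for the candidate next sentence.
    encoding = tokenizer.encode_plus(prompt, next_sentence)
    print(encoding['input_ids'])       # [CLS] prompt ... [SEP] next_sentence ... [SEP]
    print(encoding['token_type_ids'])  # [0, 0, ..., 0, 1, 1, ..., 1]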
@@ -1036,11 +1036,12 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
 tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
-input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
-outputs = model(input_ids)
-seq_relationship_scores = outputs[0]
+prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+next_sentence = "The sky is blue due to the shorter wavelength of blue light."
+encoding = tokenizer.encode_plus(prompt, next_sentence, return_tensors='pt')
+loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1]))
+assert logits[0, 0] < logits[0, 1]  # next sentence was random
 """
 outputs = self.bert(
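As a counterpart to the new example above, a hedged sketch (not from the commit) with a sentence that genuinely follows the prompt; it assumes the label convention documented for this era, where 0 means the second sentence is the true continuation and 1 means it is random.

    # Sketch only, not part of the commit; same API calls as the example above.
    import torch
    from transformers import BertTokenizer, BertForNextSentencePrediction

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')

    prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    follow_up = "It is eaten with a fork and a knife."  # plausibly the true next sentence

    encoding = tokenizer.encode_plus(prompt, follow_up, return_tensors='pt')
    loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([0]))  # 0 = "is the next sentence"
    print(logits)  # index 0 should now outscore index 1, the opposite of the random-sentence case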
@@ -1191,7 +1192,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
 r"""
 labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
     Labels for computing the multiple choice classification loss.
-    Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
+    Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
     of the input tensors. (see `input_ids` above)
 Returns:
@@ -1221,14 +1222,17 @@ class BertForMultipleChoice(BertPreTrainedModel):
 tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
-choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
-input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
-labels = torch.tensor(1).unsqueeze(0)  # Batch size 1
-outputs = model(input_ids, labels=labels)
-loss, classification_scores = outputs[:2]
+prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+choice0 = "It is eaten with a fork and a knife."
+choice1 = "It is eaten while held in the hand."
+labels = torch.tensor(0)  # choice0 is correct (according to Wikipedia ;))
+encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='pt', pad_to_max_length=True)
+outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()}, labels=labels)  # batch size is 1
+# the linear classifier still needs to be trained
+loss, logits = outputs[:2]
 """
 num_choices = input_ids.shape[1]
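A hedged sketch (not part of the commit) of the shape handling behind the new example: batch_encode_plus over the two (prompt, choice) pairs yields tensors of shape (num_choices, sequence_length), while BertForMultipleChoice expects (batch_size, num_choices, sequence_length), hence the unsqueeze(0); labels index the choice dimension, which is exactly the [0, ..., num_choices-1] range the docstring fix above spells out.

    # Sketch only, not part of the commit; uses the same tokenizer calls as the example.
    import torch
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    choice0 = "It is eaten with a fork and a knife."
    choice1 = "It is eaten while held in the hand."

    encoding = tokenizer.batch_encode_plus(
        [[prompt, choice0], [prompt, choice1]], return_tensors='pt', pad_to_max_length=True)

    print(encoding['input_ids'].shape)               # (num_choices, seq_len) == (2, padded length)
    print(encoding['input_ids'].unsqueeze(0).shape)  # (batch_size, num_choices, seq_len) == (1, 2, padded length)
    # With two choices the only valid labels are 0 and 1, i.e. [0, ..., num_choices-1].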
...
@@ -857,10 +857,13 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
 tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 model = TFBertForNextSentencePrediction.from_pretrained('bert-base-uncased')
-input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
-outputs = model(input_ids)
-seq_relationship_scores = outputs[0]
+prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+next_sentence = "The sky is blue due to the shorter wavelength of blue light."
+encoding = tokenizer.encode_plus(prompt, next_sentence, return_tensors='tf')
+logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
+assert logits[0][0] < logits[0][1]  # the next sentence was random
 """
 outputs = self.bert(inputs, **kwargs)
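Unlike the PyTorch version, the TF model here is called without a next_sentence_label and returns raw relationship scores only, so the example reads the logits directly; a hedged sketch (not from the commit) of turning them into probabilities.

    # Sketch only, not part of the commit; same calls as the TF example above plus tf.nn.softmax.
    import tensorflow as tf
    from transformers import BertTokenizer, TFBertForNextSentencePrediction

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = TFBertForNextSentencePrediction.from_pretrained('bert-base-uncased')

    prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    next_sentence = "The sky is blue due to the shorter wavelength of blue light."

    encoding = tokenizer.encode_plus(prompt, next_sentence, return_tensors='tf')
    logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
    probs = tf.nn.softmax(logits, axis=-1)  # [P(is next), P(random)]; mass should sit on index 1 here
    print(probs)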
@@ -990,11 +993,15 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
 tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 model = TFBertForMultipleChoice.from_pretrained('bert-base-uncased')
-choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
-input_ids = tf.constant([tokenizer.encode(s) for s in choices])[None, :]  # Batch size 1, 2 choices
-outputs = model(input_ids)
-classification_scores = outputs[0]
+prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+choice0 = "It is eaten with a fork and a knife."
+choice1 = "It is eaten while held in the hand."
+encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='tf', pad_to_max_length=True)
+# linear classifier on the output is not yet trained
+outputs = model(encoding['input_ids'][None, :])
+logits = outputs[0]
 """
 if isinstance(inputs, (tuple, list)):
     input_ids = inputs[0]
...
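Finally, a hedged sketch (not part of the commit) of reading a predicted answer out of the TF multiple-choice logits; it reuses the calls from the example above plus tf.argmax, and, as the example notes, the classification head is still untrained, so the pick is only illustrative.

    # Sketch only, not part of the commit; the multiple-choice head is untrained here.
    import tensorflow as tf
    from transformers import BertTokenizer, TFBertForMultipleChoice

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = TFBertForMultipleChoice.from_pretrained('bert-base-uncased')

    prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    choices = ["It is eaten with a fork and a knife.", "It is eaten while held in the hand."]

    encoding = tokenizer.batch_encode_plus(
        [[prompt, c] for c in choices], return_tensors='tf', pad_to_max_length=True)
    logits = model(encoding['input_ids'][None, :])[0]  # shape (1, num_choices)

    predicted = int(tf.argmax(logits, axis=-1)[0])     # index of the highest-scoring choice
    print(choices[predicted])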