Commit e76d7152 authored by jeffxtang
Browse files

the working example code to use BertForQuestionAnswering and get an answer...

the working example code to use BertForQuestionAnswering and get an answer from a text and a question
parent d844db40
...@@ -1095,12 +1095,16 @@ class BertForQuestionAnswering(BertPreTrainedModel): ...@@ -1095,12 +1095,16 @@ class BertForQuestionAnswering(BertPreTrainedModel):
Examples:: Examples::
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased') model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
start_positions = torch.tensor([1]) input_text = "[CLS] " + question + " [SEP] " + text + " [SEP]"
end_positions = torch.tensor([3]) input_ids = tokenizer.encode(input_text)
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions) token_type_ids = [0 if i <= input_ids.index(102) else 1 for i in range(len(input_ids))]
loss, start_scores, end_scores = outputs[:2] start_scores, end_scores = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))
all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
print(' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))
# a nice puppet
""" """
def __init__(self, config): def __init__(self, config):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment