Unverified Commit 4210cd96 authored by Joe Davison, committed by GitHub

fix add_token_positions fn (#10217)

parent 7169d1ea
@@ -558,15 +558,13 @@ we can use the built in :func:`~transformers.BatchEncoding.char_to_token` method
     end_positions = []
     for i in range(len(answers)):
         start_positions.append(encodings.char_to_token(i, answers[i]['answer_start']))
-        end_positions.append(encodings.char_to_token(i, answers[i]['answer_end']))
+        end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'] - 1))
         # if start position is None, the answer passage has been truncated
         if start_positions[-1] is None:
             start_positions[-1] = tokenizer.model_max_length
+            end_positions[-1] = tokenizer.model_max_length
-        # if end position is None, the 'char_to_token' function points to the space before the correct token - > add + 1
-        if end_positions[-1] is None:
-            end_positions[-1] = encodings.char_to_token(i, answers[i]['answer_end'] + 1)
     encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
 
 add_token_positions(train_encodings, train_answers)
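Note on the fix: the old snippet looked up `answer_end` directly and, when `char_to_token` returned None, retried at `answer_end + 1`. Since `answer_end` is an exclusive index, the corrected snippet looks up `answer_end - 1` (the last character of the answer) and, when the whole answer passage has been truncated away, sets both positions to `tokenizer.model_max_length`. Below is a minimal, self-contained sketch of the fixed function; the toy context/question/answer and the `DistilBertTokenizerFast` choice are illustrative assumptions, not part of this commit, which only edits the tutorial snippet.

# Sketch of the fixed add_token_positions; the data below is a toy stand-in
# for the tutorial's train_contexts/train_questions/train_answers.
from transformers import DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

contexts = ['The quick brown fox jumps over the lazy dog.']
questions = ['What does the fox jump over?']
# 'answer_end' is exclusive: context[31:43] == 'the lazy dog'
answers = [{'text': 'the lazy dog', 'answer_start': 31, 'answer_end': 43}]

encodings = tokenizer(contexts, questions, truncation=True, padding=True)

def add_token_positions(encodings, answers):
    start_positions = []
    end_positions = []
    for i in range(len(answers)):
        start_positions.append(encodings.char_to_token(i, answers[i]['answer_start']))
        # 'answer_end' is exclusive, so answer_end - 1 is the last character of the answer
        end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'] - 1))
        # if the start position is None, the answer passage has been truncated;
        # point both positions at an out-of-range sentinel
        if start_positions[-1] is None:
            start_positions[-1] = tokenizer.model_max_length
            end_positions[-1] = tokenizer.model_max_length
    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

add_token_positions(encodings, answers)
print(encodings['start_positions'], encodings['end_positions'])  # for this toy pair, something like [7] [9]

Using `tokenizer.model_max_length` as the sentinel works because the question-answering heads in Transformers clamp out-of-range target positions and ignore them when computing the loss.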