Unverified Commit 3ae8c8be authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #562 from apappu97/roc_stories_lmlabels_fix

Small fix to remove shifting of lm labels during pre process of RocStories.
parents e8952017 365fb34c
...@@ -83,8 +83,8 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d ...@@ -83,8 +83,8 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
input_ids[i, 1, :len(with_cont2)] = with_cont2 input_ids[i, 1, :len(with_cont2)] = with_cont2
mc_token_ids[i, 0] = len(with_cont1) - 1 mc_token_ids[i, 0] = len(with_cont1) - 1
mc_token_ids[i, 1] = len(with_cont2) - 1 mc_token_ids[i, 1] = len(with_cont2) - 1
lm_labels[i, 0, :len(with_cont1)-1] = with_cont1[1:] lm_labels[i, 0, :len(with_cont1)] = with_cont1
lm_labels[i, 1, :len(with_cont2)-1] = with_cont2[1:] lm_labels[i, 1, :len(with_cont2)] = with_cont2
mc_labels[i] = mc_label mc_labels[i] = mc_label
all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels) all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels)
tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs)) tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment