Commit 365fb34c authored by Aneesh Pappu's avatar Aneesh Pappu
Browse files

small fix to remove shifting of lm labels during pre process of roc stories,...

small fix to remove shifting of lm labels during pre process of roc stories, as this shifting happens interanlly in the model
parent 2dee8631
......@@ -83,8 +83,8 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
input_ids[i, 1, :len(with_cont2)] = with_cont2
mc_token_ids[i, 0] = len(with_cont1) - 1
mc_token_ids[i, 1] = len(with_cont2) - 1
lm_labels[i, 0, :len(with_cont1)-1] = with_cont1[1:]
lm_labels[i, 1, :len(with_cont2)-1] = with_cont2[1:]
lm_labels[i, 0, :len(with_cont1)] = with_cont1
lm_labels[i, 1, :len(with_cont2)] = with_cont2
mc_labels[i] = mc_label
all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels)
tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment