Commit 2950b694 authored by wanglch's avatar wanglch
Browse files

Update train_multi_dcu.py

parent a5866d29
......@@ -66,7 +66,7 @@ def collate_fn(batch_samples):
truncation=True,
return_tensors="pt"
)["input_ids"]
batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels)
batch_data['decoder_input_ids'] = model.module.prepare_decoder_input_ids_from_labels(labels)
end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
for idx, end_idx in enumerate(end_token_index):
labels[idx][end_idx+1:] = -100
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment