"torchvision/vscode:/vscode.git/clone" did not exist on "71073cb50389e811a4bdd4c2e207fa835f02856b"
Commit 2950b694 authored by wanglch's avatar wanglch
Browse files

Update train_multi_dcu.py

parent a5866d29
...@@ -66,7 +66,7 @@ def collate_fn(batch_samples): ...@@ -66,7 +66,7 @@ def collate_fn(batch_samples):
truncation=True, truncation=True,
return_tensors="pt" return_tensors="pt"
)["input_ids"] )["input_ids"]
batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels) batch_data['decoder_input_ids'] = model.module.prepare_decoder_input_ids_from_labels(labels)
end_token_index = torch.where(labels == tokenizer.eos_token_id)[1] end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
for idx, end_idx in enumerate(end_token_index): for idx, end_idx in enumerate(end_token_index):
labels[idx][end_idx+1:] = -100 labels[idx][end_idx+1:] = -100
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment