Unverified commit fa4bcd52 authored by ddobokki, committed by GitHub

edit: cast attention_mask to long in DataCollatorCTCWithPadding (#19369)

* edit: casting attention_mask to long in DataCollatorCTCWithPadding

* edit: casting attention_mask to long in DataCollatorCTCWithPadding
parent e9a49bab
@@ -317,6 +317,8 @@ class DataCollatorCTCWithPadding:
         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
         batch["labels"] = labels
+        if "attention_mask" in batch:
+            batch["attention_mask"] = batch["attention_mask"].to(torch.long)
         return batch
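A minimal sketch of why the cast matters, outside the collator itself (the tensor values and dict below are illustrative, not taken from the patch): a feature extractor may return `attention_mask` as an `int32` tensor, while downstream model code generally expects `torch.long` (`int64`), so the collator normalizes the dtype before returning the batch.

```python
import torch

# Illustrative batch: assume the padding step produced an int32 attention_mask.
batch = {
    "input_values": torch.randn(1, 3),
    "attention_mask": torch.tensor([[1, 1, 0]], dtype=torch.int32),
}

# The fix from the diff: cast attention_mask to long if it is present.
if "attention_mask" in batch:
    batch["attention_mask"] = batch["attention_mask"].to(torch.long)

print(batch["attention_mask"].dtype)  # torch.int64
```

The guard with `"attention_mask" in batch` matters because some feature extractors do not return an attention mask at all, and the cast must not fail in that case.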