"next_docs/en/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "ea8f7e0fbd8e8b907482b95b12edc8a4093a5c40"
Unverified Commit c10decf7 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[Bart: example] drop columns that are exclusively pad_token_id… (#3400)

* trim seq_len below 1024 if there are columns full of pad_token_id
* Centralize trim_batch so SummarizationDataset can use it too
parent 63f4d8ca
......@@ -1997,3 +1997,14 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
files = self._tokenizer.save(folder, name=file)
return tuple(files)
def trim_batch(
input_ids, pad_token_id, attention_mask=None,
):
"""Remove columns that are populated exclusively by pad_token_id"""
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
if attention_mask is None:
return input_ids[:, keep_column_mask]
else:
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment