Unverified Commit c10decf7 authored by Sam Shleifer, committed by GitHub

[Bart: example] drop columns that are exclusively pad_token_id… (#3400)

* Trim seq_len below 1024 if there are columns consisting entirely of pad_token_id
* Centralize trim_batch so SummarizationDataset can use it too
parent 63f4d8ca
@@ -1997,3 +1997,14 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
files = self._tokenizer.save(folder, name=file)
return tuple(files)

def trim_batch(
    input_ids, pad_token_id, attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    # Keep a column if any row in the batch has a non-pad token in it.
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
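
As a quick illustration of the behavior (a minimal sketch: the tensors below are made up for this example, and pad_token_id = 1 is assumed to match Bart's convention), trim_batch drops only the columns in which every row holds the pad token:

import torch

pad_token_id = 1
input_ids = torch.tensor([
    [0, 5, 6, 2, 1, 1],
    [0, 7, 2, 1, 1, 1],
])
attention_mask = input_ids.ne(pad_token_id).long()

# The last two columns are pad-only, so they are removed; column 3, which mixes
# a real token and padding, is kept.
trimmed_ids, trimmed_mask = trim_batch(input_ids, pad_token_id, attention_mask=attention_mask)
print(trimmed_ids.shape)   # torch.Size([2, 4])
print(trimmed_mask.shape)  # torch.Size([2, 4])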