"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "2ebf4e6a7b66cdd8acb1baa06f4c1fa62fce2dba"
Unverified Commit 71d18d08 authored by Kenneth Enevoldsen, committed by GitHub

fixed bug in run_mlm_flax_stream.py (#17203)



* fixed bug in run_mlm_flax_stream.py

Fixed a bug caused by a change in recent transformers versions (between `4.6.2` and `4.18.0`) that introduced additional keys into the tokenizer output; see the sketch below for a minimal illustration.

* Update run_mlm_flax_stream.py

* adding missing parenthesis

* formatted to black

* remove cols from dataset instead

* reformat to black

* moved rem. columns to map

* formatted to black
Co-authored-by: default avatarKennethEnevoldsen <kennethcenevolsen@gmail.com>
parent 71abd3ad
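
For context, here is a minimal sketch (with hypothetical sample values, not taken from the script) of why accumulating over `tokenized_samples.keys()` broke once the raw dataset columns started passing through, and how restricting the keys fixes it:

# The accumulator only tracks the three tokenizer keys:
samples = {"input_ids": [], "attention_mask": [], "special_tokens_mask": []}

# In newer versions the streamed batch also carries the raw dataset columns:
tokenized_samples = {
    "id": [0],
    "text": ["hello world"],
    "input_ids": [[101, 7592, 2088, 102]],
    "attention_mask": [[1, 1, 1, 1]],
    "special_tokens_mask": [[1, 0, 0, 1]],
}

# Old behaviour: iterating over every key raises KeyError: 'id',
# because samples["id"] does not exist.
# samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}

# Fix from this commit: accumulate only the keys the data collator needs.
samples = {
    k: samples[k] + tokenized_samples[k]
    for k in ["input_ids", "attention_mask", "special_tokens_mask"]
}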
@@ -288,8 +288,10 @@ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
         tokenized_samples = next(train_iterator)
         i += len(tokenized_samples["input_ids"])
 
-        # concatenate tokenized samples to list
-        samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
+        # concatenate tokenized samples to list (excluding "id" and "text")
+        samples = {
+            k: samples[k] + tokenized_samples[k] for k in ["input_ids", "attention_mask", "special_tokens_mask"]
+        }
 
         # Concatenated tokens are split to lists of length `max_seq_length`.
         # Note that remainedr of % max_seq_length are thrown away.
@@ -407,10 +409,7 @@ if __name__ == "__main__":
     def tokenize_function(examples):
         return tokenizer(examples[data_args.text_column_name], return_special_tokens_mask=True)
 
-    tokenized_datasets = dataset.map(
-        tokenize_function,
-        batched=True,
-    )
+    tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=list(dataset.features.keys()))
 
     shuffle_seed = training_args.seed
     tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
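
The second hunk attacks the same problem at the source by dropping the raw columns during tokenization. A minimal usage sketch outside the training script, assuming a streamed dataset whose features (column names) are known up front; the dataset and tokenizer names here are only illustrative, not fixed by the script:

from datasets import load_dataset
from transformers import AutoTokenizer

# Any streamable text dataset works; "oscar"/"roberta-base" are illustrative.
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train", streaming=True)
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

def tokenize_function(examples):
    return tokenizer(examples["text"], return_special_tokens_mask=True)

# remove_columns drops "id" and "text", so downstream code only ever sees
# the tokenizer outputs.
tokenized_dataset = dataset.map(
    tokenize_function, batched=True, remove_columns=list(dataset.features.keys())
)

print(next(iter(tokenized_dataset)).keys())
# -> dict_keys(['input_ids', 'attention_mask', 'special_tokens_mask'])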