Unverified Commit a7920065 authored by Boda Sadallah, committed by GitHub

fix bug in group_texts function, that was inserting short batches (#23429)

* fix bug in group_texts function, that was inserting short batches

* fully exclude short batches and return empty dict instead

* fix style
parent b7b81d93
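
For reference, here is the full group_texts helper as it reads after this change. This is a minimal sketch, not the exact script code: block_size is passed explicitly for illustration (the example scripts capture it from the enclosing scope), and the labels column that the CLM scripts add afterwards is omitted.

    from itertools import chain


    def group_texts(examples, block_size):
        # Concatenate all texts in the batch into one long list per column.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # Round down to a multiple of block_size. When total_length < block_size
        # this becomes 0, so the short batch is fully excluded and every column
        # comes back as an empty list.
        total_length = (total_length // block_size) * block_size
        # Split by chunks of block_size.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        return result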
@@ -491,9 +491,8 @@ def main():
         # Concatenate all texts.
         concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
         total_length = len(concatenated_examples[list(examples.keys())[0]])
-        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
-        # customize this part to your needs.
-        if total_length >= block_size:
-            total_length = (total_length // block_size) * block_size
+        # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
+        # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
+        total_length = (total_length // block_size) * block_size
         # Split by chunks of max_len.
         result = {
...
@@ -434,9 +434,8 @@ def main():
         # Concatenate all texts.
         concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
         total_length = len(concatenated_examples[list(examples.keys())[0]])
-        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
-        # customize this part to your needs.
-        if total_length >= block_size:
-            total_length = (total_length // block_size) * block_size
+        # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
+        # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
+        total_length = (total_length // block_size) * block_size
         # Split by chunks of max_len.
         result = {
...
@@ -506,9 +506,8 @@ def main():
         # Concatenate all texts.
         concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
         total_length = len(concatenated_examples[list(examples.keys())[0]])
-        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
-        # customize this part to your needs.
-        if total_length >= max_seq_length:
-            total_length = (total_length // max_seq_length) * max_seq_length
+        # We drop the small remainder, and if the total_length < max_seq_length we exclude this batch and return an empty dict.
+        # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
+        total_length = (total_length // max_seq_length) * max_seq_length
         # Split by chunks of max_len.
         result = {
...
@@ -472,9 +472,8 @@ def main():
         # Concatenate all texts.
         concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
         total_length = len(concatenated_examples[list(examples.keys())[0]])
-        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
-        # customize this part to your needs.
-        if total_length >= max_seq_length:
-            total_length = (total_length // max_seq_length) * max_seq_length
+        # We drop the small remainder, and if the total_length < max_seq_length we exclude this batch and return an empty dict.
+        # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
+        total_length = (total_length // max_seq_length) * max_seq_length
         # Split by chunks of max_len.
         result = {
...
@@ -450,9 +450,8 @@ def main():
         # Concatenate all texts.
         concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
         total_length = len(concatenated_examples[list(examples.keys())[0]])
-        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
-        # customize this part to your needs.
-        if total_length >= max_seq_length:
-            total_length = (total_length // max_seq_length) * max_seq_length
+        # We drop the small remainder, and if the total_length < max_seq_length we exclude this batch and return an empty dict.
+        # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
+        total_length = (total_length // max_seq_length) * max_seq_length
         # Split by chunks of max_len.
         result = {
...