Commit ebccca1e authored by Baber's avatar Baber
Browse files

nit

parent ca348256
......@@ -42,7 +42,13 @@ def get_example(num_words, common_repeats=30, uncommon_repeats=3, common_nums=10
return context, common
def generate_input_output(num_words, max_seq_length, freq_cw=30, freq_ucw=3, num_cw=10):
def generate_input_output(
num_words: int,
max_seq_length: int,
freq_cw: int = 30,
freq_ucw: int = 3,
num_cw: int = 10,
):
if max_seq_length < 4096:
context_example, answer_example = get_example(20, 3, 1, num_cw)
context, answer = get_example(num_words, 6, 1, num_cw)
......@@ -62,7 +68,7 @@ def generate_input_output(num_words, max_seq_length, freq_cw=30, freq_ucw=3, num
query="",
)
return input_example + "\n" + input_text, answer
return input_example, input_text, answer
def sys_word_pair_random(
......@@ -82,11 +88,15 @@ def sys_word_pair_random(
total_tokens = 0
while total_tokens + tokens_to_generate < max_seq_length:
input_text, answer = generate_input_output(num_words, max_seq_length)
input_example, input_text, answer = generate_input_output(
num_words, max_seq_length
)
# Calculate the number of tokens in the example
total_tokens = len(
TOKENIZER(
input_text
input_example
+ "\n"
+ input_text
+ " "
+ " ".join([f"{i + 1}. {word}" for i, word in enumerate(answer)])
)
......@@ -110,7 +120,9 @@ def sys_word_pair_random(
used_words = num_words
while True:
try:
input_text, answer = generate_input_output(used_words)
input_example, input_text, answer = generate_input_output(
used_words, max_seq_length
)
length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate
assert length <= max_seq_length, f"{length} exceeds max_seq_length."
break
......@@ -122,6 +134,9 @@ def sys_word_pair_random(
input_text = " ".join(
input_text.replace("\n", " ").replace("\t", " ").strip().split()
)
input_example = " ".join(
input_example.replace("\n", " ").replace("\t", " ").strip().split()
)
formatted_output = {
"index": index,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment