Commit f283c335 authored by Baber's avatar Baber
Browse files

nit

parent 537a77fd
...@@ -19,23 +19,19 @@ NUM_SAMPLES = 500 ...@@ -19,23 +19,19 @@ NUM_SAMPLES = 500
REMOVE_NEWLINE_TAB = "" REMOVE_NEWLINE_TAB = ""
STOP_WORDS = "" STOP_WORDS = ""
RANDOM_SEED = 42 RANDOM_SEED = 42
# SEQ_LENGTHS = ( # Define Needle/Haystack Format
# # 131072,
# # 65536,
# # 32768,
# 16384,
# 8192,
# 4096,
# )
# # Define Needle/Haystack Format
NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}." NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
# Words # Words
nouns = wonderwords.random_word._get_words_from_text_file("nounlist.txt") r = wonderwords.RandomWord()
adjs = wonderwords.random_word._get_words_from_text_file("adjectivelist.txt")
verbs = wonderwords.random_word._get_words_from_text_file("verblist.txt") nouns = r._categories["nouns"]
adjs = r._categories["adjectives"]
verbs = r._categories["verbs"]
# nouns = wonderwords.random_word._get_words_from_text_file("nounlist.txt")
# adjs = wonderwords.random_word._get_words_from_text_file("adjectivelist.txt")
# verbs = wonderwords.random_word._get_words_from_text_file("verblist.txt")
words = [f"{adj}-{noun}" for adj in adjs for noun in nouns] words = [f"{adj}-{noun}" for adj in adjs for noun in nouns]
WORDS = sorted(list(set(words))) WORDS = sorted(list(set(words)))
...@@ -70,22 +66,22 @@ def download_nltk_resources(): ...@@ -70,22 +66,22 @@ def download_nltk_resources():
download_nltk_resources() download_nltk_resources()
def generate_random_number(num_digits=7): def generate_random_number(num_digits=7) -> str:
lower_bound = 10 ** (num_digits - 1) lower_bound = 10 ** (num_digits - 1)
upper_bound = 10**num_digits - 1 upper_bound = 10**num_digits - 1
return str(random.randint(lower_bound, upper_bound)) return str(random.randint(lower_bound, upper_bound))
def generate_random_word(): def generate_random_word() -> str:
word = random.choice(WORDS) word = random.choice(WORDS)
return word return word
def generate_random_uuid(): def generate_random_uuid() -> str:
return str(uuid.UUID(int=random.getrandbits(128), version=4)) return str(uuid.UUID(int=random.getrandbits(128), version=4))
def generate_random(type_needle: str): def generate_random(type_needle: str) -> str:
if type_needle == "numbers": if type_needle == "numbers":
return generate_random_number() return generate_random_number()
elif type_needle == "words": elif type_needle == "words":
...@@ -108,7 +104,7 @@ def generate_input_output( ...@@ -108,7 +104,7 @@ def generate_input_output(
template: str, template: str,
num_needle_q: int = 1, num_needle_q: int = 1,
random_seed: int = RANDOM_SEED, random_seed: int = RANDOM_SEED,
): ) -> tuple[str, list[str], str]:
NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}." NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
keys, values, needles = [], [], [] keys, values, needles = [], [], []
for _ in range(num_needle_k): for _ in range(num_needle_k):
...@@ -199,6 +195,7 @@ def generate_input_output( ...@@ -199,6 +195,7 @@ def generate_input_output(
def generate_samples( def generate_samples(
haystack, haystack,
TOKENIZER=None,
*, *,
max_seq_length: int, max_seq_length: int,
type_haystack: str, type_haystack: str,
...@@ -213,7 +210,6 @@ def generate_samples( ...@@ -213,7 +210,6 @@ def generate_samples(
incremental: int = 500, incremental: int = 500,
remove_newline_tab: bool = False, remove_newline_tab: bool = False,
random_seed: int = 42, random_seed: int = 42,
TOKENIZER=None,
) -> list[dict]: ) -> list[dict]:
assert TOKENIZER is not None, "TOKENIZER is not defined." assert TOKENIZER is not None, "TOKENIZER is not defined."
num_needle_k = max(num_needle_k, num_needle_q) num_needle_k = max(num_needle_k, num_needle_q)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment