Commit c8a3f2e2 authored by Baber
Browse files

pre-commit

parent 7f22572a
...@@ -53,10 +53,11 @@ def cached_sent_tokenize(text: str) -> List[str]: ...@@ -53,10 +53,11 @@ def cached_sent_tokenize(text: str) -> List[str]:
def download_nltk_resources(): def download_nltk_resources():
"""Download 'punkt' if not already installed""" """Download 'punkt' if not already installed"""
assert ( assert (nltk_version := parse_version(version("nltk"))) >= parse_version(
(nltk_version := parse_version(version("nltk"))) NLTK_MIN_VERSION
>= parse_version(NLTK_MIN_VERSION) ), (
), f"`nltk` version {nltk_version} is not >= {NLTK_MIN_VERSION}. Please update `nltk` before proceeding--older versions are vulnerable to a remote code execution vulnerability." f"`nltk` version {nltk_version} is not >= {NLTK_MIN_VERSION}. Please update `nltk` before proceeding--older versions are vulnerable to a remote code execution vulnerability."
)
try: try:
nltk.data.find("tokenizers/punkt_tab") nltk.data.find("tokenizers/punkt_tab")
...@@ -303,8 +304,8 @@ def generate_samples( ...@@ -303,8 +304,8 @@ def generate_samples(
else f"The special magic {type_needle_v} for {query} mentioned in the provided text are", else f"The special magic {type_needle_v} for {query} mentioned in the provided text are",
} }
if formatted_output["outputs"][0] not in formatted_output["input"]: if formatted_output["outputs"][0] not in formatted_output["input"]:
assert ( assert False, (
False f"Needle not in input: {formatted_output}. Something went wrong."
), f"Needle not in input: {formatted_output}. Something went wrong." )
write_jsons.append(formatted_output) write_jsons.append(formatted_output)
return write_jsons return write_jsons
...@@ -81,9 +81,9 @@ def generate_input_output(num_noises, num_chains, num_hops, is_icl=False): ...@@ -81,9 +81,9 @@ def generate_input_output(num_noises, num_chains, num_hops, is_icl=False):
for n in [x for noise in sentences for x in noise.split(".")] for n in [x for noise in sentences for x in noise.split(".")]
] ]
try: try:
assert len(sentences) > len( assert len(sentences) > len(chains[0]), (
chains[0] "Noises too short, unable to generate data"
), "Noises too short, unable to generate data" )
# ruff: noqa # ruff: noqa
except: except:
print("reduces chain length for not enough noises") print("reduces chain length for not enough noises")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment