Commit 7f22572a authored by Baber's avatar Baber
Browse files

Merge branch 'main' into longcxt

parents 5e2979d2 f724be69
...@@ -259,7 +259,7 @@ def doc_to_text(src: str, tgt: str) -> str: ...@@ -259,7 +259,7 @@ def doc_to_text(src: str, tgt: str) -> str:
src_name, tgt_name = map(code_to_language_name, [src, tgt]) src_name, tgt_name = map(code_to_language_name, [src, tgt])
return f"""\ return f"""\
{src_name} sentence: {jinja_var('sentence_' + src)} {src_name} sentence: {jinja_var("sentence_" + src)}
{tgt_name} sentence:""" {tgt_name} sentence:"""
......
...@@ -7,7 +7,7 @@ def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: ...@@ -7,7 +7,7 @@ def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
### Context: {doc["context"]} ### Context: {doc["context"]}
### Question: {doc["question"]} ### Question: {doc["question"]}
### Options: ### Options:
(1) {doc['option#1']}\n(2) {doc["option#2"]}\n(3) {doc["option#3"]}\n(4) {doc['option#4']}\n(5) {doc['option#5']} (1) {doc["option#1"]}\n(2) {doc["option#2"]}\n(3) {doc["option#3"]}\n(4) {doc["option#4"]}\n(5) {doc["option#5"]}
### Answer: 주어진 문제의 정답은""" ### Answer: 주어진 문제의 정답은"""
out_doc = { out_doc = {
......
...@@ -258,7 +258,7 @@ def doc_to_text(src: str, tgt: str) -> str: ...@@ -258,7 +258,7 @@ def doc_to_text(src: str, tgt: str) -> str:
src_name, tgt_name = map(code_to_language_name, [src, tgt]) src_name, tgt_name = map(code_to_language_name, [src, tgt])
return f"""\ return f"""\
{src_name} sentence: {jinja_var('sentence_' + src)} {src_name} sentence: {jinja_var("sentence_" + src)}
{tgt_name} sentence:""" {tgt_name} sentence:"""
......
...@@ -722,7 +722,7 @@ class RephraseChecker(Instruction): ...@@ -722,7 +722,7 @@ class RephraseChecker(Instruction):
if not self.is_change(value): if not self.is_change(value):
raise ValueError( raise ValueError(
f"value {value} does not contain " "changes in the form of *change me*." f"value {value} does not contain changes in the form of *change me*."
) )
response_without_changes = self.strip_changes(value) response_without_changes = self.strip_changes(value)
......
...@@ -35,10 +35,11 @@ RANK = os.environ.get("LOCAL_RANK", "0") ...@@ -35,10 +35,11 @@ RANK = os.environ.get("LOCAL_RANK", "0")
def download_nltk_resources(): def download_nltk_resources():
"""Download 'punkt' if not already installed""" """Download 'punkt' if not already installed"""
assert ( assert (nltk_version := parse_version(version("nltk"))) >= parse_version(
(nltk_version := parse_version(version("nltk"))) NLTK_MIN_VERSION
>= parse_version(NLTK_MIN_VERSION) ), (
), f"`nltk` version {nltk_version} is not >= {NLTK_MIN_VERSION}. Please update `nltk` before proceeding--older versions are vulnerable to a remote code execution vulnerability." f"`nltk` version {nltk_version} is not >= {NLTK_MIN_VERSION}. Please update `nltk` before proceeding--older versions are vulnerable to a remote code execution vulnerability."
)
try: try:
nltk.data.find("tokenizers/punkt_tab") nltk.data.find("tokenizers/punkt_tab")
......
...@@ -23,9 +23,9 @@ def _extract_answer(completion): ...@@ -23,9 +23,9 @@ def _extract_answer(completion):
def process_results(doc, results): def process_results(doc, results):
assert ( assert len(results) == 1, (
len(results) == 1 f"results should be a list with 1 str element, but is {results}"
), f"results should be a list with 1 str element, but is {results}" )
completion = results[0] completion = results[0]
extracted_answer = _extract_answer(completion) extracted_answer = _extract_answer(completion)
......
...@@ -722,7 +722,7 @@ class RephraseChecker(Instruction): ...@@ -722,7 +722,7 @@ class RephraseChecker(Instruction):
if not self.is_change(value): if not self.is_change(value):
raise ValueError( raise ValueError(
f"value {value} does not contain " "changes in the form of *change me*." f"value {value} does not contain changes in the form of *change me*."
) )
response_without_changes = self.strip_changes(value) response_without_changes = self.strip_changes(value)
......
...@@ -34,9 +34,9 @@ NLTK_MIN_VERSION = "3.9.1" ...@@ -34,9 +34,9 @@ NLTK_MIN_VERSION = "3.9.1"
def download_nltk_resources(): def download_nltk_resources():
"""Download 'punkt' if not already installed""" """Download 'punkt' if not already installed"""
nltk_version = pkg_resources.get_distribution("nltk").version nltk_version = pkg_resources.get_distribution("nltk").version
assert ( assert version.parse(nltk_version) >= version.parse(NLTK_MIN_VERSION), (
version.parse(nltk_version) >= version.parse(NLTK_MIN_VERSION) f"`nltk` version {nltk_version} is not >= {NLTK_MIN_VERSION}. Please update `nltk` before proceeding--older versions are vulnerable to a remote code execution vulnerability."
), f"`nltk` version {nltk_version} is not >= {NLTK_MIN_VERSION}. Please update `nltk` before proceeding--older versions are vulnerable to a remote code execution vulnerability." )
try: try:
nltk.data.find("tokenizers/punkt_tab") nltk.data.find("tokenizers/punkt_tab")
......
...@@ -8,7 +8,7 @@ def doc_to_choice(doc): ...@@ -8,7 +8,7 @@ def doc_to_choice(doc):
return ast.literal_eval(doc["choices"]) return ast.literal_eval(doc["choices"])
DOC_TO_TEXT = "{narrative}\n\n" "{question}\n\n" "{choices}\n" "Answer:" DOC_TO_TEXT = "{narrative}\n\n{question}\n\n{choices}\nAnswer:"
def doc_to_text(doc): def doc_to_text(doc):
...@@ -17,7 +17,7 @@ def doc_to_text(doc): ...@@ -17,7 +17,7 @@ def doc_to_text(doc):
""" """
choices = "" choices = ""
for i, choice in enumerate(ast.literal_eval(doc["choices"])): for i, choice in enumerate(ast.literal_eval(doc["choices"])):
choices += f"{i+1} - {choice}\n" choices += f"{i + 1} - {choice}\n"
text = DOC_TO_TEXT.format( text = DOC_TO_TEXT.format(
narrative=doc["narrative"], question=doc["question"], choices=choices narrative=doc["narrative"], question=doc["question"], choices=choices
......
...@@ -14,13 +14,13 @@ def load_questionsheet(qsheet: dict, no_context: bool = False): ...@@ -14,13 +14,13 @@ def load_questionsheet(qsheet: dict, no_context: bool = False):
all_subquestions += "\n" all_subquestions += "\n"
if no_context: if no_context:
prompt = f"""{qsheet['preamble']} prompt = f"""{qsheet["preamble"]}
{all_subquestions} {all_subquestions}
""" """
else: else:
prompt = f"""{qsheet['preamble']} prompt = f"""{qsheet["preamble"]}
{qsheet['context']} {qsheet["context"]}
{all_subquestions} {all_subquestions}
""" """
......
...@@ -258,7 +258,7 @@ def doc_to_text(src: str, tgt: str) -> str: ...@@ -258,7 +258,7 @@ def doc_to_text(src: str, tgt: str) -> str:
src_name, tgt_name = map(code_to_language_name, [src, tgt]) src_name, tgt_name = map(code_to_language_name, [src, tgt])
return f"""\ return f"""\
{src_name} sentence: {jinja_var('sentence_' + src)} {src_name} sentence: {jinja_var("sentence_" + src)}
{tgt_name} sentence:""" {tgt_name} sentence:"""
......
...@@ -127,9 +127,9 @@ def main(): ...@@ -127,9 +127,9 @@ def main():
for seed in range(1, N_SEEDS + 1): for seed in range(1, N_SEEDS + 1):
# Checking if directories exist # Checking if directories exist
seed_log_dir = os.path.join(args.log_dir, f"seed_{seed}") seed_log_dir = os.path.join(args.log_dir, f"seed_{seed}")
assert os.path.exists( assert os.path.exists(seed_log_dir), (
seed_log_dir f"No logs found for seed={seed}. No directory found at {seed_log_dir}"
), f"No logs found for seed={seed}. No directory found at {seed_log_dir}" )
subtasks = None subtasks = None
if args.dataset == "agieval": if args.dataset == "agieval":
agieval_subtasks = [ agieval_subtasks = [
......
...@@ -258,7 +258,7 @@ def doc_to_text(src: str, tgt: str) -> str: ...@@ -258,7 +258,7 @@ def doc_to_text(src: str, tgt: str) -> str:
src_name, tgt_name = map(code_to_language_name, [src, tgt]) src_name, tgt_name = map(code_to_language_name, [src, tgt])
return f"""\ return f"""\
{src_name} sentence: {jinja_var('sentence_' + src)} {src_name} sentence: {jinja_var("sentence_" + src)}
{tgt_name} sentence:""" {tgt_name} sentence:"""
......
...@@ -58,9 +58,9 @@ class SQuAD2(ConfigurableTask): ...@@ -58,9 +58,9 @@ class SQuAD2(ConfigurableTask):
super().__init__(config={"metadata": {"version": self.VERSION}}) super().__init__(config={"metadata": {"version": self.VERSION}})
# HF changed squad on us so we have to make sure we aren't running the old one # HF changed squad on us so we have to make sure we aren't running the old one
assert version.parse(datasets.__version__) >= version.parse( assert version.parse(datasets.__version__) >= version.parse("1.11.0"), (
"1.11.0" "datasets v1.11.0 or later required for SQuAD"
), "datasets v1.11.0 or later required for SQuAD" )
def has_training_docs(self): def has_training_docs(self):
return True return True
......
...@@ -14,7 +14,8 @@ categories = { ...@@ -14,7 +14,8 @@ categories = {
"STEM": [ "STEM": [
"biology", "biology",
"chemistry", "chemistry",
"mathematics" "physics", "mathematics",
"physics",
"earth science", "earth science",
], ],
"humanities": ["Chinese", "history", "Tour", "law"], "humanities": ["Chinese", "history", "Tour", "law"],
......
...@@ -48,9 +48,9 @@ def escaped_split(text, sep_char, maxsplit=-1): ...@@ -48,9 +48,9 @@ def escaped_split(text, sep_char, maxsplit=-1):
is not specified or less than 0, then there is no limit on the is not specified or less than 0, then there is no limit on the
number of splits (all possible splits are made). number of splits (all possible splits are made).
""" """
assert ( assert len(sep_char) == 1, (
len(sep_char) == 1 "separation string must be a single character for escaped splitting"
), "separation string must be a single character for escaped splitting" )
if maxsplit == 0: if maxsplit == 0:
return text return text
......
...@@ -17,7 +17,7 @@ eval_logger = utils.eval_logger ...@@ -17,7 +17,7 @@ eval_logger = utils.eval_logger
def memory_stats(): def memory_stats():
eval_logger.info( eval_logger.info(
f"Memory allocated: {torch.cuda.memory_allocated() / 1024 ** 2}, reserved: {torch.cuda.memory_reserved() // 1024 ** 2}" f"Memory allocated: {torch.cuda.memory_allocated() / 1024**2}, reserved: {torch.cuda.memory_reserved() // 1024**2}"
) )
......
...@@ -66,9 +66,9 @@ def main(): ...@@ -66,9 +66,9 @@ def main():
f"All models must have the same tasks. {model} has tasks: {model_tasks} but have already recorded tasks: {old_tasks}. Taking intersection {tasks}" f"All models must have the same tasks. {model} has tasks: {model_tasks} but have already recorded tasks: {old_tasks}. Taking intersection {tasks}"
) )
assert ( assert len(tasks) > 0, (
len(tasks) > 0 "Must provide at least one task in common amongst models to compare."
), "Must provide at least one task in common amongst models to compare." )
for task in tasks: for task in tasks:
# Upload data for all models # Upload data for all models
......
...@@ -87,7 +87,9 @@ class TestNewTasks: ...@@ -87,7 +87,9 @@ class TestNewTasks:
(x[-1].isspace() is False if len(x) > 0 else True) (x[-1].isspace() is False if len(x) > 0 else True)
if target_delimiter.isspace() if target_delimiter.isspace()
else True else True
), "doc_to_text ends in a whitespace and target delimiter also a whitespace" ), (
"doc_to_text ends in a whitespace and target delimiter also a whitespace"
)
else: else:
pass pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment