Unverified commit 26bc3eab, authored by Lintang Sutawika, committed by GitHub
Browse files

Merge branch 'big-refactor' into model-written-eval

parents 0d701496 cf617ab1
......@@ -58,7 +58,7 @@ def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
try:
source, target = code_to_language(src), code_to_language(tgt)
groups = ["greedy_until", "translation", lang]
groups = ["generate_until", "translation", lang]
if lang in gpt3_translation_benchmarks.keys():
groups += ["gpt3_translation_benchmarks"]
......
......@@ -6,7 +6,7 @@ doc_to_text: 'English phrase: {{translation["en"]}}
French phrase:'
group:
- greedy_until
- generate_until
- translation
- wmt14
- gpt3_translation_benchmarks
......
......@@ -6,7 +6,7 @@ doc_to_text: 'French phrase: {{translation["fr"]}}
English phrase:'
group:
- greedy_until
- generate_until
- translation
- wmt14
- gpt3_translation_benchmarks
......
......@@ -6,7 +6,7 @@ doc_to_text: 'German phrase: {{translation["de"]}}
English phrase:'
group:
- greedy_until
- generate_until
- translation
- wmt16
- gpt3_translation_benchmarks
......
......@@ -6,7 +6,7 @@ doc_to_text: 'English phrase: {{translation["en"]}}
German phrase:'
group:
- greedy_until
- generate_until
- translation
- wmt16
- gpt3_translation_benchmarks
......
......@@ -6,7 +6,7 @@ doc_to_text: 'English phrase: {{translation["en"]}}
Romanian phrase:'
group:
- greedy_until
- generate_until
- translation
- wmt16
- gpt3_translation_benchmarks
......
......@@ -6,7 +6,7 @@ doc_to_text: 'Romanian phrase: {{translation["ro"]}}
English phrase:'
group:
- greedy_until
- generate_until
- translation
- wmt16
- gpt3_translation_benchmarks
......
output_type: greedy_until
output_type: generate_until
training_split: train
validation_split: validation
fewshot_split: validation
......
task: triviaqa
dataset_path: trivia_qa
dataset_name: rc.nocontext
output_type: greedy_until
output_type: generate_until
training_split: train
validation_split: validation
doc_to_text: "Question: {{question}}?\nAnswer:"
......
......@@ -3,7 +3,7 @@ group:
task: truthfulqa_gen
dataset_path: truthful_qa
dataset_name: generation
output_type: greedy_until
output_type: generate_until
training_split: null
validation_split: validation
test_split: null
......
......@@ -3,7 +3,7 @@ group:
task: anagrams1
dataset_path: EleutherAI/unscramble
dataset_name: mid_word_1_anagrams
output_type: greedy_until
output_type: generate_until
test_split: validation
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
......
......@@ -3,7 +3,7 @@ group:
task: anagrams2
dataset_path: EleutherAI/unscramble
dataset_name: mid_word_2_anagrams
output_type: greedy_until
output_type: generate_until
test_split: validation
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
......
......@@ -3,7 +3,7 @@ group:
task: cycle_letters
dataset_path: EleutherAI/unscramble
dataset_name: cycle_letters_in_word
output_type: greedy_until
output_type: generate_until
test_split: validation
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
......
......@@ -3,7 +3,7 @@ group:
task: random_insertion
dataset_path: EleutherAI/unscramble
dataset_name: random_insertion_in_word
output_type: greedy_until
output_type: generate_until
test_split: validation
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
......
......@@ -3,7 +3,7 @@ group:
task: reversed_words
dataset_path: EleutherAI/unscramble
dataset_name: reversed_words
output_type: greedy_until
output_type: generate_until
test_split: validation
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
......
......@@ -5,7 +5,7 @@ dataset_path: wmt16
dataset_name: ro-en
training_split: train
validation_split: validation
output_type: greedy_until
output_type: generate_until
doc_to_text: "translate English to Romanian: {{translation.en}}"
doc_to_target: "{{translation.ro}}"
metric_list:
......
......@@ -78,7 +78,7 @@ def chunks(iter, n: int = 0, fn=None):
arr = []
for i, x in enumerate(iter):
arr.append(x)
if len(arr) == (fn(i) if fn else n):
if len(arr) == (fn(i, iter) if fn else n):
yield arr
arr = []
......
......@@ -23,7 +23,7 @@ class DryrunLM(LM):
return res
def greedy_until(self, requests):
def generate_until(self, requests):
res = []
for ctx, _ in requests:
......
......@@ -15,10 +15,10 @@ class Test_HFLM:
multiple_choice_task = tasks.TASK_REGISTRY.get("arc_easy")() # type: ignore
multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1)
MULTIPLE_CH: list[Instance] = multiple_choice_task.instances
greedy_until_task = tasks.TASK_REGISTRY.get("gsm8k_yaml")() # type: ignore
greedy_until_task.build_all_requests(limit=10, rank=0, world_size=1)
greedy_until_task._config.generation_kwargs["max_gen_toks"] = 10
GREEDY_UNTIL: list[Instance] = greedy_until_task.instances
generate_until_task = tasks.TASK_REGISTRY.get("gsm8k_yaml")() # type: ignore
generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
generate_until: list[Instance] = generate_until_task.instances
rolling_task = tasks.TASK_REGISTRY.get("wikitext")() # type: ignore
rolling_task.build_all_requests(limit=10, rank=0, world_size=1)
ROLLING: list[Instance] = rolling_task.instances
......@@ -65,7 +65,7 @@ class Test_HFLM:
-52.70050811767578,
-56.25089645385742,
]
GREEDY_UNTIL_RES = [
generate_until_RES = [
" The average of $2.50 each is $",
" A robe takes 2 bolts of blue fiber and half",
" $50,000 in repairs.",
......@@ -109,9 +109,9 @@ class Test_HFLM:
), np.argmax(np.array(_res).reshape(-1, 4), axis=1)
assert (argmax_RES == argmax_res).all()
def test_greedy_until(self) -> None:
res = self.LM.greedy_until(self.GREEDY_UNTIL)
assert res == self.GREEDY_UNTIL_RES
def test_generate_until(self) -> None:
res = self.LM.generate_until(self.generate_until)
assert res == self.generate_until_RES
def test_logliklihood_rolling(self) -> None:
res = self.LM.loglikelihood_rolling(self.ROLLING)
......
......@@ -78,7 +78,7 @@ def test_gpt2():
# test empty context
gpt2.loglikelihood([("", "test")])
(gen,) = gpt2.greedy_until(
(gen,) = gpt2.generate_until(
[("The quick brown fox jumps over the lazy", [".", "\n"])]
)
......@@ -204,7 +204,7 @@ def test_gpt3():
# test empty context
gpt3.loglikelihood([("", "test")])
(gen,) = gpt3.greedy_until(
(gen,) = gpt3.generate_until(
[("The quick brown fox jumps over the lazy", [".", "\n"])]
)
......@@ -300,7 +300,7 @@ def test_textsynth():
# test empty context
textsynth.loglikelihood([("", "test")])
(gen,) = textsynth.greedy_until(
(gen,) = textsynth.generate_until(
[("The quick brown fox jumps over the lazy", [".", "\n"])]
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment