"docs/source/en/tasks/masked_language_modeling.md" did not exist on "20273ee214d253eaa133ac3f146cf778d843ace4"
Unverified commit 26bc3eab, authored by Lintang Sutawika, committed by GitHub

Merge branch 'big-refactor' into model-written-eval

parents 0d701496 cf617ab1
@@ -58,7 +58,7 @@ def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
         try:
             source, target = code_to_language(src), code_to_language(tgt)
-            groups = ["greedy_until", "translation", lang]
+            groups = ["generate_until", "translation", lang]
             if lang in gpt3_translation_benchmarks.keys():
                 groups += ["gpt3_translation_benchmarks"]
...
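For orientation, here is a rough, hypothetical sketch (not the harness's actual generator; `write_pair_yaml` and the stand-in data are invented for illustration) of how a script like the one above could turn the `groups` list into the `group:` key of the YAML files changed below:

```python
# Illustrative sketch only: the real utils.py differs in details;
# code_to_language and gpt3_translation_benchmarks are stand-ins here.
import yaml

# Stand-in data; the real module defines these differently.
gpt3_translation_benchmarks = {
    "wmt14": ["en-fr", "fr-en"],
    "wmt16": ["de-en", "en-de", "en-ro", "ro-en"],
}
code_to_language = {"en": "English", "fr": "French", "de": "German", "ro": "Romanian"}.get


def write_pair_yaml(lang: str, src: str, tgt: str, path: str) -> None:
    source, target = code_to_language(src), code_to_language(tgt)
    # The renamed tag: tasks are grouped under "generate_until", not "greedy_until".
    groups = ["generate_until", "translation", lang]
    if lang in gpt3_translation_benchmarks.keys():
        groups += ["gpt3_translation_benchmarks"]
    config = {
        "group": groups,
        "doc_to_text": f'{source} phrase: {{{{translation["{src}"]}}}}\n{target} phrase:',
        "output_type": "generate_until",
    }
    with open(path, "w") as f:
        yaml.dump(config, f)


write_pair_yaml("wmt14", "en", "fr", "wmt14-en-fr.yaml")
```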
@@ -6,7 +6,7 @@ doc_to_text: 'English phrase: {{translation["en"]}}
   French phrase:'
 group:
-  - greedy_until
+  - generate_until
   - translation
   - wmt14
   - gpt3_translation_benchmarks
...
@@ -6,7 +6,7 @@ doc_to_text: 'French phrase: {{translation["fr"]}}
   English phrase:'
 group:
-  - greedy_until
+  - generate_until
   - translation
   - wmt14
   - gpt3_translation_benchmarks
...
@@ -6,7 +6,7 @@ doc_to_text: 'German phrase: {{translation["de"]}}
   English phrase:'
 group:
-  - greedy_until
+  - generate_until
   - translation
   - wmt16
   - gpt3_translation_benchmarks
...
@@ -6,7 +6,7 @@ doc_to_text: 'English phrase: {{translation["en"]}}
   German phrase:'
 group:
-  - greedy_until
+  - generate_until
   - translation
   - wmt16
   - gpt3_translation_benchmarks
...
@@ -6,7 +6,7 @@ doc_to_text: 'English phrase: {{translation["en"]}}
   Romanian phrase:'
 group:
-  - greedy_until
+  - generate_until
   - translation
   - wmt16
   - gpt3_translation_benchmarks
...
@@ -6,7 +6,7 @@ doc_to_text: 'Romanian phrase: {{translation["ro"]}}
   English phrase:'
 group:
-  - greedy_until
+  - generate_until
   - translation
   - wmt16
   - gpt3_translation_benchmarks
...
-output_type: greedy_until
+output_type: generate_until
 training_split: train
 validation_split: validation
 fewshot_split: validation
...
 task: triviaqa
 dataset_path: trivia_qa
 dataset_name: rc.nocontext
-output_type: greedy_until
+output_type: generate_until
 training_split: train
 validation_split: validation
 doc_to_text: "Question: {{question}}?\nAnswer:"
...
@@ -3,7 +3,7 @@ group:
 task: truthfulqa_gen
 dataset_path: truthful_qa
 dataset_name: generation
-output_type: greedy_until
+output_type: generate_until
 training_split: null
 validation_split: validation
 test_split: null
...
@@ -3,7 +3,7 @@ group:
 task: anagrams1
 dataset_path: EleutherAI/unscramble
 dataset_name: mid_word_1_anagrams
-output_type: greedy_until
+output_type: generate_until
 test_split: validation
 doc_to_text: "{{context}}"
 doc_to_target: "{{completion}}"
...
@@ -3,7 +3,7 @@ group:
 task: anagrams2
 dataset_path: EleutherAI/unscramble
 dataset_name: mid_word_2_anagrams
-output_type: greedy_until
+output_type: generate_until
 test_split: validation
 doc_to_text: "{{context}}"
 doc_to_target: "{{completion}}"
...
@@ -3,7 +3,7 @@ group:
 task: cycle_letters
 dataset_path: EleutherAI/unscramble
 dataset_name: cycle_letters_in_word
-output_type: greedy_until
+output_type: generate_until
 test_split: validation
 doc_to_text: "{{context}}"
 doc_to_target: "{{completion}}"
...
@@ -3,7 +3,7 @@ group:
 task: random_insertion
 dataset_path: EleutherAI/unscramble
 dataset_name: random_insertion_in_word
-output_type: greedy_until
+output_type: generate_until
 test_split: validation
 doc_to_text: "{{context}}"
 doc_to_target: "{{completion}}"
...
@@ -3,7 +3,7 @@ group:
 task: reversed_words
 dataset_path: EleutherAI/unscramble
 dataset_name: reversed_words
-output_type: greedy_until
+output_type: generate_until
 test_split: validation
 doc_to_text: "{{context}}"
 doc_to_target: "{{completion}}"
...
@@ -5,7 +5,7 @@ dataset_path: wmt16
 dataset_name: ro-en
 training_split: train
 validation_split: validation
-output_type: greedy_until
+output_type: generate_until
 doc_to_text: "translate English to Romanian: {{translation.en}}"
 doc_to_target: "{{translation.ro}}"
 metric_list:
...
@@ -78,7 +78,7 @@ def chunks(iter, n: int = 0, fn=None):
     arr = []
     for i, x in enumerate(iter):
         arr.append(x)
-        if len(arr) == (fn(i) if fn else n):
+        if len(arr) == (fn(i, iter) if fn else n):
            yield arr
            arr = []
...
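The only behavioral change here is the callback signature: `fn` now receives the whole iterable along with the running index, so a scheduler can size batches from global information rather than position alone. A self-contained sketch (the `batch_size_fn` callback is a made-up example, not part of the harness):

```python
# Sketch of the chunking utility as it appears in the hunk above, with a
# trailing yield so a partial final chunk is not dropped.
def chunks(iter, n: int = 0, fn=None):
    arr = []
    for i, x in enumerate(iter):
        arr.append(x)
        # fn(i, iter) lets the caller decide the chunk size dynamically.
        if len(arr) == (fn(i, iter) if fn else n):
            yield arr
            arr = []
    if arr:
        yield arr


# Hypothetical usage: smaller chunks for the first few items, larger afterwards.
def batch_size_fn(i, requests):
    return 2 if i < 4 else 4


print(list(chunks(range(10), fn=batch_size_fn)))
# [[0, 1], [2, 3], [4, 5, 6, 7], [8, 9]]
```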
@@ -23,7 +23,7 @@ class DryrunLM(LM):
         return res

-    def greedy_until(self, requests):
+    def generate_until(self, requests):
         res = []
         for ctx, _ in requests:
...
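For backends, the rename means implementing `generate_until` with the same request shape as before: a context string paired with its stop sequences or generation settings. A minimal, hypothetical stand-in (`TinyDryrunLM` is not the real `DryrunLM` or `LM` base class) illustrating the interface:

```python
# Hypothetical stand-in for illustration only; it does not subclass the
# harness's LM base class and returns placeholder outputs.
class TinyDryrunLM:
    def generate_until(self, requests):
        res = []
        for ctx, _ in requests:
            # A dry run ignores the context and stop sequences and
            # returns a placeholder string per request.
            res.append("")
        return res


lm = TinyDryrunLM()
print(lm.generate_until([("The quick brown fox jumps over the lazy", [".", "\n"])]))
# ['']
```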
@@ -15,10 +15,10 @@ class Test_HFLM:
     multiple_choice_task = tasks.TASK_REGISTRY.get("arc_easy")()  # type: ignore
     multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1)
     MULTIPLE_CH: list[Instance] = multiple_choice_task.instances
-    greedy_until_task = tasks.TASK_REGISTRY.get("gsm8k_yaml")()  # type: ignore
-    greedy_until_task.build_all_requests(limit=10, rank=0, world_size=1)
-    greedy_until_task._config.generation_kwargs["max_gen_toks"] = 10
-    GREEDY_UNTIL: list[Instance] = greedy_until_task.instances
+    generate_until_task = tasks.TASK_REGISTRY.get("gsm8k_yaml")()  # type: ignore
+    generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
+    generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
+    generate_until: list[Instance] = generate_until_task.instances
     rolling_task = tasks.TASK_REGISTRY.get("wikitext")()  # type: ignore
     rolling_task.build_all_requests(limit=10, rank=0, world_size=1)
     ROLLING: list[Instance] = rolling_task.instances
...
@@ -65,7 +65,7 @@ class Test_HFLM:
         -52.70050811767578,
         -56.25089645385742,
     ]
-    GREEDY_UNTIL_RES = [
+    generate_until_RES = [
         " The average of $2.50 each is $",
         " A robe takes 2 bolts of blue fiber and half",
         " $50,000 in repairs.",
...
@@ -109,9 +109,9 @@ class Test_HFLM:
         ), np.argmax(np.array(_res).reshape(-1, 4), axis=1)
         assert (argmax_RES == argmax_res).all()

-    def test_greedy_until(self) -> None:
-        res = self.LM.greedy_until(self.GREEDY_UNTIL)
-        assert res == self.GREEDY_UNTIL_RES
+    def test_generate_until(self) -> None:
+        res = self.LM.generate_until(self.generate_until)
+        assert res == self.generate_until_RES

     def test_logliklihood_rolling(self) -> None:
         res = self.LM.loglikelihood_rolling(self.ROLLING)
...
@@ -78,7 +78,7 @@ def test_gpt2():
     # test empty context
     gpt2.loglikelihood([("", "test")])

-    (gen,) = gpt2.greedy_until(
+    (gen,) = gpt2.generate_until(
         [("The quick brown fox jumps over the lazy", [".", "\n"])]
     )
...
@@ -204,7 +204,7 @@ def test_gpt3():
     # test empty context
     gpt3.loglikelihood([("", "test")])

-    (gen,) = gpt3.greedy_until(
+    (gen,) = gpt3.generate_until(
         [("The quick brown fox jumps over the lazy", [".", "\n"])]
     )
...
@@ -300,7 +300,7 @@ def test_textsynth():
     # test empty context
     textsynth.loglikelihood([("", "test")])

-    (gen,) = textsynth.greedy_until(
+    (gen,) = textsynth.generate_until(
         [("The quick brown fox jumps over the lazy", [".", "\n"])]
     )
...