Unverified commit d67c77be authored by Leo Gao, committed by GitHub
Browse files

Merge pull request #214 from Muennighoff/tokenization

Add space tokenization for JA/ZH
parents 4d2c1522 c0f0f7e1
...@@ -3,6 +3,11 @@ from pprint import pprint ...@@ -3,6 +3,11 @@ from pprint import pprint
from sacrebleu import sacrebleu from sacrebleu import sacrebleu
from lm_eval import metrics from lm_eval import metrics
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from typing import List
import jieba
import nagisa
""" """
This file implements translation tasks using datasets from WMT conferences, provided by sacrebleu. This file implements translation tasks using datasets from WMT conferences, provided by sacrebleu.
...@@ -19,18 +24,38 @@ def create_tasks_from_benchmarks(benchmark_dict): ...@@ -19,18 +24,38 @@ def create_tasks_from_benchmarks(benchmark_dict):
:return: {task_name: task} :return: {task_name: task}
e.g. {wmt14-fr-en: Task, wmt16-de-en: Task} e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
""" """
def version_of(dataset, language_pair):
    """Return the task version number for one benchmark/pair combination.

    Pairs whose target language is Chinese or Japanese are version 1
    (their outputs get segmented with jieba/nagisa before BLEU scoring);
    every other pair stays at version 0. ``dataset`` is unused but kept
    so all call sites share a uniform signature.
    """
    target = language_pair[-2:]
    return 1 if target in ("zh", "ja") else 0
return { return {
f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair) f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair, version_of(dataset, language_pair))
for dataset, language_pairs in benchmark_dict.items() for dataset, language_pairs in benchmark_dict.items()
for language_pair in language_pairs for language_pair in language_pairs
} }
########################################
# Language Specifics
########################################
def zh_split(zh_text: List[str]) -> List[str]:
    """Segment each Chinese string into space-joined tokens.

    Chinese text has no word-delimiting spaces, so each entry is run
    through jieba and its tokens are re-joined with single spaces.
    """
    segmented = []
    for sentence in zh_text:
        tokens = jieba.cut(sentence.strip())
        segmented.append(" ".join(tokens))
    return segmented
def ja_split(ja_text: List[str]) -> List[str]:
    """Segment each Japanese string into space-joined tokens.

    Japanese text has no word-delimiting spaces, so each entry is
    tagged with nagisa and the resulting words are re-joined with
    single spaces.
    """
    segmented = []
    for sentence in ja_text:
        words = nagisa.tagging(sentence.strip()).words
        segmented.append(" ".join(words))
    return segmented
NO_SPACE_LANG = {"zh": zh_split, "ja": ja_split}
######################################## ########################################
# Tasks # Tasks
######################################## ########################################
def create_translation_task(dataset, language_pair, version=0):
    """Build a Task subclass bound to one dataset and language pair.

    The returned class carries ``version`` as its ``VERSION`` attribute
    and forwards ``dataset``/``language_pair`` to the
    ``GeneralTranslationTask`` constructor when instantiated.
    """

    class TranslationTask(GeneralTranslationTask):
        VERSION = version

        def __init__(self):
            super().__init__(dataset, language_pair)

    return TranslationTask
...@@ -102,6 +127,12 @@ class GeneralTranslationTask(Task): ...@@ -102,6 +127,12 @@ class GeneralTranslationTask(Task):
return rf.greedy_until(ctx, ["\n"]) return rf.greedy_until(ctx, ["\n"])
def process_results(self, doc, results): def process_results(self, doc, results):
# Add spaces between words for BLEU score calculation of target languages like Chinese
tar_lang_code = self.sacrebleu_language_pair.split("-")[-1]
if tar_lang_code in NO_SPACE_LANG:
doc["ref"] = NO_SPACE_LANG[tar_lang_code]([doc["ref"]])[0]
results = NO_SPACE_LANG[tar_lang_code](results)
# These metrics are corpus-level not sentence level, so we'll hide the # These metrics are corpus-level not sentence level, so we'll hide the
# results in this dict and compute the corpus score in the aggregate method # results in this dict and compute the corpus score in the aggregate method
ref_pred = (doc["ref"], results) ref_pred = (doc["ref"], results)
......
...@@ -39,6 +39,8 @@ setuptools.setup( ...@@ -39,6 +39,8 @@ setuptools.setup(
"zstandard==0.15.2", "zstandard==0.15.2",
"jsonlines==2.0.0", "jsonlines==2.0.0",
"mock==4.0.3", "mock==4.0.3",
"openai==0.6.4" "openai==0.6.4",
"jieba==0.42.1",
"nagisa==0.2.7"
] ]
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment