Unverified commit d67c77be, authored by Leo Gao and committed by GitHub
Browse files

Merge pull request #214 from Muennighoff/tokenization

Add space tokenization for JA/ZH
parents 4d2c1522 c0f0f7e1
......@@ -3,6 +3,11 @@ from pprint import pprint
from sacrebleu import sacrebleu
from lm_eval import metrics
from lm_eval.base import Task, rf
from typing import List
import jieba
import nagisa
"""
This file implements translation tasks using datasets from WMT conferences, provided by sacrebleu.
......@@ -19,18 +24,38 @@ def create_tasks_from_benchmarks(benchmark_dict):
:return: {task_name: task}
e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
"""
def version_of(dataset, language_pair):
    """Pick the task version number for a dataset / language pair.

    :param dataset: benchmark name; not consulted when choosing the version.
    :param language_pair: language pair string, e.g. "en-zh".
    :return: 1 when the pair ends in "zh" or "ja", otherwise 0.
    """
    trailing_code = language_pair[-2:]
    # Version 1 marks the pairs that changed to use jieba/nagisa.
    return 1 if trailing_code in ("zh", "ja") else 0
return {
f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair)
f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair, version_of(dataset, language_pair))
for dataset, language_pairs in benchmark_dict.items()
for language_pair in language_pairs
}
########################################
# Language Specifics
########################################
def zh_split(zh_text: List[str]) -> List[str]:
    """Chinese splitting.

    Strips each input string, segments it with ``jieba.cut`` and joins the
    resulting pieces with single spaces.

    :param zh_text: strings to segment.
    :return: one space-joined string per input, in input order.
    """
    segmented = []
    for sentence in zh_text:
        pieces = jieba.cut(sentence.strip())
        segmented.append(" ".join(pieces))
    return segmented
def ja_split(ja_text: List[str]) -> List[str]:
    """Japanese splitting.

    Strips each input string, runs ``nagisa.tagging`` on it and joins the
    ``words`` of the tagging result with single spaces.

    :param ja_text: strings to segment.
    :return: one space-joined string per input, in input order.
    """
    segmented = []
    for sentence in ja_text:
        tagged = nagisa.tagging(sentence.strip())
        segmented.append(" ".join(tagged.words))
    return segmented
# Language code -> function that inserts spaces between the words of a list of strings.
NO_SPACE_LANG = {
    "zh": zh_split,
    "ja": ja_split,
}
########################################
# Tasks
########################################
def create_translation_task(dataset, language_pair):
def create_translation_task(dataset, language_pair, version=0):
    """Build a task class bound to one dataset and language pair.

    :param dataset: dataset name forwarded to ``GeneralTranslationTask``.
    :param language_pair: language pair forwarded to ``GeneralTranslationTask``.
    :param version: value stored as the class-level ``VERSION`` (default 0).
    :return: the generated ``TranslationTask`` class (not an instance).
    """

    class TranslationTask(GeneralTranslationTask):
        # Exposed on the class so the version is readable without instantiating.
        VERSION = version

        def __init__(self):
            # The dataset and pair are captured from the enclosing call.
            super(TranslationTask, self).__init__(dataset, language_pair)

    return TranslationTask
......@@ -102,6 +127,12 @@ class GeneralTranslationTask(Task):
return rf.greedy_until(ctx, ["\n"])
def process_results(self, doc, results):
# Add spaces between words for BLEU score calculation of target languages like Chinese
tar_lang_code = self.sacrebleu_language_pair.split("-")[-1]
if tar_lang_code in NO_SPACE_LANG:
doc["ref"] = NO_SPACE_LANG[tar_lang_code]([doc["ref"]])[0]
results = NO_SPACE_LANG[tar_lang_code](results)
# These metrics are corpus-level not sentence level, so we'll hide the
# results in this dict and compute the corpus score in the aggregate method
ref_pred = (doc["ref"], results)
......
......@@ -39,6 +39,8 @@ setuptools.setup(
"zstandard==0.15.2",
"jsonlines==2.0.0",
"mock==4.0.3",
"openai==0.6.4"
"openai==0.6.4",
"jieba==0.42.1",
"nagisa==0.2.7"
]
)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment