lambada_multilingual.py 2.38 KB
Newer Older
sdtblck's avatar
sdtblck committed
1
2
3
4
5
from . import lambada
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
from best_download import download_file
Leo Gao's avatar
Leo Gao committed
6
import json
Leo Gao's avatar
Leo Gao committed
7
from functools import partial
sdtblck's avatar
sdtblck committed
8
import os 
sdtblck's avatar
sdtblck committed
9

10
11
# This task is LAMBADA, machine-translated into other languages.

sdtblck's avatar
sdtblck committed
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# SHA-256 checksum of each language's test file, keyed by language code.
CHECKSUMS = {
    "en": "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226",
    "fr": "941ec6a73dba7dc91c860bf493eb66a527cd430148827a4753a4535a046bf362",
    "de": "51c6c1795894c46e88e4c104b5667f488efe79081fb34d746b82b8caa663865e",
    "it": "86654237716702ab74f42855ae5a78455c1b0e50054a4593fb9c6fcf7fad0850",
    "es": "ffd760026c647fb43c67ce1bc56fd527937304b348712dce33190ea6caba6f9c",
}

# Supported language codes, in the insertion order of the checksum table.
LANGS = list(CHECKSUMS)

class MultilingualLAMBADA(lambada.LAMBADA):
    """LAMBADA task evaluated on a per-language JSONL test file.

    Each instance is bound to a single language code, stored in ``self.LANG``.
    The data for that language is kept locally at
    ``data/lambada/lambada_test_<lang>.jsonl``.
    """

    VERSION = 0

    def __init__(self, lang=None):
        """Bind the task to ``lang`` and run the base-class initialisation.

        Args:
            lang: Language code selecting the data file; it must be a key of
                ``CHECKSUMS`` for ``download`` to be able to fetch the file.
        """
        # LANG has to be assigned before delegating: download() reads it, and
        # the base constructor presumably triggers download() -- TODO confirm.
        self.LANG = lang
        super().__init__()

    def _jsonl_path(self):
        """Return the local path of the test file for ``self.LANG``."""
        return f"data/lambada/lambada_test_{self.LANG}.jsonl"

    def download(self):
        """Fetch the test file for ``self.LANG`` unless it already exists.

        The primary download goes through ``download_file`` with the expected
        checksum; if that raises, the file is fetched with ``wget`` and then
        checked with ``sha256sum --check``.

        Raises:
            KeyError: If a download is needed and ``self.LANG`` has no entry
                in ``CHECKSUMS``.
        """
        os.makedirs("data/lambada", exist_ok=True)

        path = self._jsonl_path()
        if os.path.exists(path):
            return

        url = f"http://eaidata.bmk.sh/data/lambada_test_{self.LANG}.jsonl"
        # Looked up outside the try block so that an unsupported language code
        # surfaces as a KeyError instead of silently triggering the fallback.
        checksum = CHECKSUMS[self.LANG]

        try:
            download_file(url, path, checksum)
        except Exception:
            # Fallback: best_download does not work reliably here. Only
            # Exception is caught so that Ctrl-C / SystemExit still abort.
            sh(f"wget {url} -O {path}")
            sh(f'echo "{checksum}  {path}" | sha256sum --check')

    def validation_docs(self):
        """Yield one parsed JSON document per line of the language's test file."""
        # Explicit UTF-8: the files hold non-English text, and the platform
        # default encoding is not guaranteed to be UTF-8.
        with open(self._jsonl_path(), encoding="utf-8") as fh:
            for line in fh:
                yield json.loads(line)

sdtblck's avatar
sdtblck committed
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class MultilingualLAMBADAEN(MultilingualLAMBADA):
    """MultilingualLAMBADA bound to the language code ``en``."""

    def __init__(self):
        super().__init__(lang="en")

class MultilingualLAMBADAFR(MultilingualLAMBADA):
    """MultilingualLAMBADA bound to the language code ``fr``."""

    def __init__(self):
        super().__init__(lang="fr")

class MultilingualLAMBADADE(MultilingualLAMBADA):
    """MultilingualLAMBADA bound to the language code ``de``."""

    def __init__(self):
        super().__init__(lang="de")

class MultilingualLAMBADAIT(MultilingualLAMBADA):
    """MultilingualLAMBADA bound to the language code ``it``."""

    def __init__(self):
        super().__init__(lang="it")

class MultilingualLAMBADAES(MultilingualLAMBADA):
    """MultilingualLAMBADA bound to the language code ``es``."""

    def __init__(self):
        super().__init__(lang="es")

# Task classes, listed in the same order as the codes in LANGS; the two lists
# are paired positionally by construct_tasks().
LANG_CLASSES = [
    MultilingualLAMBADAEN,
    MultilingualLAMBADAFR,
    MultilingualLAMBADADE,
    MultilingualLAMBADAIT,
    MultilingualLAMBADAES,
]

sdtblck's avatar
sdtblck committed
71
72
def construct_tasks():
    """Build the registry of multilingual LAMBADA task classes.

    Returns:
        A dict mapping ``"lambada_mt_<lang>"`` to a task class, where LANGS
        and LANG_CLASSES are paired positionally.
    """
    return {
        f"lambada_mt_{code}": task_cls
        for code, task_cls in zip(LANGS, LANG_CLASSES)
    }
sdtblck's avatar
sdtblck committed
76