Unverified Commit 42cc971f authored by Stella Biderman, committed by GitHub

Merge pull request #615 from cardy20/polyglot

Add legal judgement prediction tasks: polyglot-ko evaluation
parents 1f66adc8 be3969c6
...@@ -7,3 +7,4 @@ build/
logs/
output/
lm_eval.egg-info/
shell/
# Decontamination
## Usage
Simply add the `--decontamination_ngrams_path` argument when running `main.py`. The provided directory should contain
the n-gram files and `info.json` produced in the "Pile Ngram Generation" section further down.
```bash
python main.py \
--model gpt2 \
--device 0 \
--tasks sciq \
--decontamination_ngrams_path path/containing/training/set/ngrams
```
## Background
Downstream evaluations test model generalization, and they are less useful when test-set data also appears in the training set, which is referred to as leakage or contamination.
Filtering your training set against the test set is a good first step; however, this isn't always possible, as in the case of a new benchmark or one that wasn't considered prior to model training. When training-set filtering isn't possible, it is useful to measure the impact of test-set leakage by detecting the contaminated test examples and producing a clean version of the benchmark.
The basis for our decontamination procedure can be found in Appendix C of "Language Models are Few-Shot Learners". OpenAI defined a test document as contaminated if any N-gram overlap existed with any training document. They used a range of N values between 8 and 13 depending on the dataset, while we use 13 for simplicity.
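As a rough illustration of that rule (a standalone sketch, not the harness code; the tokenization and names here are made up):
```python
def ngrams(tokens, n=13):
    """Return the set of all n-grams (as tuples) in a token sequence."""
    return {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}

def is_contaminated(test_doc_tokens, training_ngrams, n=13):
    # Contaminated if ANY 13-gram of the test document also occurs in the training data.
    return not ngrams(test_doc_tokens, n).isdisjoint(training_ngrams)
```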
## Implementation
Contamination detection can be found in `lm_eval/decontaminate.py` with supporting code in `lm_eval/decontamination/`.
`decontaminate.py` does the following (a rough sketch of the matching step is shown after the list):
1. Build dictionaries of all ngrams and their corresponding evaluation/document ids.
2. Scan through sorted files containing training set n-grams.
3. If a match is found, the corresponding evaluation/document combinations are marked as contaminated.
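A minimal sketch of that scan, assuming the dictionary from step 1 maps each n-gram string to the set of (task, doc_id) pairs that contain it (simplified; the real implementation streams the sorted, bucketed n-gram files described further down):
```python
def find_contaminated(test_ngrams, sorted_ngram_files):
    """test_ngrams: dict mapping an n-gram string -> set of (task_name, doc_id)."""
    contaminated = set()
    for path in sorted_ngram_files:
        with open(path, encoding="utf-8") as fh:
            for line in fh:
                ngram = line.rstrip("\n")
                if ngram in test_ngrams:
                    # Mark every evaluation document containing this n-gram.
                    contaminated |= test_ngrams[ngram]
    return contaminated
```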
`lm_eval/evaluator.py` can then produce a clean version of the benchmark by excluding the results of contaminated documents. For each metric, a clean version will be shown in the results with a "decontaminate" suffix.
This is disabled by default for new tasks; to support decontamination on a task, override the `should_decontaminate` and `doc_to_decontamination_query` methods. For more details see the [task guide](task_guide.md).
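A minimal task-side override might look like the following (the task and field names are placeholders; the same pattern appears in existing tasks):
```python
class MyTask(Task):
    # ... dataset config and the usual Task methods ...

    def should_decontaminate(self):
        # Opt this task into decontamination.
        return True

    def doc_to_decontamination_query(self, doc):
        # Return the text to check against training-set n-grams,
        # typically the same text used to build the prompt.
        return doc["question"]
```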
## Pile Ngram Generation
The relevant scripts can be found in `scripts/clean_training_data`, which also import from
`lm_eval/decontamination/`.
1. git clone https://github.com/EleutherAI/lm-evaluation-harness.git
2. pip install -r requirements.txt
3. Download The Pile from [The Eye](https://the-eye.eu/public/AI/pile/train/)
4. Place pile files in "pile" directory under "lm-evaluation-harness" (or create a symlink)
5. Run generate_13_grams.
```bash
export PYTHONHASHSEED=0
python -m scripts.clean_training_data.generate_13_grams \
-dir path/to/working/directory \
-n 13 \
-buckets 500
```
This took us approximately 4 days. We had the time to wait, but it could be scaled out by doing partial Pile scans on multiple instances of this script and merging the relevant buckets. We fixed PYTHONHASHSEED to ensure reproducible bucket hashing in case you need to stop and restart.
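The seed matters because each n-gram is routed to a bucket with Python's built-in `hash()`, which is randomly salted per process unless PYTHONHASHSEED is fixed; roughly:
```python
# Sketch of the bucket routing (see Buckets.add_data in generate_13_grams):
# with a fixed PYTHONHASHSEED the same n-gram always lands in the same bucket,
# so a stopped run can be resumed without re-routing earlier n-grams.
bucket_index = hash(ngram) % num_buckets
```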
6. Sort the generated 13-grams.
```bash
python -m scripts.clean_training_data.sort_13_gram_buckets \
-dir path/to/working/directory/output
```
This took us approximately 5 days. You could speed it up by spreading the files across different machines, running the sort script on each, and then gathering the sorted files back together.
7. Compress the sorted 13-gram files and place them together with info.json.
This step only takes a few hours.
```bash
python -m scripts.clean_training_data.compress_and_package \
-dir path/to/working/directory \
-output path/to/final/directory \
-procs 8
```
Congratulations, the final directory can now be passed to lm-evaluation-harness with the `--decontamination_ngrams_path` argument.
...@@ -391,6 +391,7 @@ class BaseLM(LM):
re_ord = utils.Reorderer(requests, _collate)
print(re_ord.arr)
for context, request_args in tqdm(re_ord.get_reordered()):
until = request_args["until"]
if isinstance(until, str):
......
...@@ -248,6 +248,7 @@ def evaluate(
if not isinstance(reqs, (list, tuple)):
reqs = [reqs]
for i, req in enumerate(reqs):
requests[req.request_type].append(req)
# i: index in requests for a single task instance
......
...@@ -8,6 +8,7 @@ import random
def mean(arr):
print(len(arr))
return sum(arr) / len(arr)
...@@ -41,7 +42,6 @@ def f1_score(items):
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = sklearn.metrics.f1_score(golds, preds)
return np.max(fscore)
def macro_f1_score(items):
...@@ -49,7 +49,6 @@ def macro_f1_score(items):
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = sklearn.metrics.f1_score(golds, preds, average='macro')
return fscore
def acc_all(items):
...@@ -194,8 +193,6 @@ def _sacreformat(refs, preds):
# stderr stuff
class _bootstrap_internal:
def __init__(self, f, n):
self.f = f
......
import torch
import transformers
from typing import Optional
from typing import Optional, Union
from lm_eval.base import BaseLM
def _get_dtype(
dtype: Union[str, torch.dtype]
) -> torch.dtype:
"""Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
if isinstance(dtype, str) and dtype != "auto":
# Convert `str` args torch dtype: `float16` -> `torch.float16`
_torch_dtype = getattr(torch, dtype)
else:
_torch_dtype = dtype
return _torch_dtype
class HFLM(BaseLM):
def __init__(
self,
...@@ -16,6 +28,7 @@ class HFLM(BaseLM):
batch_size=1,
load_in_8bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
dtype: Optional[Union[str, torch.dtype]]="auto",
):
super().__init__()
...@@ -46,6 +59,7 @@ class HFLM(BaseLM):
load_in_8bit=load_in_8bit,
low_cpu_mem_usage=low_cpu_mem_usage,
revision=revision,
torch_dtype=_get_dtype(dtype),
trust_remote_code=trust_remote_code,
).to(self.device)
self.gpt2.eval()
......
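For reference, a rough sketch of what the new `_get_dtype` helper returns for typical inputs (a standalone illustration mirroring the function above, not part of the diff):
```python
import torch

def _get_dtype(dtype):
    """Mirror of the helper above, reproduced here so the example is self-contained."""
    if isinstance(dtype, str) and dtype != "auto":
        return getattr(torch, dtype)
    return dtype

assert _get_dtype("float16") is torch.float16       # dtype strings resolve via getattr(torch, ...)
assert _get_dtype("auto") == "auto"                  # "auto" passes through for HF to resolve
assert _get_dtype(torch.bfloat16) is torch.bfloat16  # torch dtypes are returned unchanged
```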
...@@ -57,6 +57,7 @@ from . import ko_translation
from . import korquad
from . import korunsmile
from . import kohatespeech
from . import legal_test
from . import kold
from . import kosbi
from . import toxigen
...@@ -346,6 +347,10 @@ TASK_REGISTRY = {
"kohatespeech":kohatespeech.HateSpeech,
"kohatespeech_gen_bias":kohatespeech.GenderBias,
"kohatespeech_apeach":kohatespeech.Apeach,
"kolegal_legalcase":legal_test.LegalBinary,
"kolegal_civilcase":legal_test.LJPCivil,
"kolegal_criminalcase":legal_test.LJPCriminal,
"kosbi":kosbi.KoSBi, "kosbi":kosbi.KoSBi,
**xcopa.construct_tasks(), **xcopa.construct_tasks(),
**bigbench.create_all_tasks(), **bigbench.create_all_tasks(),
......
...@@ -9,7 +9,8 @@ import hashlib
import functools
import numpy as np
import re
import importlib.resources
# import importlib.resources
from importlib_resources import files
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
...@@ -229,7 +230,8 @@ def create_task_from_path(json_path):
def create_all_tasks():
resources_dir = importlib.resources.files("lm_eval.datasets") / "bigbench_resources"
# resources_dir = importlib.resources.files("lm_eval.datasets") / "bigbench_resources"
resources_dir = files("lm_eval.datasets") / "bigbench_resources"
supported_tasks = [os.path.splitext(x)[0] for x in os.listdir(resources_dir)]
res = {}
for task_name in supported_tasks:
......
"""
Korean legal AI datasets, LBox OPEN
Multi-task on Legal corpus
https://arxiv.org/pdf/2206.05224.pdf
"""
import numpy as np
from lm_eval.base import Task, MultipleChoiceTask, rf
from lm_eval.metrics import bleu, chrf, ter
from lm_eval.metrics import macro_f1_score, mean, matthews_corrcoef, f1_score, yesno
from lm_eval.utils import general_detokenize
_CITATION ="""
@article{hwang2022multi,
title={A multi-task benchmark for korean legal language understanding and judgement prediction},
author={Hwang, Wonseok and Lee, Dongjun and Cho, Kyoungyeon and Lee, Hanuhl and Seo, Minjoon},
journal={arXiv preprint arXiv:2206.05224},
year={2022}
}
"""
class LegalBinary(Task):
""" Predict civil(민사) or criminal(형사) case"""
VERSION = 0
DATASET_PATH = "lbox/lbox_open"
DATASET_NAME = "casename_classification"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["valid"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def doc_to_text(self, doc):
return "문장: {} ".format(doc["facts"])
def doc_to_target(self, doc):
return " {}".format({"civil": "민사", "criminal": "형사"}[doc["casetype"]])
def construct_requests(self, doc, ctx):
ll_m, _ = rf.loglikelihood(ctx, " 민사")
ll_h, _ = rf.loglikelihood(ctx, " 형사")
return ll_m, ll_h
def process_results(self, doc, results):
ll_m, ll_h = results
pred = ll_h > ll_m
gold = {"civil": 0, "criminal": 1}[doc["casetype"]]
return {
"acc": pred == gold,
"macro_f1": (gold, pred)
}
def higher_is_better(self):
return {
"acc": True,
"macro_f1": True
}
def aggregation(self):
return {
"acc": mean,
"macro_f1": macro_f1_score
}
class LJPCivil(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "lbox/lbox_open"
DATASET_NAME = "ljp_civil"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def doc_to_text(self, doc):
return doc["query"]
def doc_to_target(self, doc):
return " {}".format(doc['gold'])
def process_label(self, doc):
return {'구상금':0, '대여금':1, '부당이득금':2, '손해배상(기)':3}[doc['gold']]
def _process_doc(self, doc):
out_doc = {
"query": "{}".format(doc['facts']),
"choices": ['구상금', '대여금', '부당이득금', '손해배상(기)'],
"gold": doc['casename']
}
return out_doc
def process_results(self, doc, results):
pred = np.argmax(results)
gold = self.process_label(doc)
return {
"acc": pred == gold,
"macro_f1": (gold, pred)
}
def higher_is_better(self):
return {
"acc": True,
"macro_f1": True
}
def aggregation(self):
return {
"acc": mean,
"macro_f1": macro_f1_score
}
class LJPCriminal(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "lbox/lbox_open"
DATASET_NAME = "ljp_criminal"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def doc_to_text(self, doc):
return doc["query"]
def doc_to_target(self, doc):
return " {}".format(doc['gold'])
def process_label(self, doc):
return {'강제추행':0, '공무집행방해':1, '교통사고처리특례법위반(치상)':2, '도로교통법위반(음주운전)':3,\
'사기':4, '상해':5, '폭행':6}[doc['gold']]
def _process_doc(self, doc):
out_doc = {
"query": "{}".format(doc['facts']),
"choices": ['강제추행', '공무집행방해', '교통사고처리특례법위반(치상)', \
'도로교통법위반(음주운전)', '사기', '상해', '폭행'],
"gold": doc['casename']
}
return out_doc
def process_results(self, doc, results):
pred = np.argmax(results)
gold = self.process_label(doc)
return {
"acc": pred == gold,
"macro_f1": (gold, pred)
}
def higher_is_better(self):
return {
"acc": True,
"macro_f1": True
}
def aggregation(self):
return {
"acc": mean,
"macro_f1": macro_f1_score
}
class LegalSummarization(Task):
VERSION = 0
def __init__(self):
# NOTE: this stub bypasses Task.__init__, so no dataset is loaded here; attributes
# such as self.train_src/self.train_tgt and self.src_lang are expected to be set elsewhere.
pass
def has_training_docs(self):
"""Whether the task has a training set"""
return True
def has_validation_docs(self):
"""Whether the task has a validation set"""
return True
def has_test_docs(self):
"""Whether the task has a test set"""
return True
def training_docs(self):
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
if self._training_docs is None:
self._training_docs = [
{"src": src, "tgt": tgt} for src, tgt in zip(self.train_src, self.train_tgt)
]
return self._training_docs
def validation_docs(self):
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
return [
{"src": src, "tgt": tgt} for src, tgt in zip(self.valid_src, self.valid_tgt)
]
def test_docs(self):
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
return [
{"src": src, "tgt": tgt} for src, tgt in zip(self.tst_src, self.tst_tgt)
]
def doc_to_text(self, doc):
src_lang = self.src_lang
tar_lang = self.tar_lang
if src_lang == 'ko':
return f"{src_lang}{tar_lang}으로 번역해주는 모델입니다.\n\n###\n{src_lang}:" + doc["src"] + f"\n{tar_lang}:"
elif src_lang == 'en':
return f"Translate {src_lang} to {tar_lang}.\n\n###\n{src_lang}:" + doc["src"] + f"\n{tar_lang}:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["src"]
def doc_to_target(self, doc):
# This shows a single target, though there may be multiple targets in a lang test
return " " + doc["tgt"] if isinstance(doc["tgt"], str) else doc["tgt"][0]
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
return rf.greedy_until(ctx, ["\n"])
def process_results(self, doc, results):
ref_pred = (doc["tgt"], results)
return {
"bleu": ref_pred,
"chrf": ref_pred,
}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"bleu": bleu,
"chrf": chrf,
"ter": ter,
}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"bleu": True,
"chrf": True,
"ter": False,
}
def __str__(self):
return f"{self.src_lang} to {self.tar_lang} Task"
...@@ -108,7 +108,6 @@ def main():
)
dumped = json.dumps(results, indent=2)
print(dumped)
if args.output_path:
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
......
# python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
# --task kolegal_criminalcase --num_fewshot 0
# python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
# --task kolegal_criminalcase --num_fewshot 5
# python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
# --task kolegal_criminalcase --num_fewshot 10
python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
--task ko_en_translation --num_fewshot 5
# test : numbers
#python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
# --task kolegal_legalcase --num_fewshot 0
# python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
# --task kolegal_legalcase --num_fewshot 5
# python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
# --task kolegal_legalcase --num_fewshot 10
import os
import zstandard
import json
import jsonlines
import io
import datetime
def json_serial(obj):
"""JSON serializer for objects not serializable by default json code"""
if isinstance(obj, (datetime.datetime,)):
return obj.isoformat()
raise TypeError("Type %s not serializable" % type(obj))
# Modified version of lm_dataformat Archive for single file.
class Archive:
def __init__(self, file_path, compression_level=3):
self.file_path = file_path
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
self.fh = open(self.file_path, 'wb')
self.cctx = zstandard.ZstdCompressor(level=compression_level)
self.compressor = self.cctx.stream_writer(self.fh)
def add_data(self, data, meta={}):
self.compressor.write(json.dumps({'text': data, 'meta': meta}, default=json_serial).encode('UTF-8') + b'\n')
def commit(self):
self.compressor.flush(zstandard.FLUSH_FRAME)
self.fh.flush()
self.fh.close()
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
def __init__(self):
pass
def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n'):
with open(file, 'rb') as fh:
self.fh = fh
#cctx = zstandard.ZstdDecompressor()
# reader = io.BufferedReader(cctx.stream_reader(fh))
reader = io.BufferedReader(fh)
rdr = jsonlines.Reader(reader)
for ob in rdr:
# naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
if isinstance(ob, str):
assert not get_meta
yield ob
continue
text = ob['text']
if autojoin_paragraphs and isinstance(text, list):
text = para_joiner.join(text)
if get_meta:
yield text, (ob['meta'] if 'meta' in ob else {})
else:
yield text
# Simple text reader and writer with same interface as above
class TextArchive:
def __init__(self, file_path, mode="ab"):
self.file_path = file_path
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
self.fh = open(self.file_path, mode)
def add_data(self, data, meta={}):
self.fh.write(data.encode('UTF-8') + b'\n')
def commit(self):
self.fh.flush()
self.fh.close()
class TextReader:
def __init__(self, file_path):
self.file_path = file_path
def read(self):
with open(self.file_path, 'r', encoding="utf8") as fh:
self.fh = fh
while True:
line = self.fh.readline()
# readline() returns an empty string at EOF
if not line:
break
yield line[:-1]
\ No newline at end of file
...@@ -41,89 +41,22 @@ from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)
terminate = False
def handler(signal_received, frame):
global terminate
terminate = True
def get_pile(directory):
def yield_pile(start_offsets=None, checkpoint_offset=None):
directory = "pile"
if not os.path.exists(directory):
print(
"We expect the pile archives to be in the 'pile' directory, but this was not found."
)
raise Exception("Pile directory not found.")
files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))
pile_global_offset = 0
start_file = 0
if checkpoint_offset:
for file_i, start_offset in enumerate(start_offsets):
if start_offset > checkpoint_offset:
break
start_file = file_i
pile_global_offset = start_offset
for file_i, file in enumerate(files):
if file_i < start_file:
logger.info(f"Skipping file {file}")
continue
logger.info(f"Reading from pile file: {file}")
reader = Reader()
for file in glob.glob(os.path.join(directory, f"*.jsonl.zst*")):
for document in reader.read(file):
yield (pile_global_offset, document)
yield document
pile_global_offset += 1
# Hash buckets > disk backed files. Supports file position checkpointing and resuming
# Allows you to write continuously and checkpoint intermittently. If a failure occurs
# the buckets are simply truncated at your last checkpoint.
class Buckets:
def __init__(self, directory, num_buckets):
self.bucket_files = [
os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)
]
self.buckets = list(map(TextArchive, self.bucket_files))
self.checkpoint_file = os.path.join(directory, f"bucket_offsets.ckpt")
if os.path.exists(self.checkpoint_file):
self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))
else:
self.bucket_offsets = [0 for i in range(len(self.buckets))]
for i, offset in enumerate(self.bucket_offsets):
bucket = self.buckets[i]
bucket.fh.seek(offset)
bucket.fh.truncate()
def add_data(self, key, value):
i = hash(key) % len(self.buckets)
bucket = self.buckets[i]
bucket.add_data(value)
def save_checkpoint(self):
for bucket in self.buckets:
bucket.fh.flush()
bucket_offsets = [bucket.fh.tell() for bucket in self.buckets]
pickle.dump(bucket_offsets, open(self.checkpoint_file, "wb"))
def close_buckets(self):
for bucket in self.buckets:
bucket.commit()
def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
pile_statistics = json.load(open("pile_statistics.json", "r"))
pile_document_count = pile_statistics["Document Count"]
start_offsets = pile_statistics["File Start Offsets"]
output_directory = os.path.join(working_directory, "output")
os.makedirs(output_directory, exist_ok=True)
...@@ -165,10 +98,6 @@ def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
return
continue
if offset == checkpoint_offset:
progress.reset(total=pile_document_count)
progress.update(checkpoint_offset)
# Save checkpoint every "batch_size", only allow terminate after checkpoint
if batch_counter == batch_size:
progress.update(batch_size)
...@@ -191,6 +120,7 @@ def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
parser = argparse.ArgumentParser(description="Generate 13 grams from Pile.")
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-sdir", "--save_directory", default="")
parser.add_argument("-n", "--n_value", type=int, default=13) parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500) parser.add_argument("-buckets", "--bucket_count", type=int, default=500)
...@@ -210,7 +140,3 @@ if __name__ == "__main__": ...@@ -210,7 +140,3 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count) do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)
\ No newline at end of file
info_dict = {"title": "dataset ngrams", "ngram_size": 13}
info_dict_path = os.path.join(args.working_directory, "info.json")
json.dump(info_dict, open(info_dict_path, "w"))
from lm_eval.decontamination.archiver import Reader
import os
import json
from functools import reduce
import glob
import tqdm
from tqdm_multiprocess import TqdmMultiProcessPool
def get_file_stats(file_path, tqdm_func, global_tqdm):
reader = Reader()
total_documents = 0
total_size = 0
update_frequency = 10000
current_file_position = 0
with tqdm_func(
total=os.path.getsize(file_path), dynamic_ncols=True, unit="byte", unit_scale=1
) as progress:
for document in reader.read(file_path, get_meta=True):
total_size += len(document)
total_documents += 1
if total_documents % update_frequency == 0:
new_file_pos = reader.fh.tell()
bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
progress.update(bytes_read)
global_tqdm.update(bytes_read)
return (total_documents, total_size)
def get_files():
directory = "pile"
files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))
print(files)
return files
def get_stats():
files = get_files()
total_size_bytes = sum(map(lambda x: os.path.getsize(x), files))
pool = TqdmMultiProcessPool(4)
global_tqdm = tqdm.tqdm(
total=total_size_bytes, dynamic_ncols=True, unit="byte", unit_scale=1
)
# Compute per-file document counts and sizes with the pool
tasks = [(get_file_stats, (file,)) for file in files]
def on_done(_):
return None
def on_error(_):
return None
results = pool.map(global_tqdm, tasks, on_error, on_done)
total_documents, total_size = reduce(
lambda x, y: (x[0] + y[0], x[1] + y[1]), results
)
start_offsets = []
current_offset = 0
for file_document_count, _ in results:
start_offsets.append(current_offset)
current_offset += file_document_count
return (total_documents, total_size, start_offsets)
if __name__ == "__main__":
version = 1.01
print(f"Running version {version}")
stats_file_path = "pile_statistics.json"
if os.path.exists(stats_file_path):
stats = json.load(open(stats_file_path, "r"))
else:
document_count, total_document_size_chars, start_offsets = get_stats()
stats = {
"Data": "Pile statistics",
"Document Count": document_count,
"Total Pile Characters": total_document_size_chars,
"File Start Offsets": start_offsets,
}
json.dump(stats, open(stats_file_path, "w"), indent=4)
print(f"document_count: {stats['Document Count']}") print(f"document_count: {stats['Document Count']}")
print(f"total_chars: {stats['Total Pile Characters']}") print(f"total_chars: {stats['Total Pile Characters']}")
......
...@@ -34,6 +34,13 @@ def sort_13_gram_buckets(working_directory):
bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt"))
for bucket_file_path in tqdm(bucket_file_paths, dynamic_ncols=True):
bucket_id = re.sub(r"\D", "", os.path.basename(bucket_file_path))
done_file = os.path.join(working_directory, f"ngram_bucket_sorting_{bucket_id}.done")
if os.path.exists(done_file):
logger.info(f"bucket {bucket_id} already processed, skipping")
continue
sorted_file_path = bucket_file_path + ".sorted"
command = f"sort {bucket_file_path} > {sorted_file_path}"
logger.info(command)
......