"vscode:/vscode.git/clone" did not exist on "3a8428ecaa6375996de0142afd73df2f98c4cc23"
Commit 8e041322 authored by cardy20

Korean legal update

parent 7604b873
@@ -6,4 +6,5 @@ lm_cache
 build/
 logs/
 output/
-lm_eval.egg-info/
\ No newline at end of file
+lm_eval.egg-info/
+shell/
@@ -248,6 +248,7 @@ def evaluate(
             if not isinstance(reqs, (list, tuple)):
                 reqs = [reqs]
             for i, req in enumerate(reqs):
                 requests[req.request_type].append(req)
                 # i: index in requests for a single task instance
...
@@ -194,8 +194,6 @@ def _sacreformat(refs, preds):
 # stderr stuff
 class _bootstrap_internal:
     def __init__(self, f, n):
         self.f = f
...
 import torch
 import transformers
-from typing import Optional
+from typing import Optional, Union
 from lm_eval.base import BaseLM
 
+def _get_dtype(
+    dtype: Union[str, torch.dtype]
+) -> torch.dtype:
+    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
+    if isinstance(dtype, str) and dtype != "auto":
+        # Convert `str` args torch dtype: `float16` -> `torch.float16`
+        _torch_dtype = getattr(torch, dtype)
+    else:
+        _torch_dtype = dtype
+    return _torch_dtype
+
 class HFLM(BaseLM):
     def __init__(
         self,
@@ -16,6 +28,7 @@ class HFLM(BaseLM):
         batch_size=1,
         load_in_8bit: Optional[bool] = False,
         trust_remote_code: Optional[bool] = False,
+        dtype: Optional[Union[str, torch.dtype]]="auto",
     ):
         super().__init__()
@@ -46,37 +59,13 @@ class HFLM(BaseLM):
             load_in_8bit=load_in_8bit,
             low_cpu_mem_usage=low_cpu_mem_usage,
             revision=revision,
+            torch_dtype=_get_dtype(dtype),
             trust_remote_code=trust_remote_code,
         ).to(self.device)
         self.gpt2.eval()
 
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
             pretrained if tokenizer is None else tokenizer,
-<<<<<<< HEAD
-<<<<<<< HEAD
-            revision=revision + ("/" + subfolder if subfolder is not None else ""))
-
-        # assert isinstance(self.tokenizer, (
-        #     transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
-        #     transformers.T5Tokenizer, transformers.T5TokenizerFast,
-        # )), "this tokenizer has not been checked for compatibility yet!"
-
-        self.vocab_size = self.tokenizer.vocab_size
-
-        # if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)):
-        #     assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373], \
-        #         self.tokenizer.encode('hello\n\nhello')
-
-        # multithreading and batching
-        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
-
-        # TODO: fix multi-gpu
-        # gpus = torch.cuda.device_count()
-        # if gpus > 1:
-        #     self.gpt2 = nn.DataParallel(self.gpt2)
-=======
-=======
->>>>>>> e8f38aee79569d51bd6c84f23f4227771291a816
             revision=revision,
             trust_remote_code=trust_remote_code,
         )
@@ -97,12 +86,7 @@ class HFLM(BaseLM):
         if batch_size == "auto":
             self.batch_size_per_gpu = batch_size
         else:
-<<<<<<< HEAD
-            self.batch_size_per_gpu = int(batch_size)
->>>>>>> 0542d35d5e56768dd9041ef9b88b90256970d843
-=======
             self.batch_size_per_gpu = int(batch_size)
->>>>>>> e8f38aee79569d51bd6c84f23f4227771291a816
 
     @property
     def eot_token_id(self):
@@ -157,4 +141,4 @@ class HFLM(BaseLM):
 # for backwards compatibility
-GPT2LM = HFLM
\ No newline at end of file
+GPT2LM = HFLM
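Note on the dtype plumbing above: `_get_dtype` lets callers request reduced precision as a plain string, without building an HF config object first. A minimal self-contained check of its mapping (this snippet is illustrative, not part of the commit):

import torch

def _get_dtype(dtype):
    # Same logic as the helper added in this diff.
    if isinstance(dtype, str) and dtype != "auto":
        return getattr(torch, dtype)  # e.g. "float16" -> torch.float16
    return dtype

assert _get_dtype("float16") is torch.float16
assert _get_dtype(torch.bfloat16) is torch.bfloat16
assert _get_dtype("auto") == "auto"  # forwarded as-is to from_pretrained(torch_dtype=...)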
@@ -57,6 +57,7 @@ from . import ko_translation
 from . import korquad
 from . import korunsmile
 from . import kohatespeech
+from . import legal_test
 from . import kold
 from . import toxigen
 from . import crowspairs
@@ -345,6 +346,7 @@ TASK_REGISTRY = {
     "kohatespeech":kohatespeech.HateSpeech,
     "kohatespeech_gen_bias":kohatespeech.GenderBias,
     "kohatespeech_apeach":kohatespeech.Apeach,
+    "kolegal_legalcase":legal_test.LegalCasename,
     **xcopa.construct_tasks(),
     **bigbench.create_all_tasks(),
     **xstorycloze.create_all_tasks(),
...
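With the registry entry in place, the task resolves like any other; a quick lookup sketch (assuming the harness's existing `lm_eval.tasks.get_task` helper):

from lm_eval import tasks

task_class = tasks.get_task("kolegal_legalcase")  # -> legal_test.LegalCasename
task = task_class()  # instantiating triggers the HF datasets download of lbox/lbox_open
assert task.has_test_docs()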
@@ -9,7 +9,8 @@ import hashlib
 import functools
 import numpy as np
 import re
-import importlib.resources
+# import importlib.resources
+from importlib_resources import files
 from lm_eval.base import rf, Task
 from lm_eval.metrics import mean
@@ -229,7 +230,8 @@ def create_task_from_path(json_path):
 def create_all_tasks():
-    resources_dir = importlib.resources.files("lm_eval.datasets") / "bigbench_resources"
+    # resources_dir = importlib.resources.files("lm_eval.datasets") / "bigbench_resources"
+    resources_dir = files("lm_eval.datasets") / "bigbench_resources"
     supported_tasks = [os.path.splitext(x)[0] for x in os.listdir(resources_dir)]
     res = {}
     for task_name in supported_tasks:
...
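The switch from the stdlib module to the `importlib_resources` backport keeps `files()` working on interpreters older than Python 3.9, where `importlib.resources.files` does not exist. A common compatibility pattern (a sketch; the commit itself unconditionally uses the backport):

try:
    from importlib.resources import files  # stdlib, Python 3.9+
except ImportError:
    from importlib_resources import files  # PyPI backport for older interpreters

resources_dir = files("lm_eval.datasets") / "bigbench_resources"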
"""
Korean legal AI datasets, LBox OPEN
Multi-task on Legal corpus
https://arxiv.org/pdf/2206.05224.pdf
"""
import numpy as np
from lm_eval.base import Task, MultipleChoiceTask, rf
from lm_eval.metrics import macro_f1_score, mean, matthews_corrcoef, f1_score, yesno
from lm_eval.utils import general_detokenize
_CITATION ="""
@article{hwang2022multi,
title={A multi-task benchmark for korean legal language understanding and judgement prediction},
author={Hwang, Wonseok and Lee, Dongjun and Cho, Kyoungyeon and Lee, Hanuhl and Seo, Minjoon},
journal={arXiv preprint arXiv:2206.05224},
year={2022}
}
"""
class LegalCasename(Task):
VERSION = 0
DATASET_PATH = "lbox/lbox_open"
DATASET_NAME = "casename_classification"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
return map(self._process_doc, self.dataset["valid"])
def test_docs(self):
return map(self._process_doc, self.dataset["test"])
def doc_to_text(self, doc):
return "문장: {} ".format(doc["facts"])
def doc_to_target(self, doc):
return " {}".format({"civil": "민사", "criminal": "형사"}[doc["casetype"]])
def construct_requests(self, doc, ctx):
ll_m, _ = rf.loglikelihood(ctx, " 민사")
ll_h, _ = rf.loglikelihood(ctx, " 형사")
return ll_m, ll_h
def process_results(self, doc, results):
ll_m, ll_h = results
pred = ll_h > ll_m
gold = {"civil": 0, "criminal": 1}[doc["casetype"]]
return {
"acc": pred == gold,
"macro_f1": (gold, pred)
}
def higher_is_better(self):
return {
"acc": True,
"macro_f1": True
}
def aggregation(self):
return {
"acc": mean,
"macro_f1": macro_f1_score
}
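For an end-to-end smoke test of the new task, something along these lines should work (checkpoint name and sample limit are placeholders, assuming the harness's `evaluator.simple_evaluate` API):

from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="gpt2",  # the HFLM class edited above, registered as "gpt2"
    model_args="pretrained=skt/kogpt2-base-v2",  # placeholder Korean checkpoint
    tasks=["kolegal_legalcase"],
    num_fewshot=0,
    limit=8,  # evaluate only a few documents as a quick check
)
print(results["results"]["kolegal_legalcase"])  # expects "acc" and "macro_f1"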