Commit 9161ebbc authored by fromSun2Moon

legal update: add LegalSummarization (ko/en) task and fewshot eval commands

parent dc1f2539
......@@ -391,6 +391,7 @@ class BaseLM(LM):
        re_ord = utils.Reorderer(requests, _collate)
        print(re_ord.arr)  # debug: inspect the reordered request array
        for context, request_args in tqdm(re_ord.get_reordered()):
            until = request_args["until"]
            if isinstance(until, str):
......
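Since this loop reads request_args["until"], each greedy_until request is expected to carry a dict payload rather than a bare stop-string list. A minimal self-contained sketch of that contract; the toy context string and stop sequence are illustrative, not from the diff:

# Sketch: the (context, request_args) shape the reordered loop consumes.
# The "until" key mirrors the request_args["until"] lookup in the hunk above.
requests = [
    ("ko:...\nen:", {"until": ["\n"]}),
]
for context, request_args in requests:
    until = request_args["until"]   # list of stop sequences
    if isinstance(until, str):      # tolerate a bare string, as above
        until = [until]
    print(context, until)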
......@@ -8,6 +8,7 @@ import random
def mean(arr):
    print(len(arr))  # debug: number of items being averaged
    return sum(arr) / len(arr)
......@@ -41,7 +42,6 @@ def f1_score(items):
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = sklearn.metrics.f1_score(golds, preds)
    return np.max(fscore)  # scalar under the default binary average; np.max is a no-op safeguard
def macro_f1_score(items):
......@@ -49,7 +49,6 @@ def macro_f1_score(items):
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = sklearn.metrics.f1_score(golds, preds, average='macro')
    return fscore
def acc_all(items):
......
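For reference, the aggregators above consume a list of (gold, pred) pairs; a small self-contained check with toy binary labels, not project data:

# Toy illustration of the F1 aggregators: items are (gold, pred) pairs.
import sklearn.metrics

items = [(1, 1), (0, 1), (1, 0), (0, 0), (1, 1)]
unzipped_list = list(zip(*items))
golds, preds = unzipped_list[0], unzipped_list[1]
print(sklearn.metrics.f1_score(golds, preds))                   # binary F1
print(sklearn.metrics.f1_score(golds, preds, average="macro"))  # macro F1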
......@@ -266,3 +266,117 @@ class LJPCriminal(MultipleChoiceTask):
"macro_f1": macro_f1_score
}
class LegalSummarization(Task):
    VERSION = 0

    def __init__(self):
        # Base Task.__init__ is bypassed, so initialize the training-doc
        # cache that training_docs() checks below. Subclasses are expected
        # to set train_src/train_tgt, valid_src/valid_tgt, tst_src/tst_tgt,
        # src_lang, and tar_lang.
        self._training_docs = None

    def has_training_docs(self):
        """Whether the task has a training set"""
        return True

    def has_validation_docs(self):
        """Whether the task has a validation set"""
        return True

    def has_test_docs(self):
        """Whether the task has a test set"""
        return True

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        if self._training_docs is None:
            self._training_docs = [
                {"src": src, "tgt": tgt} for src, tgt in zip(self.train_src, self.train_tgt)
            ]
        return self._training_docs
    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return [
            {"src": src, "tgt": tgt} for src, tgt in zip(self.valid_src, self.valid_tgt)
        ]

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        return [
            {"src": src, "tgt": tgt} for src, tgt in zip(self.tst_src, self.tst_tgt)
        ]
    def doc_to_text(self, doc):
        src_lang = self.src_lang
        tar_lang = self.tar_lang
        if src_lang == 'ko':
            # Korean instruction; roughly: "This model translates {src_lang} to {tar_lang}."
            return f"{src_lang}를 {tar_lang}으로 번역해주는 모델입니다.\n\n###\n{src_lang}:" + doc["src"] + f"\n{tar_lang}:"
        elif src_lang == 'en':
            return f"Translate {src_lang} to {tar_lang}.\n\n###\n{src_lang}:" + doc["src"] + f"\n{tar_lang}:"

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["src"]

    def doc_to_target(self, doc):
        # This shows a single target, though there may be multiple targets in a lang test
        return " " + (doc["tgt"] if isinstance(doc["tgt"], str) else doc["tgt"][0])
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        # The greedy_until loop above reads request_args["until"], so the
        # stop sequences are passed as a dict rather than a bare list.
        return rf.greedy_until(ctx, {"until": ["\n"]})
    def process_results(self, doc, results):
        ref_pred = (doc["tgt"], results)
        # aggregation() and higher_is_better() also report TER, so emit it here too.
        return {
            "bleu": ref_pred,
            "chrf": ref_pred,
            "ter": ref_pred,
        }
    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "bleu": bleu,
            "chrf": chrf,
            "ter": ter,
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "bleu": True,
            "chrf": True,
            "ter": False,
        }

    def __str__(self):
        return f"{self.src_lang} to {self.tar_lang} Task"
......@@ -108,7 +108,6 @@ def main():
    )
    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.output_path:
        os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
......
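Only the makedirs call and the dumped string are visible in this hunk; presumably a file write follows. A minimal sketch of that step, where the open/write call is an assumption:

# Sketch: persist the dumped results JSON (the write itself is assumed;
# only os.makedirs and `dumped` appear in the diff).
import json
import os

def save_results(results, output_path):
    dumped = json.dumps(results, indent=2)
    if output_path:
        # `or "."` guards against a bare filename with no directory part
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        with open(output_path, "w") as f:
            f.write(dumped)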
# python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
# --task kolegal_criminalcase --num_fewshot 0
# python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
# --task kolegal_criminalcase --num_fewshot 5
# python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
# --task kolegal_criminalcase --num_fewshot 10
python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
--task kolegal_criminalcase --num_fewshot 0
python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
--task kolegal_criminalcase --num_fewshot 5
python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
--task kolegal_criminalcase --num_fewshot 10
# python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
# --task kolegal_legalcase --num_fewshot 0
python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
--task ko_en_translation --num_fewshot 5
# test : numbers
#python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
# --task kolegal_legalcase --num_fewshot 0
# python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
# --task kolegal_legalcase --num_fewshot 5
# python3 -W ignore main.py --model gpt2 --model_args pretrained=EleutherAI/polyglot-ko-1.3B \
......