Unverified commit cc88a0cf authored by Taekyoon, committed by GitHub

Add nsmc task (Korean language task)

- Add nsmc dataset and task modules
parents 4887d9d3 5e7738d7
# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Naver movie review corpus for binary sentiment classification"""
import csv
import datasets
_CITATION = """\
@InProceedings{Park:2016,
title = "Naver Sentiment Movie Corpus",
author = "Lucy Park",
year = "2016",
howpublished = {\\url{https://github.com/e9t/nsmc}}
}
"""
_DESCRIPTION = """\
This is a movie review dataset in the Korean language. Reviews were scraped from Naver Movies. The dataset construction follows the method described for the Large Movie Review Dataset (Maas et al., 2011).
"""
_HOMEPAGE = "https://github.com/e9t/nsmc/"
_LICENSE = "CC0 1.0 Universal (CC0 1.0)"
_URL = "https://raw.githubusercontent.com/e9t/nsmc/master/"
_URLs = {
"train": _URL + "ratings_train.txt",
"test": _URL + "ratings_test.txt",
}
# TODO: the dataset class name usually matches the script name, with CamelCase instead of snake_case
class NSMC(datasets.GeneratorBasedBuilder):
    """Korean Naver movie review dataset."""

    VERSION = datasets.Version("1.1.0")

    def _info(self):
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    "id": datasets.Value("string"),
                    "document": datasets.Value("string"),
                    # 0 = "부정" (negative), 1 = "긍정" (positive)
                    "label": datasets.ClassLabel(names=["부정", "긍정"]),
                }
            ),
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )
    def _split_generators(self, dl_manager):
        downloaded_files = dl_manager.download_and_extract(_URLs)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepath": downloaded_files["train"],
                    "split": "train",
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={
                    "filepath": downloaded_files["test"],
                    "split": "test",
                },
            ),
        ]
    def _generate_examples(self, filepath, split):
        with open(filepath, encoding="utf-8") as f:
            next(f)  # skip the header row
            reader = csv.reader(f, delimiter="\t")
            for id_, row in enumerate(reader):
                yield id_, {
                    "id": row[0],
                    "document": row[1],
                    "label": int(row[2]),
                }
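As a quick sanity check of the loading script above, here is a minimal usage sketch. It assumes the script is reachable by the datasets library under the name "nsmc" (the same path the task module below uses) and only touches the fields declared in _info.

# Minimal usage sketch (assumes the script above is available as the "nsmc" dataset).
from datasets import load_dataset

nsmc = load_dataset("nsmc")  # "train" and "test" splits, per _split_generators
example = nsmc["train"][0]   # dict with "id", "document", "label"
label_feature = nsmc["train"].features["label"]
print(example["document"])                      # raw Korean review text
print(label_feature.int2str(example["label"]))  # "부정" (negative) or "긍정" (positive)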
@@ -75,7 +75,7 @@ class HFLM(BaseLM):
    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)
@@ -89,7 +89,7 @@ class HFLM(BaseLM):
"""
with torch.no_grad():
return self.gpt2(inps)[0][:, :, :50257]
def _model_generate(self, context, max_length, eos_token_id):
return self.gpt2.generate(
context,
......
@@ -52,6 +52,7 @@ from . import asdiv
from . import gsm8k
from . import storycloze
from . import klue
from . import nsmc
########################################
# Translation tasks
@@ -103,13 +104,13 @@ TASK_REGISTRY = {
"record": superglue.ReCoRD,
"wic": superglue.WordsInContext,
"wsc": superglue.SGWinogradSchemaChallenge,
# Order by benchmark/genre?
"coqa": coqa.CoQA,
"drop": drop.DROP,
"lambada": lambada.LAMBADA,
"lambada_cloze": lambada_cloze.LAMBADA_cloze,
# multilingual lambada
**lambada_multilingual.construct_tasks(),
@@ -229,7 +230,7 @@ TASK_REGISTRY = {
"pile_ubuntu-irc": pile.PileUbuntuIrc,
"pile_wikipedia": pile.PileWikipedia,
"pile_youtubesubtitles": pile.PileYoutubeSubtitles,
# BLiMP
"blimp_adjunct_island": blimp.BlimpAdjunctIsland,
"blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
@@ -298,14 +299,12 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
"klue_sts": klue.STS,
"nsmc": nsmc.NSMC,
# Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018,
# "sat": sat.SATAnalogies,
# KLUE
"klue_sts": klue.STS
}
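A hedged sketch of resolving the newly registered task by name follows. The lm_eval.tasks import path and the constructor behavior (loading the Hugging Face dataset named by DATASET_PATH on instantiation) are assumptions based on this repository's layout, not something shown in the diff.

# Sketch: look up the new registry entry and instantiate the task.
# Assumes this file is importable as lm_eval.tasks and that Task.__init__
# loads the underlying Hugging Face dataset (DATASET_PATH = "nsmc").
from lm_eval import tasks

nsmc_class = tasks.TASK_REGISTRY["nsmc"]  # -> nsmc.NSMC, registered above
nsmc_task = nsmc_class()
print(nsmc_task.has_training_docs(), nsmc_task.has_validation_docs())  # True True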
@@ -325,7 +324,7 @@ def get_task_name_from_object(task_object):
    for name, class_ in TASK_REGISTRY.items():
        if class_ is task_object:
            return name

    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__
......
"""
NSMC:
"""
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
from lm_eval.utils import general_detokenize
_CITATION = """
@inproceedings{zellers2019hellaswag,
title={NSMC: Can a Machine Really Finish Your Sentence?},
author={Zellers, Rowan and Holtzman, Ari and Bisk, Yonatan and Farhadi, Ali and Choi, Yejin},
booktitle ={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics},
year={2019}
}
"""
class NSMC(Task):
    VERSION = 0
    DATASET_PATH = "nsmc"
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        return self.dataset["test"]

    def doc_to_text(self, doc):
        # Prompt: "Is the following sentence positive or negative?\n<review>\nAnswer:"
        return "다음 문장은 긍정일까요 부정일까요?\n{}\n정답:".format(
            general_detokenize(doc["document"]),
        )

    def doc_to_target(self, doc):
        # Target: " 긍정" (positive) for label 1, " 부정" (negative) for label 0
        return " {}".format({1: "긍정", 0: "부정"}[doc["label"]])

    def construct_requests(self, doc, ctx):
        # Compare the log-likelihood of the positive and negative answer strings
        ll_positive, _ = rf.loglikelihood(ctx, " 긍정")
        ll_negative, _ = rf.loglikelihood(ctx, " 부정")
        return ll_positive, ll_negative

    def process_results(self, doc, results):
        ll_positive, ll_negative = results
        pred = ll_positive > ll_negative
        gold = doc["label"]
        return {
            "acc": pred == gold
        }

    def higher_is_better(self):
        return {
            "acc": True
        }

    def aggregation(self):
        return {
            "acc": mean
        }
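To make the prompt and scoring convention concrete, a standalone illustration follows; the review text, label, and log-likelihood values are fabricated for the example, and general_detokenize is skipped for brevity.

# Standalone illustration of the prompt/target format and the accuracy decision above.
# The document, label, and log-likelihoods are made up; no model is involved.
doc = {"id": "0", "document": "이 영화 정말 재미있어요", "label": 1}  # hypothetical positive review

prompt = "다음 문장은 긍정일까요 부정일까요?\n{}\n정답:".format(doc["document"])
target = " {}".format({1: "긍정", 0: "부정"}[doc["label"]])  # " 긍정" here

# construct_requests asks the model for the log-likelihood of " 긍정" and " 부정" after the prompt;
# process_results marks the example correct when the comparison agrees with the gold label.
ll_positive, ll_negative = -1.2, -3.4  # hypothetical model scores
pred = ll_positive > ll_negative       # True -> predicted 긍정 (label 1)
acc = pred == doc["label"]             # True here; aggregated with mean as "acc"
print(prompt, target, acc)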