Unverified commit a2cada5d authored by Jonathan Tow, committed by GitHub

Merge pull request #317 from EleutherAI/Mistobaan/add-pre-commit

Add pre-commit
parents 7a038118 83507c4b
[flake8]
ignore = E203, E266, E501, W503, F403, F401, C901
max-line-length = 127
max-complexity = 10
select = B,C,E,F,W,T4,B9
name: Pull Request
on: [pull_request]
jobs:
pre-commit:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: 3.8
- uses: pre-commit/action@v2.0.3
# Ignore test linting to avoid conflicting changes to version stability.
exclude: ^tests/testdata/
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
hooks:
- id: check-added-large-files
- id: check-ast
- id: check-byte-order-marker
- id: check-case-conflict
- id: check-json
- id: check-merge-conflict
- id: check-symlinks
- id: check-yaml
- id: destroyed-symlinks
- id: detect-private-key
- id: end-of-file-fixer
- id: no-commit-to-branch
- id: requirements-txt-fixer
- id: trailing-whitespace
- id: fix-byte-order-marker
exclude: docs/CNAME
- id: fix-encoding-pragma
args: [--remove]
- id: mixed-line-ending
args: [--fix=lf]
- repo: https://gitlab.com/pycqa/flake8
rev: 3.7.9
hooks:
- id: flake8
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
language_version: python3.8
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
hooks:
- id: codespell
exclude: >
(?x)^(
.*\.json|ignore.txt
)$
args: [--check-filenames, --check-hidden, --ignore-words=ignore.txt]
@@ -73,4 +73,3 @@ python -m scripts/clean_training_data/compress_and_package \
```
Congratulations, the final directory can now be passed to lm-evaluation-harness with the "--decontamination_ngrams_path" argument.
@@ -295,6 +295,11 @@ class TaskName(...):
## Submitting your Task
Although we currently do not work behind a specific style guide, we'd appreciate it if you tidied up your files with the `black` formatter (which should have been installed through `requirements.txt`). Keep things clean…ish 🙂.
You can format your changes and run the standard `flake8` checks with the following commands:
```sh
pre-commit install
pre-commit run --all-files
```
Now push your work and make a pull request! Thanks for the contribution 👍. If there are any questions, leave a message in the `#lm-thunderdome` channel on the EAI discord.
ROUGE
rouge
nin
@@ -51,7 +51,7 @@ class LM(abc.ABC):
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
the max context length.
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
which may simply concatenate multiple documents together.
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
multiple chunks, the last input will still have a full-sized context.
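For concreteness, a toy walk-through of the windowing rule described above (hypothetical numbers, not drawn from the source):

```python
# Suppose max_seq_len = 4, the prefix (EOT) token is P, and a document
# tokenizes to [t1, t2, t3, t4, t5, t6, t7]. The disjoint rolling windows
# come out as (context, continuation) pairs:
#
#   ([P],      [t1, t2, t3, t4])  # first chunk: a full window from the prefix
#   ([t3, t4], [t5, t6, t7])      # last chunk: padded with left-context, so
#                                 # the model input [t3, t4, t5, t6] is still
#                                 # full-sized
#
# Every token is scored exactly once, and the final window still sees a
# maximal context of preceding tokens.
```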
@@ -118,7 +118,6 @@ class LM(abc.ABC):
class BaseLM(LM):
@property
@abstractmethod
def eot_token_id(self):
@@ -145,13 +144,16 @@ class BaseLM(LM):
pass
@abstractmethod
def tok_encode(self, string: str):
pass
@abstractmethod
def tok_decode(self, tokens: Iterable[int]):
pass
@abstractmethod
def _model_generate(self, context, max_length, eos_token_id):
pass
@abstractmethod
def _model_call(self, inps):
@@ -187,19 +189,26 @@ class BaseLM(LM):
# TODO: automatic batch size detection for vectorization
loglikelihoods = []
for (string,) in tqdm(requests):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
),
)
)
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
# TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
# that
string_nll = self._loglikelihood_tokens(
rolling_token_windows, disable_tqdm=True
)
# discard is_greedy
string_nll = [x[0] for x in string_nll]
@@ -225,8 +234,10 @@ class BaseLM(LM):
return -len(toks), tuple(toks)
# TODO: automatic (variable) batch size detection for vectorization
re_ord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(
tqdm(re_ord.get_reordered(), disable=disable_tqdm), self.batch_size
):
inps = []
cont_toks_list = []
inplens = []
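For reference, the reordering trick used above, sketched as a minimal standalone class (simplified; the real `utils.Reorderer` also groups duplicate requests so each unique request runs once):

```python
class SimpleReorderer:
    """Sort items by a key for efficient batching, then map results back."""

    def __init__(self, items, sort_key):
        # Remember each item's original position before sorting.
        self._order = sorted(enumerate(items), key=lambda p: sort_key(p[1]))

    def get_reordered(self):
        return [item for _, item in self._order]

    def get_original(self, results):
        # Scatter results back into the caller's original order.
        out = [None] * len(self._order)
        for (orig_idx, _), res in zip(self._order, results):
            out[orig_idx] = res
        return out
```

Because `_collate` returns `(-len(toks), tuple(toks))`, the longest request sorts first, which is why the first item in each batch can set `padding_length` for the rest.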
@@ -252,44 +263,60 @@ class BaseLM(LM):
# when too long to fit in context, truncate from the left
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
dtype=torch.long,
).to(self.device)
(inplen,) = inp.shape
cont = continuation_enc
# since in _collate we make sure length is descending, the longest is always the first one.
padding_length = (
padding_length if padding_length is not None else inplen
)
# pad length from seq to padding_length
inp = torch.cat(
[
inp, # [seq]
torch.zeros(padding_length - inplen, dtype=torch.long).to(
inp.device
), # [padding_length - seq]
],
dim=0,
)
inps.append(inp.unsqueeze(0)) # [1, padding_length]
cont_toks_list.append(cont)
inplens.append(inplen)
batched_inps = torch.cat(inps, dim=0) # [batch, padding_length]
multi_logits = F.log_softmax(
self._model_call(batched_inps), dim=-1
).cpu() # [batch, padding_length, vocab]
for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
chunk, multi_logits, inps, inplens, cont_toks_list
):
# Slice to original seq length
contlen = len(cont_toks)
logits = logits[inplen - contlen : inplen].unsqueeze(
0
) # [1, seq, vocab]
# Check if per-token argmax is exactly equal to continuation
greedy_tokens = logits.argmax(dim=-1)
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
0
) # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()
# Obtain log-probs at the corresponding continuation token indices
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-1
) # [1, seq]
# Answer: (log prob, is-exact-match)
answer = (float(logits.sum()), bool(max_equal))
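A standalone toy version of the scoring step above (random logits and a made-up continuation), showing how `torch.gather` pulls out each continuation token's log-prob and how the greedy exact-match check works:

```python
import torch
import torch.nn.functional as F

logits = F.log_softmax(torch.randn(1, 3, 5), dim=-1)  # [1, seq=3, vocab=5]
cont_toks = torch.tensor([[2, 0, 4]])                 # [1, seq]

# Exact-match check: does per-token argmax reproduce the continuation?
max_equal = (logits.argmax(dim=-1) == cont_toks).all()

# Log-prob of each continuation token at its position.
token_logprobs = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)

answer = (float(token_logprobs.sum()), bool(max_equal))
```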
@@ -300,10 +327,10 @@ class BaseLM(LM):
res.append(answer)
return re_ord.get_original(res)
def greedy_until(self, requests):
# TODO: implement fully general `until` that handles stop sequences that are
# multiple tokens or that span token boundaries correctly
# TODO: extract to TokenizedLM?
@@ -313,19 +340,23 @@ class BaseLM(LM):
toks = self.tok_encode(x[0])
return len(toks), x[0]
re_ord = utils.Reorderer(requests, _collate)
for context, until in tqdm(re_ord.get_reordered()):
if isinstance(until, str):
until = [until]
(primary_until,) = self.tok_encode(until[0])
context_enc = torch.tensor(
[self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
).to(self.device)
cont = self._model_generate(
context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
)
s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
for term in until:
s = s.split(term)[0]
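The loop above truncates the decoded continuation at the first occurrence of any stop string; for example (made-up generation):

```python
s = "The answer is 42.\nQ: What comes next?"
for term in ["\nQ:", "\n\n"]:
    s = s.split(term)[0]
# s == "The answer is 42."
```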
@@ -335,7 +366,7 @@ class BaseLM(LM):
res.append(s)
return re_ord.get_original(res)
class Task(abc.ABC):
@@ -383,7 +414,7 @@ class Task(abc.ABC):
self._fewshot_docs = None
def download(self, data_dir=None, cache_dir=None, download_mode=None):
""" Downloads and returns the task dataset.
"""Downloads and returns the task dataset.
Override this method to download the dataset from a custom API.
:param data_dir: str
@@ -412,7 +443,7 @@ class Task(abc.ABC):
name=self.DATASET_NAME,
data_dir=data_dir,
cache_dir=cache_dir,
download_mode=download_mode,
)
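The base `download` is a thin wrapper over `datasets.load_dataset`, with `DATASET_PATH`/`DATASET_NAME` supplied by the subclass; it is roughly equivalent to (hypothetical identifiers, for illustration only):

```python
import datasets

# For a task with DATASET_PATH = "coqa" and DATASET_NAME = None:
dataset = datasets.load_dataset(
    path="coqa",
    name=None,
    data_dir=None,
    cache_dir=None,
    download_mode=None,
)
```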
def should_decontaminate(self):
@@ -473,8 +504,10 @@ class Task(abc.ABC):
return rnd.sample(self._training_docs, k)
def doc_to_decontamination_query(self, doc):
print("Override doc_to_decontamination_query with document specific decontamination query.")
assert(False)
print(
"Override doc_to_decontamination_query with document specific decontamination query."
)
assert False
@abstractmethod
def doc_to_text(self, doc):
@@ -486,7 +519,7 @@ class Task(abc.ABC):
@abstractmethod
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
@@ -531,15 +564,19 @@ class Task(abc.ABC):
def fewshot_description(self):
import warnings
warnings.warn(
"`fewshot_description` will be removed in futures versions. Pass "
"any custom descriptions to the `evaluate` function instead.",
DeprecationWarning,
)
return ""
@utils.positional_deprecated
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
@@ -556,7 +593,9 @@ class Task(abc.ABC):
:returns: str
The fewshot context.
"""
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
@@ -564,7 +603,9 @@ class Task(abc.ABC):
)
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
description = description + "\n\n" if description else ""
@@ -577,7 +618,9 @@ class Task(abc.ABC):
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
)
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
@@ -585,23 +628,27 @@ class Task(abc.ABC):
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = (
"\n\n".join(
[
self.doc_to_text(doc) + self.doc_to_target(doc)
for doc in fewshotex
]
)
+ "\n\n"
)
example = self.doc_to_text(doc)
return description + labeled_examples + example
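End to end, the returned string is `description + labeled_examples + example`. A self-contained toy run (hypothetical docs and formatting functions, not from the source):

```python
import random

docs = [
    {"q": "What is 2 plus 2?", "a": " 4"},
    {"q": "What is 3 plus 3?", "a": " 6"},
    {"q": "What is 5 plus 5?", "a": " 10"},
]
doc_to_text = lambda d: f"Question: {d['q']}\nAnswer:"
doc_to_target = lambda d: d["a"]

rnd = random.Random(1234)
eval_doc = docs[-1]
# Sample num_fewshot + 1, drop the doc under evaluation, keep num_fewshot.
fewshotex = [d for d in rnd.sample(docs, 3) if d != eval_doc][:2]

labeled_examples = (
    "\n\n".join(doc_to_text(d) + doc_to_target(d) for d in fewshotex) + "\n\n"
)
description = "Answer the arithmetic questions.\n\n"
print(description + labeled_examples + doc_to_text(eval_doc))
```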
class MultipleChoiceTask(Task):
def doc_to_target(self, doc):
return " " + doc["choices"][doc["gold"]]
def construct_requests(self, doc, ctx):
lls = [
rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
]
return lls
@@ -609,9 +656,9 @@ class MultipleChoiceTask(Task):
def process_results(self, doc, results):
gold = doc["gold"]
acc = 1.0 if np.argmax(results) == gold else 0.0
completion_len = np.array([float(len(i)) for i in doc["choices"]])
acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
return {
"acc": acc,
@@ -632,7 +679,6 @@ class MultipleChoiceTask(Task):
class PerplexityTask(Task, abc.ABC):
def should_decontaminate(self):
"""Whether this task supports decontamination against model training set."""
return True
@@ -644,9 +690,15 @@ class PerplexityTask(Task, abc.ABC):
assert k == 0
return []
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert (
num_fewshot == 0
), "The number of fewshot examples must be 0 for perplexity tasks."
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`."
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
@@ -654,7 +706,9 @@ class PerplexityTask(Task, abc.ABC):
)
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
return ""
@@ -680,7 +734,7 @@ class PerplexityTask(Task, abc.ABC):
return req
def process_results(self, doc, results):
(loglikelihood,) = results
words = self.count_words(doc)
bytes_ = self.count_bytes(doc)
return {
@@ -702,13 +756,13 @@ class PerplexityTask(Task, abc.ABC):
@classmethod
def count_words(cls, doc):
""" Downstream tasks with custom word boundaries should override this! """
"""Downstream tasks with custom word boundaries should override this!"""
return len(re.split(r"\s+", doc))
def hash_args(attr, args):
dat = json.dumps([attr] + list(args))
return hashlib.sha256(dat.encode("utf-8")).hexdigest()
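For example, using the function just defined, a loglikelihood request keys on its type plus arguments, giving a stable digest for the `CachingLM` store:

```python
key = hash_args("loglikelihood", ("Question: 2+2?\nAnswer:", " 4"))
# A deterministic sha256 hex string, identical across runs for the same request.
```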
class CacheHook:
@@ -779,6 +833,7 @@ class CachingLM:
self.dbdict.commit()
return res
return fn
def get_cache_hook(self):
@@ -786,16 +841,18 @@ class CachingLM:
REQUEST_RETURN_LENGTHS = {
"loglikelihood": 2,
"greedy_until": None,
"loglikelihood_rolling": None,
}
class Request:
def __init__(self, request_type, args, index=None):
if request_type not in REQUEST_RETURN_LENGTHS.keys():
raise NotImplementedError(
"The request type {} is not implemented!".format(request_type)
)
self.request_type = request_type
self.args = args
@@ -803,17 +860,21 @@ class Request:
def __iter__(self):
if REQUEST_RETURN_LENGTHS[self.request_type] is None:
raise IndexError("This request type does not return multiple arguments!")
for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
yield Request(self.request_type, self.args, i)
def __getitem__(self, i):
if REQUEST_RETURN_LENGTHS[self.request_type] is None:
raise IndexError("This request type does not return multiple arguments!")
return Request(self.request_type, self.args, i)
def __eq__(self, other):
return (
self.request_type == other.request_type
and self.args == other.args
and self.index == other.index
)
def __repr__(self):
return f"Req_{self.request_type}{self.args}[{self.index}]\n"
@@ -823,6 +884,7 @@ class RequestFactory:
def __getattr__(self, attr):
def fn(*args):
return Request(attr, args)
return fn
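Putting the two classes above together: attribute access on the factory builds a `Request`, and indexing or unpacking selects one of the returned values when the request type has a known return length:

```python
rf = RequestFactory()

req = rf.loglikelihood("Question: 2+2?\nAnswer:", " 4")
# -> Request("loglikelihood", ("Question: 2+2?\nAnswer:", " 4"))

# "loglikelihood" returns 2 values, so the request unpacks into one indexed
# sub-request per value: (log-prob, is_greedy).
logprob_req, is_greedy_req = req
assert logprob_req == req[0]
```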
@@ -68,61 +68,111 @@ class Arithmetic(datasets.GeneratorBasedBuilder):
ArithmeticConfig(
name="arithmetic_2da",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/two_digit_addition.jsonl",
features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="2-digit addition",
),
ArithmeticConfig(
name="arithmetic_2ds",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/two_digit_subtraction.jsonl",
features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="2-digit subtraction",
),
ArithmeticConfig(
name="arithmetic_3da",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/three_digit_addition.jsonl",
features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="3-digit addition",
),
ArithmeticConfig(
name="arithmetic_3ds",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/three_digit_subtraction.jsonl",
features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="3-digit subtraction",
),
ArithmeticConfig(
name="arithmetic_4da",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/four_digit_addition.jsonl",
features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="4-digit addition",
),
ArithmeticConfig(
name="arithmetic_4ds",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/four_digit_subtraction.jsonl",
features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="4-digit subtraction",
),
ArithmeticConfig(
name="arithmetic_5da",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/five_digit_addition.jsonl",
features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="5-digit addition",
),
ArithmeticConfig(
name="arithmetic_5ds",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/five_digit_subtraction.jsonl",
features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="5-digit subtraction",
),
ArithmeticConfig(
name="arithmetic_2dm",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/two_digit_multiplication.jsonl",
features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="2-digit multiplication",
),
ArithmeticConfig(
name="arithmetic_1dc",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/single_digit_three_ops.jsonl",
features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="Single digit 3 operations",
),
]
@@ -155,9 +205,12 @@ class Arithmetic(datasets.GeneratorBasedBuilder):
with open(filepath, encoding="utf-8") as f:
for key, row in enumerate(f):
data = json.loads(row)
context = (
data["context"]
.strip()
.replace("\n\n", "\n")
.replace("Q:", "Question:")
.replace("A:", "Answer:")
)
completion = data["completion"]
yield key, {"context": context, "completion": completion}
@@ -50,13 +50,16 @@ _URLS = "https://github.com/chaochun/nlu-asdiv-dataset/archive/55790e5270bb91ccf
class ASDiv(datasets.GeneratorBasedBuilder):
""" ASDiv: A Diverse Corpus for Evaluating and Developing English Math Word Problem Solvers """
"""ASDiv: A Diverse Corpus for Evaluating and Developing English Math Word Problem Solvers"""
VERSION = datasets.Version("0.0.1")
BUILDER_CONFIGS = [
datasets.BuilderConfig(name="asdiv", version=VERSION,
description="A diverse corpus for evaluating and developing english math word problem solvers")
datasets.BuilderConfig(
name="asdiv",
version=VERSION,
description="A diverse corpus for evaluating and developing english math word problem solvers",
)
]
def _info(self):
@@ -86,7 +89,9 @@ class ASDiv(datasets.GeneratorBasedBuilder):
name=datasets.Split.VALIDATION,
# These kwargs will be passed to _generate_examples
gen_kwargs={
"filepath": os.path.join(data_dir, base_filepath, "dataset", "ASDiv.xml"),
"filepath": os.path.join(
data_dir, base_filepath, "dataset", "ASDiv.xml"
),
"split": datasets.Split.VALIDATION,
},
),
@@ -61,7 +61,7 @@ _EMPTY_ADDITIONAL_ANSWER = {
"span_end": -1,
"span_text": "",
"input_text": "",
"turn_id": -1
"turn_id": -1,
}
],
"1": [
@@ -70,7 +70,7 @@ _EMPTY_ADDITIONAL_ANSWER = {
"span_end": -1,
"span_text": "",
"input_text": "",
"turn_id": -1
"turn_id": -1,
}
],
"2": [
......@@ -79,7 +79,7 @@ _EMPTY_ADDITIONAL_ANSWER = {
"span_end": -1,
"span_text": "",
"input_text": "",
"turn_id": -1
"turn_id": -1,
}
],
}
@@ -91,8 +91,9 @@ class Coqa(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.1")
BUILDER_CONFIGS = [
datasets.BuilderConfig(name="coqa", version=VERSION,
description="The CoQA dataset."),
datasets.BuilderConfig(
name="coqa", version=VERSION, description="The CoQA dataset."
),
]
def _info(self):
@@ -101,41 +102,52 @@ class Coqa(datasets.GeneratorBasedBuilder):
"id": datasets.Value("string"),
"source": datasets.Value("string"),
"story": datasets.Value("string"),
"questions": datasets.features.Sequence({
"questions": datasets.features.Sequence(
{
"input_text": datasets.Value("string"),
"turn_id": datasets.Value("int32"),
}
),
"answers": datasets.features.Sequence(
{
"span_start": datasets.Value("int32"),
"span_end": datasets.Value("int32"),
"span_text": datasets.Value("string"),
"input_text": datasets.Value("string"),
"turn_id": datasets.Value("int32"),
}
),
"additional_answers": {
"0": datasets.features.Sequence({
"0": datasets.features.Sequence(
{
"span_start": datasets.Value("int32"),
"span_end": datasets.Value("int32"),
"span_text": datasets.Value("string"),
"input_text": datasets.Value("string"),
"turn_id": datasets.Value("int32"),
}
),
"1": datasets.features.Sequence(
{
"span_start": datasets.Value("int32"),
"span_end": datasets.Value("int32"),
"span_text": datasets.Value("string"),
"input_text": datasets.Value("string"),
"turn_id": datasets.Value("int32"),
}
),
"2": datasets.features.Sequence(
{
"span_start": datasets.Value("int32"),
"span_end": datasets.Value("int32"),
"span_text": datasets.Value("string"),
"input_text": datasets.Value("string"),
"turn_id": datasets.Value("int32"),
}
),
},
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
@@ -175,10 +187,7 @@ class Coqa(datasets.GeneratorBasedBuilder):
source = row["source"]
story = row["story"]
questions = [
{"input_text": q["input_text"], "turn_id": q["turn_id"]}
for q in row["questions"]
]
answers = [
@@ -187,7 +196,7 @@ class Coqa(datasets.GeneratorBasedBuilder):
"span_end": a["span_end"],
"span_text": a["span_text"],
"input_text": a["input_text"],
"turn_id": a["turn_id"]
"turn_id": a["turn_id"],
}
for a in row["answers"]
]
@@ -201,7 +210,7 @@ class Coqa(datasets.GeneratorBasedBuilder):
"span_end": a0["span_end"],
"span_text": a0["span_text"],
"input_text": a0["input_text"],
"turn_id": a0["turn_id"]
"turn_id": a0["turn_id"],
}
for a0 in row["additional_answers"]["0"]
],
@@ -211,7 +220,7 @@ class Coqa(datasets.GeneratorBasedBuilder):
"span_end": a1["span_end"],
"span_text": a1["span_text"],
"input_text": a1["input_text"],
"turn_id": a1["turn_id"]
"turn_id": a1["turn_id"],
}
for a1 in row["additional_answers"]["1"]
],
@@ -221,7 +230,7 @@ class Coqa(datasets.GeneratorBasedBuilder):
"span_end": a2["span_end"],
"span_text": a2["span_text"],
"input_text": a2["input_text"],
"turn_id": a2["turn_id"]
"turn_id": a2["turn_id"],
}
for a2 in row["additional_answers"]["2"]
],
@@ -232,5 +241,5 @@ class Coqa(datasets.GeneratorBasedBuilder):
"source": source,
"questions": questions,
"answers": answers,
"additional_answers": additional_answers
"additional_answers": additional_answers,
}
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Custom DROP dataset that, unlike HF, keeps all question-answer pairs
# even if there are multiple types of answers for the same question.
"""DROP dataset."""
@@ -50,7 +50,8 @@ _URLS = {
"drop": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip",
}
_EMPTY_VALIDATED_ANSWER = [
{
"number": "",
"date": {
"day": "",
@@ -59,8 +60,9 @@ _EMPTY_VALIDATED_ANSWER = [{
},
"spans": [],
"worker_id": "",
"hit_id": ""
}]
"hit_id": "",
}
]
class Drop(datasets.GeneratorBasedBuilder):
@@ -69,12 +71,14 @@ class Drop(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.1")
BUILDER_CONFIGS = [
datasets.BuilderConfig(name="drop", version=VERSION,
description="The DROP dataset."),
datasets.BuilderConfig(
name="drop", version=VERSION, description="The DROP dataset."
),
]
def _info(self):
features = datasets.Features(
{
"section_id": datasets.Value("string"),
"passage": datasets.Value("string"),
"question": datasets.Value("string"),
@@ -90,7 +94,8 @@ class Drop(datasets.GeneratorBasedBuilder):
"worker_id": datasets.Value("string"),
"hit_id": datasets.Value("string"),
},
"validated_answers": datasets.features.Sequence({
"validated_answers": datasets.features.Sequence(
{
"number": datasets.Value("string"),
"date": {
"day": datasets.Value("string"),
@@ -100,8 +105,10 @@ class Drop(datasets.GeneratorBasedBuilder):
"spans": datasets.features.Sequence(datasets.Value("string")),
"worker_id": datasets.Value("string"),
"hit_id": datasets.Value("string"),
}
),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
@@ -118,7 +125,9 @@ class Drop(datasets.GeneratorBasedBuilder):
name=datasets.Split.TRAIN,
# These kwargs will be passed to _generate_examples
gen_kwargs={
"filepath": os.path.join(data_dir, "drop_dataset", "drop_dataset_train.json"),
"filepath": os.path.join(
data_dir, "drop_dataset", "drop_dataset_train.json"
),
"split": "train",
},
),
@@ -126,7 +135,9 @@ class Drop(datasets.GeneratorBasedBuilder):
name=datasets.Split.VALIDATION,
# These kwargs will be passed to _generate_examples
gen_kwargs={
"filepath": os.path.join(data_dir, "drop_dataset", "drop_dataset_dev.json"),
"filepath": os.path.join(
data_dir, "drop_dataset", "drop_dataset_dev.json"
),
"split": "validation",
},
),