Commit 121b7096 authored by Fabrizio Milo's avatar Fabrizio Milo
Browse files

add pre-commit

parent 7a038118
[flake8]
ignore = E203, E266, E501, W503, F403, F401, C901
max-line-length = 127
max-complexity = 10
select = B,C,E,F,W,T4,B9
name: Pull Request
on: [pull_request]
jobs:
pre-commit:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: 3.8
- uses: pre-commit/action@v2.0.3
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
hooks:
- id: check-added-large-files
- id: check-ast
- id: check-byte-order-marker
- id: check-case-conflict
- id: check-json
- id: check-merge-conflict
- id: check-symlinks
- id: check-yaml
- id: destroyed-symlinks
- id: detect-private-key
- id: end-of-file-fixer
- id: no-commit-to-branch
- id: requirements-txt-fixer
- id: trailing-whitespace
- id: fix-byte-order-marker
exclude: docs/CNAME
- id: fix-encoding-pragma
args: [--remove]
- id: mixed-line-ending
args: [--fix=lf]
- repo: https://gitlab.com/pycqa/flake8
rev: 3.7.9
hooks:
- id: flake8
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
language_version: python3.8
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
hooks:
- id: codespell
args: [
"--ignore-words-list=reord", # Word used in error messages that need rewording
--check-filenames,
--check-hidden,
]
...@@ -73,4 +73,3 @@ python -m scripts/clean_training_data/compress_and_package \ ...@@ -73,4 +73,3 @@ python -m scripts/clean_training_data/compress_and_package \
``` ```
Congratulations, the final directory can now be passed to lm-evaulation-harness with the "--decontamination_ngrams_path" argument. Congratulations, the final directory can now be passed to lm-evaulation-harness with the "--decontamination_ngrams_path" argument.
...@@ -118,7 +118,6 @@ class LM(abc.ABC): ...@@ -118,7 +118,6 @@ class LM(abc.ABC):
class BaseLM(LM): class BaseLM(LM):
@property @property
@abstractmethod @abstractmethod
def eot_token_id(self): def eot_token_id(self):
...@@ -145,13 +144,16 @@ class BaseLM(LM): ...@@ -145,13 +144,16 @@ class BaseLM(LM):
pass pass
@abstractmethod @abstractmethod
def tok_encode(self, string: str): pass def tok_encode(self, string: str):
pass
@abstractmethod @abstractmethod
def tok_decode(self, tokens: Iterable[int]): pass def tok_decode(self, tokens: Iterable[int]):
pass
@abstractmethod @abstractmethod
def _model_generate(self, context, max_length, eos_token_id): pass def _model_generate(self, context, max_length, eos_token_id):
pass
@abstractmethod @abstractmethod
def _model_call(self, inps): def _model_call(self, inps):
...@@ -187,19 +189,26 @@ class BaseLM(LM): ...@@ -187,19 +189,26 @@ class BaseLM(LM):
# TODO: automatic batch size detection for vectorization # TODO: automatic batch size detection for vectorization
loglikelihoods = [] loglikelihoods = []
for string, in tqdm(requests): for (string,) in tqdm(requests):
rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows( rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string), token_list=self.tok_encode(string),
prefix_token=self.eot_token_id, prefix_token=self.eot_token_id,
max_seq_len=self.max_length, max_seq_len=self.max_length,
context_len=1, context_len=1,
))) ),
)
)
rolling_token_windows = [(None,) + x for x in rolling_token_windows] rolling_token_windows = [(None,) + x for x in rolling_token_windows]
# TODO: extract out this call so it only gets called once and also somehow figure out partial caching for # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
# that # that
string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True) string_nll = self._loglikelihood_tokens(
rolling_token_windows, disable_tqdm=True
)
# discard is_greedy # discard is_greedy
string_nll = [x[0] for x in string_nll] string_nll = [x[0] for x in string_nll]
...@@ -226,7 +235,9 @@ class BaseLM(LM): ...@@ -226,7 +235,9 @@ class BaseLM(LM):
# TODO: automatic (variable) batch size detection for vectorization # TODO: automatic (variable) batch size detection for vectorization
reord = utils.Reorderer(requests, _collate) reord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size): for chunk in utils.chunks(
tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size
):
inps = [] inps = []
cont_toks_list = [] cont_toks_list = []
inplens = [] inplens = []
...@@ -252,44 +263,60 @@ class BaseLM(LM): ...@@ -252,44 +263,60 @@ class BaseLM(LM):
# when too long to fit in context, truncate from the left # when too long to fit in context, truncate from the left
inp = torch.tensor( inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length+1):][:-1], (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
dtype=torch.long dtype=torch.long,
).to(self.device) ).to(self.device)
inplen, = inp.shape (inplen,) = inp.shape
cont = continuation_enc cont = continuation_enc
# since in _collate we make sure length is descending, the longest is always the first one. # since in _collate we make sure length is descending, the longest is always the first one.
padding_length = padding_length if padding_length is not None else inplen padding_length = (
padding_length if padding_length is not None else inplen
)
# pad length from seq to padding_length # pad length from seq to padding_length
inp = torch.cat([ inp = torch.cat(
[
inp, # [seq] inp, # [seq]
torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq] torch.zeros(padding_length - inplen, dtype=torch.long).to(
], dim=0) inp.device
), # [padding_length - seq]
],
dim=0,
)
inps.append(inp.unsqueeze(0)) # [1, padding_length] inps.append(inp.unsqueeze(0)) # [1, padding_length]
cont_toks_list.append(cont) cont_toks_list.append(cont)
inplens.append(inplen) inplens.append(inplen)
batched_inps = torch.cat(inps, dim=0) # [batch, padding_length batched_inps = torch.cat(inps, dim=0) # [batch, padding_length
multi_logits = F.log_softmax(self._model_call(batched_inps), dim=-1).cpu() # [batch, padding_length, vocab] multi_logits = F.log_softmax(
self._model_call(batched_inps), dim=-1
).cpu() # [batch, padding_length, vocab]
for (cache_key, _, _), logits, inp, inplen, cont_toks \ for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
in zip(chunk, multi_logits, inps, inplens, cont_toks_list): chunk, multi_logits, inps, inplens, cont_toks_list
):
# Slice to original seq length # Slice to original seq length
contlen = len(cont_toks) contlen = len(cont_toks)
logits = logits[inplen-contlen:inplen].unsqueeze(0) # [1, seq, vocab] logits = logits[inplen - contlen : inplen].unsqueeze(
0
) # [1, seq, vocab]
# Check if per-token argmax is exactly equal to continuation # Check if per-token argmax is exactly equal to continuation
greedy_tokens = logits.argmax(dim=-1) greedy_tokens = logits.argmax(dim=-1)
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0) # [1, seq] cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
0
) # [1, seq]
max_equal = (greedy_tokens == cont_toks).all() max_equal = (greedy_tokens == cont_toks).all()
# Obtain log-probs at the corresponding continuation token indices # Obtain log-probs at the corresponding continuation token indices
# last_token_slice = logits[:, -1, :].squeeze(0).tolist() # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-1
) # [1, seq]
# Answer: (log prob, is-exact-match) # Answer: (log prob, is-exact-match)
answer = (float(logits.sum()), bool(max_equal)) answer = (float(logits.sum()), bool(max_equal))
...@@ -319,13 +346,17 @@ class BaseLM(LM): ...@@ -319,13 +346,17 @@ class BaseLM(LM):
if isinstance(until, str): if isinstance(until, str):
until = [until] until = [until]
primary_until, = self.tok_encode(until[0]) (primary_until,) = self.tok_encode(until[0])
context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device) context_enc = torch.tensor(
[self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
).to(self.device)
cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until) cont = self._model_generate(
context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
)
s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:]) s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
for term in until: for term in until:
s = s.split(term)[0] s = s.split(term)[0]
...@@ -383,7 +414,7 @@ class Task(abc.ABC): ...@@ -383,7 +414,7 @@ class Task(abc.ABC):
self._fewshot_docs = None self._fewshot_docs = None
def download(self, data_dir=None, cache_dir=None, download_mode=None): def download(self, data_dir=None, cache_dir=None, download_mode=None):
""" Downloads and returns the task dataset. """Downloads and returns the task dataset.
Override this method to download the dataset from a custom API. Override this method to download the dataset from a custom API.
:param data_dir: str :param data_dir: str
...@@ -412,7 +443,7 @@ class Task(abc.ABC): ...@@ -412,7 +443,7 @@ class Task(abc.ABC):
name=self.DATASET_NAME, name=self.DATASET_NAME,
data_dir=data_dir, data_dir=data_dir,
cache_dir=cache_dir, cache_dir=cache_dir,
download_mode=download_mode download_mode=download_mode,
) )
def should_decontaminate(self): def should_decontaminate(self):
...@@ -473,8 +504,10 @@ class Task(abc.ABC): ...@@ -473,8 +504,10 @@ class Task(abc.ABC):
return rnd.sample(self._training_docs, k) return rnd.sample(self._training_docs, k)
def doc_to_decontamination_query(self, doc): def doc_to_decontamination_query(self, doc):
print("Override doc_to_decontamination_query with document specific decontamination query.") print(
assert(False) "Override doc_to_decontamination_query with document specific decontamination query."
)
assert False
@abstractmethod @abstractmethod
def doc_to_text(self, doc): def doc_to_text(self, doc):
...@@ -486,7 +519,7 @@ class Task(abc.ABC): ...@@ -486,7 +519,7 @@ class Task(abc.ABC):
@abstractmethod @abstractmethod
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
...@@ -531,15 +564,19 @@ class Task(abc.ABC): ...@@ -531,15 +564,19 @@ class Task(abc.ABC):
def fewshot_description(self): def fewshot_description(self):
import warnings import warnings
warnings.warn( warnings.warn(
"`fewshot_description` will be removed in futures versions. Pass " "`fewshot_description` will be removed in futures versions. Pass "
"any custom descriptions to the `evaluate` function instead.", "any custom descriptions to the `evaluate` function instead.",
DeprecationWarning) DeprecationWarning,
)
return "" return ""
@utils.positional_deprecated @utils.positional_deprecated
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): def fewshot_context(
""" Returns a fewshot context string that is made up of a prepended description self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example. (if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str :param doc: str
...@@ -556,7 +593,9 @@ class Task(abc.ABC): ...@@ -556,7 +593,9 @@ class Task(abc.ABC):
:returns: str :returns: str
The fewshot context. The fewshot context.
""" """
assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`" assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, ( assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend " "The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the " "a custom description to the context, supply the corresponding string via the "
...@@ -564,7 +603,9 @@ class Task(abc.ABC): ...@@ -564,7 +603,9 @@ class Task(abc.ABC):
) )
if provide_description is not None: if provide_description is not None:
# nudge people to not specify it at all # nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
description = description + "\n\n" if description else "" description = description + "\n\n" if description else ""
...@@ -577,7 +618,9 @@ class Task(abc.ABC): ...@@ -577,7 +618,9 @@ class Task(abc.ABC):
else: else:
if self._fewshot_docs is None: if self._fewshot_docs is None:
self._fewshot_docs = list( self._fewshot_docs = list(
self.validation_docs() if self.has_validation_docs() else self.test_docs() self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
) )
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1) fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
...@@ -585,23 +628,27 @@ class Task(abc.ABC): ...@@ -585,23 +628,27 @@ class Task(abc.ABC):
# get rid of the doc that's the one we're evaluating, if it's in the fewshot # get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot] fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = "\n\n".join( labeled_examples = (
[self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex] "\n\n".join(
) + "\n\n" [
self.doc_to_text(doc) + self.doc_to_target(doc)
for doc in fewshotex
]
)
+ "\n\n"
)
example = self.doc_to_text(doc) example = self.doc_to_text(doc)
return description + labeled_examples + example return description + labeled_examples + example
class MultipleChoiceTask(Task): class MultipleChoiceTask(Task):
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + doc['choices'][doc['gold']] return " " + doc["choices"][doc["gold"]]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
lls = [ lls = [
rf.loglikelihood(ctx, " {}".format(choice))[0] rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
for choice in doc['choices']
] ]
return lls return lls
...@@ -609,9 +656,9 @@ class MultipleChoiceTask(Task): ...@@ -609,9 +656,9 @@ class MultipleChoiceTask(Task):
def process_results(self, doc, results): def process_results(self, doc, results):
gold = doc["gold"] gold = doc["gold"]
acc = 1. if np.argmax(results) == gold else 0. acc = 1.0 if np.argmax(results) == gold else 0.0
completion_len = np.array([float(len(i)) for i in doc["choices"]]) completion_len = np.array([float(len(i)) for i in doc["choices"]])
acc_norm = 1. if np.argmax(results / completion_len) == gold else 0. acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
return { return {
"acc": acc, "acc": acc,
...@@ -632,7 +679,6 @@ class MultipleChoiceTask(Task): ...@@ -632,7 +679,6 @@ class MultipleChoiceTask(Task):
class PerplexityTask(Task, abc.ABC): class PerplexityTask(Task, abc.ABC):
def should_decontaminate(self): def should_decontaminate(self):
"""Whether this task supports decontamination against model training set.""" """Whether this task supports decontamination against model training set."""
return True return True
...@@ -644,9 +690,15 @@ class PerplexityTask(Task, abc.ABC): ...@@ -644,9 +690,15 @@ class PerplexityTask(Task, abc.ABC):
assert k == 0 assert k == 0
return [] return []
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): def fewshot_context(
assert num_fewshot == 0, "The number of fewshot examples must be 0 for perplexity tasks." self, doc, num_fewshot, provide_description=None, rnd=None, description=None
assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`." ):
assert (
num_fewshot == 0
), "The number of fewshot examples must be 0 for perplexity tasks."
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`."
assert not provide_description, ( assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend " "The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the " "a custom description to the context, supply the corresponding string via the "
...@@ -654,7 +706,9 @@ class PerplexityTask(Task, abc.ABC): ...@@ -654,7 +706,9 @@ class PerplexityTask(Task, abc.ABC):
) )
if provide_description is not None: if provide_description is not None:
# nudge people to not specify it at all # nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
return "" return ""
...@@ -680,7 +734,7 @@ class PerplexityTask(Task, abc.ABC): ...@@ -680,7 +734,7 @@ class PerplexityTask(Task, abc.ABC):
return req return req
def process_results(self, doc, results): def process_results(self, doc, results):
loglikelihood, = results (loglikelihood,) = results
words = self.count_words(doc) words = self.count_words(doc)
bytes_ = self.count_bytes(doc) bytes_ = self.count_bytes(doc)
return { return {
...@@ -702,13 +756,13 @@ class PerplexityTask(Task, abc.ABC): ...@@ -702,13 +756,13 @@ class PerplexityTask(Task, abc.ABC):
@classmethod @classmethod
def count_words(cls, doc): def count_words(cls, doc):
""" Downstream tasks with custom word boundaries should override this! """ """Downstream tasks with custom word boundaries should override this!"""
return len(re.split(r"\s+", doc)) return len(re.split(r"\s+", doc))
def hash_args(attr, args): def hash_args(attr, args):
dat = json.dumps([attr] + list(args)) dat = json.dumps([attr] + list(args))
return hashlib.sha256(dat.encode('utf-8')).hexdigest() return hashlib.sha256(dat.encode("utf-8")).hexdigest()
class CacheHook: class CacheHook:
...@@ -779,6 +833,7 @@ class CachingLM: ...@@ -779,6 +833,7 @@ class CachingLM:
self.dbdict.commit() self.dbdict.commit()
return res return res
return fn return fn
def get_cache_hook(self): def get_cache_hook(self):
...@@ -786,16 +841,18 @@ class CachingLM: ...@@ -786,16 +841,18 @@ class CachingLM:
REQUEST_RETURN_LENGTHS = { REQUEST_RETURN_LENGTHS = {
'loglikelihood': 2, "loglikelihood": 2,
'greedy_until': None, "greedy_until": None,
'loglikelihood_rolling': None, "loglikelihood_rolling": None,
} }
class Request: class Request:
def __init__(self, request_type, args, index=None): def __init__(self, request_type, args, index=None):
if request_type not in REQUEST_RETURN_LENGTHS.keys(): if request_type not in REQUEST_RETURN_LENGTHS.keys():
raise NotImplementedError('The request type {} is not implemented!'.format(request_type)) raise NotImplementedError(
"The request type {} is not implemented!".format(request_type)
)
self.request_type = request_type self.request_type = request_type
self.args = args self.args = args
...@@ -803,17 +860,21 @@ class Request: ...@@ -803,17 +860,21 @@ class Request:
def __iter__(self): def __iter__(self):
if REQUEST_RETURN_LENGTHS[self.request_type] is None: if REQUEST_RETURN_LENGTHS[self.request_type] is None:
raise IndexError('This request type does not return multiple arguments!') raise IndexError("This request type does not return multiple arguments!")
for i in range(REQUEST_RETURN_LENGTHS[self.request_type]): for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
yield Request(self.request_type, self.args, i) yield Request(self.request_type, self.args, i)
def __getitem__(self, i): def __getitem__(self, i):
if REQUEST_RETURN_LENGTHS[self.request_type] is None: if REQUEST_RETURN_LENGTHS[self.request_type] is None:
raise IndexError('This request type does not return multiple arguments!') raise IndexError("This request type does not return multiple arguments!")
return Request(self.request_type, self.args, i) return Request(self.request_type, self.args, i)
def __eq__(self, other): def __eq__(self, other):
return self.request_type == other.request_type and self.args == other.args and self.index == other.index return (
self.request_type == other.request_type
and self.args == other.args
and self.index == other.index
)
def __repr__(self): def __repr__(self):
return f"Req_{self.request_type}{self.args}[{self.index}]\n" return f"Req_{self.request_type}{self.args}[{self.index}]\n"
...@@ -823,6 +884,7 @@ class RequestFactory: ...@@ -823,6 +884,7 @@ class RequestFactory:
def __getattr__(self, attr): def __getattr__(self, attr):
def fn(*args): def fn(*args):
return Request(attr, args) return Request(attr, args)
return fn return fn
......
...@@ -68,61 +68,111 @@ class Arithmetic(datasets.GeneratorBasedBuilder): ...@@ -68,61 +68,111 @@ class Arithmetic(datasets.GeneratorBasedBuilder):
ArithmeticConfig( ArithmeticConfig(
name="arithmetic_2da", name="arithmetic_2da",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/two_digit_addition.jsonl", url="https://raw.githubusercontent.com/openai/gpt-3/master/data/two_digit_addition.jsonl",
features=datasets.Features({"context": datasets.Value("string"), "completion": datasets.Value("string")}), features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="2-digit addition", description="2-digit addition",
), ),
ArithmeticConfig( ArithmeticConfig(
name="arithmetic_2ds", name="arithmetic_2ds",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/two_digit_subtraction.jsonl", url="https://raw.githubusercontent.com/openai/gpt-3/master/data/two_digit_subtraction.jsonl",
features=datasets.Features({"context": datasets.Value("string"), "completion": datasets.Value("string")}), features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="2-digit subtraction", description="2-digit subtraction",
), ),
ArithmeticConfig( ArithmeticConfig(
name="arithmetic_3da", name="arithmetic_3da",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/three_digit_addition.jsonl", url="https://raw.githubusercontent.com/openai/gpt-3/master/data/three_digit_addition.jsonl",
features=datasets.Features({"context": datasets.Value("string"), "completion": datasets.Value("string")}), features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="3-digit addition", description="3-digit addition",
), ),
ArithmeticConfig( ArithmeticConfig(
name="arithmetic_3ds", name="arithmetic_3ds",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/three_digit_subtraction.jsonl", url="https://raw.githubusercontent.com/openai/gpt-3/master/data/three_digit_subtraction.jsonl",
features=datasets.Features({"context": datasets.Value("string"), "completion": datasets.Value("string")}), features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="3-digit subtraction", description="3-digit subtraction",
), ),
ArithmeticConfig( ArithmeticConfig(
name="arithmetic_4da", name="arithmetic_4da",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/four_digit_addition.jsonl", url="https://raw.githubusercontent.com/openai/gpt-3/master/data/four_digit_addition.jsonl",
features=datasets.Features({"context": datasets.Value("string"), "completion": datasets.Value("string")}), features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="4-digit addition", description="4-digit addition",
), ),
ArithmeticConfig( ArithmeticConfig(
name="arithmetic_4ds", name="arithmetic_4ds",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/four_digit_subtraction.jsonl", url="https://raw.githubusercontent.com/openai/gpt-3/master/data/four_digit_subtraction.jsonl",
features=datasets.Features({"context": datasets.Value("string"), "completion": datasets.Value("string")}), features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="4-digit subtraction", description="4-digit subtraction",
), ),
ArithmeticConfig( ArithmeticConfig(
name="arithmetic_5da", name="arithmetic_5da",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/five_digit_addition.jsonl", url="https://raw.githubusercontent.com/openai/gpt-3/master/data/five_digit_addition.jsonl",
features=datasets.Features({"context": datasets.Value("string"), "completion": datasets.Value("string")}), features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="5-digit addition", description="5-digit addition",
), ),
ArithmeticConfig( ArithmeticConfig(
name="arithmetic_5ds", name="arithmetic_5ds",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/five_digit_subtraction.jsonl", url="https://raw.githubusercontent.com/openai/gpt-3/master/data/five_digit_subtraction.jsonl",
features=datasets.Features({"context": datasets.Value("string"), "completion": datasets.Value("string")}), features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="5-digit subtraction", description="5-digit subtraction",
), ),
ArithmeticConfig( ArithmeticConfig(
name="arithmetic_2dm", name="arithmetic_2dm",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/two_digit_multiplication.jsonl", url="https://raw.githubusercontent.com/openai/gpt-3/master/data/two_digit_multiplication.jsonl",
features=datasets.Features({"context": datasets.Value("string"), "completion": datasets.Value("string")}), features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="2-digit multiplication", description="2-digit multiplication",
), ),
ArithmeticConfig( ArithmeticConfig(
name="arithmetic_1dc", name="arithmetic_1dc",
url="https://raw.githubusercontent.com/openai/gpt-3/master/data/single_digit_three_ops.jsonl", url="https://raw.githubusercontent.com/openai/gpt-3/master/data/single_digit_three_ops.jsonl",
features=datasets.Features({"context": datasets.Value("string"), "completion": datasets.Value("string")}), features=datasets.Features(
{
"context": datasets.Value("string"),
"completion": datasets.Value("string"),
}
),
description="Single digit 3 operations", description="Single digit 3 operations",
), ),
] ]
...@@ -155,9 +205,12 @@ class Arithmetic(datasets.GeneratorBasedBuilder): ...@@ -155,9 +205,12 @@ class Arithmetic(datasets.GeneratorBasedBuilder):
with open(filepath, encoding="utf-8") as f: with open(filepath, encoding="utf-8") as f:
for key, row in enumerate(f): for key, row in enumerate(f):
data = json.loads(row) data = json.loads(row)
context = data['context'].strip() \ context = (
.replace('\n\n', '\n') \ data["context"]
.replace('Q:', 'Question:') \ .strip()
.replace('A:', 'Answer:') .replace("\n\n", "\n")
completion = data['completion'] .replace("Q:", "Question:")
yield key, {'context': context, 'completion': completion} .replace("A:", "Answer:")
)
completion = data["completion"]
yield key, {"context": context, "completion": completion}
...@@ -50,13 +50,16 @@ _URLS = "https://github.com/chaochun/nlu-asdiv-dataset/archive/55790e5270bb91ccf ...@@ -50,13 +50,16 @@ _URLS = "https://github.com/chaochun/nlu-asdiv-dataset/archive/55790e5270bb91ccf
class ASDiv(datasets.GeneratorBasedBuilder): class ASDiv(datasets.GeneratorBasedBuilder):
""" ASDiv: A Diverse Corpus for Evaluating and Developing English Math Word Problem Solvers """ """ASDiv: A Diverse Corpus for Evaluating and Developing English Math Word Problem Solvers"""
VERSION = datasets.Version("0.0.1") VERSION = datasets.Version("0.0.1")
BUILDER_CONFIGS = [ BUILDER_CONFIGS = [
datasets.BuilderConfig(name="asdiv", version=VERSION, datasets.BuilderConfig(
description="A diverse corpus for evaluating and developing english math word problem solvers") name="asdiv",
version=VERSION,
description="A diverse corpus for evaluating and developing english math word problem solvers",
)
] ]
def _info(self): def _info(self):
...@@ -86,7 +89,9 @@ class ASDiv(datasets.GeneratorBasedBuilder): ...@@ -86,7 +89,9 @@ class ASDiv(datasets.GeneratorBasedBuilder):
name=datasets.Split.VALIDATION, name=datasets.Split.VALIDATION,
# These kwargs will be passed to _generate_examples # These kwargs will be passed to _generate_examples
gen_kwargs={ gen_kwargs={
"filepath": os.path.join(data_dir, base_filepath, "dataset", "ASDiv.xml"), "filepath": os.path.join(
data_dir, base_filepath, "dataset", "ASDiv.xml"
),
"split": datasets.Split.VALIDATION, "split": datasets.Split.VALIDATION,
}, },
), ),
......
...@@ -61,7 +61,7 @@ _EMPTY_ADDITIONAL_ANSWER = { ...@@ -61,7 +61,7 @@ _EMPTY_ADDITIONAL_ANSWER = {
"span_end": -1, "span_end": -1,
"span_text": "", "span_text": "",
"input_text": "", "input_text": "",
"turn_id": -1 "turn_id": -1,
} }
], ],
"1": [ "1": [
...@@ -70,7 +70,7 @@ _EMPTY_ADDITIONAL_ANSWER = { ...@@ -70,7 +70,7 @@ _EMPTY_ADDITIONAL_ANSWER = {
"span_end": -1, "span_end": -1,
"span_text": "", "span_text": "",
"input_text": "", "input_text": "",
"turn_id": -1 "turn_id": -1,
} }
], ],
"2": [ "2": [
...@@ -79,7 +79,7 @@ _EMPTY_ADDITIONAL_ANSWER = { ...@@ -79,7 +79,7 @@ _EMPTY_ADDITIONAL_ANSWER = {
"span_end": -1, "span_end": -1,
"span_text": "", "span_text": "",
"input_text": "", "input_text": "",
"turn_id": -1 "turn_id": -1,
} }
], ],
} }
...@@ -91,8 +91,9 @@ class Coqa(datasets.GeneratorBasedBuilder): ...@@ -91,8 +91,9 @@ class Coqa(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.1") VERSION = datasets.Version("0.0.1")
BUILDER_CONFIGS = [ BUILDER_CONFIGS = [
datasets.BuilderConfig(name="coqa", version=VERSION, datasets.BuilderConfig(
description="The CoQA dataset."), name="coqa", version=VERSION, description="The CoQA dataset."
),
] ]
def _info(self): def _info(self):
...@@ -101,41 +102,52 @@ class Coqa(datasets.GeneratorBasedBuilder): ...@@ -101,41 +102,52 @@ class Coqa(datasets.GeneratorBasedBuilder):
"id": datasets.Value("string"), "id": datasets.Value("string"),
"source": datasets.Value("string"), "source": datasets.Value("string"),
"story": datasets.Value("string"), "story": datasets.Value("string"),
"questions": datasets.features.Sequence({ "questions": datasets.features.Sequence(
{
"input_text": datasets.Value("string"), "input_text": datasets.Value("string"),
"turn_id": datasets.Value("int32"), "turn_id": datasets.Value("int32"),
}), }
"answers": datasets.features.Sequence({ ),
"answers": datasets.features.Sequence(
{
"span_start": datasets.Value("int32"), "span_start": datasets.Value("int32"),
"span_end": datasets.Value("int32"), "span_end": datasets.Value("int32"),
"span_text": datasets.Value("string"), "span_text": datasets.Value("string"),
"input_text": datasets.Value("string"), "input_text": datasets.Value("string"),
"turn_id": datasets.Value("int32"), "turn_id": datasets.Value("int32"),
}), }
),
"additional_answers": { "additional_answers": {
"0": datasets.features.Sequence({ "0": datasets.features.Sequence(
{
"span_start": datasets.Value("int32"), "span_start": datasets.Value("int32"),
"span_end": datasets.Value("int32"), "span_end": datasets.Value("int32"),
"span_text": datasets.Value("string"), "span_text": datasets.Value("string"),
"input_text": datasets.Value("string"), "input_text": datasets.Value("string"),
"turn_id": datasets.Value("int32"), "turn_id": datasets.Value("int32"),
}), }
"1": datasets.features.Sequence({ ),
"1": datasets.features.Sequence(
{
"span_start": datasets.Value("int32"), "span_start": datasets.Value("int32"),
"span_end": datasets.Value("int32"), "span_end": datasets.Value("int32"),
"span_text": datasets.Value("string"), "span_text": datasets.Value("string"),
"input_text": datasets.Value("string"), "input_text": datasets.Value("string"),
"turn_id": datasets.Value("int32"), "turn_id": datasets.Value("int32"),
}), }
"2": datasets.features.Sequence({ ),
"2": datasets.features.Sequence(
{
"span_start": datasets.Value("int32"), "span_start": datasets.Value("int32"),
"span_end": datasets.Value("int32"), "span_end": datasets.Value("int32"),
"span_text": datasets.Value("string"), "span_text": datasets.Value("string"),
"input_text": datasets.Value("string"), "input_text": datasets.Value("string"),
"turn_id": datasets.Value("int32"), "turn_id": datasets.Value("int32"),
}),
} }
}) ),
},
}
)
return datasets.DatasetInfo( return datasets.DatasetInfo(
description=_DESCRIPTION, description=_DESCRIPTION,
features=features, features=features,
...@@ -175,10 +187,7 @@ class Coqa(datasets.GeneratorBasedBuilder): ...@@ -175,10 +187,7 @@ class Coqa(datasets.GeneratorBasedBuilder):
source = row["source"] source = row["source"]
story = row["story"] story = row["story"]
questions = [ questions = [
{ {"input_text": q["input_text"], "turn_id": q["turn_id"]}
"input_text": q["input_text"],
"turn_id": q["turn_id"]
}
for q in row["questions"] for q in row["questions"]
] ]
answers = [ answers = [
...@@ -187,7 +196,7 @@ class Coqa(datasets.GeneratorBasedBuilder): ...@@ -187,7 +196,7 @@ class Coqa(datasets.GeneratorBasedBuilder):
"span_end": a["span_end"], "span_end": a["span_end"],
"span_text": a["span_text"], "span_text": a["span_text"],
"input_text": a["input_text"], "input_text": a["input_text"],
"turn_id": a["turn_id"] "turn_id": a["turn_id"],
} }
for a in row["answers"] for a in row["answers"]
] ]
...@@ -201,7 +210,7 @@ class Coqa(datasets.GeneratorBasedBuilder): ...@@ -201,7 +210,7 @@ class Coqa(datasets.GeneratorBasedBuilder):
"span_end": a0["span_end"], "span_end": a0["span_end"],
"span_text": a0["span_text"], "span_text": a0["span_text"],
"input_text": a0["input_text"], "input_text": a0["input_text"],
"turn_id": a0["turn_id"] "turn_id": a0["turn_id"],
} }
for a0 in row["additional_answers"]["0"] for a0 in row["additional_answers"]["0"]
], ],
...@@ -211,7 +220,7 @@ class Coqa(datasets.GeneratorBasedBuilder): ...@@ -211,7 +220,7 @@ class Coqa(datasets.GeneratorBasedBuilder):
"span_end": a1["span_end"], "span_end": a1["span_end"],
"span_text": a1["span_text"], "span_text": a1["span_text"],
"input_text": a1["input_text"], "input_text": a1["input_text"],
"turn_id": a1["turn_id"] "turn_id": a1["turn_id"],
} }
for a1 in row["additional_answers"]["1"] for a1 in row["additional_answers"]["1"]
], ],
...@@ -221,7 +230,7 @@ class Coqa(datasets.GeneratorBasedBuilder): ...@@ -221,7 +230,7 @@ class Coqa(datasets.GeneratorBasedBuilder):
"span_end": a2["span_end"], "span_end": a2["span_end"],
"span_text": a2["span_text"], "span_text": a2["span_text"],
"input_text": a2["input_text"], "input_text": a2["input_text"],
"turn_id": a2["turn_id"] "turn_id": a2["turn_id"],
} }
for a2 in row["additional_answers"]["2"] for a2 in row["additional_answers"]["2"]
], ],
...@@ -232,5 +241,5 @@ class Coqa(datasets.GeneratorBasedBuilder): ...@@ -232,5 +241,5 @@ class Coqa(datasets.GeneratorBasedBuilder):
"source": source, "source": source,
"questions": questions, "questions": questions,
"answers": answers, "answers": answers,
"additional_answers": additional_answers "additional_answers": additional_answers,
} }
...@@ -50,7 +50,8 @@ _URLS = { ...@@ -50,7 +50,8 @@ _URLS = {
"drop": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip", "drop": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip",
} }
_EMPTY_VALIDATED_ANSWER = [{ _EMPTY_VALIDATED_ANSWER = [
{
"number": "", "number": "",
"date": { "date": {
"day": "", "day": "",
...@@ -59,8 +60,9 @@ _EMPTY_VALIDATED_ANSWER = [{ ...@@ -59,8 +60,9 @@ _EMPTY_VALIDATED_ANSWER = [{
}, },
"spans": [], "spans": [],
"worker_id": "", "worker_id": "",
"hit_id": "" "hit_id": "",
}] }
]
class Drop(datasets.GeneratorBasedBuilder): class Drop(datasets.GeneratorBasedBuilder):
...@@ -69,12 +71,14 @@ class Drop(datasets.GeneratorBasedBuilder): ...@@ -69,12 +71,14 @@ class Drop(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.1") VERSION = datasets.Version("0.0.1")
BUILDER_CONFIGS = [ BUILDER_CONFIGS = [
datasets.BuilderConfig(name="drop", version=VERSION, datasets.BuilderConfig(
description="The DROP dataset."), name="drop", version=VERSION, description="The DROP dataset."
),
] ]
def _info(self): def _info(self):
features = datasets.Features({ features = datasets.Features(
{
"section_id": datasets.Value("string"), "section_id": datasets.Value("string"),
"passage": datasets.Value("string"), "passage": datasets.Value("string"),
"question": datasets.Value("string"), "question": datasets.Value("string"),
...@@ -90,7 +94,8 @@ class Drop(datasets.GeneratorBasedBuilder): ...@@ -90,7 +94,8 @@ class Drop(datasets.GeneratorBasedBuilder):
"worker_id": datasets.Value("string"), "worker_id": datasets.Value("string"),
"hit_id": datasets.Value("string"), "hit_id": datasets.Value("string"),
}, },
"validated_answers": datasets.features.Sequence({ "validated_answers": datasets.features.Sequence(
{
"number": datasets.Value("string"), "number": datasets.Value("string"),
"date": { "date": {
"day": datasets.Value("string"), "day": datasets.Value("string"),
...@@ -100,8 +105,10 @@ class Drop(datasets.GeneratorBasedBuilder): ...@@ -100,8 +105,10 @@ class Drop(datasets.GeneratorBasedBuilder):
"spans": datasets.features.Sequence(datasets.Value("string")), "spans": datasets.features.Sequence(datasets.Value("string")),
"worker_id": datasets.Value("string"), "worker_id": datasets.Value("string"),
"hit_id": datasets.Value("string"), "hit_id": datasets.Value("string"),
}), }
}) ),
}
)
return datasets.DatasetInfo( return datasets.DatasetInfo(
description=_DESCRIPTION, description=_DESCRIPTION,
features=features, features=features,
...@@ -118,7 +125,9 @@ class Drop(datasets.GeneratorBasedBuilder): ...@@ -118,7 +125,9 @@ class Drop(datasets.GeneratorBasedBuilder):
name=datasets.Split.TRAIN, name=datasets.Split.TRAIN,
# These kwargs will be passed to _generate_examples # These kwargs will be passed to _generate_examples
gen_kwargs={ gen_kwargs={
"filepath": os.path.join(data_dir, "drop_dataset", "drop_dataset_train.json"), "filepath": os.path.join(
data_dir, "drop_dataset", "drop_dataset_train.json"
),
"split": "train", "split": "train",
}, },
), ),
...@@ -126,7 +135,9 @@ class Drop(datasets.GeneratorBasedBuilder): ...@@ -126,7 +135,9 @@ class Drop(datasets.GeneratorBasedBuilder):
name=datasets.Split.VALIDATION, name=datasets.Split.VALIDATION,
# These kwargs will be passed to _generate_examples # These kwargs will be passed to _generate_examples
gen_kwargs={ gen_kwargs={
"filepath": os.path.join(data_dir, "drop_dataset", "drop_dataset_dev.json"), "filepath": os.path.join(
data_dir, "drop_dataset", "drop_dataset_dev.json"
),
"split": "validation", "split": "validation",
}, },
), ),
......
...@@ -56,8 +56,11 @@ class GSM8K(datasets.GeneratorBasedBuilder): ...@@ -56,8 +56,11 @@ class GSM8K(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.1") VERSION = datasets.Version("0.0.1")
BUILDER_CONFIGS = [ BUILDER_CONFIGS = [
datasets.BuilderConfig(name="gsm8k", version=VERSION, datasets.BuilderConfig(
description="The Grade School Math 8k dataset."), name="gsm8k",
version=VERSION,
description="The Grade School Math 8k dataset.",
),
] ]
def _info(self): def _info(self):
...@@ -90,10 +93,7 @@ class GSM8K(datasets.GeneratorBasedBuilder): ...@@ -90,10 +93,7 @@ class GSM8K(datasets.GeneratorBasedBuilder):
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.TEST, name=datasets.Split.TEST,
# These kwargs will be passed to _generate_examples # These kwargs will be passed to _generate_examples
gen_kwargs={ gen_kwargs={"filepath": data_dir["test"], "split": "test"},
"filepath": data_dir["test"],
"split": "test"
},
), ),
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment