Commit c77b60c1 authored by Leo Gao

Merge branch 'perplexity'

# Conflicts:
#	lm_eval/models/gpt2.py
#	lm_eval/models/gpt3.py
#	tests/test_misc.py
#	tests/test_models.py
parents eb8456d0 5452fddb
import abc
import random
import numpy as np
import re
from lm_eval.metrics import mean
from lm_eval.metrics import mean, perplexity, weighted_mean
class LM(abc.ABC):
......@@ -27,9 +28,51 @@ class LM(abc.ABC):
:return: list
A list of pairs (logprob, isgreedy)
logprob: float
The log probability of `contination`
The log probability of `continuation`
isgreedy:
Whether `contination` would be generated by greedy sampling from `context`
Whether `continuation` would be generated by greedy sampling from `context`
"""
pass
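# Illustrative note, not part of this commit: for a request like
# ("The cat sat on the", " mat"), loglikelihood returns one (logprob, isgreedy)
# pair, e.g. (-2.3, True) -- the summed log probability of " mat" given the
# context, and whether greedy decoding from the context would reproduce it.
# The numeric value here is made up purely for illustration.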
@abc.abstractmethod
def loglikelihood_rolling(self, requests):
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
the max context length.
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
which may simply concatenate multiple documents together.
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
multiple chunks, the last input will still have a full-sized context.
Example:
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Prefix: EOT
Max context length: 4
Resulting input/prediction pairs:
INPUT:  EOT   0   1   2
PRED:     0   1   2   3
INPUT:    3   4   5   6
PRED:     4   5   6   7
INPUT:    5   6   7   8
PRED:             8   9
Observe that:
1. Each token is predicted exactly once
2. For the last pair, we provide the full context, but only score the last two tokens
:param requests: list
A list of strings
string: str
String for which we are computing per-token loglikelihood
:return: list
A list of loglikelihoods, one per input string
loglikelihood: float
The sum of the log probabilities of the string's tokens, each conditioned on its rolling-window context as described above
"""
pass
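# Illustrative sketch, not part of this commit: the rolling windows from the
# docstring example above can be reproduced with the helpers added in
# lm_eval/utils.py below, using -100 as a stand-in for the EOT prefix token.
#
#     from lm_eval.utils import get_rolling_token_windows, make_disjoint_window
#
#     windows = [
#         make_disjoint_window(pair)
#         for pair in get_rolling_token_windows(
#             token_list=list(range(10)),  # tokens 0..9 from the example
#             prefix_token=-100,           # EOT
#             max_seq_len=4,               # max context length from the example
#             context_len=1,
#         )
#     ]
#     # windows == [([-100], [0, 1, 2, 3]), ([3], [4, 5, 6, 7]), ([5, 6, 7], [8, 9])]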
......@@ -247,9 +290,68 @@ class MultipleChoiceTask(Task):
}
class PerplexityTask(Task, abc.ABC):
def has_training_docs(self):
return False
def fewshot_description(self):
return ""
def fewshot_examples(self, k, rnd):
assert k == 0
return []
def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
assert num_fewshot == 0
assert not provide_description
return ""
def higher_is_better(self):
return {
"word_perplexity": False,
"byte_perplexity": False,
"bits_per_byte": False,
}
def doc_to_text(self, doc):
return doc
def doc_to_target(self, doc):
raise NotImplementedError()
def construct_requests(self, doc, ctx):
assert not ctx
req = rf.loglikelihood_rolling(doc)
return req
def process_results(self, doc, results):
loglikelihood, = results
return {
"word_perplexity": loglikelihood / self.count_words(self.doc_to_text(doc)),
"byte_perplexity": loglikelihood / self.count_bytes(self.doc_to_text(doc)),
"bits_per_byte": (-loglikelihood, self.count_bytes(self.doc_to_text(doc)))
}
def aggregation(self):
return {
"word_perplexity": perplexity,
"byte_perplexity": perplexity,
"bits_per_byte": weighted_mean
}
def count_bytes(self, s):
return len(s.encode("utf-8"))
def count_words(self, s):
""" Downstream tasks with custom word boundaries should override this! """
return len(re.split(r"\s+", s))
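# Illustrative sketch, not part of this commit: how one document's rolling
# loglikelihood becomes the three PerplexityTask metrics under the aggregations
# declared above (perplexity = exp(-mean), weighted_mean = sum(a)/sum(b)).
# The loglikelihood value below is made up purely for illustration.
#
#     import math, re
#     doc = "hello world"
#     loglikelihood = -12.3                            # from loglikelihood_rolling
#     n_words = len(re.split(r"\s+", doc))             # 2
#     n_bytes = len(doc.encode("utf-8"))               # 11
#
#     word_items = [loglikelihood / n_words]           # one entry per document
#     word_ppl = math.exp(-sum(word_items) / len(word_items))
#
#     bpb_items = [(-loglikelihood, n_bytes)]          # (neg. loglikelihood, bytes)
#     bits_per_byte = sum(a for a, _ in bpb_items) / sum(b for _, b in bpb_items)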
req_ret_lens = {
'loglikelihood': 2,
'greedy_until': None,
'loglikelihood_rolling': None,
}
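# Illustrative note, not part of this commit: the value is the number of elements each
# request of that type returns, so Request.index can select one of them; None (as for
# loglikelihood_rolling and greedy_until) means the whole response is used unsliced,
# matching the `x if req.index is None else x[req.index]` handling in the evaluator below.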
import os
......
......@@ -64,6 +64,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
# only in index. We could implement some kind of caching, but that would be more of a bandaid
# solution. we could also implement some kind of autogrouping here; they should end up next to each other.
print("Running", reqtype, "requests")
resps = getattr(lm, reqtype)([req.args for req in reqs])
resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
......
......@@ -94,6 +94,11 @@ def perplexity(items):
return math.exp(-mean(items))
def weighted_mean(items):
a, b = zip(*items)
return sum(a) / sum(b)
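# Illustrative usage, not part of this commit: perplexity() expects an iterable of
# (per-token or per-word) loglikelihoods, weighted_mean() an iterable of
# (numerator, denominator) pairs.
#
#     perplexity([-1.0, -2.0, -3.0])        # == math.exp(2.0)
#     weighted_mean([(3.0, 2), (1.0, 2)])   # == (3.0 + 1.0) / (2 + 2) == 1.0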
def bleu(items):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching
......
......@@ -26,3 +26,11 @@ class DummyLM(LM):
assert ctx.strip() != ''
return res
def loglikelihood_rolling(self, requests):
res = []
for _ in requests:
res.append(-random.random())
return res
\ No newline at end of file
......@@ -5,10 +5,13 @@ import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm
import numpy as np
class GPT2LM(LM):
MAX_GEN_TOKS = 256
VOCAB_SIZE = 50257
EOT_TOKEN_ID = 50256
def __init__(self, device='cuda', pretrained='gpt2', batch_size=1):
super().__init__()
......@@ -51,7 +54,7 @@ class GPT2LM(LM):
for context, continuation in requests:
if context == "":
# end of text as context
context_enc = [50256]
context_enc = [self.EOT_TOKEN_ID]
else:
context_enc = self.tokenizer.encode(context)
......@@ -61,7 +64,36 @@ class GPT2LM(LM):
return self._loglikelihood_tokens(new_reqs)
def _loglikelihood_tokens(self, requests):
def loglikelihood_rolling(self, requests):
# TODO: Implement caching once we've confirmed the perplexity implementation
# TODO: automatic batch size detection for vectorization
loglikelihoods = []
with torch.no_grad():
for string, in tqdm(requests):
encoded = self.tokenizer.encode_plus(string)["input_ids"]
rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
token_list=encoded,
prefix_token=self.EOT_TOKEN_ID,
max_seq_len=self.max_length,
context_len=1,
)))
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
# TODO: extract out this call so it only gets called once and also somehow figure out partial caching for that
string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
return loglikelihoods
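# Illustrative note, not part of this commit: each rolling window is prefixed with
# (None,) so it matches the (cache_key, context_enc, continuation_enc) triples that
# _loglikelihood_tokens unpacks below; cache_key stays None for now, per the caching TODO above.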
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
with torch.no_grad():
......@@ -78,18 +110,38 @@ class GPT2LM(LM):
# TODO: automatic (variable) batch size detection for vectorization
reord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(tqdm(reord.get_reordered()), self.batch_size):
for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
inps = []
contlens = []
inplens = []
ctxlens = []
padding_length = None
# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
# again because vectorizing is annoying
for _, context_enc, continuation_enc in chunk:
# sanity check
assert len(context_enc) > 0
assert len(continuation_enc) > 0
assert len(continuation_enc) <= self.max_length
# how this all works:
#          CTX      CONT
# inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
# gpt2    \               \
# logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the [:, -len(continuation_enc):, :self.VOCAB_SIZE] slice
# cont_toks      4 5 6 7 8 9
# when too long to fit in context, truncate from the left
inp = torch.tensor((context_enc + continuation_enc)[-self.max_length:], dtype=torch.long).to(self.device)
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length+1):][:-1]
, dtype=torch.long).to(self.device)
inplen, = inp.shape
cont = continuation_enc
# since in _collate we make sure length is descending, the longest is always the first one.
padding_length = padding_length if padding_length is not None else inplen
......@@ -100,19 +152,24 @@ class GPT2LM(LM):
], dim=0)
inps.append(inp.unsqueeze(0))
contlens.append(cont)
inplens.append(inplen)
ctxlens.append(ctxlen)
multi_logits = F.log_softmax(self.gpt2(torch.cat(inps, dim=0))[0][:, :, :50257], dim=-1) # [batch, seq, vocab]
multi_logits = F.log_softmax(self.gpt2(torch.cat(inps, dim=0))[0][:, :, :50257], dim=-1).cpu() # [batch, seq, vocab]
for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, contlens):
contlen = len(cont_toks)
for (cache_key, _, _), logits, ctxlen, inp, inplen in zip(chunk, multi_logits, ctxlens, inps, inplens):
logits = logits[ctxlen - 1:inplen - 1].unsqueeze(0) # [1, seq, vocab]
logits = logits[inplen-contlen:inplen].unsqueeze(0) # [1, seq, vocab]
greedy_tokens = logits.argmax(dim=-1)
cont_toks = inp[:, ctxlen:inplen] # [1, seq]
# cont_toks :: [1, seq]
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)
max_equal = (greedy_tokens == cont_toks).all()
last_token_slice = logits[:, -1, :].squeeze(0).tolist()
#last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq]
......
import os
import numpy as np
import transformers
from lm_eval.base import LM
from lm_eval import utils
......@@ -58,6 +59,7 @@ class GPT3LM(LM):
self.tokenizer.pad_token = "<|endoftext|>"
assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
self.truncate = truncate
self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])[0]
# Read from environment variable OPENAI_API_SECRET_KEY
openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
......@@ -83,6 +85,30 @@ class GPT3LM(LM):
return self._loglikelihood_tokens(new_reqs)
def loglikelihood_rolling(self, requests):
# TODO: switch implementation to use _loglikelihood_tokens rather than having it do its own thing
loglikelihoods = []
for string, in tqdm(requests):
encoded = self.tokenizer.encode_plus(string)["input_ids"]
rolling_token_windows = utils.get_rolling_token_windows(
token_list=encoded,
prefix_token=self.end_of_text_token_id,
max_seq_len=self.MAX_LENGTH,
context_len=1,
)
string_loglikelihoods = []
for input_tokens, pred_tokens in rolling_token_windows:
block_output = self.get_token_logprobs(
input_tokens=input_tokens,
pred_tokens=pred_tokens,
)
string_loglikelihoods.append(block_output["logprobs"])
string_loglikelihoods = np.concatenate(string_loglikelihoods).sum()
loglikelihoods.append(string_loglikelihoods)
return loglikelihoods
def _loglikelihood_tokens(self, requests):
import openai
res = []
......@@ -95,7 +121,7 @@ class GPT3LM(LM):
return (-len(toks), tuple(toks))
reord = utils.Reorderer(requests, _collate)
for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
inps = []
ctxlens = []
......@@ -122,9 +148,30 @@ class GPT3LM(LM):
# partial caching
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return reord.get_original(res)
def get_token_logprobs(self, input_tokens, pred_tokens):
pred_start = len(input_tokens) - len(pred_tokens) + 1
# We're going to stitch together the input_tokens and pred_tokens
# In the longest case, this gets us to length = max_seq_len+1 (which the API works with)
assert input_tokens[pred_start:] == pred_tokens[:-1]
token_ids = input_tokens + [pred_tokens[-1]]
response = oa_completion(
engine=self.engine,
prompt=token_ids,
max_tokens=0,
temperature=0.0,
logprobs=0,
echo=True,
)
logprobs = np.array(response["choices"][0]["logprobs"]["token_logprobs"][pred_start:])
positions = np.arange(pred_start-1, pred_start-1 + len(token_ids[pred_start:]))
return {
"logprobs": logprobs,
"positions": positions,
}
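# Illustrative sketch, not part of this commit: the index arithmetic above with
# small concrete values. With input_tokens = [A, B, C, D] and pred_tokens = [C, D, E]:
#   pred_start = 4 - 3 + 1 = 2, and input_tokens[2:] == pred_tokens[:-1] == [C, D]
#   token_ids  = [A, B, C, D, E]   (at most max_seq_len + 1 tokens, echoed to the API)
#   logprobs   = token_logprobs[2:], i.e. the scores of C, D, E given their prefixes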
def greedy_until(self, requests):
if not requests: return []
import openai
......
......@@ -37,6 +37,7 @@ from . import hendrycks_test
from . import hendrycks_math
from . import cbt
from . import lambada_cloze
from . import pile
########################################
# Translation tasks
......@@ -171,6 +172,31 @@ TASK_REGISTRY = {
"cycle_letters": unscramble.CycleLetters,
"random_insertion": unscramble.RandomInsertion,
"reversed_words": unscramble.ReversedWords,
# Pile
"pile_arxiv": pile.PileArxiv,
"pile_books3": pile.PileBooks3,
"pile_bookcorpus2": pile.PileBookCorpus2,
"pile_commoncrawl": pile.PileCommonCrawl,
"pile_dm-mathematics": pile.PileDmMathematics,
"pile_enron": pile.PileEnron,
"pile_europarl": pile.PileEuroparl,
"pile_freelaw": pile.PileFreeLaw,
"pile_github": pile.PileGithub,
"pile_gutenberg": pile.PileGutenberg,
"pile_hackernews": pile.PileHackernews,
"pile_nih-exporter": pile.PileNIHExporter,
"pile_opensubtitles": pile.PileOpenSubtitles,
"pile_openwebtext2": pile.PileOpenWebText2,
"pile_philpapers": pile.PilePhilPapers,
"pile_pile-cc": pile.PilePileCc,
"pile_pubmed-abstracts": pile.PilePubmedAbstracts,
"pile_pubmed-central": pile.PilePubmedCentral,
"pile_stackexchange": pile.PileStackExchange,
"pile_uspto": pile.PileUspto,
"pile_ubuntu-irc": pile.PileUbuntuIrc,
"pile_wikipedia": pile.PileWikipedia,
"pile_youtubesubtitles": pile.PileYoutubeSubtitles,
}
......
import os
import lm_dataformat
import abc
import numpy as np
from lm_eval.base import rf, PerplexityTask
from ..metrics import mean, matthews_corrcoef, f1_score
from ..utils import general_detokenize
from best_download import download_file
class PilePerplexityTask(PerplexityTask, abc.ABC):
PILE_SET_NAME = None
VAL_PATH = 'data/pile/val.jsonl.zst'
TEST_PATH = 'data/pile/test.jsonl.zst'
def download(self):
# TODO: separate pile val/test out by component so we don't have to scan the entire file once per set
os.makedirs("data/pile/", exist_ok=True)
if not os.path.exists(self.VAL_PATH):
download_file("https://the-eye.eu/public/AI/pile/val.jsonl.zst", self.VAL_PATH)
if not os.path.exists(self.TEST_PATH):
download_file("https://the-eye.eu/public/AI/pile/test.jsonl.zst", self.TEST_PATH)
def validation_docs(self):
rdr = lm_dataformat.Reader(self.VAL_PATH)
for doc, metadata in rdr.stream_data(get_meta=True):
if metadata["pile_set_name"] == self.PILE_SET_NAME:
yield doc
def test_docs(self):
rdr = lm_dataformat.Reader(self.TEST_PATH)
for doc, metadata in rdr.stream_data(get_meta=True):
if metadata["pile_set_name"] == self.PILE_SET_NAME:
yield doc
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
class PileArxiv(PilePerplexityTask):
PILE_SET_NAME = "ArXiv"
class PileBooks3(PilePerplexityTask):
PILE_SET_NAME = "Books3"
class PileBookCorpus2(PilePerplexityTask):
PILE_SET_NAME = "BookCorpus2"
class PileCommonCrawl(PilePerplexityTask):
PILE_SET_NAME = "CommonCrawl"
class PileDmMathematics(PilePerplexityTask):
PILE_SET_NAME = "DM Mathematics"
class PileEnron(PilePerplexityTask):
PILE_SET_NAME = "Enron Emails"
class PileEuroparl(PilePerplexityTask):
PILE_SET_NAME = "EuroParl"
class PileFreeLaw(PilePerplexityTask):
PILE_SET_NAME = "FreeLaw"
class PileGithub(PilePerplexityTask):
PILE_SET_NAME = "Github"
class PileGutenberg(PilePerplexityTask):
PILE_SET_NAME = "Gutenberg (PG-19)"
class PileHackernews(PilePerplexityTask):
PILE_SET_NAME = "HackerNews"
class PileNIHExporter(PilePerplexityTask):
PILE_SET_NAME = "NIH ExPorter"
class PileOpenSubtitles(PilePerplexityTask):
PILE_SET_NAME = "OpenSubtitles"
class PileOpenWebText2(PilePerplexityTask):
PILE_SET_NAME = "OpenWebText2"
class PilePhilPapers(PilePerplexityTask):
PILE_SET_NAME = "PhilPapers"
class PilePileCc(PilePerplexityTask):
PILE_SET_NAME = "Pile-CC"
class PilePubmedAbstracts(PilePerplexityTask):
PILE_SET_NAME = "PubMed Abstracts"
class PilePubmedCentral(PilePerplexityTask):
PILE_SET_NAME = "PubMed Central"
class PileStackExchange(PilePerplexityTask):
PILE_SET_NAME = "StackExchange"
class PileUspto(PilePerplexityTask):
PILE_SET_NAME = "USPTO Backgrounds"
class PileUbuntuIrc(PilePerplexityTask):
PILE_SET_NAME = "Ubuntu IRC"
class PileWikipedia(PilePerplexityTask):
PILE_SET_NAME = "Wikipedia (en)"
class PileYoutubeSubtitles(PilePerplexityTask):
PILE_SET_NAME = "YoutubeSubtitles"
......@@ -61,6 +61,56 @@ def general_detokenize(string):
return string
def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
"""
- context_len allows for a rolling window context, allowing each prediction window to potentially
condition on some context
:param token_list: list
List of tokens to be PREDICTED
:param max_seq_len: int
max_seq_len of model (or max_seq_len we want to use)
:param context_len: int
Amount of desired token context for prediction. Needs to be at least 1.
:param prefix_token: token
Dummy token like <eos> so the first token has something to condition on
:return: generator
Generator of tuples
(input_tokens, pred_tokens)
Note: Score only the last len(pred_tokens) logits of the LM
"""
assert 1 <= context_len <= max_seq_len
if not token_list:
return
# +1 offset, going from input->preds
pred_len = max_seq_len - context_len + 1
predicted = 0
# Special handling for first window: predict all tokens
first_seq_len = min(max_seq_len, len(token_list))
yield (
[prefix_token] + token_list[:first_seq_len - 1],
token_list[:first_seq_len]
)
predicted += first_seq_len
while predicted < len(token_list):
window_pred_len = min(len(token_list) - predicted, pred_len)
window_end = predicted + window_pred_len
yield (
token_list[window_end - max_seq_len - 1:window_end - 1],
token_list[window_end - window_pred_len:window_end],
)
predicted += window_pred_len
def make_disjoint_window(pair):
""" Takes output from get_rolling_token_windows and makes the context not overlap with the continuation """
a, b = pair
return a[:-(len(b) - 1)], b
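# Illustrative note, not part of this commit: with max_seq_len=10 and context_len=8,
# every window after the first predicts pred_len = 10 - 8 + 1 = 3 new tokens while
# conditioning on the 8 tokens before them (see test_get_rolling_token_windows_v2 below),
# and make_disjoint_window then trims the overlap so context and continuation are disjoint.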
class Reorderer:
def __init__(self, arr, fn):
self.size = len(arr)
......
......@@ -15,6 +15,7 @@ def test_evaluator(taskname, Task):
def ll_fn(reqs):
for ctx, cont in reqs:
if len(ctx) == 0: continue
# space convention
assert ctx[-1] != ' '
assert cont[0] == ' ' or ctx[-1] == '\n'
......@@ -26,7 +27,18 @@ def test_evaluator(taskname, Task):
res.append((-random.random(), False))
return res
def ll_perp_fn(reqs):
for string, in reqs:
assert isinstance(string, str)
res = []
random.seed(42)
for _ in reqs:
res.append(-random.random())
return res
lm.loglikelihood = ll_fn
lm.loglikelihood_rolling = ll_perp_fn
evaluator.evaluate(lm, task_dict, False, 0, 10)
......@@ -7,6 +7,6 @@ def test_bootstrapping():
random.seed(42)
arr = [random.random() for _ in range(1000)]
expected = metrics.mean_stderr(arr)
bootstrapped = metrics.bootstrap_stderr(metrics.mean, arr)
bootstrapped = metrics.bootstrap_stderr(metrics.mean, arr, iters=100000)
assert bootstrapped == pytest.approx(expected, abs=1e-4)
......@@ -41,4 +41,19 @@ def test_gpt2():
targets = [-61.60536193847656, -56.57843780517578, -62.131004333496094, -9.799489974975586, -153.96334838867188, -341.222900390625, -731.1475830078125, -61.60536193847656, -8.682319641113281]
for (pred, _), tgt in zip(vals, targets):
assert pred == pytest.approx(tgt, abs=1e-3)
\ No newline at end of file
assert pred == pytest.approx(tgt, abs=1e-3)
def test_gpt2_perplexity():
gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu")
test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss."
perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0]
tgt = sum([-4.9599953, -8.069298, -8.308624, -10.178513, -8.906924, -1.9318912, -7.745445, -7.146077, -5.2072, -3.5882986, -1.9957212, -8.044922, -0.20841774, -5.1096807, -0.099879116, -8.888423, -4.6180487])
assert perplexity == pytest.approx(tgt, abs=1e-3)
# Hack: modify gpt2 to have shorter context length to induce rolling windows
gpt2.max_length = 5
perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0]
tgt = sum([-4.96001, -8.069275, -8.308612, -10.178482, -8.90691, -4.037338, -8.09261, -11.662385, -10.206891, -4.425003, -2.2563353, -7.909143, -1.9304147, -7.3610134, -2.3120654, -7.3229, -2.1643813])
assert perplexity == pytest.approx(tgt, abs=1e-3)
from lm_eval.utils import get_rolling_token_windows, make_disjoint_window
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v1():
gold = [
([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]),
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]),
([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [30, 31, 32, 33]),
]
x = list(range(34))
generator = get_rolling_token_windows(
token_list=x,
prefix_token=-100,
max_seq_len=10,
context_len=1,
)
pred_length = 0
output = []
for input_tokens, pred_tokens in generator:
output.append((input_tokens, pred_tokens))
pred_length += len(pred_tokens)
assert pred_length == len(x)
assert gold == output
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v2():
gold = [
([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
([2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [10, 11, 12]),
([5, 6, 7, 8, 9, 10, 11, 12, 13, 14], [13, 14, 15]),
([8, 9, 10, 11, 12, 13, 14, 15, 16, 17], [16, 17, 18]),
([11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [19, 20, 21]),
([14, 15, 16, 17, 18, 19, 20, 21, 22, 23], [22, 23, 24]),
([17, 18, 19, 20, 21, 22, 23, 24, 25, 26], [25, 26, 27]),
([20, 21, 22, 23, 24, 25, 26, 27, 28, 29], [28, 29, 30]),
([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [31, 32, 33]),
]
x = list(range(34))
generator = get_rolling_token_windows(
token_list=x,
prefix_token=-100,
max_seq_len=10,
context_len=8,
)
pred_length = 0
output = []
for input_tokens, pred_tokens in generator:
output.append((input_tokens, pred_tokens))
pred_length += len(pred_tokens)
assert pred_length == len(x)
assert gold == output
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v3():
gold = [
([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10]),
([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [11]),
([2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12]),
([3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [13]),
([4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [14]),
([5, 6, 7, 8, 9, 10, 11, 12, 13, 14], [15]),
([6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [16]),
([7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [17]),
([8, 9, 10, 11, 12, 13, 14, 15, 16, 17], [18]),
([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [19]),
([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [20]),
([11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [21]),
([12, 13, 14, 15, 16, 17, 18, 19, 20, 21], [22]),
([13, 14, 15, 16, 17, 18, 19, 20, 21, 22], [23]),
([14, 15, 16, 17, 18, 19, 20, 21, 22, 23], [24]),
([15, 16, 17, 18, 19, 20, 21, 22, 23, 24], [25]),
([16, 17, 18, 19, 20, 21, 22, 23, 24, 25], [26]),
([17, 18, 19, 20, 21, 22, 23, 24, 25, 26], [27]),
([18, 19, 20, 21, 22, 23, 24, 25, 26, 27], [28]),
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [29]),
([20, 21, 22, 23, 24, 25, 26, 27, 28, 29], [30]),
([21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [31]),
([22, 23, 24, 25, 26, 27, 28, 29, 30, 31], [32]),
([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [33]),
]
x = list(range(34))
generator = get_rolling_token_windows(
token_list=x,
prefix_token=-100,
max_seq_len=10,
context_len=10,
)
pred_length = 0
output = []
for input_tokens, pred_tokens in generator:
output.append((input_tokens, pred_tokens))
pred_length += len(pred_tokens)
assert pred_length == len(x)
assert gold == output
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v4():
gold = [
([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10]),
([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [11]),
([2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12]),
([3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [13]),
([4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [14]),
([5, 6, 7, 8, 9, 10, 11, 12, 13, 14], [15]),
([6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [16]),
([7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [17]),
([8, 9, 10, 11, 12, 13, 14, 15, 16, 17], [18]),
([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [19]),
([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [20]),
([11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [21]),
([12, 13, 14, 15, 16, 17, 18, 19, 20, 21], [22]),
([13, 14, 15, 16, 17, 18, 19, 20, 21, 22], [23]),
([14, 15, 16, 17, 18, 19, 20, 21, 22, 23], [24]),
([15, 16, 17, 18, 19, 20, 21, 22, 23, 24], [25]),
([16, 17, 18, 19, 20, 21, 22, 23, 24, 25], [26]),
([17, 18, 19, 20, 21, 22, 23, 24, 25, 26], [27]),
([18, 19, 20, 21, 22, 23, 24, 25, 26, 27], [28]),
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [29]),
]
x = list(range(30))
generator = get_rolling_token_windows(
token_list=x,
prefix_token=-100,
max_seq_len=10,
context_len=10,
)
pred_length = 0
output = []
for input_tokens, pred_tokens in generator:
output.append((input_tokens, pred_tokens))
pred_length += len(pred_tokens)
assert pred_length == len(x)
assert gold == output
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v5():
gold = [
([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]),
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]),
]
x = list(range(30))
generator = get_rolling_token_windows(
token_list=x,
prefix_token=-100,
max_seq_len=10,
context_len=1,
)
pred_length = 0
output = []
for input_tokens, pred_tokens in generator:
output.append((input_tokens, pred_tokens))
pred_length += len(pred_tokens)
assert pred_length == len(x)
assert gold == output
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v6():
gold = [
([-100, 0], [0, 1]),
([1, 2], [2, 3]),
([3, 4], [4, 5]),
([5, 6], [6, 7]),
([6, 7], [8]),
]
x = list(range(9))
generator = get_rolling_token_windows(
token_list=x,
prefix_token=-100,
max_seq_len=2,
context_len=1,
)
pred_length = 0
output = []
for input_tokens, pred_tokens in generator:
output.append((input_tokens, pred_tokens))
pred_length += len(pred_tokens)
assert pred_length == len(x)
assert gold == output
def test_get_rolling_token_windows_empty():
generator = get_rolling_token_windows(
token_list=[],
prefix_token=-100,
max_seq_len=2,
context_len=1,
)
n = 0
for _ in generator:
n += 1
assert n == 0
def test_make_disjoint_window():
assert make_disjoint_window(([1,2,3,4,5], [2,3,4,5,6])) == ([1], [2,3,4,5,6])
assert make_disjoint_window(([1,2,3,4,5], [4,5,6])) == ([1,2,3], [4,5,6])
\ No newline at end of file