Commit f16c301e authored by Leo Gao

Minor changes

parent 1e7f884d
@@ -346,8 +346,7 @@ class PerplexityTask(Task, abc.ABC):
def count_words(self, s):
""" Downstream tasks with custom word boundaries should override this! """
return len(re.split(r"\s+", s))
def
req_ret_lens = {
'loglikelihood': 2,
......
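The count_words change above keeps the default whitespace-based word count used for word-level perplexity, and the docstring invites downstream tasks to override it. As a minimal standalone sketch (the function names below are illustrative, not part of the harness), this contrasts the default whitespace split with a character-based count a task might prefer for text without whitespace word boundaries:

import re

def count_words_whitespace(s):
    # Default behaviour from the hunk above: split on runs of whitespace.
    return len(re.split(r"\s+", s))

def count_words_chars(s):
    # Hypothetical override for scripts without whitespace word boundaries:
    # count non-whitespace characters instead.
    return len(re.sub(r"\s+", "", s))

assert count_words_whitespace("hello world") == 2
assert count_words_chars("你好 世界") == 4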
@@ -63,6 +63,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
# only in index. We could implement some kind of caching, but that would be more of a bandaid
# solution. we could also implement some kind of autogrouping here; they should end up next to each other.
print("Running", reqtype, "requests")
resps = getattr(lm, reqtype)([req.args for req in reqs])
resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
......
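The evaluate() hunk sends every request of one type to the LM in a single batched call, then uses each request's optional index to pull one element out of a multi-part response. A minimal sketch of that dispatch pattern, with a stand-in LM and request class (DummyLM and Req are illustrative names, not the harness classes):

class DummyLM:
    def loglikelihood(self, args_list):
        # One result per request; a real LM would batch these model calls.
        return [(-1.0, True) for _ in args_list]

class Req:
    def __init__(self, args, index=None):
        self.args = args
        self.index = index  # None -> keep the whole result, int -> pick one element

lm = DummyLM()
reqtype = "loglikelihood"
reqs = [Req(("ctx", "cont")), Req(("ctx", "cont"), index=0)]

resps = getattr(lm, reqtype)([req.args for req in reqs])
resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
# -> [(-1.0, True), -1.0]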
@@ -71,7 +71,7 @@ class GPT2LM(LM):
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
# TODO: extract out this call so it only gets called once and also somehow figure out partial caching for that
string_nll = self._loglikelihood_tokens(rolling_token_windows)
string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)
# discard is_greedy
string_nll = [x[0] for x in string_nll]
@@ -81,7 +81,7 @@ class GPT2LM(LM):
return loglikelihoods
def _loglikelihood_tokens(self, requests):
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
with torch.no_grad():
@@ -93,7 +93,7 @@ class GPT2LM(LM):
return (len(toks), tuple(toks))
reord = utils.Reorderer(requests, _collate)
for cache_key, context_enc, continuation_enc in tqdm(reord.get_reordered()):
for cache_key, context_enc, continuation_enc in tqdm(reord.get_reordered(), disable=disable_tqdm):
assert len(context_enc) > 0
assert len(continuation_enc) > 0
assert len(continuation_enc) <= self.max_length
......
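The GPT2LM hunks add a disable_tqdm flag so the rolling log-likelihood path, which already iterates per string, can silence the per-request progress bar inside _loglikelihood_tokens instead of showing nested bars. A minimal sketch of that pass-through (function names and the toy data are illustrative):

from tqdm import tqdm

def _loglikelihood_tokens_sketch(requests, disable_tqdm=False):
    # Inner loop: its progress bar can be silenced by the caller.
    return [sum(r) for r in tqdm(requests, disable=disable_tqdm)]

def loglikelihood_rolling_sketch(all_windows):
    # The rolling-window path drives its own progress reporting, so it
    # disables the inner bar to avoid nested progress output.
    return [_loglikelihood_tokens_sketch(w, disable_tqdm=True)
            for w in tqdm(all_windows)]

print(loglikelihood_rolling_sketch([[(1, 2), (3, 4)], [(5, 6)]]))  # [[3, 7], [11]]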
@@ -85,8 +85,7 @@ class GPT3LM(LM):
return self._loglikelihood_tokens(new_reqs)
def loglikelihood_perplexity(self, requests):
# TODO: Implement caching once we've confirmed the perplexity implementation
# TODO: Add chunking
# TODO: switch implementation to use _loglikelihood_tokens rather than having it do its own thing
loglikelihoods = []
for string, in tqdm(requests):
@@ -104,7 +103,7 @@ class GPT3LM(LM):
pred_tokens=pred_tokens,
)
string_loglikelihoods.append(block_output["logprobs"])
string_loglikelihoods = np.concatenate(string_loglikelihoods)
string_loglikelihoods = np.concatenate(string_loglikelihoods).sum()
loglikelihoods.append(string_loglikelihoods)
return loglikelihoods
......
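The GPT3LM hunk changes what loglikelihood_perplexity appends per string: instead of the flat array of per-token logprobs, it now stores their sum, i.e. a single log-likelihood scalar for the whole string. A quick NumPy illustration with made-up logprob values:

import numpy as np

# Token-level logprobs returned block by block (values are made up).
string_loglikelihoods = [np.array([-1.2, -0.3]), np.array([-0.7])]

# Before the change: a flat array of per-token logprobs.
per_token = np.concatenate(string_loglikelihoods)    # array([-1.2, -0.3, -0.7])

# After the change: one log-likelihood for the whole string.
total = np.concatenate(string_loglikelihoods).sum()  # approximately -2.2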