Unverified Commit 8138fd52 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Fix `loglikelihood_rolling` caching ( #1821 ) (#2187)



* fix revision type

* allow for None-input loglikelihood reqs to be cached

* handle no remaining cache items

* pre-commit

* change cache_hook.add_partial(loglikelihood_rolling...) convention

---------
Co-authored-by: Baber Abbasi <baber@eleuther.ai>
parent 2de3688f
...@@ -283,8 +283,11 @@ class CachingLM: ...@@ -283,8 +283,11 @@ class CachingLM:
eval_logger.info( eval_logger.info(
f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}" f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
) )
# actually run the LM on the requests that do not have cached results if remaining_reqs:
rem_res = getattr(self.lm, attr)(remaining_reqs) # actually run the LM on the requests that do not have cached results
rem_res = getattr(self.lm, attr)(remaining_reqs)
else:
rem_res = []
# stick the new ones back into the list and also cache any of the new ones # stick the new ones back into the list and also cache any of the new ones
resptr = 0 resptr = 0
......
...@@ -510,7 +510,7 @@ class TemplateAPI(TemplateLM): ...@@ -510,7 +510,7 @@ class TemplateAPI(TemplateLM):
): ):
if answer_ is not None: if answer_ is not None:
res.append(answer_) res.append(answer_)
# partial caching # cache requests that aren't from a loglikelihood_rolling request
if cache_key is not None: if cache_key is not None:
self.cache_hook.add_partial( self.cache_hook.add_partial(
"loglikelihood", cache_key, answer_ "loglikelihood", cache_key, answer_
...@@ -638,4 +638,7 @@ class TemplateAPI(TemplateLM): ...@@ -638,4 +638,7 @@ class TemplateAPI(TemplateLM):
string_nll = sum(string_nll) string_nll = sum(string_nll)
loglikelihoods.append(string_nll) loglikelihoods.append(string_nll)
# cache this loglikelihood_rolling request
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
return loglikelihoods return loglikelihoods
...@@ -1018,6 +1018,9 @@ class HFLM(TemplateLM): ...@@ -1018,6 +1018,9 @@ class HFLM(TemplateLM):
string_nll = sum(string_nll) string_nll = sum(string_nll)
loglikelihoods.append(string_nll) loglikelihoods.append(string_nll)
# cache this loglikelihood_rolling request
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
return loglikelihoods return loglikelihoods
def _batch_scheduler(self, pos, n_reordered_requests): def _batch_scheduler(self, pos, n_reordered_requests):
...@@ -1246,7 +1249,13 @@ class HFLM(TemplateLM): ...@@ -1246,7 +1249,13 @@ class HFLM(TemplateLM):
res.append(answer) res.append(answer)
self.cache_hook.add_partial("loglikelihood", request_str, answer) if request_str is not None:
# special case: loglikelihood_rolling produces a number of loglikelihood requests
# all with cache key None. instead do add_partial on the per-example level
# in the loglikelihood_rolling() function for those.
self.cache_hook.add_partial(
"loglikelihood", request_str, answer
)
pbar.update(1) pbar.update(1)
pbar.close() pbar.close()
......
...@@ -386,6 +386,9 @@ class NeMoLM(LM): ...@@ -386,6 +386,9 @@ class NeMoLM(LM):
string_nll = sum(string_nll) string_nll = sum(string_nll)
loglikelihoods.append(string_nll) loglikelihoods.append(string_nll)
# cache this loglikelihood_rolling request
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
return loglikelihoods return loglikelihoods
def _loglikelihood_tokens(self, requests, disable_tqdm=False): def _loglikelihood_tokens(self, requests, disable_tqdm=False):
...@@ -468,6 +471,9 @@ class NeMoLM(LM): ...@@ -468,6 +471,9 @@ class NeMoLM(LM):
answer = (logprob, is_greedy) answer = (logprob, is_greedy)
if cache_key is not None: if cache_key is not None:
# special case: loglikelihood_rolling produces a number of loglikelihood requests
# all with cache key None. instead do add_partial on the per-example level
# in the loglikelihood_rolling() function for those.
self.cache_hook.add_partial("loglikelihood", cache_key, answer) self.cache_hook.add_partial("loglikelihood", cache_key, answer)
res.append(answer) res.append(answer)
......
...@@ -321,6 +321,9 @@ class DeepSparseLM(LM): ...@@ -321,6 +321,9 @@ class DeepSparseLM(LM):
res.append(answer) res.append(answer)
if cache_key is not None: if cache_key is not None:
# special case: loglikelihood_rolling produces a number of loglikelihood requests
# all with cache key None. instead do add_partial on the per-example level
# in the loglikelihood_rolling() function for those.
self.cache_hook.add_partial("loglikelihood", cache_key, answer) self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return re_ord.get_original(res) return re_ord.get_original(res)
......
...@@ -502,7 +502,8 @@ class NEURON_HF(TemplateLM): ...@@ -502,7 +502,8 @@ class NEURON_HF(TemplateLM):
string_nll = sum(string_nll) string_nll = sum(string_nll)
loglikelihoods.append(string_nll) loglikelihoods.append(string_nll)
# cache this loglikelihood_rolling request
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
return loglikelihoods return loglikelihoods
def _loglikelihood_tokens( def _loglikelihood_tokens(
...@@ -620,7 +621,11 @@ class NEURON_HF(TemplateLM): ...@@ -620,7 +621,11 @@ class NEURON_HF(TemplateLM):
res.append(answer) res.append(answer)
self.cache_hook.add_partial("loglikelihood", cache_key, answer) if cache_key is not None:
# special case: loglikelihood_rolling produces a number of loglikelihood requests
# all with cache key None. instead do add_partial on the per-example level
# in the loglikelihood_rolling() function for those.
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return re_ord.get_original(res) return re_ord.get_original(res)
......
...@@ -307,6 +307,10 @@ class VLLM(TemplateLM): ...@@ -307,6 +307,10 @@ class VLLM(TemplateLM):
string_nll = sum(string_nll) string_nll = sum(string_nll)
loglikelihoods.append(string_nll) loglikelihoods.append(string_nll)
# cache this loglikelihood_rolling request
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
return loglikelihoods return loglikelihoods
def generate_until( def generate_until(
...@@ -453,8 +457,10 @@ class VLLM(TemplateLM): ...@@ -453,8 +457,10 @@ class VLLM(TemplateLM):
res.append(answer) res.append(answer)
# partial caching
if cache_key is not None: if cache_key is not None:
# special case: loglikelihood_rolling produces a number of loglikelihood requests
# all with cache key None. instead do add_partial on the per-example level
# in the loglikelihood_rolling() function for those.
self.cache_hook.add_partial("loglikelihood", cache_key, answer) self.cache_hook.add_partial("loglikelihood", cache_key, answer)
pbar.update(1) pbar.update(1)
pbar.close() pbar.close()
......
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment