Commit 2fce67b3 authored by Albert Jiang's avatar Albert Jiang
Browse files

small debug w.r.t. caching and output

parent a09cc41f
......@@ -344,6 +344,7 @@ class BaseLM(LM):
# Unpack greedy sample request
context, until, = request
k, temperature = 1, 0.
greedy = True
_model_generate_kwargs = {}
elif len(request) == 4:
# Unpack temperature sample request
......@@ -351,6 +352,7 @@ class BaseLM(LM):
for key in ["k", "temperature"]:
assert key in inspect.getfullargspec(self._model_generate).args, \
f"Model generation parameter '{key}' not accepted as an argument for _model_generate"
greedy = False
_model_generate_kwargs = {"k": k, "temperature": temperature}
else:
raise AssertionError
......@@ -373,6 +375,7 @@ class BaseLM(LM):
for term in until:
s = [candidate.split(term)[0] for candidate in s]
s = s[0] if greedy else s
# partial caching
self.cache_hook.add_partial("generate", (context, until, k, temperature), s)
res.append(s)
......@@ -383,7 +386,7 @@ class BaseLM(LM):
# multiple tokens or that span multiple tokens correctly
# TODO: extract to TokenizedLM?
return self.generate(requests)[0]
return self.generate(requests)
class Task(abc.ABC):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment