Commit e9953abb authored by mgoin

Override greedy_until

parent 2c8e66d7
@@ -67,19 +67,65 @@ class DeepSparseLM(BaseLM):
         logits_numpy = numpy.stack([generation.score for generation in out.generations])
         return torch.from_numpy(logits_numpy)
-    def _model_generate(self, context, max_length, eos_token_id):
-        # Encode the prompt tokens to strings
-        prompt = self.tokenizer.batch_decode(context.numpy())
-        # Run generation
-        out = self.model(
-            prompt=prompt, max_new_tokens=max_length, force_max_tokens=True
-        )
-        # Return tokens for prompt + generated text
-        return numpy.array(
-            [self.tokenizer(prompt[0] + out.generations[0].text)["input_ids"]]
-        )
+    def greedy_until(
+        self, requests: List[Tuple[str, Union[List[str], str]]]
+    ) -> List[str]:
+        def _collate(x):
+            tokens = self.tok_encode(x[0])
+            return len(tokens), x[0]
+
+        results = []
+        reorder = utils.Reorderer(requests, _collate)
+
+        for chunk in utils.chunks(
+            tqdm(reorder.get_reordered(), disable=False),
+            self.batch_size,
+        ):
+            context = [c[0] for c in chunk]
+            request_args = chunk[0][1]
+            stop = request_args.get("until", None)
+            stop_sequences = stop if isinstance(stop, list) else [stop]
+            max_generation_length = request_args.get("max_length", None)
+
+            assert (
+                isinstance(max_generation_length, int) or max_generation_length is None
+            )
+            assert isinstance(stop_sequences, list) or stop_sequences is None
+
+            # TODO: Find a better way to handle stop sequences for 0-shot.
+            if stop_sequences is None:
+                until = [self.eot_token]
+            else:
+                until = stop_sequences + [self.eot_token]
+
+            if max_generation_length is None:
+                max_tokens = self.max_gen_toks
+            else:
+                max_tokens = max_generation_length
+
+            responses = self.model(
+                sequences=context,
+                max_new_tokens=max_tokens,
+                stop=until,
+                do_sample=False,
+            )
+
+            responses = responses if type(responses) is list else [responses]
+
+            for response in responses:
+                response = response.generations[0].text
+                # Ensure the generated responses do not contain the stop sequences.
+                for term in until:
+                    response = response.split(term)[0]
+
+                # partial caching
+                self.cache_hook.add_partial("greedy_until", (context, until), response)
+
+                results.append(response)
+
+        return reorder.get_original(results)
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        # Isn't used because we override greedy_until
+        raise NotImplementedError()
     @property
     def eot_token(self) -> str:
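For context on the new override: `greedy_until` receives lm-evaluation-harness style `(context, request_args)` pairs and trims each generation at the first occurrence of any stop sequence. A minimal sketch of that trimming step, with illustrative values (the stop strings and response text below are assumptions, not taken from this commit):

```python
# Sketch of the per-response post-processing in greedy_until above.
stop_sequences = ["\n\n"]   # e.g. what request_args.get("until") might return
eot_token = "</s>"          # model-specific end-of-text string (assumed)
until = stop_sequences + [eot_token]

response_text = "Paris\n\nQ: What is the capital of Spain?"
for term in until:
    # Keep only the text before the first occurrence of each stop sequence.
    response_text = response_text.split(term)[0]

print(response_text)  # -> Paris
```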
@@ -106,8 +152,7 @@ class DeepSparseLM(BaseLM):
         pass

     def tok_encode(self, string: str):
-        return self.tokenizer.encode(string)
+        return self.tokenizer.encode(string, add_special_tokens=False)

     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens)
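A note on the `tok_encode` change: passing `add_special_tokens=False` keeps the tokenizer from prepending special tokens (e.g. BOS) to every encoded context, which would otherwise skew token counts and the prompts rebuilt from raw strings. A hedged illustration using Hugging Face `transformers` (the checkpoint name is an assumption; any BOS-inserting tokenizer behaves this way):

```python
# Illustration of the add_special_tokens=False behavior; the model name is
# an assumption, not taken from this commit.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

with_special = tokenizer.encode("Hello")                               # BOS id prepended
without_special = tokenizer.encode("Hello", add_special_tokens=False)  # raw token ids only

print(with_special, without_special)
```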