Commit e9953abb authored by mgoin

Override greedy_until

parent 2c8e66d7
@@ -67,19 +67,65 @@ class DeepSparseLM(BaseLM):
         logits_numpy = numpy.stack([generation.score for generation in out.generations])
         return torch.from_numpy(logits_numpy)

-    def _model_generate(self, context, max_length, eos_token_id):
-        # Encode the prompt tokens to strings
-        prompt = self.tokenizer.batch_decode(context.numpy())
-
-        # Run generation
-        out = self.model(
-            prompt=prompt, max_new_tokens=max_length, force_max_tokens=True
-        )
-
-        # Return tokens for prompt + generated text
-        return numpy.array(
-            [self.tokenizer(prompt[0] + out.generations[0].text)["input_ids"]]
-        )
+    def greedy_until(
+        self, requests: List[Tuple[str, Union[List[str], str]]]
+    ) -> List[str]:
+        def _collate(x):
+            tokens = self.tok_encode(x[0])
+            return len(tokens), x[0]
+
+        results = []
+        reorder = utils.Reorderer(requests, _collate)
+
+        for chunk in utils.chunks(
+            tqdm(reorder.get_reordered(), disable=False),
+            self.batch_size,
+        ):
+            context = [c[0] for c in chunk]
+            request_args = chunk[0][1]
+            stop = request_args.get("until", None)
+            stop_sequences = stop if isinstance(stop, list) else [stop]
+            max_generation_length = request_args.get("max_length", None)
+
+            assert (
+                isinstance(max_generation_length, int) or max_generation_length is None
+            )
+            assert isinstance(stop_sequences, list) or stop_sequences is None
+
+            # TODO: Find a better way to handle stop sequences for 0-shot.
+            if stop_sequences is None:
+                until = [self.eot_token]
+            else:
+                until = stop_sequences + [self.eot_token]
+
+            if max_generation_length is None:
+                max_tokens = self.max_gen_toks
+            else:
+                max_tokens = max_generation_length
+
+            responses = self.model(
+                sequences=context,
+                max_new_tokens=max_tokens,
+                stop=until,
+                do_sample=False,
+            )
+
+            responses = responses if type(responses) is list else [responses]
+
+            for response in responses:
+                response = response.generations[0].text
+                # Ensure the generated responses do not contain the stop sequences.
+                for term in until:
+                    response = response.split(term)[0]
+                # partial caching
+                self.cache_hook.add_partial("greedy_until", (context, until), response)
+                results.append(response)
+
+        return reorder.get_original(results)
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        # Isn't used because we override greedy_until
+        raise NotImplementedError()

     @property
     def eot_token(self) -> str:
@@ -106,8 +152,7 @@ class DeepSparseLM(BaseLM):
         pass

     def tok_encode(self, string: str):
-        return self.tokenizer.encode(string)
+        return self.tokenizer.encode(string, add_special_tokens=False)

     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens)
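For reference, a minimal usage sketch of the new override follows (not part of this commit). It assumes `lm` is an already-constructed DeepSparseLM instance; the prompts and request arguments are illustrative. Each request pairs a context string with a dict of arguments, from which the override reads the optional "until" (stop sequences) and "max_length" (generation cap) keys.

    # Hypothetical usage sketch; `lm` is assumed to be a constructed DeepSparseLM.
    requests = [
        ("Question: What is 2 + 2?\nAnswer:", {"until": ["\n"], "max_length": 32}),
        ("Translate to French: hello", {"until": ["\n"]}),
    ]

    # One decoded completion per request, returned in the original order,
    # with stop sequences stripped from the generated text.
    completions = lm.greedy_until(requests)
    for (context, _), text in zip(requests, completions):
        print(repr(context), "->", repr(text))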