Unverified Commit 07eb712a authored by lazymio's avatar lazymio
Browse files

Left out

parent 91062a83
...@@ -127,7 +127,7 @@ class KTransformersInterface(TransformersInterface): ...@@ -127,7 +127,7 @@ class KTransformersInterface(TransformersInterface):
@torch.no_grad @torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool): def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
input_ids_length = input_ids.shape[-1] input_ids_length = input_ids.shape[-1]
logger.debug(f"input_ids: {input_ids.shape}") logger.debug(f"input_ids: {input_ids.shape}")
...@@ -198,7 +198,7 @@ class KTransformersInterface(TransformersInterface): ...@@ -198,7 +198,7 @@ class KTransformersInterface(TransformersInterface):
else: else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0] logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
self.prepare_logits_wrapper(input_ids, device) self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
next_token = self.logits_to_token(logits[0, -1, :]) next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token) yield self.append_new_tokens(next_token)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment