def apply_stopping_strings(reply: str, stop_strings) -> Tuple[str, bool]:
    """Cut *reply* at a stop string, or hide a partially generated one.

    If any stop string occurs in ``reply``, the text is truncated at the
    earliest occurrence (regardless of the order of ``stop_strings``) and
    the flag is ``True``.

    Otherwise, if ``reply`` ends with a proper prefix of a stop string
    (e.g. ``"\nYo"`` while ``"\nYou:"`` is still being generated), the
    longest such dangling prefix is trimmed and the flag is ``False``.

    Args:
        reply: The text generated so far.
        stop_strings: Iterable of strings that terminate the reply. Empty
            strings are ignored.

    Returns:
        A ``(text, stop_found)`` tuple, where ``stop_found`` tells whether a
        complete stop string was present in ``reply``.
    """
    # Materialise once so one-shot iterables survive both passes, and drop
    # empty strings: "" is found at index 0 and would erase the whole reply.
    candidates = [s for s in stop_strings if s]

    # Truncate at the earliest occurrence so that no other stop string can
    # remain in the returned text.
    hits = [idx for idx in (reply.find(s) for s in candidates) if idx != -1]
    if hits:
        return reply[:min(hits)], True

    # No complete stop string: if something like "\nYo" was generated just
    # before "\nYou:" is completed, trim it. Keep the longest overlap across
    # all stop strings so the outcome does not depend on their order.
    longest = 0
    for s in candidates:
        # Only proper prefixes (length < len(s)) that fit inside reply and
        # are longer than the best overlap found so far are worth checking.
        for j in range(min(len(s) - 1, len(reply)), longest, -1):
            if reply.endswith(s[:j]):
                longest = j
                break

    if longest:
        reply = reply[:-longest]
    return reply, False