# Copyright (c) OpenMMLab. All rights reserved.
import logging
import re
from transformers import PreTrainedTokenizerFast
from .base import BasicAdapterFast
logger = logging.getLogger(__name__)

# Special delimiters of the Llama-2 chat template.  NOTE: the original
# source had these corrupted to '<>' (the '<<SYS>>' tags look like HTML
# and were stripped somewhere); restored to the official Llama-2 format.
B_INST, E_INST = '[INST]', '[/INST]'
B_SYS, E_SYS = '<<SYS>>\n', '\n<</SYS>>\n\n'

DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""  # noqa: E501


class Llama2Adapter(BasicAdapterFast):
    """Adapter for llama2.

    Llama2 uses the following template, and the first user prompt
    should contain a system prompt.

    The user can specify the system prompt with ``<<SYS>> ... <</SYS>>``
    tags; otherwise the default system prompt is prepended to the
    user's input.

    [INST]
    <<SYS>>\n
    SYSTEM_PROMPT\n
    <</SYS>>\n\n
    {user_prompt_1}
    [/INST]
    {answer_1}
    [INST]
    {user_prompt_2}
    [/INST]
    {answer_2}
    [INST]
    {user_prompt_3}(no space here)
    ...
    """

    start_ids = []
    sep_ids = []

    def __init__(self, tokenizer: PreTrainedTokenizerFast):
        super().__init__(tokenizer)
        # Number of rounds encoded so far; the system prompt is only
        # injected on the very first round.
        self.prev_round = 0

    def encode_and_decorate(self, prompt):
        r"""Encode ``prompt`` decorated with the llama2 chat template.

        On the first round, a user-supplied ``<<SYS>> ... <</SYS>>``
        system prompt (or the default one when absent) is merged into
        the user prompt; every round is wrapped in ``[INST] ... [/INST]``.

        Args:
            prompt (str): the user input of the current round.

        Returns:
            the tokenized prompt as a ``'pt'`` tensor (shape determined
            by the tokenizer).
        """
        if self.prev_round == 0:
            # re.DOTALL so a multi-line system prompt still matches
            # (without it, '.' stops at the first newline).
            res = re.search(r'<<SYS>>(.*?)<</SYS>>(.*)', prompt, re.DOTALL)
            if res:
                prompt = B_SYS + res.group(1).strip() \
                    + E_SYS + res.group(2).strip()
            else:
                prompt = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + prompt
        prompt = f'{B_INST} {prompt.strip()} {E_INST}'
        logger.debug(f'decorated prompt: {repr(prompt)}')
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=True,
            return_tensors='pt',
        )
        self.prev_round += 1
        return input_ids