"...lm-evaluation-harness.git" did not exist on "21970120f9ca2ac02dbb4db63654da4303bb66f2"
Commit 7d06d0f9 authored by yangzhong's avatar yangzhong
Browse files

Update files

parent 2f320edb
Pipeline #2827 failed with stages
in 0 seconds
# OPENAI API DEMO
> For more detailed information on the OpenAI API, see: <https://platform.openai.com/docs/api-reference>
This is a simple OpenAI-API-style server DEMO implemented with fastapi. You can use this API DEMO to quickly build personal websites and other interesting WEB DEMOs backed by Chinese LLMs.
This implementation serves the LLM backend with vLLM and currently does not support loading LoRA models, CPU-only deployment, or 8-bit inference.
## Deployment
Install dependencies
``` shell
pip install fastapi uvicorn shortuuid vllm fschat
```
Launch the server
``` shell
python scripts/openai_server_demo/openai_api_server_vllm.py --model /path/to/base_model --tokenizer-mode slow --served-model-name chinese-llama-alpaca-2
```
### Arguments
`--model {base_model}`: Directory containing the LLaMA-2 model weights and configuration files in HF format; this can be a merged Chinese Alpaca-2 model.
`--tokenizer {tokenizer_path}`: Directory containing the corresponding tokenizer. If omitted, it defaults to the value of `--model`.
`--tokenizer-mode {tokenizer-mode}`: Tokenizer mode. For models based on LLaMA/LLaMA-2 this must be set to `slow`.
`--tensor-parallel-size {tensor_parallel_size}`: Number of GPUs to use. Defaults to 1.
`--served-model-name {served-model-name}`: Model name used in the API. When serving a Chinese Alpaca-2 model, the name must contain `chinese-llama-alpaca-2`.
`--host {host_name}`: Host name to serve on. Defaults to `localhost`.
`--port {port}`: Port to serve on. Defaults to `8000`.
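Once the server is up, you can sanity-check it from Python. The sketch below is a minimal example (assuming the default `localhost:8000` address) that lists the served models via the `/v1/models` endpoint:
``` python
import requests

resp = requests.get("http://localhost:8000/v1/models", timeout=10)
resp.raise_for_status()
for card in resp.json()["data"]:
    print(card["id"])  # e.g. chinese-llama-alpaca-2 (the --served-model-name value)
```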
## API Documentation
### Completion
> Professor Hung-yi Lee translates "completion" into Chinese as 文字接龙 (word chaining): <https://www.youtube.com/watch?v=yiY4nPOzJEg>
The most basic endpoint: you provide a prompt and the large language model returns its completion.
The API DEMO has a built-in prompt template; the prompt is wrapped into an instruction template, so the input here should read more like an instruction than a chat message.
#### Quick start with the completion endpoint
Request:
``` shell
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "chinese-llama-alpaca-2",
"prompt": "告诉我中国的首都在哪里"
}'
```
JSON response:
``` json
{
"id": "cmpl-41234d71fa034ec3ae90bbf6b5be7",
"object": "text_completion",
"created": 1690870733,
"model": "chinese-llama-alpaca-2",
"choices": [
{
"index": 0,
"text": "中国的首都是北京。"
}
]
}
```
#### Completion with advanced parameters
Request:
``` shell
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "chinese-llama-alpaca-2",
"prompt": "告诉我中国和美国分别各有哪些优点缺点",
"max_tokens": 90,
"temperature": 0.7,
"num_beams": 4,
"top_k": 40
}'
```
JSON response:
``` json
{
"id": "cmpl-ceca9906bf0a429989e850368cc3f893",
"object": "text_completion",
"created": 1690870952,
"model": "chinese-llama-alpaca-2",
"choices": [
{
"index": 0,
"text": "中国的优点是拥有丰富的文化和历史,而美国的优点是拥有先进的科技和经济体系。"
}
]
}
```
#### Advanced completion parameters explained
> For more on decoding strategies, see <https://towardsdatascience.com/the-three-decoding-methods-for-nlp-23ca59cb1e9d>, which walks through the three decoding strategies used with LLaMA: greedy decoding, random sampling, and beam search. These strategies are the basis of advanced parameters such as top_k, top_p, and temperature.
`prompt`: The prompt from which the completion is generated.
`max_tokens`: The maximum number of newly generated tokens.
`temperature`: Sampling temperature between 0 and 2. Higher values such as 0.8 make the output more random, while lower values such as 0.2 make it more deterministic; a higher temperature flattens the token distribution and therefore increases randomness when sampling.
`use_beam_search`: Use beam search. Defaults to `false`, i.e. random sampling is used.
`n`: Number of output sequences. Defaults to 1.
`best_of`: When beam search is used, the number of beams. Defaults to the same value as `n`.
`top_k`: With random sampling, only the top_k highest-probability tokens are kept as candidates to sample from.
`top_p`: With random sampling, only the smallest set of tokens whose cumulative probability reaches top_p is kept as candidates; lower values reduce randomness. For example, with top_p set to 0.6 and the top-5 token probabilities being {0.23, 0.20, 0.18, 0.11, 0.10}, the first three tokens already reach a cumulative probability of 0.61, so the 4th token is filtered out and only the first three are sampled from.
`presence_penalty`: Repetition penalty in the range -2 to 2, defaulting to 0. Values above 0 encourage the model to use new tokens; values below 0 encourage repetition.
`stream`: When set to `true`, the result is returned as a stream. Defaults to `false`.
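For programmatic use, the same request can be sent from Python. The following is a minimal sketch using the `requests` library (parameter values are only illustrative, and the server address assumes the default deployment above):
``` python
import requests

payload = {
    "model": "chinese-llama-alpaca-2",
    "prompt": "告诉我中国的首都在哪里",
    "max_tokens": 128,
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 40,
    "presence_penalty": 0,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```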
### Chat completion
The chat endpoint supports multi-turn conversations.
#### Quick start with the chat endpoint
Request:
``` shell
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "chinese-llama-alpaca-2",
"messages": [
{"role": "user","content": "给我讲一些有关杭州的故事吧"}
]
}'
```
JSON response:
``` json
{
"id": "cmpl-8fc1b6356cf64681a41a8739445a8cf8",
"object": "chat.completion",
"created": 1690872695,
"model": "chinese-llama-alpaca-2",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "好的,请问您对杭州有什么特别的偏好吗?"
}
}
]
}
```
#### Multi-turn conversation
Request:
``` shell
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "chinese-llama-alpaca-2",
"messages": [
{"role": "user","content": "给我讲一些有关杭州的故事吧"},
{"role": "assistant","content": "好的,请问您对杭州有什么特别的偏好吗?"},
{"role": "user","content": "我比较喜欢和西湖,可以给我讲一下西湖吗"}
],
"repetition_penalty": 1.0
}'
```
JSON response:
``` json
{
"id": "cmpl-02bf36497d3543c980ca2ae8cc4feb63",
"object": "chat.completion",
"created": 1690872676,
"model": "chinese-llama-alpaca-2",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "是的,西湖是杭州最著名的景点之一,它被誉为“人间天堂”。 <\\s>"
}
}
]
}
```
#### Advanced chat parameters explained
`messages`: The list of chat messages (the conversation so far) from which the reply is generated.
`max_tokens`: The maximum number of newly generated tokens.
`temperature`: Sampling temperature between 0 and 2. Higher values such as 0.8 make the output more random, while lower values such as 0.2 make it more deterministic; a higher temperature flattens the token distribution and therefore increases randomness when sampling.
`use_beam_search`: Use beam search. Defaults to `false`, i.e. random sampling is used.
`n`: Number of output sequences. Defaults to 1.
`best_of`: When beam search is used, the number of beams. Defaults to the same value as `n`.
`top_k`: With random sampling, only the top_k highest-probability tokens are kept as candidates to sample from.
`top_p`: With random sampling, only the smallest set of tokens whose cumulative probability reaches top_p is kept as candidates; lower values reduce randomness. For example, with top_p set to 0.6 and the top-5 token probabilities being {0.23, 0.20, 0.18, 0.11, 0.10}, the first three tokens already reach a cumulative probability of 0.61, so the 4th token is filtered out and only the first three are sampled from.
`presence_penalty`: Repetition penalty in the range -2 to 2, defaulting to 0. Values above 0 encourage the model to use new tokens; values below 0 encourage repetition.
`stream`: When set to `true`, the result is returned as a stream. Defaults to `false`.
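The chat endpoint can be called from Python in the same way. The sketch below is a minimal example (illustrative values, default server address) that sends one user message with `"stream": true` and reassembles the reply from the SSE `data:` lines returned by the server; to continue the conversation, append the reply to the message list with role `assistant` and send it again:
``` python
import json
import requests

messages = [{"role": "user", "content": "给我讲一些有关杭州的故事吧"}]
payload = {"model": "chinese-llama-alpaca-2", "messages": messages, "stream": True}
reply = ""
with requests.post("http://localhost:8000/v1/chat/completions",
                   json=payload, stream=True, timeout=300) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        reply += chunk["choices"][0]["delta"].get("content", "")

messages.append({"role": "assistant", "content": reply})
print(reply)
```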
from typing import Optional, List, Dict, Any, Union, Literal
import time
import shortuuid
from pydantic import BaseModel, Field
class ChatCompletionRequest(BaseModel):
model: str = "chinese-llama-alpaca-2"
messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.2
top_p: Optional[float] = 0.9
top_k: Optional[int] = 40
n: Optional[int] = 1
max_tokens: Optional[int] = 512
num_beams: Optional[int] = 1
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
repetition_penalty: Optional[float] = 1.1
user: Optional[str] = None
do_sample: Optional[bool] = True
class ChatMessage(BaseModel):
role: str
content: str
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]]
class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str = "chinese-llama-alpaca-2"
choices: List[
Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
]
class EmbeddingsRequest(BaseModel):
input: Union[str, List[Any]]
user: Optional[str] = None
class EmbeddingsResponse(BaseModel):
object: str = "list"
data: List[Dict[str, Any]]
model: str = "chinese-llama-alpaca-2"
class CompletionRequest(BaseModel):
prompt: Union[str, List[Any]]
temperature: Optional[float] = 0.2
n: Optional[int] = 1
max_tokens: Optional[int] = 512
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
top_p: Optional[float] = 0.9
top_k: Optional[int] = 40
num_beams: Optional[int] = 1
logprobs: Optional[int] = None
echo: Optional[bool] = False
repetition_penalty: Optional[float] = 1.1
user: Optional[str] = None
do_sample: Optional[bool] = True
class CompletionResponseChoice(BaseModel):
index: int
text: str
class CompletionResponse(BaseModel):
id: Optional[str] = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
object: Optional[str] = "text_completion"
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
model: Optional[str] = "chinese-llama-alpaca-2"
choices: List[CompletionResponseChoice]
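# The schemas below form a separate protocol module used by the vLLM-based server
# (imported later as `openai_api_protocol_vllm`); compared with the definitions above,
# they additionally carry usage statistics and dedicated streaming-chunk types.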
import time
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from vllm.utils import random_uuid
class ErrorResponse(BaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: Optional[str] = None
class ModelPermission(BaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"
created: int = Field(default_factory=lambda: int(time.time()))
allow_create_engine: bool = False
allow_sampling: bool = True
allow_logprobs: bool = True
allow_search_indices: bool = False
allow_view: bool = True
allow_fine_tuning: bool = False
organization: str = "*"
group: Optional[str] = None
    is_blocking: bool = False
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "vllm"
root: Optional[str] = None
parent: Optional[str] = None
permission: List[ModelPermission] = Field(default_factory=list)
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ChatCompletionRequest(BaseModel):
model: str
messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.2
top_p: Optional[float] = 0.9
n: Optional[int] = 1
max_tokens: Optional[int] = 512
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
presence_penalty: Optional[float] = 1.0
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Additional parameters supported by vLLM
best_of: Optional[int] = None
top_k: Optional[int] = 40
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
class CompletionRequest(BaseModel):
model: str
prompt: Union[str, List[str]]
suffix: Optional[str] = None
max_tokens: Optional[int] = 512
temperature: Optional[float] = 0.2
top_p: Optional[float] = 0.9
n: Optional[int] = 1
stream: Optional[bool] = False
logprobs: Optional[int] = None
echo: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
presence_penalty: Optional[float] = 1.0
frequency_penalty: Optional[float] = 0.0
best_of: Optional[int] = None
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Additional parameters supported by vLLM
top_k: Optional[int] = 40
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str,
float]]] = Field(default_factory=list)
class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
class CompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
class CompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
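# The server below is a FastAPI demo that loads the model directly with transformers
# (optionally applying a LoRA adapter via peft) and exposes /v1/completions,
# /v1/chat/completions and /v1/embeddings using the `openai_api_protocol` schemas above.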
import argparse
import os
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from threading import Thread
from sse_starlette.sse import EventSourceResponse
parser = argparse.ArgumentParser()
parser.add_argument('--base_model', default=None, type=str, required=True)
parser.add_argument('--lora_model', default=None, type=str,help="If None, perform inference on the base model")
parser.add_argument('--tokenizer_path',default=None,type=str)
parser.add_argument('--gpus', default="0", type=str)
parser.add_argument('--load_in_8bit',action='store_true', help='Load the model in 8bit mode')
parser.add_argument('--load_in_4bit',action='store_true', help='Load the model in 4bit mode')
parser.add_argument('--only_cpu',action='store_true',help='Only use CPU for inference')
parser.add_argument('--alpha',type=str,default="1.0", help="The scaling factor of NTK method, can be a float or 'auto'. ")
parser.add_argument('--use_ntk', action='store_true', help="Use dynamic-ntk to extend context window")
parser.add_argument('--use_flash_attention_2', action='store_true', help="Use flash-attention2 to accelerate inference")
args = parser.parse_args()
if args.only_cpu is True:
args.gpus = ""
if args.load_in_8bit or args.load_in_4bit:
raise ValueError("Quantization is unavailable on CPU.")
if args.load_in_8bit and args.load_in_4bit:
raise ValueError("Only one quantization method can be chosen for inference. Please check your arguments")
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
import torch
import torch.nn.functional as F
from transformers import (
AutoModelForCausalLM,
LlamaTokenizer,
GenerationConfig,
TextIteratorStreamer,
BitsAndBytesConfig
)
from peft import PeftModel
import sys
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
apply_attention_patch(use_memory_efficient_attention=True)
if args.use_ntk:
apply_ntk_scaling_patch(args.alpha)
from openai_api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatMessage,
ChatCompletionResponseChoice,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
EmbeddingsRequest,
EmbeddingsResponse,
ChatCompletionResponseStreamChoice,
DeltaMessage,
)
load_type = torch.float16
if torch.cuda.is_available():
device = torch.device(0)
else:
device = torch.device("cpu")
if args.tokenizer_path is None:
args.tokenizer_path = args.lora_model
if args.lora_model is None:
args.tokenizer_path = args.base_model
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)
if args.load_in_4bit or args.load_in_8bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=args.load_in_4bit,
load_in_8bit=args.load_in_8bit,
bnb_4bit_compute_dtype=load_type,
)
base_model = AutoModelForCausalLM.from_pretrained(
args.base_model,
torch_dtype=load_type,
low_cpu_mem_usage=True,
device_map='auto' if not args.only_cpu else None,
load_in_4bit=args.load_in_4bit,
load_in_8bit=args.load_in_8bit,
quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,
use_flash_attention_2=args.use_flash_attention_2,
trust_remote_code=True
)
model_vocab_size = base_model.get_input_embeddings().weight.size(0)
tokenizer_vocab_size = len(tokenizer)
print(f"Vocab of the base model: {model_vocab_size}")
print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
if model_vocab_size != tokenizer_vocab_size:
print("Resize model embeddings to fit tokenizer")
base_model.resize_token_embeddings(tokenizer_vocab_size)
if args.lora_model is not None:
print("loading peft model")
model = PeftModel.from_pretrained(
base_model,
args.lora_model,
torch_dtype=load_type,
device_map="auto",
)
else:
model = base_model
if device == torch.device("cpu"):
model.float()
model.eval()
DEFAULT_SYSTEM_PROMPT = """You are a helpful assistant. 你是一个乐于助人的助手。"""
TEMPLATE_WITH_SYSTEM_PROMPT = (
"[INST] <<SYS>>\n" "{system_prompt}\n" "<</SYS>>\n\n" "{instruction} [/INST]"
)
TEMPLATE_WITHOUT_SYSTEM_PROMPT = "[INST] {instruction} [/INST]"
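# For illustration, generate_prompt("告诉我中国的首都在哪里") (defined below) renders to:
# [INST] <<SYS>>
# You are a helpful assistant. 你是一个乐于助人的助手。
# <</SYS>>
#
# 告诉我中国的首都在哪里 [/INST]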
def generate_prompt(
instruction, response="", with_system_prompt=True, system_prompt=None
):
if with_system_prompt is True:
if system_prompt is None:
system_prompt = DEFAULT_SYSTEM_PROMPT
prompt = TEMPLATE_WITH_SYSTEM_PROMPT.format_map(
{"instruction": instruction, "system_prompt": system_prompt}
)
else:
prompt = TEMPLATE_WITHOUT_SYSTEM_PROMPT.format_map({"instruction": instruction})
if len(response) > 0:
prompt += " " + response
return prompt
def generate_completion_prompt(instruction: str):
"""Generate prompt for completion"""
return generate_prompt(instruction, response="", with_system_prompt=True)
def generate_chat_prompt(messages: list):
"""Generate prompt for chat completion"""
system_msg = None
for msg in messages:
if msg.role == "system":
system_msg = msg.content
prompt = ""
is_first_user_content = True
for msg in messages:
if msg.role == "system":
continue
if msg.role == "user":
if is_first_user_content is True:
prompt += generate_prompt(
msg.content, with_system_prompt=True, system_prompt=system_msg
)
is_first_user_content = False
else:
prompt += "<s>" + generate_prompt(msg.content, with_system_prompt=False)
if msg.role == "assistant":
prompt += f" {msg.content}" + "</s>"
return prompt
def predict(
input,
max_new_tokens=128,
top_p=0.9,
temperature=0.2,
top_k=40,
num_beams=1,
repetition_penalty=1.1,
do_sample=True,
**kwargs,
):
"""
Main inference method
type(input) == str -> /v1/completions
type(input) == list -> /v1/chat/completions
"""
if isinstance(input, str):
prompt = generate_completion_prompt(input)
else:
prompt = generate_chat_prompt(input)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
do_sample=do_sample,
**kwargs,
)
generation_config.return_dict_in_generate = True
generation_config.output_scores = False
generation_config.max_new_tokens = max_new_tokens
generation_config.repetition_penalty = float(repetition_penalty)
with torch.no_grad():
generation_output = model.generate(
input_ids=input_ids,
generation_config=generation_config,
)
s = generation_output.sequences[0]
output = tokenizer.decode(s, skip_special_tokens=True)
output = output.split("[/INST]")[-1].strip()
return output
def stream_predict(
input,
max_new_tokens=128,
top_p=0.75,
temperature=0.1,
top_k=40,
num_beams=4,
repetition_penalty=1.0,
do_sample=True,
model_id="chinese-llama-alpaca-2",
**kwargs,
):
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
)
chunk = ChatCompletionResponse(
model=model_id,
choices=[choice_data],
object="chat.completion.chunk",
)
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
if isinstance(input, str):
prompt = generate_completion_prompt(input)
else:
prompt = generate_chat_prompt(input)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
do_sample=do_sample,
**kwargs,
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
streamer=streamer,
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=False,
max_new_tokens=max_new_tokens,
repetition_penalty=float(repetition_penalty),
)
Thread(target=model.generate, kwargs=generation_kwargs).start()
for new_text in streamer:
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=DeltaMessage(content=new_text), finish_reason=None
)
chunk = ChatCompletionResponse(
model=model_id, choices=[choice_data], object="chat.completion.chunk"
)
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=DeltaMessage(), finish_reason="stop"
)
chunk = ChatCompletionResponse(
model=model_id, choices=[choice_data], object="chat.completion.chunk"
)
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield "[DONE]"
def get_embedding(input):
"""Get embedding main function"""
with torch.no_grad():
encoding = tokenizer(input, padding=True, return_tensors="pt")
input_ids = encoding["input_ids"].to(device)
attention_mask = encoding["attention_mask"].to(device)
model_output = model(input_ids, attention_mask, output_hidden_states=True)
data = model_output.hidden_states[-1]
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
masked_embeddings = data * mask
sum_embeddings = torch.sum(masked_embeddings, dim=1)
seq_length = torch.sum(mask, dim=1)
embedding = sum_embeddings / seq_length
normalized_embeddings = F.normalize(embedding, p=2, dim=1)
ret = normalized_embeddings.squeeze(0).tolist()
return ret
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
"""Creates a completion for the chat message"""
msgs = request.messages
if isinstance(msgs, str):
msgs = [ChatMessage(role="user", content=msgs)]
else:
msgs = [ChatMessage(role=x["role"], content=x["content"]) for x in msgs]
if request.stream:
generate = stream_predict(
input=msgs,
max_new_tokens=request.max_tokens,
top_p=request.top_p,
top_k=request.top_k,
temperature=request.temperature,
num_beams=request.num_beams,
repetition_penalty=request.repetition_penalty,
do_sample=request.do_sample,
)
return EventSourceResponse(generate, media_type="text/event-stream")
output = predict(
input=msgs,
max_new_tokens=request.max_tokens,
top_p=request.top_p,
top_k=request.top_k,
temperature=request.temperature,
num_beams=request.num_beams,
repetition_penalty=request.repetition_penalty,
do_sample=request.do_sample,
)
choices = [
ChatCompletionResponseChoice(index=i, message=msg) for i, msg in enumerate(msgs)
]
choices += [
ChatCompletionResponseChoice(
index=len(choices), message=ChatMessage(role="assistant", content=output)
)
]
return ChatCompletionResponse(choices=choices)
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
"""Creates a completion"""
output = predict(
input=request.prompt,
max_new_tokens=request.max_tokens,
top_p=request.top_p,
top_k=request.top_k,
temperature=request.temperature,
num_beams=request.num_beams,
repetition_penalty=request.repetition_penalty,
do_sample=request.do_sample,
)
choices = [CompletionResponseChoice(index=0, text=output)]
return CompletionResponse(choices=choices)
@app.post("/v1/embeddings")
async def create_embeddings(request: EmbeddingsRequest):
"""Creates text embedding"""
embedding = get_embedding(request.input)
data = [{"object": "embedding", "embedding": embedding, "index": 0}]
return EmbeddingsResponse(data=data)
if __name__ == "__main__":
log_config = uvicorn.config.LOGGING_CONFIG
log_config["formatters"]["access"][
"fmt"
] = "%(asctime)s - %(levelname)s - %(message)s"
log_config["formatters"]["default"][
"fmt"
] = "%(asctime)s - %(levelname)s - %(message)s"
uvicorn.run(app, host="0.0.0.0", port=19327, workers=1, log_config=log_config)
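# The server below is the vLLM-backed variant referenced in the README
# (scripts/openai_server_demo/openai_api_server_vllm.py): it registers fastchat
# conversation templates for Chinese-LLaMA-Alpaca(-2) and delegates generation
# to an AsyncLLMEngine.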
import argparse
import asyncio
from http import HTTPStatus
import json
import time
from typing import AsyncGenerator, Dict, List, Optional
from packaging import version
import fastapi
from fastapi import BackgroundTasks, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
import uvicorn
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
import sys
import os
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from openai_api_protocol_vllm import (
CompletionRequest, CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse,
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__)
served_model = None
app = fastapi.FastAPI()
from fastchat.conversation import register_conv_template, get_conv_template
from fastchat.model.model_adapter import BaseModelAdapter, model_adapters
import fastchat
def compare_version(version1, version2):
# if v1 >= v2, return True, else return False
v1 = version.parse(version1)
v2 = version.parse(version2)
return v1 >= v2
if compare_version(fastchat.__version__, '0.2.23'):
use_old_conversation = False
else:
use_old_conversation = True
def getConversation(name, system, roles, messages, offset, sep_style, sep, sep2=None, stop_str=None, stop_token_ids=None):
if not use_old_conversation:
return Conversation(
name=name,
system_message=system,
roles=roles,
messages=messages,
offset=offset,
sep_style=sep_style,
sep=sep,
sep2=sep2,
stop_str=stop_str,
stop_token_ids=stop_token_ids
)
else:
return Conversation(
name=name,
system=system,
roles=roles,
messages=messages,
offset=offset,
sep_style=sep_style,
sep=sep,
sep2=sep2,
stop_str=stop_str,
stop_token_ids=stop_token_ids
)
# Chinese LLaMA Alpaca default template
register_conv_template(
getConversation(
name="chinese-llama-alpaca",
system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
roles=("### Instruction:\n", "### Response:"),
messages=(),
offset=0,
sep_style=SeparatorStyle.NO_COLON_SINGLE,
sep="\n\n",
sep2="",
)
)
# Chinese LLaMA Alpaca 2 default template
register_conv_template(
getConversation(
name="chinese-llama-alpaca-2",
system="[INST] <<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n",
roles=("[INST]", "[/INST]"),
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA2,
sep=" ",
sep2=" </s><s>",
stop_token_ids=[2],
)
)
class ChineseLLaMAAlpacaAdapter(BaseModelAdapter):
"""The model adapter for Chinese-LLaMA-Alpaca"""
use_fast_tokenizer = False
def match(self, model_path: str):
return "chinese-llama-alpaca" in model_path.lower()
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("chinese-llama-alpaca")
class ChineseLLaMAAlpaca2Adapter(BaseModelAdapter):
"""The model adapter for Chinese-LLaMA-Alpaca-2"""
def match(self, model_path: str):
return "chinese-llama-alpaca-2" in model_path.lower()
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("chinese-llama-alpaca-2")
# add model adapters to head of List model_adapters
model_adapters = [ChineseLLaMAAlpacaAdapter()] + model_adapters
model_adapters = [ChineseLLaMAAlpaca2Adapter()] + model_adapters
fastchat.model.model_adapter.model_adapters = model_adapters
def create_error_response(status_code: HTTPStatus,
message: str) -> JSONResponse:
return JSONResponse(ErrorResponse(message=message,
type="invalid_request_error").dict(),
status_code=status_code.value)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc): # pylint: disable=unused-argument
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
async def check_model(request) -> Optional[JSONResponse]:
if request.model == served_model:
return
ret = create_error_response(
HTTPStatus.NOT_FOUND,
f"The model `{request.model}` does not exist.",
)
return ret
async def get_gen_prompt(request) -> str:
conv = get_conversation_template(request.model)
conv = getConversation(
name=conv.name,
system=conv.system_message if not use_old_conversation else conv.system,
roles=conv.roles,
messages=list(conv.messages), # prevent in-place modification
offset=conv.offset,
sep_style=SeparatorStyle(conv.sep_style),
sep=conv.sep,
sep2=conv.sep2,
stop_str=conv.stop_str,
stop_token_ids=conv.stop_token_ids,
)
if isinstance(request.messages, str):
prompt = request.messages
else:
for message in request.messages:
msg_role = message["role"]
if msg_role == "system":
if not use_old_conversation:
conv.system_message = message["content"]
else:
conv.system = message["content"]
elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"])
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message["content"])
else:
raise ValueError(f"Unknown role: {msg_role}")
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
return prompt
async def get_gen_prompt_nochat(request) -> str:
conv = get_conversation_template(request.model)
conv = getConversation(
name=conv.name,
system=conv.system_message if not use_old_conversation else conv.system,
roles=conv.roles,
messages=list(conv.messages), # prevent in-place modification
offset=conv.offset,
sep_style=SeparatorStyle(conv.sep_style),
sep=conv.sep,
sep2=conv.sep2,
stop_str=conv.stop_str,
stop_token_ids=conv.stop_token_ids,
)
prompt = request.prompt
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
return prompt
async def check_length(request, prompt, model_config):
if hasattr(model_config.hf_config, "max_sequence_length"):
context_len = model_config.hf_config.max_sequence_length
elif hasattr(model_config.hf_config, "seq_length"):
context_len = model_config.hf_config.seq_length
elif hasattr(model_config.hf_config, "max_position_embeddings"):
context_len = model_config.hf_config.max_position_embeddings
elif hasattr(model_config.hf_config, "seq_length"):
context_len = model_config.hf_config.seq_length
else:
context_len = 2048
input_ids = tokenizer(prompt).input_ids
token_num = len(input_ids)
if token_num + request.max_tokens > context_len:
return create_error_response(
HTTPStatus.BAD_REQUEST,
f"This model's maximum context length is {context_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.",
)
else:
return None
@app.get("/v1/models")
async def show_available_models():
"""Show available models. Right now we only have one model."""
model_cards = [
ModelCard(id=served_model,
root=served_model,
permission=[ModelPermission()])
]
return ModelList(data=model_cards)
def create_logprobs(token_ids: List[int],
id_logprobs: List[Dict[int, float]],
initial_text_offset: int = 0) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
for token_id, id_logprob in zip(token_ids, id_logprobs):
token = tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(id_logprob[token_id])
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)
logprobs.top_logprobs.append({
tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()
})
return logprobs
@app.post("/v1/chat/completions")
async def create_chat_completion(raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
NOTE: Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by vLLM engine)
"""
request = ChatCompletionRequest(**await raw_request.json())
logger.info(f"Received chat completion request: {request}")
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.logit_bias is not None:
# TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
prompt = await get_gen_prompt(request)
error_check_ret = await check_length(request, prompt, engine_model_config)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
try:
sampling_params = SamplingParams(
n=request.n,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
temperature=request.temperature,
top_p=request.top_p,
stop=request.stop,
max_tokens=request.max_tokens,
best_of=request.best_of,
top_k=request.top_k,
ignore_eos=request.ignore_eos,
use_beam_search=request.use_beam_search,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = engine.generate(prompt, sampling_params, request_id)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json(
index: int,
text: str,
finish_reason: Optional[str] = None,
) -> str:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=text),
finish_reason=finish_reason,
)
response = ChatCompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.json(ensure_ascii=False)
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
# First chunk with role
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.json(exclude_unset=True, ensure_ascii=False)
yield f"data: {data}\n\n"
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
response_json = create_stream_response_json(
index=i,
text=delta_text,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
response_json = create_stream_response_json(
index=i,
text="",
finish_reason=output.finish_reason,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if request.stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream",
background=background_tasks)
# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
for output in final_res.outputs:
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role="assistant", content=output.text),
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
return response
@app.post("/v1/completions")
async def create_completion(raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following features:
- echo (since the vLLM engine does not currently support
getting the logprobs of prompt tokens)
- suffix (the language models we currently support do not support
suffix)
- logit_bias (to be supported by vLLM engine)
"""
request = CompletionRequest(**await raw_request.json())
logger.info(f"Received completion request: {request}")
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.echo:
# We do not support echo since the vLLM engine does not
# currently support getting the logprobs of prompt tokens.
return create_error_response(HTTPStatus.BAD_REQUEST,
"echo is not currently supported")
if request.suffix is not None:
# The language models we currently support do not support suffix.
return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported")
if request.logit_bias is not None:
# TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
if isinstance(request.prompt, list):
if len(request.prompt) == 0:
return create_error_response(HTTPStatus.BAD_REQUEST,
"please provide at least one prompt")
if len(request.prompt) > 1:
return create_error_response(
HTTPStatus.BAD_REQUEST,
"multiple prompts in a batch is not currently supported")
prompt = request.prompt[0]
else:
prompt = request.prompt
request.prompt = prompt
prompt = await get_gen_prompt_nochat(request)
created_time = int(time.time())
try:
sampling_params = SamplingParams(
n=request.n,
best_of=request.best_of,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
ignore_eos=request.ignore_eos,
max_tokens=request.max_tokens,
logprobs=request.logprobs,
use_beam_search=request.use_beam_search,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = engine.generate(prompt, sampling_params, request_id)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json(
index: int,
text: str,
logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None,
) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
logprobs=logprobs,
finish_reason=finish_reason,
)
response = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.json(ensure_ascii=False)
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
if request.logprobs is not None:
logprobs = create_logprobs(
output.token_ids[previous_num_tokens[i]:],
output.logprobs[previous_num_tokens[i]:],
len(previous_texts[i]))
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
response_json = create_stream_response_json(
index=i,
text=delta_text,
logprobs=logprobs,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = (LogProbs()
if request.logprobs is not None else None)
response_json = create_stream_response_json(
index=i,
text="",
logprobs=logprobs,
finish_reason=output.finish_reason,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream",
background=background_tasks)
# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
for output in final_res.outputs:
if request.logprobs is not None:
logprobs = create_logprobs(output.token_ids, output.logprobs)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=output.index,
text=output.text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
return response
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host",
type=str,
default="localhost",
help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
allow_credentials=args.allow_credentials,
allow_methods=args.allowed_methods,
allow_headers=args.allowed_headers,
)
logger.info(f"args: {args}")
if args.served_model_name is not None:
served_model = args.served_model_name
else:
served_model = args.model
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
engine_model_config = asyncio.run(engine.get_model_config())
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(engine_args.tokenizer,
tokenizer_mode=engine_args.tokenizer_mode,
trust_remote_code=engine_args.trust_remote_code)
uvicorn.run(app,
host=args.host,
port=args.port,
log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
## privateGPT example scripts
Detailed usage (Chinese): https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/privategpt_zh
Detailed usage (English): https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/privategpt_en
The following code is adapted from https://github.com/imartinez/privateGPT/blob/main/privateGPT.py
### privateGPT.py
Example entry point with the Alpaca-2 instruction template embedded. Third-party libraries change frequently, so please do not use this script as-is; adapt it yourself following the wiki tutorial.
### privateGPT_refine.py
Example entry point that uses the `refine` chain strategy.
#!/usr/bin/env python3
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from langchain.llms import GPT4All, LlamaCpp
import os
import argparse
import time
load_dotenv()
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = os.environ.get('MODEL_N_CTX')
model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
from constants import CHROMA_SETTINGS
def main():
# Parse the command line arguments
args = parse_arguments()
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
# activate/deactivate the streaming StdOut callback for LLMs
callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
# Prepare the LLM
match model_type:
case "LlamaCpp":
llm = LlamaCpp(model_path=model_path, max_tokens=model_n_ctx, n_ctx=model_n_ctx,
n_gpu_layers=1, n_batch=model_n_batch, callbacks=callbacks, n_threads=8, verbose=False)
case "GPT4All":
llm = GPT4All(model=model_path, max_tokens=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=False)
case _default:
# raise exception if model_type is not supported
raise Exception(f"Model type {model_type} is not supported. Please choose one of the following: LlamaCpp, GPT4All")
# The followings are specifically designed for Chinese-Alpaca-2
# For detailed usage: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/privategpt_en
alpaca2_prompt_template = (
"[INST] <<SYS>>\n"
"You are a helpful assistant. 你是一个乐于助人的助手。\n"
"<</SYS>>\n\n"
"{context}\n\n{question} [/INST]"
)
from langchain import PromptTemplate
input_with_prompt = PromptTemplate(template=alpaca2_prompt_template, input_variables=["context", "question"])
qa = RetrievalQA.from_chain_type(
llm=llm, chain_type="stuff", retriever=retriever,
return_source_documents= not args.hide_source,
chain_type_kwargs={"prompt": input_with_prompt})
# Interactive questions and answers
while True:
query = input("\nEnter a query: ")
if query == "exit":
break
if query.strip() == "":
continue
# Get the answer from the chain
start = time.time()
res = qa(query)
answer, docs = res['result'], [] if args.hide_source else res['source_documents']
end = time.time()
# Print the result
print("\n\n> Question:")
print(query)
print(f"\n> Answer (took {round(end - start, 2)} s.):")
print(answer)
# Print the relevant sources used for the answer
for document in docs:
print("\n> " + document.metadata["source"] + ":")
print(document.page_content)
def parse_arguments():
parser = argparse.ArgumentParser(description='privateGPT: Ask questions to your documents without an internet connection, '
'using the power of LLMs.')
parser.add_argument("--hide-source", "-S", action='store_true',
help='Use this flag to disable printing of source documents used for answers.')
parser.add_argument("--mute-stream", "-M",
action='store_true',
help='Use this flag to disable the streaming StdOut callback for LLMs.')
return parser.parse_args()
if __name__ == "__main__":
main()
#!/usr/bin/env python3
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from langchain.llms import GPT4All, LlamaCpp
import os
import argparse
import time
load_dotenv()
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = os.environ.get('MODEL_N_CTX')
model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
from constants import CHROMA_SETTINGS
def main():
# Parse the command line arguments
args = parse_arguments()
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
# activate/deactivate the streaming StdOut callback for LLMs
callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
# Prepare the LLM
match model_type:
case "LlamaCpp":
llm = LlamaCpp(model_path=model_path, max_tokens=model_n_ctx, n_ctx=model_n_ctx,
n_gpu_layers=1, n_batch=model_n_batch, callbacks=callbacks, n_threads=8, verbose=False)
case "GPT4All":
llm = GPT4All(model=model_path, max_tokens=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=False)
case _default:
# raise exception if model_type is not supported
raise Exception(f"Model type {model_type} is not supported. Please choose one of the following: LlamaCpp, GPT4All")
# The followings are specifically designed for Chinese-Alpaca-2
# For detailed usage: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/privategpt_en
alpaca2_refine_prompt_template = (
"[INST] <<SYS>>\n"
"You are a helpful assistant. 你是一个乐于助人的助手。\n"
"<</SYS>>\n\n"
"这是原始问题:{question}\n"
"已有的回答: {existing_answer}\n"
"现在还有一些文字,(如果有需要)你可以根据它们完善现有的回答。"
"\n\n{context_str}\n\n"
"请根据新的文段,进一步完善你的回答。 [/INST]"
)
alpaca2_initial_prompt_template = (
"[INST] <<SYS>>\n"
"You are a helpful assistant. 你是一个乐于助人的助手。\n"
"<</SYS>>\n\n"
"以下为背景知识:\n{context_str}\n"
"请根据以上背景知识,回答这个问题:{question} [/INST]"
)
from langchain import PromptTemplate
refine_prompt = PromptTemplate(
input_variables=["question", "existing_answer", "context_str"],
template=alpaca2_refine_prompt_template,
)
initial_qa_prompt = PromptTemplate(
input_variables=["context_str", "question"],
template=alpaca2_initial_prompt_template,
)
chain_type_kwargs = {"question_prompt": initial_qa_prompt, "refine_prompt": refine_prompt}
qa = RetrievalQA.from_chain_type(
llm=llm, chain_type="refine",
retriever=retriever, return_source_documents= not args.hide_source,
chain_type_kwargs=chain_type_kwargs)
# Interactive questions and answers
while True:
query = input("\nEnter a query: ")
if query == "exit":
break
if query.strip() == "":
continue
# Get the answer from the chain
start = time.time()
res = qa(query)
answer, docs = res['result'], [] if args.hide_source else res['source_documents']
end = time.time()
# Print the result
print("\n\n> Question:")
print(query)
print(f"\n> Answer (took {round(end - start, 2)} s.):")
print(answer)
# Print the relevant sources used for the answer
for document in docs:
print("\n> " + document.metadata["source"] + ":")
print(document.page_content)
def parse_arguments():
parser = argparse.ArgumentParser(description='privateGPT: Ask questions to your documents without an internet connection, '
'using the power of LLMs.')
parser.add_argument("--hide-source", "-S", action='store_true',
help='Use this flag to disable printing of source documents used for answers.')
parser.add_argument("--mute-stream", "-M",
action='store_true',
help='Use this flag to disable the streaming StdOut callback for LLMs.')
return parser.parse_args()
if __name__ == "__main__":
main()
{
"bos_token": {
"content": "<s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"pad_token": "<pad>",
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}
{
"add_bos_token": true,
"add_eos_token": false,
"bos_token": {
"__type": "AddedToken",
"content": "<s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"clean_up_tokenization_spaces": false,
"eos_token": {
"__type": "AddedToken",
"content": "</s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"model_max_length": 1000000000000000019884624838656,
"pad_token": null,
"sp_model_kwargs": {},
"tokenizer_class": "LlamaTokenizer",
"unk_token": {
"__type": "AddedToken",
"content": "<unk>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"use_fast": false
}
import logging
import os
from dataclasses import dataclass
from typing import Dict, Sequence, Union, List
import datasets
import torch
from datasets import load_dataset, concatenate_datasets
import transformers
IGNORE_INDEX = -100
logger = logging.getLogger(__name__)
PROMPT_TEMPLATE = (
"[INST] <<SYS>>\n"
"You are a helpful assistant. 你是一个乐于助人的助手。\n"
"<</SYS>>\n\n{instruction} [/INST]"
)
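# Each training example is rendered as PROMPT_TEMPLATE.format(instruction=...) followed by
# the target response; the prompt part is labeled IGNORE_INDEX so that only response tokens
# contribute to the loss (see `tokenization` below).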
def build_instruction_dataset(data_path: Union[List[str],str],
tokenizer: transformers.PreTrainedTokenizer,
max_seq_length: int, data_cache_dir = None,
preprocessing_num_workers = None,
):
def tokenization(examples):
sources = []
targets = []
prompt = PROMPT_TEMPLATE
for instruction, input, output in zip(examples['instruction'],examples['input'],examples['output']):
if input is not None and input !="":
instruction = instruction+'\n'+input
source = prompt.format_map({'instruction':instruction})
target = f"{output}{tokenizer.eos_token}"
sources.append(source)
targets.append(target)
tokenized_sources = tokenizer(sources,return_attention_mask=False)
tokenized_targets = tokenizer(targets,return_attention_mask=False,add_special_tokens=False)
all_input_ids = []
all_labels = []
for s,t in zip(tokenized_sources['input_ids'],tokenized_targets['input_ids']):
input_ids = torch.LongTensor(s + t)[:max_seq_length]
labels = torch.LongTensor([IGNORE_INDEX] * len(s) + t)[:max_seq_length]
assert len(input_ids) == len(labels)
all_input_ids.append(input_ids)
all_labels.append(labels)
results = {'input_ids':all_input_ids, 'labels': all_labels}
return results
logging.warning("building dataset...")
all_datasets = []
if not isinstance(data_path,(list,tuple)):
data_path = [data_path]
for file in data_path:
if data_cache_dir is None:
data_cache_dir = str(os.path.dirname(file))
cache_path = os.path.join(data_cache_dir,os.path.basename(file).split('.')[0]+f"_{max_seq_length}")
os.makedirs(cache_path, exist_ok=True)
try:
processed_dataset = datasets.load_from_disk(cache_path)
logger.info(f'training datasets-{file} has been loaded from disk')
except Exception:
raw_dataset = load_dataset("json", data_files=file, cache_dir=cache_path)
tokenization_func = tokenization
tokenized_dataset = raw_dataset.map(
tokenization_func,
batched=True,
num_proc=preprocessing_num_workers,
remove_columns=["instruction","input","output"],
keep_in_memory=False,
desc="preprocessing on dataset",
)
processed_dataset = tokenized_dataset
processed_dataset.save_to_disk(cache_path)
processed_dataset.set_format('torch')
all_datasets.append(processed_dataset['train'])
all_datasets = concatenate_datasets(all_datasets)
return all_datasets
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
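# Hypothetical usage sketch (paths and the surrounding Trainer wiring are assumptions,
# not part of this file):
#   train_dataset = build_instruction_dataset(
#       data_path=["/path/to/instruction_data.json"], tokenizer=tokenizer,
#       max_seq_length=512, preprocessing_num_workers=4)
#   data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
#   # pass train_dataset and data_collator to a transformers Trainer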
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 100,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1e-10
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 1e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 1e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
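One common way to wire the ZeRO stage-2 configuration above into a `transformers` `Trainer` run; the filename `ds_zero2_no_offload.json` is an assumption, and the `"auto"` fields (batch sizes, gradient accumulation, clipping, fp16) are resolved from `TrainingArguments` at launch:

``` python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output",
    per_device_train_batch_size=1,         # fills train_micro_batch_size_per_gpu: "auto"
    gradient_accumulation_steps=8,         # fills gradient_accumulation_steps: "auto"
    max_grad_norm=1.0,                     # fills gradient_clipping: "auto"
    fp16=True,                             # fills fp16.enabled: "auto"
    deepspeed="ds_zero2_no_offload.json",  # path to the JSON above (filename is an assumption)
)
```

Equivalently, when the training script parses `TrainingArguments` from the command line, the same file can be passed as `--deepspeed ds_zero2_no_offload.json` under `torchrun`.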
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module while preserving other warnings, so don't check this module at all.
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.3.0.dev0"
from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_config, get_peft_model
from .peft_model import (
PeftModel,
PeftModelForCausalLM,
PeftModelForSeq2SeqLM,
PeftModelForSequenceClassification,
PeftModelForTokenClassification,
)
from .tuners import (
LoraConfig,
LoraModel,
PrefixEncoder,
PrefixTuningConfig,
PromptEmbedding,
PromptEncoder,
PromptEncoderConfig,
PromptEncoderReparameterizationType,
PromptTuningConfig,
PromptTuningInit,
)
from .utils import (
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
PeftConfig,
PeftType,
PromptLearningConfig,
TaskType,
bloom_model_postprocess_past_key_value,
get_peft_model_state_dict,
# prepare_model_for_int8_training,
set_peft_model_state_dict,
shift_tokens_right,
)
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .peft_model import (
PeftModel,
PeftModelForCausalLM,
PeftModelForSeq2SeqLM,
PeftModelForSequenceClassification,
PeftModelForTokenClassification,
)
from .tuners import LoraConfig, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig
from .utils import PromptLearningConfig
MODEL_TYPE_TO_PEFT_MODEL_MAPPING = {
"SEQ_CLS": PeftModelForSequenceClassification,
"SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
"CAUSAL_LM": PeftModelForCausalLM,
"TOKEN_CLS": PeftModelForTokenClassification,
}
PEFT_TYPE_TO_CONFIG_MAPPING = {
"PROMPT_TUNING": PromptTuningConfig,
"PREFIX_TUNING": PrefixTuningConfig,
"P_TUNING": PromptEncoderConfig,
"LORA": LoraConfig,
}
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
"t5": ["q", "v"],
"mt5": ["q", "v"],
"bart": ["q_proj", "v_proj"],
"gpt2": ["c_attn"],
"bloom": ["query_key_value"],
"opt": ["q_proj", "v_proj"],
"gptj": ["q_proj", "v_proj"],
"gpt_neox": ["query_key_value"],
"gpt_neo": ["q_proj", "v_proj"],
"bert": ["query", "value"],
"roberta": ["query", "value"],
"xlm-roberta": ["query", "value"],
"electra": ["query", "value"],
"deberta-v2": ["query_proj", "value_proj"],
"deberta": ["in_proj"],
"layoutlm": ["query", "value"],
"llama": ["q_proj", "v_proj"],
"chatglm": ["query_key_value"],
}
def get_peft_config(config_dict):
"""
Returns a Peft config object from a dictionary.
Args:
config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters.
"""
return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)
def _prepare_prompt_learning_config(peft_config, model_config):
if peft_config.num_layers is None:
if "num_hidden_layers" in model_config:
num_layers = model_config["num_hidden_layers"]
elif "num_layers" in model_config:
num_layers = model_config["num_layers"]
elif "n_layer" in model_config:
num_layers = model_config["n_layer"]
else:
raise ValueError("Please specify `num_layers` in `peft_config`")
peft_config.num_layers = num_layers
if peft_config.token_dim is None:
if "hidden_size" in model_config:
token_dim = model_config["hidden_size"]
elif "n_embd" in model_config:
token_dim = model_config["n_embd"]
elif "d_model" in model_config:
token_dim = model_config["d_model"]
else:
raise ValueError("Please specify `token_dim` in `peft_config`")
peft_config.token_dim = token_dim
if peft_config.num_attention_heads is None:
if "num_attention_heads" in model_config:
num_attention_heads = model_config["num_attention_heads"]
elif "n_head" in model_config:
num_attention_heads = model_config["n_head"]
elif "num_heads" in model_config:
num_attention_heads = model_config["num_heads"]
elif "encoder_attention_heads" in model_config:
num_attention_heads = model_config["encoder_attention_heads"]
else:
raise ValueError("Please specify `num_attention_heads` in `peft_config`")
peft_config.num_attention_heads = num_attention_heads
if getattr(peft_config, "encoder_hidden_size", None) is None:
setattr(peft_config, "encoder_hidden_size", token_dim)
return peft_config
def _prepare_lora_config(peft_config, model_config):
if peft_config.target_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
if len(peft_config.target_modules) == 1:
peft_config.fan_in_fan_out = True
peft_config.enable_lora = [True, False, True]
if peft_config.inference_mode:
peft_config.merge_weights = True
return peft_config
def get_peft_model(model, peft_config):
"""
Returns a Peft model object from a model and a config.
Args:
model ([`transformers.PreTrainedModel`]): Model to be wrapped.
peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
"""
model_config = model.config.to_dict()
peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
peft_config = _prepare_lora_config(peft_config, model_config)
return PeftModel(model, peft_config)
if not isinstance(peft_config, PromptLearningConfig):
peft_config = _prepare_lora_config(peft_config, model_config)
else:
peft_config = _prepare_prompt_learning_config(peft_config, model_config)
return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config)
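A short sketch of the entry points defined above: `get_peft_config` builds a config from a dict, while `get_peft_model` wraps a base model. The checkpoint path is a placeholder; any LLaMA-architecture model resolves its LoRA targets to `["q_proj", "v_proj"]` via `TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING` when `target_modules` is left unset:

``` python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained("/path/to/chinese-llama-alpaca-2")  # placeholder path

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # routes to PeftModelForCausalLM via MODEL_TYPE_TO_PEFT_MODEL_MAPPING
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    # target_modules omitted: _prepare_lora_config falls back to the "llama" entry in the mapping above
)
peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()
```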
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import warnings
from contextlib import contextmanager
import torch
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
from accelerate.utils import get_balanced_memory
from huggingface_hub import hf_hub_download
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput
from transformers.utils import PushToHubMixin
from .tuners import LoraModel, PrefixEncoder, PromptEmbedding, PromptEncoder
from .utils import (
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
WEIGHTS_NAME,
PeftConfig,
PeftType,
PromptLearningConfig,
TaskType,
_set_trainable,
get_peft_model_state_dict,
set_peft_model_state_dict,
shift_tokens_right,
)
class PeftModel(PushToHubMixin, torch.nn.Module):
"""
Parameter-Efficient Fine-Tuning Model. Base model encompassing various Peft methods.
Args:
model ([`PreTrainedModel`]): The base transformer model used for Peft.
peft_config ([`PeftConfig`]): The configuration of the Peft model.
**Attributes**:
- **base_model** ([`PreTrainedModel`]) -- The base transformer model used for Peft.
- **peft_config** ([`PeftConfig`]) -- The configuration of the Peft model.
- **modules_to_save** (`list` of `str`) -- The list of sub-module names to save when
saving the model.
- **prompt_encoder** ([`PromptEncoder`]) -- The prompt encoder used for Peft if
`isinstance(self.peft_config, PromptLearningConfig)`.
- **prompt_tokens** (`torch.Tensor`) -- The virtual prompt tokens used for Peft if
`isinstance(self.peft_config, PromptLearningConfig)`.
- **transformer_backbone_name** (`str`) -- The name of the transformer
backbone in the base model if `isinstance(self.peft_config, PromptLearningConfig)`.
- **word_embeddings** (`torch.nn.Embedding`) -- The word embeddings of the transformer backbone
in the base model if `isinstance(self.peft_config, PromptLearningConfig)`.
"""
def __init__(self, model, peft_config: PeftConfig):
super().__init__()
self.peft_config = peft_config
self.base_model = model
self.config = self.base_model.config
self.modules_to_save = None
if isinstance(self.peft_config, PromptLearningConfig):
self._setup_prompt_encoder()
else:
self.base_model = LoraModel(peft_config, model)
if getattr(self.peft_config, "modules_to_save", None) is not None:
self.modules_to_save = self.peft_config.modules_to_save
_set_trainable(self)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def save_pretrained(self, save_directory, **kwargs):
r"""
Args:
This function saves the adapter model and the adapter configuration files to a directory, so that it can be
re-loaded using the `LoraModel.from_pretrained` class method, and also used by the `LoraModel.push_to_hub`
method.
save_directory (`str`):
Directory where the adapter model and configuration files will be saved (will be created if it does not
exist).
**kwargs:
Additional keyword arguments passed along to the `push_to_hub` method.
"""
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
# save only the trainable weights
output_state_dict = get_peft_model_state_dict(self, kwargs.get("state_dict", None))
torch.save(output_state_dict, os.path.join(save_directory, WEIGHTS_NAME))
# save the config and change the inference mode to `True`
if self.peft_config.base_model_name_or_path is None:
self.peft_config.base_model_name_or_path = (
self.base_model.__dict__.get("name_or_path", None)
if isinstance(self.peft_config, PromptLearningConfig)
else self.base_model.model.__dict__.get("name_or_path", None)
)
inference_mode = self.peft_config.inference_mode
self.peft_config.inference_mode = True
self.peft_config.save_pretrained(save_directory)
self.peft_config.inference_mode = inference_mode
@classmethod
def from_pretrained(cls, model, model_id, **kwargs):
r"""
Args:
Instantiate a `LoraModel` from a pretrained Lora configuration and weights.
model (`transformers.PreTrainedModel`):
The model to be adapted. The model should be initialized with the `from_pretrained` method. from
`transformers` library.
model_id (`str`):
The name of the Lora configuration to use. Can be either:
- A string, the `model id` of a Lora configuration hosted inside a model repo on
huggingface Hub
- A path to a directory containing a Lora configuration file saved using the
`save_pretrained` method, e.g., ``./my_lora_config_directory/``.
"""
from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING
# load the config
config = PEFT_TYPE_TO_CONFIG_MAPPING[PeftConfig.from_pretrained(model_id).peft_type].from_pretrained(model_id)
if getattr(model, "hf_device_map", None) is not None:
remove_hook_from_submodules(model)
if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
model = cls(model, config)
else:
model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config)
# load weights if any
if os.path.exists(os.path.join(model_id, WEIGHTS_NAME)):
filename = os.path.join(model_id, WEIGHTS_NAME)
else:
try:
filename = hf_hub_download(model_id, WEIGHTS_NAME)
except: # noqa
raise ValueError(
f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. "
f"Please check that the file {WEIGHTS_NAME} is present at {model_id}."
)
adapters_weights = torch.load(
filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
# load the weights into the model
model = set_peft_model_state_dict(model, adapters_weights)
if getattr(model, "hf_device_map", None) is not None:
device_map = kwargs.get("device_map", "auto")
max_memory = kwargs.get("max_memory", None)
no_split_module_classes = model._no_split_modules
if device_map != "sequential":
max_memory = get_balanced_memory(
model,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes,
low_zero=(device_map == "balanced_low_0"),
)
if isinstance(device_map, str):
device_map = infer_auto_device_map(
model, max_memory=max_memory, no_split_module_classes=no_split_module_classes
)
model = dispatch_model(model, device_map=device_map)
hook = AlignDevicesHook(io_same_device=True)
if model.peft_config.peft_type == PeftType.LORA:
add_hook_to_module(model.base_model.model, hook)
else:
remove_hook_from_submodules(model.prompt_encoder)
add_hook_to_module(model.base_model, hook)
return model
def _setup_prompt_encoder(self):
transformer_backbone = None
for name, module in self.base_model.named_children():
for param in module.parameters():
param.requires_grad = False
if isinstance(module, PreTrainedModel):
# Make sure to freeze the Transformers model
if transformer_backbone is None:
transformer_backbone = module
self.transformer_backbone_name = name
if self.peft_config.num_transformer_submodules is None:
self.peft_config.num_transformer_submodules = (
2 if self.peft_config.task_type == TaskType.SEQ_2_SEQ_LM else 1
)
for named_param, value in list(transformer_backbone.named_parameters()):
if value.shape[0] == self.base_model.config.vocab_size:
self.word_embeddings = transformer_backbone.get_submodule(named_param.replace(".weight", ""))
break
if self.peft_config.peft_type == PeftType.PROMPT_TUNING:
prompt_encoder = PromptEmbedding(self.peft_config, self.word_embeddings)
elif self.peft_config.peft_type == PeftType.P_TUNING:
prompt_encoder = PromptEncoder(self.peft_config)
elif self.peft_config.peft_type == PeftType.PREFIX_TUNING:
prompt_encoder = PrefixEncoder(self.peft_config)
else:
raise ValueError("Not supported")
self.prompt_encoder = prompt_encoder
self.prompt_tokens = torch.arange(
self.peft_config.num_virtual_tokens * self.peft_config.num_transformer_submodules
).long()
def get_prompt_embedding_to_save(self):
"""
Returns the prompt embedding to save when saving the model. Only applicable when `peft_config.peft_type !=
PeftType.LORA`.
"""
prompt_tokens = self.prompt_tokens.unsqueeze(0).expand(1, -1).to(self.device)
if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
prompt_tokens = prompt_tokens[:, : self.peft_config.num_virtual_tokens]
prompt_embeddings = self.prompt_encoder(prompt_tokens)
return prompt_embeddings[0].detach().cpu()
def get_prompt(self, batch_size):
"""
Returns the virtual prompts to use for Peft. Only applicable when `peft_config.peft_type != PeftType.LORA`.
"""
prompt_tokens = self.prompt_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device)
if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
prompt_tokens = prompt_tokens[:, : self.peft_config.num_virtual_tokens]
if self.peft_config.inference_mode:
past_key_values = self.prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
else:
past_key_values = self.prompt_encoder(prompt_tokens)
past_key_values = past_key_values.view(
batch_size,
self.peft_config.num_virtual_tokens,
self.peft_config.num_layers * 2,
self.peft_config.num_attention_heads,
self.peft_config.token_dim // self.peft_config.num_attention_heads,
)
if self.peft_config.num_transformer_submodules == 2:
past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
self.peft_config.num_transformer_submodules * 2
)
if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
past_key_values = post_process_fn(past_key_values)
return past_key_values
else:
if self.peft_config.inference_mode:
prompts = self.prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
else:
prompts = self.prompt_encoder(prompt_tokens)
return prompts
def print_trainable_parameters(self):
"""
Prints the number of trainable parameters in the model.
"""
trainable_params = 0
all_param = 0
for _, param in self.named_parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
all_param += num_params
if param.requires_grad:
trainable_params += num_params
print(
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
)
def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.base_model, name)
def forward(self, *args, **kwargs): # pylint: disable=E0202
"""
Forward pass of the model.
"""
return self.get_base_model()(*args, **kwargs)
@contextmanager
def disable_adapter(self):
"""
Disables the adapter module.
"""
if isinstance(self.peft_config, PromptLearningConfig):
old_forward = self.forward
self.forward = self.base_model.forward
else:
self.base_model.disable_adapter_layers()
yield
if isinstance(self.peft_config, PromptLearningConfig):
self.forward = old_forward
else:
self.base_model.enable_adapter_layers()
def get_base_model(self):
"""
Returns the base model.
"""
return self.base_model if isinstance(self.peft_config, PromptLearningConfig) else self.base_model.model
class PeftModelForSequenceClassification(PeftModel):
"""
Peft model for sequence classification tasks.
Args:
model ([`PreTrainedModel`]): Base transformer model
peft_config ([`PeftConfig`]): Peft config.
**Attributes**:
- **config** ([`PretrainedConfig`]) -- The configuration object of the base model.
- **cls_layer_name** (`str`) -- The name of the classification layer.
Example::

    >>> from transformers import AutoModelForSequenceClassification
    >>> from peft import PeftModelForSequenceClassification, get_peft_config
    >>> config = {
    ...     'peft_type': 'PREFIX_TUNING',
    ...     'task_type': 'SEQ_CLS',
    ...     'inference_mode': False,
    ...     'num_virtual_tokens': 20,
    ...     'token_dim': 768,
    ...     'num_transformer_submodules': 1,
    ...     'num_attention_heads': 12,
    ...     'num_layers': 12,
    ...     'encoder_hidden_size': 768,
    ...     'prefix_projection': False,
    ...     'postprocess_past_key_value_function': None,
    ... }
    >>> peft_config = get_peft_config(config)
    >>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
    >>> peft_model = PeftModelForSequenceClassification(model, peft_config)
    >>> peft_model.print_trainable_parameters()
    trainable params: 370178 || all params: 108680450 || trainable%: 0.3406113979101117
"""
def __init__(self, model, peft_config: PeftConfig):
super().__init__(model, peft_config)
self.modules_to_save = ["classifier", "score"]
for name, _ in self.base_model.named_children():
if any(module_name in name for module_name in self.modules_to_save):
self.cls_layer_name = name
break
# to make sure classifier layer is trainable
_set_trainable(self)
def forward( # pylint: disable=W0221
self,
input_ids=None,
attention_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if not isinstance(self.peft_config, PromptLearningConfig):
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
batch_size = input_ids.shape[0]
if attention_mask is not None:
# concat prompt attention mask
prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
if kwargs.get("position_ids", None) is not None:
warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
kwargs["position_ids"] = None
kwargs.update(
{
"attention_mask": attention_mask,
"labels": labels,
"output_attentions": output_attentions,
"output_hidden_states": output_hidden_states,
"return_dict": return_dict,
}
)
if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
return self._prefix_tuning_forward(input_ids=input_ids, **kwargs)
else:
if kwargs.get("token_type_ids", None) is not None:
kwargs["token_type_ids"] = torch.cat(
(
torch.zeros(batch_size, self.peft_config.num_virtual_tokens).to(self.device),
kwargs["token_type_ids"],
),
dim=1,
).long()
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
prompts = self.get_prompt(batch_size=batch_size)
prompts = prompts.to(inputs_embeds.dtype)
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
def _prefix_tuning_forward(
self,
input_ids=None,
attention_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
batch_size = input_ids.shape[0]
past_key_values = self.get_prompt(batch_size)
fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys())
kwargs.update(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"inputs_embeds": inputs_embeds,
"output_attentions": output_attentions,
"output_hidden_states": output_hidden_states,
"return_dict": return_dict,
"past_key_values": past_key_values,
}
)
if "past_key_values" in fwd_params:
return self.base_model(labels=labels, **kwargs)
else:
transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name)
fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys())
if "past_key_values" not in fwd_params:
raise ValueError("Model does not support past key values which are required for prefix tuning.")
outputs = transformer_backbone_name(**kwargs)
pooled_output = outputs[1] if len(outputs) > 1 else outputs[0]
if "dropout" in [name for name, _ in list(self.base_model.named_children())]:
pooled_output = self.base_model.dropout(pooled_output)
logits = self.base_model.get_submodule(self.cls_layer_name)(pooled_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.base_model.num_labels == 1:
self.config.problem_type = "regression"
elif self.base_model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.base_model.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.base_model.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class PeftModelForCausalLM(PeftModel):
"""
Peft model for Causal LM
Args:
model ([`PreTrainedModel`]): Base transformer model
peft_config ([`PeftConfig`]): Peft config.
Example::

    >>> from transformers import AutoModelForCausalLM
    >>> from peft import PeftModelForCausalLM, get_peft_config
    >>> config = {
    ...     'peft_type': 'PREFIX_TUNING',
    ...     'task_type': 'CAUSAL_LM',
    ...     'inference_mode': False,
    ...     'num_virtual_tokens': 20,
    ...     'token_dim': 1280,
    ...     'num_transformer_submodules': 1,
    ...     'num_attention_heads': 20,
    ...     'num_layers': 36,
    ...     'encoder_hidden_size': 1280,
    ...     'prefix_projection': False,
    ...     'postprocess_past_key_value_function': None,
    ... }
    >>> peft_config = get_peft_config(config)
    >>> model = AutoModelForCausalLM.from_pretrained("gpt2-large")
    >>> peft_model = PeftModelForCausalLM(model, peft_config)
    >>> peft_model.print_trainable_parameters()
    trainable params: 1843200 || all params: 775873280 || trainable%: 0.23756456724479544
"""
def __init__(self, model, peft_config: PeftConfig):
super().__init__(model, peft_config)
self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
def forward(# pylint: disable=W0221
self,
input_ids=None,
attention_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
if not isinstance(self.peft_config, PromptLearningConfig):
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
batch_size = input_ids.shape[0]
if attention_mask is not None:
# concat prompt attention mask
prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
if kwargs.get("position_ids", None) is not None:
warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
kwargs["position_ids"] = None
if kwargs.get("token_type_ids", None) is not None:
warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
kwargs["token_type_ids"] = None
kwargs.update(
{
"attention_mask": attention_mask,
"labels": labels,
"output_attentions": output_attentions,
"output_hidden_states": output_hidden_states,
"return_dict": return_dict,
}
)
if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
past_key_values = self.get_prompt(batch_size)
return self.base_model(input_ids=input_ids, past_key_values=past_key_values, **kwargs)
else:
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
# concat prompt labels
if labels is not None:
prefix_labels = torch.full((batch_size, self.peft_config.num_virtual_tokens), -100).to(self.device)
kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
prompts = self.get_prompt(batch_size=batch_size)
prompts = prompts.to(inputs_embeds.dtype)
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
def generate(self, **kwargs):
self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
try:
if not isinstance(self.peft_config, PromptLearningConfig):
outputs = self.base_model.generate(**kwargs)
else:
if "input_ids" not in kwargs:
raise ValueError("input_ids must be provided for Peft model generation")
if kwargs.get("attention_mask", None) is not None:
# concat prompt attention mask
prefix_attention_mask = torch.ones(
kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens
).to(kwargs["input_ids"].device)
kwargs["attention_mask"] = torch.cat((prefix_attention_mask, kwargs["attention_mask"]), dim=1)
if kwargs.get("position_ids", None) is not None:
warnings.warn(
"Position ids are not supported for parameter efficient tuning. Ignoring position ids."
)
kwargs["position_ids"] = None
if kwargs.get("token_type_ids", None) is not None:
warnings.warn(
"Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
)
kwargs["token_type_ids"] = None
outputs = self.base_model.generate(**kwargs)
except:
self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
raise
else:
self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
return outputs
def prepare_inputs_for_generation(self, *args, **kwargs):
model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
if isinstance(self.peft_config, PromptLearningConfig):
if model_kwargs["past_key_values"] is None and self.peft_config.peft_type == PeftType.PREFIX_TUNING:
past_key_values = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0])
model_kwargs["past_key_values"] = past_key_values
else:
if model_kwargs["past_key_values"] is None:
inputs_embeds = self.word_embeddings(model_kwargs["input_ids"])
prompts = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0])
prompts = prompts.to(inputs_embeds.dtype)
model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
model_kwargs["input_ids"] = None
return model_kwargs
class PeftModelForSeq2SeqLM(PeftModel):
"""
Peft model for Seq2Seq LM
Args:
model ([`PreTrainedModel`]): Base transformer model
peft_config ([`PeftConfig`]): Peft config.
Example::

    >>> from transformers import AutoModelForSeq2SeqLM
    >>> from peft import PeftModelForSeq2SeqLM, get_peft_config
    >>> config = {
    ...     'peft_type': 'LORA',
    ...     'task_type': 'SEQ_2_SEQ_LM',
    ...     'inference_mode': False,
    ...     'r': 8,
    ...     'target_modules': ['q', 'v'],
    ...     'lora_alpha': 32,
    ...     'lora_dropout': 0.1,
    ...     'merge_weights': False,
    ...     'fan_in_fan_out': False,
    ...     'enable_lora': None,
    ...     'bias': 'none',
    ... }
    >>> peft_config = get_peft_config(config)
    >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
    >>> peft_model = PeftModelForSeq2SeqLM(model, peft_config)
    >>> peft_model.print_trainable_parameters()
    trainable params: 884736 || all params: 223843584 || trainable%: 0.3952474242013566
"""
def __init__(self, model, peft_config: PeftConfig):
super().__init__(model, peft_config)
self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
self.base_model_prepare_encoder_decoder_kwargs_for_generation = (
self.base_model._prepare_encoder_decoder_kwargs_for_generation
)
def forward(
self,
input_ids=None,
attention_mask=None,
inputs_embeds=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
if not isinstance(self.peft_config, PromptLearningConfig):
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_inputs_embeds=decoder_inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
batch_size = input_ids.shape[0]
if decoder_attention_mask is not None:
# concat prompt attention mask
prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device)
decoder_attention_mask = torch.cat((prefix_attention_mask, decoder_attention_mask), dim=1)
if kwargs.get("position_ids", None) is not None:
warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
kwargs["position_ids"] = None
if kwargs.get("token_type_ids", None) is not None:
warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
kwargs["token_type_ids"] = None
kwargs.update(
{
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"labels": labels,
"output_attentions": output_attentions,
"output_hidden_states": output_hidden_states,
"return_dict": return_dict,
}
)
if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
past_key_values = self.get_prompt(batch_size)
return self.base_model(
input_ids=input_ids, decoder_input_ids=decoder_input_ids, past_key_values=past_key_values, **kwargs
)
else:
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
if decoder_inputs_embeds is None and decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
decoder_inputs_embeds = self.word_embeddings(decoder_input_ids)
if attention_mask is not None:
# concat prompt attention mask
prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device)
kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
# concat prompt labels
if labels is not None:
if self.peft_config.num_transformer_submodules == 1:
kwargs["labels"] = labels
elif self.peft_config.num_transformer_submodules == 2:
prefix_labels = torch.full((batch_size, self.peft_config.num_virtual_tokens), -100).to(self.device)
kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
prompts = self.get_prompt(batch_size=batch_size)
prompts = prompts.to(inputs_embeds.dtype)
inputs_embeds = torch.cat((prompts[:, : self.peft_config.num_virtual_tokens], inputs_embeds), dim=1)
if self.peft_config.num_transformer_submodules == 1:
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
elif self.peft_config.num_transformer_submodules == 2:
decoder_inputs_embeds = torch.cat(
(prompts[:, self.peft_config.num_virtual_tokens :], decoder_inputs_embeds), dim=1
)
return self.base_model(
inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **kwargs
)
def generate(self, **kwargs):
self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
self._prepare_encoder_decoder_kwargs_for_generation
)
try:
if not isinstance(self.peft_config, PromptLearningConfig):
outputs = self.base_model.generate(**kwargs)
else:
if "input_ids" not in kwargs:
raise ValueError("input_ids must be provided for Peft model generation")
if kwargs.get("position_ids", None) is not None:
warnings.warn(
"Position ids are not supported for parameter efficient tuning. Ignoring position ids."
)
kwargs["position_ids"] = None
if kwargs.get("token_type_ids", None) is not None:
warnings.warn(
"Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
)
kwargs["token_type_ids"] = None
if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
outputs = self.base_model.generate(**kwargs)
else:
raise NotImplementedError
except:
self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
self.base_model_prepare_encoder_decoder_kwargs_for_generation
)
raise
else:
self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
self.base_model_prepare_encoder_decoder_kwargs_for_generation
)
return outputs
def prepare_inputs_for_generation(self, *args, **kwargs):
model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
if model_kwargs["past_key_values"] is None and self.peft_config.peft_type == PeftType.PREFIX_TUNING:
batch_size = model_kwargs["decoder_input_ids"].shape[0]
past_key_values = self.get_prompt(batch_size)
model_kwargs["past_key_values"] = past_key_values
return model_kwargs
class PeftModelForTokenClassification(PeftModel):
"""
Peft model for token classification tasks.
Args:
model ([`PreTrainedModel`]): Base transformer model
peft_config ([`PeftConfig`]): Peft config.
**Attributes**:
- **config** ([`PretrainedConfig`]) -- The configuration object of the base model.
- **cls_layer_name** (`str`) -- The name of the classification layer.
Example::

    >>> from transformers import AutoModelForTokenClassification
    >>> from peft import PeftModelForTokenClassification, get_peft_config
    >>> config = {
    ...     'peft_type': 'PREFIX_TUNING',
    ...     'task_type': 'TOKEN_CLS',
    ...     'inference_mode': False,
    ...     'num_virtual_tokens': 20,
    ...     'token_dim': 768,
    ...     'num_transformer_submodules': 1,
    ...     'num_attention_heads': 12,
    ...     'num_layers': 12,
    ...     'encoder_hidden_size': 768,
    ...     'prefix_projection': False,
    ...     'postprocess_past_key_value_function': None,
    ... }
    >>> peft_config = get_peft_config(config)
    >>> model = AutoModelForTokenClassification.from_pretrained("bert-base-cased")
    >>> peft_model = PeftModelForTokenClassification(model, peft_config)
    >>> peft_model.print_trainable_parameters()
    trainable params: 370178 || all params: 108680450 || trainable%: 0.3406113979101117
"""
def __init__(self, model, peft_config: PeftConfig):
super().__init__(model, peft_config)
self.modules_to_save = ["classifier", "score"]
for name, _ in self.base_model.named_children():
if any(module_name in name for module_name in self.modules_to_save):
self.cls_layer_name = name
break
# to make sure classifier layer is trainable
_set_trainable(self)
def forward(
self,
input_ids=None,
attention_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if not isinstance(self.peft_config, PromptLearningConfig):
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
batch_size = input_ids.shape[0]
if attention_mask is not None:
# concat prompt attention mask
prefix_attention_mask = torch.ones(batch_size, self.peft_config.num_virtual_tokens).to(self.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
if kwargs.get("position_ids", None) is not None:
warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
kwargs["position_ids"] = None
kwargs.update(
{
"attention_mask": attention_mask,
"labels": labels,
"output_attentions": output_attentions,
"output_hidden_states": output_hidden_states,
"return_dict": return_dict,
}
)
if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
return self._prefix_tuning_forward(input_ids=input_ids, **kwargs)
else:
if kwargs.get("token_type_ids", None) is not None:
kwargs["token_type_ids"] = torch.cat(
(
torch.zeros(batch_size, self.peft_config.num_virtual_tokens).to(self.device),
kwargs["token_type_ids"],
),
dim=1,
).long()
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
prompts = self.get_prompt(batch_size=batch_size)
prompts = prompts.to(inputs_embeds.dtype)
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
def _prefix_tuning_forward(
self,
input_ids=None,
attention_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
batch_size = input_ids.shape[0]
past_key_values = self.get_prompt(batch_size)
fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys())
kwargs.update(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"inputs_embeds": inputs_embeds,
"output_attentions": output_attentions,
"output_hidden_states": output_hidden_states,
"return_dict": return_dict,
"past_key_values": past_key_values,
}
)
if "past_key_values" in fwd_params:
return self.base_model(labels=labels, **kwargs)
else:
transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name)
fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys())
if "past_key_values" not in fwd_params:
raise ValueError("Model does not support past key values which are required for prefix tuning.")
outputs = transformer_backbone_name(**kwargs)
sequence_output = outputs[0]
if "dropout" in [name for name, _ in list(self.base_model.named_children())]:
sequence_output = self.base_model.dropout(sequence_output)
logits = self.base_model.get_submodule(self.cls_layer_name)(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
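To round out the `PeftModel` API above, a minimal save/load sketch: `save_pretrained` writes only the adapter weights and adapter config, and `from_pretrained` re-attaches them to a freshly loaded base model. All paths below are placeholders, and the prompt simply reuses the instruction template defined earlier in this repository:

``` python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# After training: peft_model.save_pretrained("/path/to/lora-adapter")

base_model = AutoModelForCausalLM.from_pretrained("/path/to/chinese-llama-alpaca-2")  # placeholder path
tokenizer = AutoTokenizer.from_pretrained("/path/to/chinese-llama-alpaca-2", use_fast=False)

# Directory previously produced by save_pretrained (adapter config + adapter weights).
model = PeftModel.from_pretrained(base_model, "/path/to/lora-adapter")
model.eval()

prompt = (
    "[INST] <<SYS>>\n"
    "You are a helpful assistant. 你是一个乐于助人的助手。\n"
    "<</SYS>>\n\n告诉我中国的首都在哪里 [/INST]"
)
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```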
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module while preserving other warnings, so don't check this module at all.
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .lora import LoraConfig, LoraModel
from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType
from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit