import argparse
import os
import sys
from threading import Thread

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse

parser = argparse.ArgumentParser()
parser.add_argument('--base_model', default=None, type=str, required=True)
parser.add_argument('--lora_model', default=None, type=str,
                    help="If None, perform inference on the base model")
parser.add_argument('--tokenizer_path', default=None, type=str)
parser.add_argument('--gpus', default="0", type=str)
parser.add_argument('--load_in_8bit', action='store_true',
                    help='Load the model in 8bit mode')
parser.add_argument('--load_in_4bit', action='store_true',
                    help='Load the model in 4bit mode')
parser.add_argument('--only_cpu', action='store_true',
                    help='Only use CPU for inference')
parser.add_argument('--alpha', type=str, default="1.0",
                    help="The scaling factor of the NTK method; a float or 'auto'")
parser.add_argument('--use_ntk', action='store_true',
                    help="Use dynamic NTK scaling to extend the context window")
parser.add_argument('--use_flash_attention_2', action='store_true',
                    help="Use FlashAttention-2 to accelerate inference")
args = parser.parse_args()

if args.only_cpu:
    args.gpus = ""
    if args.load_in_8bit or args.load_in_4bit:
        raise ValueError("Quantization is unavailable on CPU.")
if args.load_in_8bit and args.load_in_4bit:
    raise ValueError("Only one quantization method can be chosen for inference. Please check your arguments")

# CUDA_VISIBLE_DEVICES must be set before torch is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM,
    LlamaTokenizer,
    GenerationConfig,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
from peft import PeftModel

# Make the repository root importable so the patch module can be found.
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
apply_attention_patch(use_memory_efficient_attention=True)
if args.use_ntk:
    apply_ntk_scaling_patch(args.alpha)

from openai_api_protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatMessage,
    ChatCompletionResponseChoice,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    EmbeddingsRequest,
    EmbeddingsResponse,
    ChatCompletionResponseStreamChoice,
    DeltaMessage,
)

load_type = torch.float16
if torch.cuda.is_available():
    device = torch.device(0)
else:
    device = torch.device("cpu")

# Fall back to the LoRA directory, then to the base model, for the tokenizer.
if args.tokenizer_path is None:
    args.tokenizer_path = args.lora_model
    if args.lora_model is None:
        args.tokenizer_path = args.base_model
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)

if args.load_in_4bit or args.load_in_8bit:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=args.load_in_4bit,
        load_in_8bit=args.load_in_8bit,
        bnb_4bit_compute_dtype=load_type,
    )

base_model = AutoModelForCausalLM.from_pretrained(
    args.base_model,
    torch_dtype=load_type,
    low_cpu_mem_usage=True,
    device_map='auto' if not args.only_cpu else None,
    load_in_4bit=args.load_in_4bit,
    load_in_8bit=args.load_in_8bit,
    quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,
    use_flash_attention_2=args.use_flash_attention_2,
    trust_remote_code=True,
)

model_vocab_size = base_model.get_input_embeddings().weight.size(0)
tokenizer_vocab_size = len(tokenizer)
print(f"Vocab of the base model: {model_vocab_size}")
print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
if model_vocab_size != tokenizer_vocab_size:
    print("Resize model embeddings to fit tokenizer")
    base_model.resize_token_embeddings(tokenizer_vocab_size)
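# Example launch (a sketch; the script filename and model paths are placeholder
# assumptions, not taken from this file):
#   python openai_api_server.py --base_model /path/to/base-model \
#       --lora_model /path/to/lora --gpus 0 --load_in_8bit --use_ntk --alpha auto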
if args.lora_model is not None:
    print("loading peft model")
    model = PeftModel.from_pretrained(
        base_model,
        args.lora_model,
        torch_dtype=load_type,
        device_map="auto",
    )
else:
    model = base_model

if device == torch.device("cpu"):
    model.float()
model.eval()

DEFAULT_SYSTEM_PROMPT = """You are a helpful assistant. 你是一个乐于助人的助手。"""

TEMPLATE_WITH_SYSTEM_PROMPT = (
    "[INST] <<SYS>>\n"
    "{system_prompt}\n"
    "<</SYS>>\n\n"
    "{instruction} [/INST]"
)

TEMPLATE_WITHOUT_SYSTEM_PROMPT = "[INST] {instruction} [/INST]"


def generate_prompt(
    instruction, response="", with_system_prompt=True, system_prompt=None
):
    if with_system_prompt:
        if system_prompt is None:
            system_prompt = DEFAULT_SYSTEM_PROMPT
        prompt = TEMPLATE_WITH_SYSTEM_PROMPT.format_map(
            {"instruction": instruction, "system_prompt": system_prompt}
        )
    else:
        prompt = TEMPLATE_WITHOUT_SYSTEM_PROMPT.format_map({"instruction": instruction})
    if len(response) > 0:
        prompt += " " + response
    return prompt


def generate_completion_prompt(instruction: str):
    """Generate prompt for completion"""
    return generate_prompt(instruction, response="", with_system_prompt=True)


def generate_chat_prompt(messages: list):
    """Generate prompt for chat completion"""
    # The last system message (if any) becomes the system prompt of the first turn.
    system_msg = None
    for msg in messages:
        if msg.role == "system":
            system_msg = msg.content
    prompt = ""
    is_first_user_content = True
    for msg in messages:
        if msg.role == "system":
            continue
        if msg.role == "user":
            if is_first_user_content:
                prompt += generate_prompt(
                    msg.content, with_system_prompt=True, system_prompt=system_msg
                )
                is_first_user_content = False
            else:
                prompt += "<s>" + generate_prompt(msg.content, with_system_prompt=False)
        if msg.role == "assistant":
            prompt += f" {msg.content}" + "</s>"
    return prompt


def predict(
    input,
    max_new_tokens=128,
    top_p=0.9,
    temperature=0.2,
    top_k=40,
    num_beams=1,
    repetition_penalty=1.1,
    do_sample=True,
    **kwargs,
):
    """
    Main inference method.

    type(input) == str  -> /v1/completions
    type(input) == list -> /v1/chat/completions
    """
    if isinstance(input, str):
        prompt = generate_completion_prompt(input)
    else:
        prompt = generate_chat_prompt(input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        do_sample=do_sample,
        **kwargs,
    )
    generation_config.return_dict_in_generate = True
    generation_config.output_scores = False
    generation_config.max_new_tokens = max_new_tokens
    generation_config.repetition_penalty = float(repetition_penalty)
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s, skip_special_tokens=True)
    # Keep only the text generated after the final instruction marker.
    output = output.split("[/INST]")[-1].strip()
    return output
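# For reference, generate_chat_prompt above assembles a two-turn conversation into
# the following string (special tokens shown literally, placeholders in braces):
#
#   [INST] <<SYS>>
#   {system prompt}
#   <</SYS>>
#
#   {user turn 1} [/INST] {assistant turn 1}</s><s>[INST] {user turn 2} [/INST]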
def stream_predict(
    input,
    max_new_tokens=128,
    top_p=0.75,
    temperature=0.1,
    top_k=40,
    num_beams=4,
    repetition_penalty=1.0,
    do_sample=True,
    model_id="chinese-llama-alpaca-2",
    **kwargs,
):
    # The first chunk carries only the assistant role, per the OpenAI streaming format.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
    )
    chunk = ChatCompletionResponse(
        model=model_id,
        choices=[choice_data],
        object="chat.completion.chunk",
    )
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    if isinstance(input, str):
        prompt = generate_completion_prompt(input)
    else:
        prompt = generate_chat_prompt(input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        do_sample=do_sample,
        **kwargs,
    )
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        streamer=streamer,
        input_ids=input_ids,
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=False,
        max_new_tokens=max_new_tokens,
        repetition_penalty=float(repetition_penalty),
    )
    # Run generation in a background thread and stream tokens as they arrive.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    for new_text in streamer:
        choice_data = ChatCompletionResponseStreamChoice(
            index=0, delta=DeltaMessage(content=new_text), finish_reason=None
        )
        chunk = ChatCompletionResponse(
            model=model_id, choices=[choice_data], object="chat.completion.chunk"
        )
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(), finish_reason="stop"
    )
    chunk = ChatCompletionResponse(
        model=model_id, choices=[choice_data], object="chat.completion.chunk"
    )
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    yield "[DONE]"


def get_embedding(input):
    """Get embedding main function"""
    with torch.no_grad():
        encoding = tokenizer(input, padding=True, return_tensors="pt")
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        model_output = model(input_ids, attention_mask, output_hidden_states=True)
        data = model_output.hidden_states[-1]
        # Mean-pool the last hidden states over non-padding tokens, then L2-normalize.
        mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
        masked_embeddings = data * mask
        sum_embeddings = torch.sum(masked_embeddings, dim=1)
        seq_length = torch.sum(mask, dim=1)
        embedding = sum_embeddings / seq_length
        normalized_embeddings = F.normalize(embedding, p=2, dim=1)
        ret = normalized_embeddings.squeeze(0).tolist()
    return ret


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """Creates a completion for the chat message"""
    msgs = request.messages
    if isinstance(msgs, str):
        msgs = [ChatMessage(role="user", content=msgs)]
    else:
        msgs = [ChatMessage(role=x["role"], content=x["content"]) for x in msgs]
    if request.stream:
        generate = stream_predict(
            input=msgs,
            max_new_tokens=request.max_tokens,
            top_p=request.top_p,
            top_k=request.top_k,
            temperature=request.temperature,
            num_beams=request.num_beams,
            repetition_penalty=request.repetition_penalty,
            do_sample=request.do_sample,
        )
        return EventSourceResponse(generate, media_type="text/event-stream")
    output = predict(
        input=msgs,
        max_new_tokens=request.max_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        temperature=request.temperature,
        num_beams=request.num_beams,
        repetition_penalty=request.repetition_penalty,
        do_sample=request.do_sample,
    )
    # Echo the conversation history as choices, then append the assistant reply last.
    choices = [
        ChatCompletionResponseChoice(index=i, message=msg) for i, msg in enumerate(msgs)
    ]
    choices += [
        ChatCompletionResponseChoice(
            index=len(choices), message=ChatMessage(role="assistant", content=output)
        )
    ]
    return ChatCompletionResponse(choices=choices)


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    """Creates a completion"""
    output = predict(
        input=request.prompt,
        max_new_tokens=request.max_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        temperature=request.temperature,
        num_beams=request.num_beams,
        repetition_penalty=request.repetition_penalty,
        do_sample=request.do_sample,
    )
    choices = [CompletionResponseChoice(index=0, text=output)]
    return CompletionResponse(choices=choices)
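# Example request against the chat endpoint (a sketch; assumes the server runs on
# the default host/port configured below, and that the protocol models in
# openai_api_protocol supply defaults for the omitted sampling parameters):
#   curl http://localhost:19327/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "你好"}], "stream": false}'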
@app.post("/v1/embeddings")
async def create_embeddings(request: EmbeddingsRequest):
    """Creates text embedding"""
    embedding = get_embedding(request.input)
    data = [{"object": "embedding", "embedding": embedding, "index": 0}]
    return EmbeddingsResponse(data=data)


if __name__ == "__main__":
    log_config = uvicorn.config.LOGGING_CONFIG
    log_config["formatters"]["access"][
        "fmt"
    ] = "%(asctime)s - %(levelname)s - %(message)s"
    log_config["formatters"]["default"][
        "fmt"
    ] = "%(asctime)s - %(levelname)s - %(message)s"
    uvicorn.run(app, host="0.0.0.0", port=19327, workers=1, log_config=log_config)
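# Example embeddings request (a sketch; the "input" field mirrors
# EmbeddingsRequest.input used by the endpoint above):
#   curl http://localhost:19327/v1/embeddings \
#     -H "Content-Type: application/json" \
#     -d '{"input": "你好"}'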