Commit 5aee6c04 authored by Azure's avatar Azure
Browse files

Merge branch 'main' into develop-0.2.3

parents 216a63b8 48b98007
......@@ -14,9 +14,9 @@ from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.util.utils import get_device
from typing import Optional
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
warm_uped = False
class KTransformersThreadContext(TransformersThreadContext):
......@@ -29,6 +29,16 @@ class KTransformersInterface(TransformersInterface):
torch.set_grad_enabled(False)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
try:
generation_config = GenerationConfig.from_pretrained(args.model_dir)
except:
generation_config = GenerationConfig(
max_length=args.max_new_tokens,
temperature=args.temperature,
top_p=args.top_p,
do_sample=True
)
torch.set_default_dtype(config.torch_dtype)
if config.architectures[0] == "Qwen2MoeForCausalLM":
config._attn_implementation = "flash_attention_2"
......@@ -49,7 +59,7 @@ class KTransformersInterface(TransformersInterface):
" belong to current model):"
)
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
self.model.generation_config = generation_config
self.device_map = self.model.gguf_loader.tensor_device_map
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
self.cache = StaticCache(
......@@ -60,16 +70,7 @@ class KTransformersInterface(TransformersInterface):
dtype=self.model.dtype,
)
# logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
try:
self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
except:
gen_config = GenerationConfig(
max_length=128,
temperature=0.7,
top_p=0.9,
do_sample=True
)
self.model.generation_config = gen_config
if self.model.generation_config.pad_token_id is None:
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
self.streamer = TextStreamer(self.tokenizer)
......@@ -110,12 +111,10 @@ class KTransformersInterface(TransformersInterface):
warm_uped = True
if self.use_static_cache:
mask = torch.ones((1, self.seq_length)).to(torch_device)
logits = self.model(
self.current_ids.to(torch_device),
cache_position=self.active_cache_position,
past_key_values=self.cache,
attention_mask=mask,
return_dict=False,
use_cache=True,
)[0]
......@@ -128,10 +127,13 @@ class KTransformersInterface(TransformersInterface):
@torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool):
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
input_ids_length = input_ids.shape[-1]
if(input_ids_length >= self.args.cache_lens):
logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
self.seq_length = input_ids_length
return
logger.debug(f"input_ids: {input_ids.shape}")
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
device = "cuda:0" if device == "cuda" else device
......@@ -166,44 +168,57 @@ class KTransformersInterface(TransformersInterface):
self.ever_generated_ids.clear()
self.profiler.set_counter("prefill", input_ids_length)
logger.debug(f"input_ids: {input_ids.shape}")
logger.debug(f"generate_ids: {self.generated_ids.shape}")
former_seq_length = self.seq_length
self.seq_length += input_ids_length
expected_length = self.seq_length + self.args.max_new_tokens + 1
expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
delta_length = expected_length - self.generated_ids.shape[-1]
if delta_length > 0:
new_generate_ids = torch.zeros(
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
)
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
else:
logger.warning(f"seq_length bigger than cache_lens, killed")
exit(0)
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
mask = torch.ones((1, self.seq_length)).to(device)
if not (type(self) is TransformersInterface):
input_ids = input_ids.to("cpu")
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
torch.cuda.set_device(device)
if flashinfer_enabled:
MLAWrapperSingleton.need_plan_all()
if self.use_static_cache:
logits = self.model(
inputs_embeds=inputs_embeds,
cache_position=cache_position,
past_key_values=self.cache,
return_dict=False,
use_cache=True,
attention_mask=mask,
)[0]
else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
def chunk_prefill(input_ids, cache_position):
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
torch.cuda.set_device(device)
if flashinfer_enabled:
MLAWrapperSingleton.need_plan_all()
if self.use_static_cache:
logits = self.model(
inputs_embeds=inputs_embeds,
cache_position=cache_position,
past_key_values=self.cache,
return_dict=False,
use_cache=True,
)[0]
else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
return logits
chunk_start = 0
while chunk_start < input_ids_length:
chunk_end = min(chunk_start + self.args.chunk_prefill_size, input_ids_length)
if self.cache != None:
self.cache.cur_idx=cache_position[chunk_start:chunk_end]
logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
chunk_start += self.args.chunk_prefill_size
if flashinfer_enabled:
MLAWrapperSingleton.reset_buffer()
self.prepare_logits_wrapper(input_ids, device)
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token)
......@@ -212,7 +227,7 @@ class KTransformersInterface(TransformersInterface):
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
return torch.tensor([self.seq_length - 1], device=device)
async def inference(self, local_messages, thread_id: str):
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
async with self._infer_lock:
async for v in super().inference(local_messages, thread_id):
async for v in super().inference(local_messages, thread_id, temperature, top_p):
yield v
......@@ -13,6 +13,7 @@ from transformers import (
from ktransformers.server.config.config import Config
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.utils.multi_timer import Profiler
from torch.nn.attention import SDPBackend
import torch
import sys, os
from ..base import ThreadContext, BackendInterfaceBase
......@@ -202,20 +203,23 @@ class TransformersInterface(BackendInterfaceBase):
self.seq_length += 1
return self.streamer.put(new_tokens)
def prepare_logits_wrapper(self, inputs, device):
def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
if temperature is None or temperature == 0:
temperature = self.model.generation_config.temperature
if top_p is None:
top_p = self.model.generation_config.top_p
generation_config, model_kwargs = self.model._prepare_generation_config(
None, max_length=self.args.max_new_tokens,
do_sample=True,
top_k=self.args.top_k,
top_p=self.args.top_p,
temperature=self.args.temperature,
top_p=top_p,
temperature=temperature,
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
)
self.inputs = inputs
self.generation_config = generation_config
try: # transformers==4.43
self.logits_warper = (
self.model._get_logits_warper(generation_config,device=device)
self.model._get_logits_warper(generation_config, device=device)
)
except:
self.logits_warper = (
......@@ -239,12 +243,10 @@ class TransformersInterface(BackendInterfaceBase):
def decode_one_tokens(self):
if self.use_static_cache:
mask = torch.ones((1, self.seq_length)).to(self.args.device)
logits = self.model(
self.current_ids,
cache_position=self.active_cache_position,
past_key_values=self.cache,
attention_mask=mask,
return_dict=False,
use_cache=True,
)[0]
......@@ -255,7 +257,7 @@ class TransformersInterface(BackendInterfaceBase):
return self.logits_to_token(logits)
@torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool):
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
input_ids_length = input_ids.shape[-1]
logger.debug(f"input_ids: {input_ids.shape}")
......@@ -306,7 +308,6 @@ class TransformersInterface(BackendInterfaceBase):
cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
mask = torch.ones((1, self.seq_length)).to(self.args.device)
device = input_ids.device
if not (type(self) is TransformersInterface):
input_ids = input_ids.to("cpu")
......@@ -318,21 +319,26 @@ class TransformersInterface(BackendInterfaceBase):
past_key_values=self.cache,
return_dict=False,
use_cache=True,
attention_mask=mask,
)[0]
else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
self.prepare_logits_wrapper(input_ids, device)
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token)
@torch.no_grad
def generate(self):
self.args.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length)
if(self.args.max_new_tokens <= 0):
logger.warning("max_new_tokens is less than 0")
yield self.streamer.end()
return
logger.info(f"max_new_tokens: {self.args.max_new_tokens}")
self.profiler.set_counter("decode", 0)
for i in range(1, self.args.max_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
if flashinfer_enabled:
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
......@@ -359,7 +365,7 @@ class TransformersInterface(BackendInterfaceBase):
self.last_request_id = thread_id
return True
async def inference(self, local_messages, thread_id: str):
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
self.streamer.reset()
self.profiler.create_and_start_timer("tokenize")
if isinstance(local_messages, List):
......@@ -386,7 +392,7 @@ class TransformersInterface(BackendInterfaceBase):
print(think, end="",flush=True)
yield think
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
# output think token after prefill done
if t is not None:
print(t, end="",flush=True)
......
......@@ -105,7 +105,8 @@ class Config(metaclass=Singleton):
self.total_context = self.model.get("total_context", 2**18)
self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
self.max_chunk_size = self.model.get("max_chunk_size", 2048)
self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192)
self.max_new_tokens = self.model.get("max_new_tokens", 2000)
self.json_mode = self.model.get("json_mode", False)
self.healing = self.model.get("healing", False)
......
......@@ -25,7 +25,9 @@ class ChatCompletionCreate(BaseModel):
messages: List[Message]
model : str
stream : bool = False
temperature: Optional[float] = None
top_p: Optional[float] = None
def get_tokenizer_messages(self):
return [m.to_tokenizer_message() for m in self.messages]
......
......@@ -9,6 +9,8 @@ class CompletionCreate(BaseModel):
model: str
prompt: str | List[str]
stream: bool = False
temperature: Optional[float] = None
top_p: Optional[float] = None
def get_tokenizer_messages(self):
if isinstance(self.prompt,List):
......
......@@ -27,6 +27,7 @@ import torch
import KTransformersOps
from .custom_loader import SafeTensorLoader
import ctypes
import math
class GGMLQuantizationType(IntEnum):
F32 = 0
......@@ -230,7 +231,7 @@ class GGUFLoader:
shape = [read_value(f, DATA_TYPES["uint64"]) for _ in range(shape_len)]
ggml_type = read_value(f, DATA_TYPES["uint32"])
bad_offset = read_value(f, DATA_TYPES["uint64"])
n_elems = int(np.prod(shape))
n_elems = int(math.prod(shape))
block_size, type_size = GGML_QUANT_SIZES[ggml_type]
n_bytes = n_elems * type_size // block_size
np_dims = tuple(reversed(shape))
......
......@@ -110,7 +110,7 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
module.load()
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
mode = 'normal', force_think: bool = False, use_flashinfer_mla = False,
mode = 'normal', force_think: bool = False, chunk_prefill_size = 16384, use_flashinfer_mla = False,
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
......@@ -124,7 +124,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
tokens = []
def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
if cuda_graph_runner is None:
use_cuda_graph = False
if use_cuda_graph:
......@@ -152,25 +152,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
next_token = torch.argmax(next_token_scores, dim=-1)
return next_token
torch.cuda.set_device(torch_device)
with torch.no_grad():
stream = TextStreamer(tokenizer)
if mode != 'long_context':
past_key_values = StaticCache(
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
)
else:
past_key_values = None
cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
generated_ids = torch.zeros(
batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
)
generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
if past_key_values != None:
past_key_values.cur_idx=cache_position
start_time = time.time()
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
# TODO: use CUDA Graph for chunk prefill, may get small improvement
def chunk_prefill(inputs, cache_position, past_key_values):
if mode == "long_context":
inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
else:
......@@ -182,9 +165,24 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
logits = model(
inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
)[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
return logits
torch.cuda.set_device(torch_device)
with torch.no_grad():
stream = TextStreamer(tokenizer)
if mode != 'long_context':
past_key_values = StaticCache(
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
)
else:
past_key_values = None
generation_config, model_kwargs = model._prepare_generation_config(
None, max_length=max_new_tokens,
do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config
None, do_sample=True
# change this to modify generate config
#top_k=5, top_p=0.85, temperature=0.1
)
try: # transformers==4.43
logits_warper = (
......@@ -194,12 +192,29 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
logits_warper = (
model._get_logits_warper(generation_config)
)
cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
generated_ids = torch.zeros(
batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
)
generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
start_time = time.time()
chunk_start = 0
while chunk_start < seq_length:
chunk_end = min(chunk_start + chunk_prefill_size, seq_length)
if past_key_values != None:
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
chunk_start += chunk_prefill_size
next_token_scores = logits_warper(inputs, logits[:, -1, :])
if generation_config.do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_token = torch.argmax(next_token_scores, dim=-1)
first_token_time = time.time() - start_time
if use_flashinfer_mla:
......@@ -208,7 +223,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
prefill_count = seq_length
prefill_time = first_token_time
if force_think:
print("<think>\n")
print("<think>")
print(stream.put(next_token.item()), end="", flush=True)
generated_ids[:, seq_length] = next_token
tokens.append(int(next_token))
......@@ -230,7 +245,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
warm_uped = True
cuda_graph_runner = CUDAGraphRunner()
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment