Commit 91c16192 authored by Azure's avatar Azure
Browse files

Merge branch 'develop-0.2.2' into support-fp8

 Update README.md
parents 2c0cce90 d9b2895b
......@@ -14,6 +14,7 @@ from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.util.utils import get_device
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
warm_uped = False
......@@ -186,6 +187,8 @@ class KTransformersInterface(TransformersInterface):
input_ids = input_ids.to("cpu")
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
torch.cuda.set_device(device)
if flashinfer_enabled:
MLAWrapperSingleton.need_plan_all()
if self.use_static_cache:
logits = self.model(
inputs_embeds=inputs_embeds,
......@@ -198,6 +201,8 @@ class KTransformersInterface(TransformersInterface):
else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
if flashinfer_enabled:
MLAWrapperSingleton.reset_buffer()
self.prepare_logits_wrapper(input_ids, device)
next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token)
......
......@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
for i in range(1, self.args.max_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
if i > 1 and flashinfer_enabled:
if flashinfer_enabled:
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
......
......@@ -173,8 +173,8 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="API Generate Tester")
parser.add_argument("--concurrent", type=int, default=1000, help="Number of concurrent evaluations")
parser.add_argument("--file", type=str, default="TIGER-Lab/MMLU-Pro", help="Path to the mmlu.jsonl file")
parser.add_argument("--result", type=str, default="./mmlu_pro.json", help="Path to save the result JSON file")
parser.add_argument("--log", type=str, default="./mmlu_pro.log", help="Path to save the log file")
parser.add_argument("--result", type=str, default="./mmlu_result_pro.json", help="Path to save the result JSON file")
parser.add_argument("--log", type=str, default="./mmlu_result_pro.log", help="Path to save the log file")
parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path")
parser.add_argument("--api_url", type=str, default="http://localhost:15488/v1/chat/completions", help="API URL")
# parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
......
......@@ -330,6 +330,8 @@ class GGUFLoader:
values = GGML_DEQUANTIZE[ggml_name](data)
values = torch.from_numpy(values.copy())
if ggml_name == "BF16":
values = values.view(torch.bfloat16)
values = values.view(shape[-2::-1])
return values
......
......@@ -21,6 +21,18 @@ from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
warm_uped = False
def get_compute_capability(device:torch.device = None):
if torch.cuda.is_available():
if device is None:
num_gpus = torch.cuda.device_count()
min_compute_capability_major = 100
for gpu_id in range(num_gpus):
gpu_props = torch.cuda.get_device_properties(gpu_id)
min_compute_capability_major = min(min_compute_capability_major, gpu_props.major)
return min_compute_capability_major
else:
return torch.cuda.get_device_properties(device)
def set_module(model, submodule_key, module):
tokens = submodule_key.split('.')
sub_tokens = tokens[:-1]
......@@ -164,6 +176,10 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
else:
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
if use_flashinfer_mla:
MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
MLAWrapperSingleton.need_plan_all()
logits = model(
inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
)[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
......@@ -186,6 +202,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
else:
next_token = torch.argmax(next_token_scores, dim=-1)
first_token_time = time.time() - start_time
if use_flashinfer_mla:
MLAWrapperSingleton.reset_buffer()
prefill_count = seq_length
prefill_time = first_token_time
......@@ -203,22 +222,22 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
start_time = time.time()
for i in range(1, max_new_tokens):
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
global warm_uped
if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
warm_uped = True
cuda_graph_runner = CUDAGraphRunner()
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
if i > 1 and use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
else:
......
......@@ -350,6 +350,7 @@ elif MUSA_HOME is not None:
"at::cuda": "at::musa",
"#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
"#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
"nv_bfloat16": "mt_bfloat16",
}).run()
ops_module = MUSAExtension('KTransformersOps', [
'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment