Unverified commit 77a34c28 authored by UnicornChan, committed by GitHub

Merge pull request #36 from kvcache-ai/develop-0.1.2

Release v0.1.2
parents 44f57270 395cd3e7
@@ -2,36 +2,56 @@
     class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
   replace:
     class: ktransformers.operators.RoPE.RotaryEmbedding
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
 - match:
     name: "^model\\.layers\\..*$"  # regular expression
     class: torch.nn.Linear  # only match modules that match both name and class
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear  # optimized kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
-      generate_op: "QuantizedLinearMarlin"
-      prefill_op: "QuantizedLinearTorch"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"
     class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
   replace:
-    class: ktransformers.operators.experts.Qwen2MoeSparseMoeBlockInjected  # MLP module with custom forward function
+    class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlock  # MLP module with custom forward function
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
 - match:
     name: "^model\\.layers\\..*\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersMLPExpert  # custom MoE kernel with expert parallelism
-    device: "cpu"  # which device to load this module onto when initializing
+    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
+    # device: "cpu"  # which device to load this module onto when initializing
     kwargs:
       prefill_device: "cuda"
-      prefill_mlp_type: "MLPExpertsTorch"
+      prefill_op: "KExpertsTorch"
       generate_device: "cpu"
-      generate_mlp_type: "MLPCPUExperts"
+      generate_op: "KExpertsCPU"
       out_device: "cuda"
   recursive: False  # don't recursively inject submodules of this module
 - match:
     name: "^model$"
   replace:
-    class: "ktransformers.operators.layer_wise_prefill.Qwen2MoeModelPerLayerPrefill"
+    class: "ktransformers.operators.models.KQwen2MoeModel"
     kwargs:
       per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
+- match:
+    name: "^model.embed_tokens"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cpu"
+      prefill_device: "cpu"
+- match:
+    name: "^model\\.layers\\..*\\."
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
\ No newline at end of file
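Note: rule files like this one are applied while the model skeleton is being built. The sketch below shows roughly how such a YAML is consumed; the exact optimize_and_load_gguf signature, the meta-device construction, and the paths are assumptions for illustration, not part of this change.

import torch
from transformers import AutoConfig
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.optimize.optimize import optimize_and_load_gguf

config = AutoConfig.from_pretrained("/data/Qwen2-57B-A14B-Instruct", trust_remote_code=True)  # placeholder path
with torch.device("meta"):  # assumed: build the module tree without allocating real weights
    model = Qwen2MoeForCausalLM(config)
# Each rule matches modules by name/class regex, swaps in the replacement class with its
# kwargs (generate_device / prefill_device / *_op), then GGUF weights are loaded per device.
optimize_and_load_gguf(model, "Qwen2-57B-A14B-Instruct.yaml",  # placeholder rule-file path
                       "/data/Qwen2-57B-A14B-Instruct-GGUF/q4_k_m", config)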
@@ -6,6 +6,7 @@ from ktransformers.optimize.optimize import optimize_and_load_gguf
 from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
+from ktransformers.util.utils import get_device

 class KTransformersThreadContext(TransformersThreadContext):
@@ -48,8 +49,11 @@ class KTransformersInterface(TransformersInterface):
     def decode_one_tokens(self):
         if not hasattr(self, "cuda_graph_runner"):
+            device_map = self.model.gguf_loader.tensor_device_map
+            torch_device = get_device('blk.0.self_attn', device_map)
+            torch_device = "cuda:0" if torch_device == "cuda" else torch_device
             self.cuda_graph_runner = CUDAGraphRunner()
-            self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, return_dict=False, use_cache=True)
+            self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, main_device=torch_device, return_dict=False, use_cache=True)

         if hasattr(self, "cuda_graph_runner"):
             logits = self.cuda_graph_runner(self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position)
...
 import os
-os.environ["CUDA_VISIBLE_DEVICES"]="1"
+# os.environ["CUDA_VISIBLE_DEVICES"]="1,2"
 # add path
 import sys
 current_path = os.path.abspath(os.path.dirname(__file__))
 sys.path.append(current_path+"/../..")
-import pycuda.autoinit
-import pycuda.driver as cuda
-from pycuda.compiler import SourceModule
 import numpy as np
-# from ktransformers.operators.linear import KTransformerLinear, QuantizedLinearMarlin
-# from ktransformers.operators.experts import KTransformersMLPExpert, MLPExpertsTorch
+# from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin
+# from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch
 from ktransformers.util.custom_gguf import GGUFLoader
 import torch
 import KTransformersOps
@@ -18,40 +15,44 @@ import time
 from transformers import (
     AutoConfig,
 )
+import os
+# CUDA_LAUNCH_BLOCKING=1
+os.environ["CUDA_LAUNCH_BLOCKING"]="1"

 gguf_config = GGUFLoader("/data/Qwen2-57B-A14B-Instruct-GGUF/q4_k_m")
 model_name = "/data/Qwen2-57B-A14B-Instruct"

-key = "blk.0."
-target = "ffn_down_exps.weight"
+# Q4k
+key = "blk.1."
+target = "attn_q.weight"

 t1 = time.time()
 q_weight_cpu = gguf_config.load_gguf_tensor(key+target, "cpu")
 # q_weight_cpu = torch.from_numpy(q_weight_cpu)
 t2 = time.time()
-q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda")
+q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda:0")
 t3 = time.time()

 print()
-allclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu().to(torch.float32), atol=1e-6)
-print(f"Q6k {key+target}")
+allclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu(), atol=1e-6)
+print(f"Q4k {key+target}")
 print("load gguf tensor from cpu cost: ", t2-t1)
 print("load gguf tensor from gpu cost: ", t3-t2)
 print("allclose: ", allclose)

-key = "blk.1."
-target = "ffn_up_shexp.weight"
+# Q6k
+key = "blk.0."
+target = "ffn_down_exps.weight"

 t1 = time.time()
 q_weight_cpu = gguf_config.load_gguf_tensor(key+target, "cpu")
-# q_weight_cpu = torch.from_numpy(q_weight_cpu)
 t2 = time.time()
-q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda")
+q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda:0")
 t3 = time.time()

 print()
-allclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu(), atol=1e-6)
-print(f"Q4k {key+target}")
+allclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu().to(torch.float32), atol=1e-6)
+print(f"Q6k {key+target}")
 print("load gguf tensor from cpu cost: ", t2-t1)
 print("load gguf tensor from gpu cost: ", t3-t2)
 print("allclose: ", allclose)
@@ -7,11 +7,11 @@ import pycuda.autoinit
 import pycuda.driver as cuda
 from pycuda.compiler import SourceModule
 import numpy as np
-from ktransformers.operators.linear import KTransformerLinear, QuantizedLinearMarlin
-from ktransformers.operators.experts import KTransformersMLPExpert, MLPExpertsTorch
+from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin
+from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch
 from ktransformers.util.custom_gguf import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k
 import torch
-import CudaOps
+import KTransformersOps
 torch.set_default_dtype(torch.bfloat16)
 import time
 from transformers import (
...
@@ -21,6 +21,7 @@ class CUDAGraphRunner:
         position_ids,
         cache_position,
         past_key_values,
+        main_device,
         **kwargs,
     ) -> None:
         assert self.graph is None
@@ -29,15 +30,24 @@ class CUDAGraphRunner:
         self.graph = torch.cuda.CUDAGraph()
         #self.graph.enable_debug_mode()
         self.model = model
-        inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to("cuda")
-        with torch.cuda.graph(self.graph):
+        inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(main_device)
+        # torch.cuda.set_device can't take a bare "cuda"; it needs a device index
+        if main_device == "cuda":
+            main_device = "cuda:0"
+        torch.cuda.set_device(main_device)
+        self.main_device = main_device
+        capture_stream = torch.cuda.Stream()
+        with torch.cuda.graph(self.graph, stream=capture_stream):
             logits=model(inputs_embeds=inputs_embeds,
                          position_ids=position_ids,
                          cache_position=cache_position,
                          past_key_values=past_key_values,
                          **kwargs)[0]
+        capture_stream.wait_stream(torch.cuda.current_stream())
+        torch.cuda.set_device(main_device)
+        torch.cuda.set_stream(capture_stream)
         past_key_values.change_seq_length(-1)
-        torch.cuda.synchronize()
+        torch.cuda.synchronize(self.main_device)
         #self.graph.debug_dump("cuda_graph_hooked.dot")

         # Save the input and output buffers.
@@ -65,7 +75,7 @@ class CUDAGraphRunner:
         #print("begin replay")
         #time.sleep(1)
         self.graph.replay()
-        torch.cuda.synchronize()
+        torch.cuda.synchronize(self.main_device)
         # Return the output tensor.
         return self.output_buffers["logits"]
...
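Note: capture now takes a required main_device argument so the graph is captured and replayed on the device chosen from the tensor device map. A caller-side sketch of the intended usage (the tensor names below are placeholders, not taken from this diff):

from ktransformers.util.cuda_graph_runner import CUDAGraphRunner

runner = CUDAGraphRunner()
# model, cur_token, position_ids, cache_position, past_key_values are assumed to exist already.
runner.capture(model, cur_token, position_ids, cache_position, past_key_values,
               main_device="cuda:0",  # a bare "cuda" is rewritten to "cuda:0" inside capture
               return_dict=False, use_cache=True)
logits = runner(cur_token, position_ids, cache_position)  # replays the graph, synchronizing on main_device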
@@ -5,8 +5,8 @@ Description :
 Author : Azure-Tang, Boxin Zhang, chenht2022
 Date : 2024-07-26 08:48:54
 Version : 1.0.0
-LastEditors : Azure
-LastEditTime : 2024-07-26 09:28:25
+LastEditors : kkk1nak0
+LastEditTime : 2024-08-12 07:21:55
 Adapted from https://github.com/99991/pygguf/blob/main/gguf.py
 Copyright (c) 2023-2024 The ggml authors
 Copyright (c) 2024 Thomas Germer
@@ -18,6 +18,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 import struct
 import warnings
 import numpy as np
+import re
 import numpy.typing as npt
 from typing import Sequence
 import os
@@ -168,6 +169,7 @@ class GGUFLoader:
         self.tensor_file_map = {}
         self.file_data_map = {}
         self.gguf_file_meta = {}
+        self.tensor_device_map = {}

         # Walk through all the .gguf files in the directory
         for root, dirs, files in os.walk(gguf_path):
@@ -292,8 +294,19 @@ class GGUFLoader:
         else:
             values = GGML_DEQUANTIZE[ggml_name](data)
             values = torch.from_numpy(values)
-        return values.view(shape[::-1])
+        values = values.view(shape[::-1])
+        if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
+            n_head = self.gguf_file_meta['llama.attention.head_count']
+            values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
+                      .swapaxes(1, 2)
+                      .reshape(values.shape))
+        elif "attn_k" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
+            n_head = self.gguf_file_meta['llama.attention.head_count_kv']
+            values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
+                      .swapaxes(1, 2)
+                      .reshape(values.shape))
+        return values

 def read_value(f, data_type):
     if data_type == DATA_TYPES["string"]:
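Note: the new branch above reorders attn_q / attn_k rows for llama-architecture GGUFs so that the dequantized weight matches the per-head row layout the HF-style attention code expects. A toy illustration of the row permutation, with invented shapes (a sketch, not the production path):

import numpy as np

n_head, head_dim, hidden = 2, 4, 8
w = np.arange(n_head * head_dim * hidden).reshape(n_head * head_dim, hidden)
# Within each head, rows stored as pairs [0, 1, 2, 3] come out reordered as [0, 2, 1, 3],
# which is exactly what reshape(n_head, head_dim//2, 2, ...).swapaxes(1, 2) does above.
w_perm = (w.reshape(n_head, head_dim // 2, 2, hidden)
           .swapaxes(1, 2)
           .reshape(w.shape))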
@@ -377,8 +390,14 @@ def dequantize_q2_k(data):
     return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)

-def dequantize_q2_k_gpu(data):
-    raise NotImplementedError()
+def dequantize_q2_k_gpu(data, device:str ="cuda"):
+    block_size = GGML_BLOCK_SIZES["Q2_K"]
+    data = np.frombuffer(data, dtype=data.dtype)
+    device = torch.device(device)
+    # TODO: this and from_numpy in the other functions cause a warning that the numpy array is not writable;
+    # the clean fix is to pass the raw pointer to KTransformersOps instead of a Tensor.
+    data = torch.from_numpy(data)
+    return KTransformersOps.dequantize_q2_k(data, block_size, device)

 def dequantize_q3_k(data):
     # C implementation
@@ -422,8 +441,14 @@ def dequantize_q3_k(data):
         (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
     ], axis=1)

-def dequantize_q3_k_gpu(data):
-    raise NotImplementedError()
+def dequantize_q3_k_gpu(data, device:str ="cuda"):
+    block_size = GGML_BLOCK_SIZES["Q3_K"]
+    data = np.frombuffer(data, dtype=data.dtype)
+    device = torch.device(device)
+    # TODO: this and from_numpy in the other functions cause a warning that the numpy array is not writable;
+    # the clean fix is to pass the raw pointer to KTransformersOps instead of a Tensor.
+    data = torch.from_numpy(data)
+    return KTransformersOps.dequantize_q3_k(data, block_size, device)

 def dequantize_q4_k(data):
     # C implementation
@@ -511,9 +536,14 @@ def dequantize_q5_k(data):
         d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
     ], axis=1)

-def dequantize_q5_k_gpu(data):
-    raise NotImplementedError()
+def dequantize_q5_k_gpu(data, device:str ="cuda"):
+    block_size = GGML_BLOCK_SIZES["Q5_K"]
+    data = np.frombuffer(data, dtype=data.dtype)
+    device = torch.device(device)
+    # TODO: this and from_numpy in the other functions cause a warning that the numpy array is not writable;
+    # the clean fix is to pass the raw pointer to KTransformersOps instead of a Tensor.
+    data = torch.from_numpy(data)
+    return KTransformersOps.dequantize_q5_k(data, block_size, device)

 def dequantize_q6_k(data):
     # C implementation
@@ -570,7 +600,7 @@ def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"):
     num_blocks = len(data) // block_size
     data = np.frombuffer(data, dtype=data.dtype)
     data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q6_k(data, 210, device)
+    return KTransformersOps.dequantize_q6_k(data, block_size, device)

 def dequantize_q4_0(data):
     # C implementation
@@ -679,7 +709,34 @@ GGML_DEQUANTIZE_GPU = {
     "Q6_K": dequantize_q6_k_gpu,
 }
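Note: these CUDA kernels are reached through GGML_DEQUANTIZE_GPU when load_gguf_tensor is given a CUDA device, which is what the test script earlier in this merge exercises. A condensed sketch of that usage (the GGUF path and tensor name are placeholders):

import torch
from ktransformers.util.custom_gguf import GGUFLoader

loader = GGUFLoader("/path/to/Qwen2-57B-A14B-Instruct-GGUF/q4_k_m")       # placeholder path
w_cpu = loader.load_gguf_tensor("blk.0.ffn_down_exps.weight", "cpu")      # numpy dequantize
w_gpu = loader.load_gguf_tensor("blk.0.ffn_down_exps.weight", "cuda:0")   # KTransformersOps kernel
assert torch.allclose(w_cpu, w_gpu.cpu().to(torch.float32), atol=1e-6)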
+def translate_name_to_gguf_mixtral(name):
+    replacement_template = {
+        "w1.weight": "ffn_gate",
+        "w2.weight": "ffn_down",
+        "w3.weight": "ffn_up"
+    }
+    pattern = re.compile(r"model.layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.(w\d\.weight)")
+    def replace_match(match):
+        blk_id = match.group(1)
+        expert_id = match.group(2)
+        weight_type = match.group(3)
+        if weight_type in replacement_template:
+            return f"blk.{blk_id}.{replacement_template[weight_type]}.{expert_id}.weight"
+        else:
+            return match.group(0)
+    new_name = re.sub(pattern, replace_match, name)
+    return new_name
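Note: worked example of the new Mixtral expert-name rewrite, derived directly from the regex above:

translate_name_to_gguf_mixtral("model.layers.0.block_sparse_moe.experts.3.w1.weight")
# -> "blk.0.ffn_gate.3.weight"
translate_name_to_gguf_mixtral("model.layers.12.self_attn.q_proj.weight")
# -> unchanged (no match); such names are handled by translate_name_to_gguf below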
 def translate_name_to_gguf(name):
+    name = translate_name_to_gguf_mixtral(name)
     name = name.replace("lm_head.", "output.")
     name = name.replace("model.embed_tokens.", "token_embd.")
     name = name.replace("model.norm.", "output_norm.")
@@ -716,9 +773,14 @@ def translate_name_to_gguf(name):
     name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps")
     name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps")
+    name = name.replace(".block_sparse_moe.gate.", ".ffn_gate_inp.")
+    name = name.replace(".block_sparse_moe.experts", "")
     return name

 if __name__ == '__main__':
     gguf_path = '/mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH'
     loader = GGUFLoader(gguf_path)
     loader.load_gguf_tensor('token_embd.weight')
@@ -39,6 +39,22 @@ def set_param(module: nn.Module, name: str, weights: torch.Tensor):
             param.unsqueeze_(0)
         setattr(module, name, param)

+def get_device(gguf_module_key:str, device_map:dict):
+    if gguf_module_key in device_map:
+        return device_map[gguf_module_key]["generate_device"]
+    else:
+        return "cuda"
+
+def get_all_used_cuda_device(device_map:dict):
+    all_device_list = set()
+    for key in device_map:
+        all_device_list.add(device_map[key]["generate_device"]) if "generate_device" in device_map[key] else None
+        all_device_list.add(device_map[key]["prefill_device"]) if "prefill_device" in device_map[key] else None
+    if "cpu" in all_device_list:
+        all_device_list.remove("cpu")
+    all_device_list = list(all_device_list)
+    return all_device_list

 def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str = ""):
     prefix = prefix.replace("orig_module.", "")
     persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
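Note: the new helpers read the tensor_device_map that the optimizer fills in from the YAML rules. A small usage sketch; the map below is an illustrative shape, not copied from a real run:

from ktransformers.util.utils import get_device, get_all_used_cuda_device

device_map = {
    "blk.0.self_attn": {"generate_device": "cuda:0", "prefill_device": "cuda:0"},
    "token_embd":      {"generate_device": "cpu",    "prefill_device": "cpu"},
}
get_device("blk.0.self_attn", device_map)   # -> "cuda:0"
get_device("blk.5.ffn_gate", device_map)    # -> "cuda" (default when the key is absent)
get_all_used_cuda_device(device_map)        # -> ["cuda:0"]  (CPU entries are filtered out)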
@@ -47,18 +63,19 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
     for name, param in local_state.items():
         key = prefix + name
         translated_key = translate_name_to_gguf(key)
-        print("default loading weights", key, translated_key)
         if translated_key in gguf_loader.tensor_file_map:
             target_dtype = torch.get_default_dtype()
-            device = "cpu" if "embd" in translated_key else "cuda"
+            device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
+            print(f"loading {translated_key} to {device}")
+            # device = "cpu" if "embd" in translated_key else "cuda"
             weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
             set_param(module, name, weights)
             del weights
         else:
             #print(load_config.tensor_file_map.keys())
-            raise Exception(f"can't fand {translated_key} in GGUF file!")
+            raise Exception(f"can't find {translated_key} in GGUF file!")

-def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix='', return_when_injected:bool = False, only_load_injected:bool = False):
+def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
     # print(f"recursively loading weights {prefix},{return_when_injected=}, {only_load_injected=}")
     if not isinstance(module, base_operator.BaseInjectedModule):
         load_cur_state_dict(module, gguf_loader, prefix)
@@ -66,29 +83,36 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
             load_weights(child, gguf_loader, prefix+name+".")
     else:
         module.load()
-def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
+def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     torch._dynamo.config.suppress_errors = True
     batch_size, seq_length = inputs.shape
-    torch_device = inputs.device
+    device_map = model.gguf_loader.tensor_device_map
+    torch_device = get_device('blk.0.self_attn', device_map)
+    torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+    inputs = inputs.to(torch_device)
+    all_cuda_device = get_all_used_cuda_device(device_map)

     tokens = []

-    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values):
-        logits = cuda_graph_runner(cur_token, position_ids, cache_position)
+    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
+        if use_cuda_graph:
+            logits = cuda_graph_runner(cur_token, position_ids, cache_position)
+        else:
+            # custom_stream = torch.cuda.Stream()
+            torch.cuda.set_device(torch_device)
+            inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
+            # with torch.cuda.stream(custom_stream):
+            logits=model(inputs_embeds=inputs_embeds,
+                         position_ids=position_ids,
+                         cache_position=cache_position,
+                         past_key_values=past_key_values,
+                         return_dict=False, use_cache=True)[0]
         past_key_values.change_seq_length(1)
-        """
-        inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to("cuda")
-        custom_stream = torch.cuda.Stream()
-        with torch.cuda.stream(custom_stream):
-            logits=model(inputs_embeds = inputs_embeds,
-                position_ids = position_ids,
-                cache_position = cache_position,
-                past_key_values = past_key_values,
-                return_dict = False, use_cache = True) [0]
-        """
-        torch.cuda.synchronize()
+        for device in all_cuda_device:
+            torch.cuda.synchronize(device)
         #print(logits)
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
@@ -97,11 +121,12 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
         else:
             next_token = torch.argmax(next_token_scores, dim=-1)
         return next_token

+    torch.cuda.set_device(torch_device)
     with torch.no_grad():
         stream = TextStreamer(tokenizer)
         past_key_values = StaticCache(
-            config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = torch_device, dtype = model.dtype
+            config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
         )
         cache_position = torch.arange(seq_length, device=torch_device)
         generated_ids = torch.zeros(
@@ -111,21 +136,21 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
         past_key_values.cur_idx=cache_position
         start_time = time.time()

-        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to("cuda")
+        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
         logits = model(
             inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
-        )[0][:,-1,:].unsqueeze(0).clone()
+        )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
         generation_config, model_kwargs = model._prepare_generation_config(
             None, max_length=max_new_tokens,
             do_sample=True, top_k=5, top_p=0.85, temperature=0.1  # change this to modify generate config
         )
         try:  # transformers==4.43
             logits_warper = (
-                model._get_logits_warper(generation_config, device=inputs.device) if generation_config.do_sample else None
+                model._get_logits_warper(generation_config, device=inputs.device)
             )
         except:
             logits_warper = (
-                model._get_logits_warper(generation_config) if generation_config.do_sample else None
+                model._get_logits_warper(generation_config)
             )
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
@@ -137,7 +162,6 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
         prefill_count = seq_length
         prefill_time = first_token_time

         print(stream.put(next_token.item()), end="", flush=True)
         generated_ids[:, seq_length] = next_token
         tokens.append(next_token)
@@ -145,12 +169,16 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
         cache_position = torch.tensor([seq_length], device=torch_device)
         position_ids = cache_position.unsqueeze(0)
         seq_length += 1

-        cuda_graph_runner = CUDAGraphRunner()
-        cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, return_dict=False, use_cache=True)
+        if use_cuda_graph:
+            cuda_graph_runner = CUDAGraphRunner()
+            cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
+        else:
+            cuda_graph_runner = None

         start_time = time.time()
         for _ in range(1, max_new_tokens):
-            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values)
+            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
             inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
             generated_ids[:, cache_position] = next_token.int()
             tokens.append(next_token.int())
@@ -163,6 +191,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
             print(stream.put(next_token.item()), end="", flush=True)
             cache_position += 1
             position_ids = cache_position.unsqueeze(0)

         total_time = time.time() - start_time
         tokens_generated = len(tokens)
...
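Note: with the new use_cuda_graph flag, callers can fall back to the eager decode path when CUDA graph capture is undesirable. A hedged usage sketch; the tokenizer/model setup and the import path are assumptions based on how local_chat uses this helper, and prompt is a placeholder:

from ktransformers.util.utils import prefill_and_generate

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
prefill_and_generate(model, tokenizer, input_ids, max_new_tokens=512, use_cuda_graph=True)
# With use_cuda_graph=False the per-step model call in decode_one_tokens is used instead of
# CUDAGraphRunner, which can help on multi-device setups where graph capture is fragile.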
@@ -6,7 +6,7 @@ Author : chenxl
 Date : 2024-07-27 16:15:27
 Version : 1.0.0
 LastEditors : chenxl
-LastEditTime : 2024-08-08 02:45:15
+LastEditTime : 2024-08-14 16:36:19
 Adapted from:
   https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
 Copyright (c) 2023, Tri Dao.
@@ -299,6 +299,15 @@ setup(
             'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
             'ktransformers/ktransformers_ext/cuda/binding.cpp',
             'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
-        ])
+        ],
+        extra_compile_args={
+            'cxx': ['-O3'],
+            'nvcc': [
+                '-O3',
+                '--use_fast_math',
+                '-Xcompiler', '-fPIC',
+            ]
+        }
+        )
     ]
)
@@ -94,7 +94,6 @@ static const struct GemmFuncs {
 #if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
 #if defined(__AVX2__)
 #if defined(__AVX512F__)
-    printf("__AVX512F__\n");
 #if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)
     // AMD Zen4+ (2023-)
     sgemm = llamafile_sgemm_amd_zen4;
...