Unverified commit 77a34c28, authored by UnicornChan and committed by GitHub

Merge pull request #36 from kvcache-ai/develop-0.1.2

Release v0.1.2
parents 44f57270 395cd3e7
@@ -22,6 +22,7 @@
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#include "shared_mem_buffer.h"
struct MOEConfig {
int expert_num;
@@ -48,13 +49,13 @@ struct MOEConfig {
class MOE {
public:
MOE(MOEConfig);
~MOE();
void warm_up(Backend* backend);
void forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend);
void forward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend);
void forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend);
private:
-static uint8_t* buffer_;
MOEConfig config_;
void* gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
void* up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
......
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-08-05 04:49:08
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-08-05 09:21:29
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "shared_mem_buffer.h"
#include <cstdio>
SharedMemBuffer::SharedMemBuffer() {
buffer_ = nullptr;
size_ = 0;
}
SharedMemBuffer::~SharedMemBuffer() {
if (buffer_) {
free(buffer_);
}
}
void SharedMemBuffer::alloc(void* object, std::vector<std::pair<void**, uint64_t>> requests) {
uint64_t size = 0;
for (auto& request : requests) {
size += request.second;
}
if (size > size_) {
if (buffer_) {
free(buffer_);
}
buffer_ = malloc(size);
size_ = size;
for (auto& obj_requests : hist_requests_) {
for (auto& requests : obj_requests.second) {
arrange(requests);
}
}
}
arrange(requests);
hist_requests_[object].push_back(requests);
}
void SharedMemBuffer::dealloc(void* object) {
hist_requests_.erase(object);
}
void SharedMemBuffer::arrange(std::vector<std::pair<void**, uint64_t>> requests) {
uint64_t offset = 0;
for (auto& request : requests) {
*(request.first) = (uint8_t*)buffer_ + offset;
offset += request.second;
}
}
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-08-05 04:49:08
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-08-05 06:36:41
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_SHAREDMEMBUFFER_H
#define CPUINFER_SHAREDMEMBUFFER_H
#include <cstdint>
#include <cstdlib>
#include <map>
#include <vector>
class SharedMemBuffer {
public:
SharedMemBuffer();
~SharedMemBuffer();
void alloc(void* object, std::vector<std::pair<void**, uint64_t>> requests);
void dealloc(void* object);
private:
void* buffer_;
uint64_t size_;
std::map<void*, std::vector<std::vector<std::pair<void**, uint64_t>>>> hist_requests_;
void arrange(std::vector<std::pair<void**, uint64_t>> requests);
};
static SharedMemBuffer shared_mem_buffer;
#endif
\ No newline at end of file
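The SharedMemBuffer above gives every CPU operator one shared, growable scratch allocation: each object registers a list of (pointer, size) requests, the buffer grows whenever a single requester needs more than the current capacity, every request list is arranged from offset zero, and on growth all previously registered pointers are re-pointed into the new block. A minimal Python sketch of the same bookkeeping, purely illustrative (the class and method names below are hypothetical stand-ins, with `memoryview` slices in place of raw pointers):

```python
# Illustrative sketch of the SharedMemBuffer bookkeeping, not the project's API.
class SharedScratch:
    """One growable byte buffer shared by many objects (Python stand-in)."""

    def __init__(self):
        self.buffer = bytearray(0)
        self.slots = {}  # object -> {name: memoryview}, repointed when the buffer grows
        self.hist = {}   # object -> list of (name, size) requests

    def alloc(self, obj, requests):
        total = sum(size for _, size in requests)
        if total > len(self.buffer):
            self.buffer = bytearray(total)
            # Re-point every previously registered object into the new block,
            # which is what the arrange() loop over hist_requests_ does above.
            for prev_obj, prev_requests in self.hist.items():
                self.slots[prev_obj] = self._arrange(prev_requests)
        self.hist[obj] = requests
        self.slots[obj] = self._arrange(requests)
        return self.slots[obj]

    def dealloc(self, obj):
        self.hist.pop(obj, None)
        self.slots.pop(obj, None)

    def _arrange(self, requests):
        offset, views = 0, {}
        for name, size in requests:
            views[name] = memoryview(self.buffer)[offset:offset + size]
            offset += size
        return views
```

Note that, as in the C++ code, each request list is laid out starting at offset zero, so the scratch space is shared between requesters rather than partitioned among them.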
@@ -31,18 +31,21 @@ import fire
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate
from ktransformers.server.config.config import Config
custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
}
ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
default_optimize_rules ={
"DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
"MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
}
def local_chat(
@@ -50,7 +53,8 @@ def local_chat(
optimize_rule_path: str = None,
gguf_path: str = None,
max_new_tokens: int = 1000,
-cpu_infer: int = Config().cpu_infer
cpu_infer: int = Config().cpu_infer,
use_cuda_graph: bool = True,
):
torch.set_grad_enabled(False)
@@ -64,6 +68,8 @@ def local_chat(
print("using custom modeling_xxx.py.")
if "Qwen2Moe" in config.architectures[0]: # Qwen2Moe must use flash_attention_2 to avoid overflow.
config._attn_implementation = "flash_attention_2"
if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
model = custom_models[config.architectures[0]](config)
else:
model = AutoModelForCausalLM.from_config(
@@ -100,7 +106,6 @@ def local_chat(
while True:
content = input("Chat: ")
-# if content is num
if content == "":
content = "Please write a piece of quicksort code in C++."
@@ -109,7 +114,7 @@ def local_chat(
messages, add_generation_prompt=True, return_tensors="pt"
)
torch.set_default_dtype(torch.bfloat16) # TODO: Remove this, replace dtype using config
-generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens)
generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph)
if __name__ == "__main__":
fire.Fire(local_chat)
\ No newline at end of file
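Since the script is exposed through `fire.Fire`, the new `use_cuda_graph` switch and the Mixtral support can be exercised directly; a hedged usage sketch (the first `model_path` argument and the module path `ktransformers.local_chat` are not shown in this hunk and are assumed here, and the path values are placeholders):

```python
# Hedged sketch of calling the updated entry point from Python; the keyword
# names shown in the hunk are real, the path values are placeholders.
from ktransformers.local_chat import local_chat  # module path assumed

local_chat(
    model_path="mistralai/Mixtral-8x7B-Instruct-v0.1",  # placeholder config source
    gguf_path="/path/to/Mixtral-8x7B-GGUF/",            # placeholder GGUF directory
    max_new_tokens=512,
    cpu_infer=32,            # CPU threads handed to cpuinfer_ext
    use_cuda_graph=True,     # set False to skip CUDA graph capture during decode
)
```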
@@ -22,13 +22,14 @@ class StaticCache(transformers.StaticCache):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
-device (`torch.device`):
device (`torch.device` or `dict`):
The device on which the cache should be initialized. Should be the same as the layer.
If a `dict`, it should map each layer's key (e.g. `blk.{idx}.self_attn`) to an entry containing the `generate_device` on which that layer's cache is allocated.
dtype (*optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
"""
-def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device| dict, dtype=None) -> None:
Cache.__init__(self)
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
@@ -57,11 +58,15 @@ class StaticCache(transformers.StaticCache):
self.past_tokens = []
self.num_hidden_layers = config.num_hidden_layers
-for _ in range(self.num_hidden_layers):
for idx in range(self.num_hidden_layers):
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache.
-new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=device)
-new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=device)
if isinstance(device, dict):
target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
else:
target_device = device
new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
......
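When `device` is a `dict`, each layer's KV tensors are allocated on the device recorded under its `blk.{idx}.self_attn` key. A sketch of building such a mapping (the layer count and the two-GPU split are illustrative; in practice the mapping is derived from the optimize-rule configuration):

```python
# Illustrative per-layer device map; the key pattern follows the lookup
# device[f"blk.{idx}.self_attn"]["generate_device"] in __init__ above.
num_hidden_layers = 27  # assumed value for the example
device_map = {
    f"blk.{idx}.self_attn": {"generate_device": "cuda:0" if idx < 14 else "cuda:1"}
    for idx in range(num_hidden_layers)
}
# past_key_values = StaticCache(config, max_batch_size=1, max_cache_len=4096,
#                               device=device_map, dtype=torch.float16)
```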
@@ -1048,7 +1048,7 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
-Args:
# Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
@@ -1245,12 +1245,14 @@ class DeepseekV2DecoderLayer(nn.Module):
cache_position=cache_position,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
......
@@ -10,6 +10,7 @@ from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
from transformers.configuration_utils import PretrainedConfig
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
def __init__(self,
@@ -17,12 +18,16 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
-device: str = "cuda",
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs):
-BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.orig_module.__init__(orig_module.dim,
orig_module.max_position_embeddings,
orig_module.base)
self.generate_device = generate_device
self.prefill_device = prefill_device
def load(self):
self.orig_module.__init__(self.orig_module.dim,
@@ -36,9 +41,11 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
-device: str = "cuda",
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs):
-BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.orig_module.__init__(orig_module.dim,
orig_module.max_position_embeddings,
orig_module.base,
@@ -49,13 +56,15 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
orig_module.beta_slow,
orig_module.mscale,
orig_module.mscale_all_dim)
self.generate_device = generate_device
self.prefill_device = prefill_device
def load(self):
self.orig_module.__init__(self.orig_module.dim,
self.orig_module.max_position_embeddings,
self.orig_module.base,
-self.device,
self.generate_device,
self.orig_module.scaling_factor,
self.orig_module.original_max_position_embeddings,
self.orig_module.beta_fast,
......
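Both rotary-embedding wrappers now take separate `generate_device` and `prefill_device` keyword arguments in place of the single `device`, register the module on the generate device, and re-initialize the wrapped module on that device in `load()`. A hedged construction sketch (the key is a placeholder and `gguf_loader`, `config`, `orig_rope` are assumed to be prepared by the injection framework):

```python
# Hedged sketch of the new per-phase device kwargs for an injected RoPE module.
rope = YarnRotaryEmbedding(
    key="blk.0.attn_rot_embd",   # placeholder module key
    gguf_loader=gguf_loader,     # assumed pre-built GGUFLoader
    config=config,               # assumed PretrainedConfig
    orig_module=orig_rope,       # the original rotary embedding being wrapped
    generate_device="cuda:0",    # device used while decoding tokens
    prefill_device="cuda:0",     # device used while prefilling the prompt
)
rope.load()  # re-runs orig_module.__init__ on self.generate_device
```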
@@ -15,7 +15,7 @@ from ktransformers.util.custom_gguf import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
-class DeepseekV2AttentionInjected(BaseInjectedModule, DeepseekV2Attention):
class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self,
......
import sys, os
from typing import Any
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from ktransformers.server.config.config import Config
class CPUInfer:
cpu_infer = None
def __init__(self, cpu_infer:int = Config().cpu_infer):
if CPUInfer.cpu_infer is None:
CPUInfer.cpu_infer = cpuinfer_ext.CPUInfer(cpu_infer)
def __getattribute__(self, __name: str) -> Any:
return CPUInfer.cpu_infer.__getattribute__(__name)
def __setattr__(self, __name: str, __value: Any) -> None:
return CPUInfer.cpu_infer.__setattr__(__name, __value)
\ No newline at end of file
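`CPUInfer` is a process-wide singleton wrapper around `cpuinfer_ext.CPUInfer`: the first construction builds the thread pool, later constructions reuse it, and every attribute access is forwarded to the shared instance. A short usage sketch based on the calls that appear later in this diff (the thread counts are illustrative):

```python
# Sketch: every operator shares one cpuinfer_ext.CPUInfer thread pool.
from ktransformers.operators.cpuinfer import CPUInfer

pool = CPUInfer(48)    # first call creates the pool with 48 threads
alias = CPUInfer(96)   # later calls reuse the existing pool; the 96 is ignored

# Tasks are then submitted exactly as KLinearCPUInfer does below, e.g.:
# pool.submit(linear.forward(qlen, x_ptr, out_ptr)); pool.sync()
# or, overlapped with the current CUDA stream during decode:
# pool.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, task)
# pool.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
```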
@@ -6,13 +6,14 @@ Author : Azure-Tang, Boxin Zhang
Date : 2024-07-25 11:25:24
Version : 0.1.0
LastEditors : Azure
-LastEditTime : 2024-07-26 09:27:53
LastEditTime : 2024-08-14 14:57:04
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import ctypes
import torch
-from torch import nn
from torch import Tensor, nn
import KTransformersOps
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
@@ -25,10 +26,16 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
from abc import ABC, abstractmethod
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from ktransformers.operators.cpuinfer import CPUInfer
from ktransformers.server.config.config import Config
-#class QuantizedLinearBase(BaseInjectedModule, ABC):
#class KLinearBase(BaseInjectedModule, ABC):
-class QuantizedLinearBase(ABC):
class KLinearBase(ABC):
def __init__(
self,
key: str,
@@ -99,7 +106,7 @@ class QuantizedLinearBase(ABC):
pass
-class QuantizedLinearTorch(QuantizedLinearBase):
class KLinearTorch(KLinearBase):
def __init__(
self,
key: str,
@@ -118,6 +125,7 @@ class QuantizedLinearTorch(QuantizedLinearBase):
def forward(self, x: torch.Tensor) -> torch.Tensor:
dtype = x.dtype
out_device = x.device
# TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
x = x.to(device=self.device, dtype=self.dtype)
x = x @ self.w
if self.has_bias:
@@ -150,7 +158,7 @@ class QuantizedLinearTorch(QuantizedLinearBase):
self.bias = None
-class QuantizedLinearMarlin(QuantizedLinearBase):
class KLinearMarlin(KLinearBase):
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
g_idx: torch.Tensor
@@ -176,7 +184,7 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
self.act_order = act_order
self.is_k_full = is_k_full
-def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = "cuda"):
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if device is None: device = self.device
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
if w is None: w = self.load_weight(device=device)
@@ -200,7 +208,7 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
weight, self.num_bits, self.group_size, self.act_order
)
self.workspace = MarlinWorkspace(
-self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
)
self.marlin_q_w = marlin_q_w
self.marlin_s = marlin_s
@@ -244,35 +252,137 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
self.sort_indices = None
self.workspace = None
class KLinearCPUInfer(KLinearBase):
CPU_INFER = CPUInfer(Config().cpu_infer)
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
device: str = "cpu",
out_device: str = "cuda", # this device mean which device the output should on. TODO: support cpu.
stride = 16,
group_max_len = 1024,
**kwargs,
):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.has_bias = False
self.dtype = torch.get_default_dtype()
self.w = None
self.has_bias = False
self.stride = stride
self.group_max_len = group_max_len
self.out_device = out_device
def forward(self, x: torch.Tensor) -> torch.Tensor:
origin_shape = x.shape # [batch_size, q_len, hidden_size]
if origin_shape[1] == 1:
out_device = x.device
self.input_tensor_cpu.copy_(x, non_blocking=True)
qlen = origin_shape[1]
KLinearCPUInfer.CPU_INFER.submit_with_cuda_stream(
torch.cuda.current_stream().cuda_stream,
self.linear.forward(
qlen,
self.input_tensor_cpu.data_ptr(),
self.output_cpu.data_ptr()
)
)
KLinearCPUInfer.CPU_INFER.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
self.output_gpu.copy_(self.output_cpu, non_blocking=True)
if self.has_bias:
self.output_gpu += self.bias
return self.output_gpu
else:
dtype = x.dtype
out_device = x.device
x = x.to(device=self.device)
qlen = origin_shape[1]
output_shape = (*origin_shape[:-1], self.out_features)
output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
KLinearCPUInfer.CPU_INFER.submit(
self.linear.forward(
qlen,
x.data_ptr(),
output.data_ptr()
)
)
KLinearCPUInfer.CPU_INFER.sync()
if self.has_bias:
output = output + self.bias
output = output.to(dtype=dtype, device=out_device)
return output
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None, warmup:bool = True):
print(f"loading {self.key} to {self.device} using CPUInfer")
if device is None: device = self.device
self.load_weights(w=w, device=device)
if self.bias is not None:
self.has_bias = True
self.bias = self.bias.to(device)
weight_ptr = ctypes.addressof(
ctypes.cast(self.weight.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
)
config = cpuinfer_ext.linear.LinearConfig(self.in_features, self.out_features, self.stride, self.group_max_len, weight_ptr, self.weight_type, 30)
self.linear = cpuinfer_ext.linear.Linear(config)
if warmup:
KLinearCPUInfer.CPU_INFER.submit(self.linear.warm_up())
KLinearCPUInfer.CPU_INFER.sync()
self.input_tensor_cpu = torch.zeros((1, 1, self.in_features), device="cpu", pin_memory=True)
self.output_cpu = torch.zeros((1, 1, self.out_features), device="cpu", pin_memory=True, dtype=torch.bfloat16)
self.output_gpu = torch.zeros((1, 1, self.out_features), device=self.out_device)
def load_weights(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu"):
if self.key + ".weight" in self.gguf_loader.tensor_info:
if self.key + ".bias" in self.gguf_loader.tensor_file_map:
self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight")
self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"]
self.bias = self.gguf_loader.load_gguf_tensor(self.key + ".bias", device=device)
else:
self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight")
self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"]
self.bias = None
else:
raise ValueError(f"Linear {self.key} not found in gguf_loader")
def unload(self):
if self.w is not None:
self.w = None
if self.has_bias:
self.bias = None
LINEAR_MAP = {
-"QuantizedLinearMarlin": QuantizedLinearMarlin,
"KLinearMarlin": KLinearMarlin,
-"QuantizedLinearTorch": QuantizedLinearTorch,
"KLinearTorch": KLinearTorch,
-"QuantizedLinearTorch": QuantizedLinearTorch,
"KLinearCPUInfer": KLinearCPUInfer
}
-class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
class KTransformersLinear(BaseInjectedModule, KLinearBase):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
-device: str = "cuda",
# device: str = "cuda",
generate_device: str = "cuda",
-generate_op: str| None = "QuantizedLinearMarlin",
generate_op: str| None = "KLinearMarlin",
prefill_device: str = "cuda",
-prefill_op: str| None = "QuantizedLinearTorch",
prefill_op: str| None = "KLinearTorch",
**kwargs,
):
-BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
-QuantizedLinearBase.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
KLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
# build all the linear operators
if prefill_op is not None:
assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported"
-if prefill_op == "QuantizedLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
if prefill_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
-print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using QuantizedLinearTorch instead.")
print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.")
print(f"module info: key:{key} orig_module:{orig_module}")
-self.prefill_linear = QuantizedLinearTorch(key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.prefill_linear = KLinearTorch(key, gguf_loader, config, orig_module, prefill_device, **kwargs)
else:
self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)
else:
@@ -280,16 +390,15 @@ class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
if generate_op is not None:
assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported"
-if generate_op == "QuantizedLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
if generate_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
-print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using QuantizedLinearTorch instead.")
print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.")
print(f"module info: key:{key} orig_module:{orig_module}")
-self.generate_op = "QuantizedLinearTorch"
self.generate_op = "KLinearTorch"
-self.generate_linear = QuantizedLinearTorch(key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.generate_linear = KLinearTorch(key, gguf_loader, config, orig_module, generate_device, **kwargs)
else:
self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
else:
self.generate_linear = None
-self.device = device
self.mode = InferenceState.UNLOAD
def forward(self, x):
......
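After the rename, `KTransformersLinear` builds one backend for prefill and one for generation out of `LINEAR_MAP` and switches between them according to the inference state; the dispatch itself sits in the collapsed part of this diff. A hedged construction sketch (the key is a placeholder and `gguf_loader`, `config`, `orig_linear` are assumed to be prepared by the injection framework):

```python
# Illustrative wiring of the renamed linear operators.
linear = KTransformersLinear(
    key="blk.0.attn_q",            # placeholder GGUF tensor key
    gguf_loader=gguf_loader,       # assumed pre-built GGUFLoader
    config=config,                 # assumed PretrainedConfig
    orig_module=orig_linear,       # the nn.Linear being replaced
    generate_device="cuda",
    generate_op="KLinearMarlin",   # Marlin GPU kernels for single-token decode
    prefill_device="cuda",
    prefill_op="KLinearTorch",     # plain torch matmul for long prompts
)
# generate_op="KLinearCPUInfer" would instead route decoding through the shared
# CPU thread pool (KLinearCPUInfer defaults: device="cpu", out_device="cuda").
```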