Commit ca7366d2 authored by Azure

Merge remote-tracking branch 'upstream/develop-0.2.2' into support-fp8

parents 581a524f cdb6f896
@@ -21,6 +21,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
     MarlinWorkspace,
     marlin_quantize,
     GPTQ_MARLIN_MIN_THREAD_N,
+    GPTQ_MARLIN_MIN_THREAD_K,
     GPTQ_MARLIN_MAX_PARALLEL,
 )
 from ktransformers.operators.base_operator import BaseInjectedModule
@@ -65,6 +66,8 @@ class KLinearBase(ABC):
         self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0]
         self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1]
+        self.loaded = False  # for lm_head pre-load, TODO: use new way to do lm_head pre-load when layer wise prefill.
     @abstractmethod
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         pass
@@ -141,6 +144,7 @@ class KLinearTorch(KLinearBase):
         return x
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
+        if self.loaded: return
         if device is None: device = self.device
         if w is None: w = self.load_weight(device=device)
         # else: self.out_features = w.shape[0], self.in_features = w.shape[1]
@@ -164,6 +168,7 @@ class KLinearTorch(KLinearBase):
             self.weight = self.weight.to(device)
             if self.has_bias:
                 self.bias = self.bias.to(device)
+        self.loaded = True
     def unload(self):
         if self.weight is not None:
@@ -251,20 +256,36 @@ class KLinearMarlin(KLinearBase):
         self.group_size = group_size
         self.act_order = act_order
         self.is_k_full = is_k_full
+        self.padding = False
+        self.orin_in_features = self.in_features
+        self.orin_out_features = self.out_features
+        if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0:
+            #print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding")
+            self.padding = True
+            self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K
+            self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N
+            #print(f"After padding: in_features={in_features}, out_features={out_features}")
+        self.k = self.in_features
+        self.n = self.out_features
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
+        if self.loaded: return
         if device is None: device = self.device
         assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
+        #if self.in_features * self.out_features:
         if w is None:
             w = self.load_weight(device=device)
         if isinstance(w, nn.Parameter):
             # pad weight
-            weight = w.view(self.out_features, self.in_features).T
+            weight = w.view(self.orin_out_features, self.orin_in_features).T
             self.has_bias = False
         elif isinstance(w, tuple):
             w = list(w)
-            weight = w[0].view(self.out_features, self.in_features).T
+            weight = w[0].view(self.orin_out_features, self.orin_in_features).T
+            self.bias = w[1].view(self.orin_out_features)
             self.bias = w[1]
             self.has_bias = True
         else:
@@ -272,8 +293,14 @@ class KLinearMarlin(KLinearBase):
         weight = weight.to(device)
         if self.has_bias:
             self.bias = self.bias.to(device)
+        if self.padding:
+            padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)
+            padded_weight[:self.orin_in_features, :self.orin_out_features] = weight
+            weight = padded_weight
         # Pack Marlin linear
-        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
+        marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
             weight, self.num_bits, self.group_size, self.act_order
         )
         self.workspace = MarlinWorkspace(
@@ -286,6 +313,7 @@ class KLinearMarlin(KLinearBase):
         self.sort_indices = sort_indices
         self.k = weight.shape[0]
         self.n = weight.shape[1]
+        self.loaded = True
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Only support input x as BF16 and FP16
@@ -293,6 +321,11 @@
         orig_shape = list(x.shape)
         orig_dtype = x.dtype
         x = x.reshape(-1, orig_shape[-1])
+        x = x.reshape(-1, x.shape[-1])
+        if self.padding:
+            padding_input=torch.empty(x.shape[0], self.in_features, device=x.device, dtype=x.dtype)
+            padding_input[:,:self.orin_in_features] = x
+            x = padding_input
         marlin_s = self.marlin_s.to(x.dtype)
         x = KTransformersOps.gptq_marlin_gemm(
             x,
@@ -307,6 +340,11 @@
             x.shape[-1],
             self.is_k_full,
         )
+        if self.padding:
+            x = x[:,:self.orin_out_features]
+            orig_shape[-1] = self.orin_out_features
+        else:
+            orig_shape[-1] = self.out_features
         if self.has_bias:
             x = x + self.bias
         orig_shape[-1] = self.n
@@ -450,24 +488,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
         # build all the linear operators
         if prefill_op is not None:
             assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported"
-            if prefill_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
-                print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.")
-                print(f"module info: key:{key} orig_module:{orig_module}")
-                self.prefill_linear = KLinearTorch(key, gguf_loader, config, orig_module, prefill_device, **kwargs)
-            else:
-                self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+            self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)
         else:
             self.prefill_linear = None
         if generate_op is not None:
             assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported"
-            if generate_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
-                print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.")
-                print(f"module info: key:{key} orig_module:{orig_module}")
-                self.generate_op = "KLinearTorch"
-                self.generate_linear = KLinearTorch(key, gguf_loader, config, orig_module, generate_device, **kwargs)
-            else:
-                self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
+            self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
         else:
             self.generate_linear = None
         self.mode = InferenceState.UNLOAD
...
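A quick sketch of the padding trick KLinearMarlin now relies on (simplified and standalone, not the actual class; `GPTQ_MARLIN_MIN_THREAD_*` values are illustrative): dimensions are rounded up to the Marlin thread multiples, the weight and activations are zero-padded, and the extra output columns are sliced off afterwards, so the result matches the unpadded GEMM.

```python
import torch

GPTQ_MARLIN_MIN_THREAD_N = 64   # illustrative values; the real constants come from
GPTQ_MARLIN_MIN_THREAD_K = 128  # the custom_marlin quantize utils imported above

def round_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple

in_features, out_features = 1000, 300          # not divisible by the thread sizes
pad_in = round_up(in_features, GPTQ_MARLIN_MIN_THREAD_K)    # 1024
pad_out = round_up(out_features, GPTQ_MARLIN_MIN_THREAD_N)  # 320

weight = torch.randn(in_features, out_features)
padded_weight = torch.zeros(pad_in, pad_out)
padded_weight[:in_features, :out_features] = weight         # zero-pad, as in load()

x = torch.randn(4, in_features)
padded_x = torch.zeros(4, pad_in)
padded_x[:, :in_features] = x                                # zero-pad, as in forward()

y = padded_x @ padded_weight                                 # stands in for the Marlin GEMM
y = y[:, :out_features]                                      # slice back to the original width
assert torch.allclose(y, x @ weight, atol=1e-5)
```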
@@ -126,6 +126,8 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
     gguf_loader=GGUFLoader(gguf_path)
     with torch.device("meta"):
         inject(module, optimize_config, model_config, gguf_loader)
+    # pre load lm_head because its big inter result
+    load_weights(module.lm_head, gguf_loader, "lm_head.")
     load_weights(module, gguf_loader)
     module.gguf_loader = gguf_loader
     del_meta(module)
...
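The new `self.loaded` guard added to the linear operators is what makes this explicit pre-load safe: the later full `load_weights(module, gguf_loader)` pass reaches `lm_head` again and simply skips it. A rough, self-contained sketch of the pattern (simplified class, not the actual ktransformers operator):

```python
# Minimal sketch of the pre-load guard pattern: load lm_head first, then let the
# full recursive load skip modules that are already resident.
class LazyLinear:
    def __init__(self):
        self.loaded = False

    def load(self):
        if self.loaded:           # second call (from the full-model pass) is a no-op
            return
        print("loading weights")  # the real operators dequantize/copy tensors here
        self.loaded = True

lm_head = LazyLinear()
lm_head.load()   # explicit pre-load: the big intermediate result is handled early
lm_head.load()   # later recursive load_weights() hits it again: skipped
```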
@@ -219,8 +219,20 @@
     kwargs:
       generate_device: "cuda:2"
       prefill_device: "cuda:2"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:3"
+      prefill_device: "cuda:3"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
-    name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)|(^lm_head)"
+    name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)"
   replace:
     class: "default"
     kwargs:
...
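All of the rule-file changes in this merge follow this same pattern: `lm_head` is pulled out of the catch-all "default" rule and given its own `KTransformersLinear` rule with `KLinearMarlin` for generation, which only became practical once `KLinearMarlin` learned to pad non-divisible shapes (see the linear.py changes above). A rough sketch of how a name/class rule selects a module, using a simplified matcher that is not the actual injection code:

```python
import re
import torch.nn as nn

# Hypothetical, simplified stand-in for the optimize-rule matcher: a rule applies
# when the name regex matches and, if given, the module class matches too.
def rule_matches(rule: dict, module_name: str, module: nn.Module) -> bool:
    if not re.search(rule["name"], module_name):
        return False
    if "class" in rule and module.__class__ is not rule["class"]:
        return False
    return True

lm_head_rule = {"name": r"^lm_head", "class": nn.Linear}
default_rule = {"name": r"(^model\.layers\.([5][0-9]|[4][5-9])\.)|(^model.norm)"}

lm_head = nn.Linear(1024, 32000)  # illustrative shapes only
print(rule_matches(lm_head_rule, "lm_head", lm_head))  # True  -> injected as KLinearMarlin
print(rule_matches(default_rule, "lm_head", lm_head))  # False -> no longer falls to "default"
```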
@@ -118,7 +118,18 @@
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)"
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -15,6 +15,18 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
...
@@ -118,7 +118,18 @@
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)"
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -15,6 +15,18 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
...
@@ -188,7 +188,7 @@
 # !!!Do remember 'close' cuda graph if you are using marlin expert.!!!
 # !!!KExpertsTorch is untested, we don't have enough VRAM.!!!
-#
+# GPU 0: layers 3–4
 # - match:
 #     name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$"
 #   replace:
@@ -363,11 +363,20 @@
       generate_device: "cuda:2"
       prefill_device: "cuda:2"
-# don't inject lm_head if already inject marlin experts
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:3"
+      prefill_device: "cuda:3"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
-# For final modules (model.norm and lm_head), ensure they are on GPU 3 (as in your original config)
+# For final modules (model.norm), ensure they are on GPU 3 (as in your original config)
 - match:
-    name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)"
+    name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -713,11 +713,20 @@
       generate_device: "cuda:7"
       prefill_device: "cuda:7"
-# don't inject lm_head if already inject marlin experts
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:7"
+      prefill_device: "cuda:7"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
-# For final modules (model.norm and lm_head), ensure they are on GPU 7 (as in your original config)
+# For final modules (model.norm), ensure they are on GPU 7 (as in your original config)
 - match:
-    name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)"
+    name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -153,7 +153,18 @@
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -135,7 +135,18 @@
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -5,6 +5,18 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
+- match:
+    name: "^lm_head$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
     class: torch.nn.Linear  # only match modules matching name and class simultaneously
...
@@ -15,6 +15,16 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.block_sparse_moe$"
     class: ktransformers.models.modeling_mixtral.MixtralSparseMoeBlock
...
@@ -77,9 +77,19 @@
     kwargs:
       generate_device: "cpu"
       prefill_device: "cpu"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
-    name: "(^model.norm)|(^lm_head)"
+    name: "(^model.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -15,6 +15,16 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"
     class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
...
@@ -25,10 +25,10 @@ class KTransformersThreadContext(TransformersThreadContext):
 class KTransformersInterface(TransformersInterface):
     def __init__(self, args: ConfigArgs = default_args):
         self.args = args
-        torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
+        torch.set_default_dtype(config.torch_dtype)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
...
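Instead of hard-coding bfloat16, the interface now takes the default dtype from the model config, which the GGUF dequantizers then pick up through `torch.get_default_dtype()`. A small sketch of that flow (the model path is a placeholder; `config.torch_dtype` can be None for some configs, in which case you would keep the previous default):

```python
import torch
from transformers import AutoConfig

# Placeholder path; any HF model directory with a torch_dtype in its config.json works.
config = AutoConfig.from_pretrained("path/to/model", trust_remote_code=True)
if config.torch_dtype is not None:
    torch.set_default_dtype(config.torch_dtype)  # e.g. torch.bfloat16 or torch.float16

# Downstream, the new dequant helpers default their target_dtype to this value:
target_dtype = torch.get_default_dtype()
print(target_dtype)
```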
@@ -176,7 +176,7 @@ if __name__ == "__main__":
     parser.add_argument("--result", type=str, default="./mmlu_pro.json", help="Path to save the result JSON file")
     parser.add_argument("--log", type=str, default="./mmlu_pro.log", help="Path to save the log file")
     parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path")
-    parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
+    parser.add_argument("--api_url", type=str, default="http://localhost:15488/v1/chat/completions", help="API URL")
     # parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
     args = parser.parse_args()
...
@@ -26,6 +26,7 @@ from enum import IntEnum
 import torch
 import KTransformersOps
 from .custom_loader import SafeTensorLoader
+import ctypes
 class GGMLQuantizationType(IntEnum):
     F32 = 0
@@ -305,7 +306,7 @@ class GGUFLoader:
         data = torch.from_numpy(data)
         return data, ggml_type
-    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu")->torch.Tensor:
+    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "cuda", target_dtype = torch.get_default_dtype())->torch.Tensor:
         t = self.tensor_info[name]
         if device.lower() == "cpu":
             print(f"loading expert {expert_id} of {name} with CPU")
@@ -324,19 +325,21 @@ class GGUFLoader:
         data = data[offset: offset + block_size * blocks_per_experts]
         if "cuda" in device.lower():
-            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
+            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype)
         else:
             values = GGML_DEQUANTIZE[ggml_name](data)
-            values = torch.from_numpy(values)
+            values = torch.from_numpy(values.copy())
         values = values.view(shape[-2::-1])
         return values
-    def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor:
+    def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->torch.Tensor:
         t = self.tensor_info[name]
         if device.lower() == "cpu":
             print(f"loading {name} with CPU")
+        if target_dtype == None:
+            target_dtype = torch.get_default_dtype()
         shape = t["shape"]
         ggml_type = t["ggml_type"]
@@ -348,16 +351,38 @@ class GGUFLoader:
         data = self.get_mmap_tensor(name)
-        if "cuda" in device.lower():
-            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
-            #values = GGML_DEQUANTIZE[ggml_name](data)
-            #print("load_gguf_tensor")
-            #values = torch.from_numpy(values).to(device = device)
+        block_size = GGML_BLOCK_SIZES[ggml_name]
+        elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name]
+        num_elements = int(np.prod(shape))
+        num_blocks = num_elements // elements_per_block
+        blocks_per_iter = 16384
+        if num_blocks > blocks_per_iter: # dequant large tensor
+            values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device)
+            for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):
+                blocks_begin = i * blocks_per_iter
+                blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
+                if "cuda" in device.lower():
+                    cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
+                else:
+                    cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
+                    cur_values = torch.from_numpy(cur_values.copy())
+                cur_values = cur_values.view(-1, elements_per_block)
+                if ggml_name == "BF16":
+                    cur_values = cur_values.view(torch.bfloat16)
+                values[blocks_begin : blocks_end] = cur_values
         else:
-            values = GGML_DEQUANTIZE[ggml_name](data)
-            values = torch.from_numpy(values)
+            if "cuda" in device.lower():
+                values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
+            else:
+                values = GGML_DEQUANTIZE[ggml_name](data)
+                values = torch.from_numpy(values)
         if ggml_name == "BF16":
             values = values.view(torch.bfloat16)
         values = values.view(shape[::-1])
         if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
             n_head = self.gguf_file_meta['llama.attention.head_count']
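A standalone sketch of the chunked-dequantization idea introduced in `load_gguf_tensor` above (illustrative only: a fake "dequantize" over fixed-size blocks stands in for the GGML kernels): the raw byte buffer is processed `blocks_per_iter` blocks at a time, so only a bounded slice of the dequantized result is materialized per iteration instead of the whole tensor at once.

```python
import numpy as np
import torch

# Illustrative block geometry (not a real GGML type): 4 bytes per block -> 4 values.
BLOCK_SIZE = 4          # bytes per quantized block
ELEMENTS_PER_BLOCK = 4  # dequantized elements per block

def fake_dequantize(chunk: bytes) -> np.ndarray:
    # Stand-in for GGML_DEQUANTIZE[...]: here "dequantization" is just int8 -> float32.
    return np.frombuffer(chunk, dtype=np.int8).astype(np.float32)

def dequantize_in_chunks(data: bytes, blocks_per_iter: int = 3) -> torch.Tensor:
    num_blocks = len(data) // BLOCK_SIZE
    values = torch.empty((num_blocks, ELEMENTS_PER_BLOCK), dtype=torch.float32)
    for i in range((num_blocks + blocks_per_iter - 1) // blocks_per_iter):
        begin = i * blocks_per_iter
        end = min(begin + blocks_per_iter, num_blocks)
        cur = fake_dequantize(data[begin * BLOCK_SIZE : end * BLOCK_SIZE])
        values[begin:end] = torch.from_numpy(cur.copy()).view(-1, ELEMENTS_PER_BLOCK)
    return values.view(-1)

raw = bytes(range(40))  # 10 blocks of illustrative data
chunked = dequantize_in_chunks(raw)
whole = torch.from_numpy(fake_dequantize(raw).copy())
assert torch.equal(chunked, whole)  # chunking does not change the result
```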
@@ -456,14 +481,15 @@ def dequantize_q2_k(data):
     return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
-def dequantize_q2_k_gpu(data, device:str ="cuda"):
+def dequantize_q2_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q2_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q2_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q2_k(data, block_size, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q2_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 def dequantize_q3_k(data):
     # C implementation
@@ -507,14 +533,15 @@ def dequantize_q3_k(data):
         (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
     ], axis=1)
-def dequantize_q3_k_gpu(data, device:str ="cuda"):
+def dequantize_q3_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q3_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q3_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q3_k(data, block_size, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q3_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 def dequantize_q4_k(data):
     # C implementation
@@ -538,13 +565,15 @@ def dequantize_q4_k(data):
     # Dequantize final weights using scales and offsets
     return factors * qs2 - offsets
-def dequantize_q4_k_gpu(data, device:str ="cuda"):
+def dequantize_q4_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
+    block_size = GGML_BLOCK_SIZES["Q4_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q4_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q4_k(data, 144, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q4_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 def dequantize_q5_k(data):
     # C implementation
@@ -602,14 +631,15 @@ def dequantize_q5_k(data):
         d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
     ], axis=1)
-def dequantize_q5_k_gpu(data, device:str ="cuda"):
+def dequantize_q5_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q5_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q5_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q5_k(data, block_size, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q5_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 def dequantize_q6_k(data):
     # C implementation
@@ -660,13 +690,14 @@ def dequantize_q6_k(data):
     ], axis=1)
 # @torch.jit.script
-def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"):
+def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q6_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q6_K"]
     device = torch.device(device)
     num_blocks = len(data) // block_size
     data = np.frombuffer(data, dtype=data.dtype)
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q6_k(data, block_size, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q6_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
@@ -700,13 +731,14 @@ def dequantize_iq4_xs(data):
     return y.flatten()
-def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda"):
+def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["IQ4_XS"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["IQ4_XS"]
     device = torch.device(device)
     num_blocks = len(data) // block_size
     data = np.frombuffer(data, dtype=data.dtype)
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_iq4_xs(data, block_size, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_iq4_xs(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 def dequantize_q4_0(data):
     # C implementation
@@ -723,7 +755,7 @@ def dequantize_q4_0(data):
         scales * ((qs >> 4).astype(np.int8) - 8),
     ], axis=1)
-def dequantize_q4_0_gpu(data):
+def dequantize_q4_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     raise NotImplementedError()
 def dequantize_q5_0(data):
@@ -747,7 +779,7 @@ def dequantize_q5_0(data):
         scales * x1,
     ], axis=1)
-def dequantize_q5_0_gpu(data):
+def dequantize_q5_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     raise NotImplementedError()
 def dequantize_q8_0(data):
@@ -759,32 +791,41 @@ def dequantize_q8_0(data):
     qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]
     return scales * qs
-def dequantize_q8_0_gpu(data, device:str = "cuda"):
+def dequantize_q8_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     # C struct definition
     # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
-    num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"]
+    block_size = GGML_BLOCK_SIZES["Q8_0"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q8_0"]
     device = torch.device(device)
     data = np.frombuffer(data, dtype=data.dtype)
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q8_0(data, 34, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q8_0(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 def dequantize_f32(data):
     return np.frombuffer(data, dtype=np.float32)
-def dequantize_f32_gpu(data, device):
+def dequantize_f32_gpu(data, device, target_dtype = torch.get_default_dtype()):
     data = np.frombuffer(data, dtype=np.float32)
-    res = torch.from_numpy(data)
-    res_gpu = torch.empty_like(res, device=device)
+    res = torch.from_numpy(data.copy())
+    res_gpu = torch.empty_like(res, device=device, dtype=target_dtype)
     res_gpu.copy_(res)
     return res_gpu
 def dequantize_f16(data):
     return np.frombuffer(data, dtype=np.float16)
-def dequantize_f16_gpu(data, device):
+def dequantize_f16_gpu(data, device, target_dtype = torch.get_default_dtype()):
+    data = np.frombuffer(data, dtype=np.float16)
+    res = torch.from_numpy(data.copy())
+    res_gpu = torch.empty_like(res, device=device, dtype=target_dtype)
+    res_gpu.copy_(res)
+    return res_gpu
+def dequantize_bf16_gpu(data, device, target_dtype = torch.get_default_dtype()):
     data = np.frombuffer(data, dtype=np.float16)
-    res = torch.from_numpy(data)
+    res = torch.from_numpy(data.copy())
     res_gpu = torch.empty_like(res, device=device)
     res_gpu.copy_(res)
     return res_gpu
@@ -807,7 +848,7 @@ GGML_DEQUANTIZE = {
 GGML_DEQUANTIZE_GPU = {
     "F32": dequantize_f32_gpu,
     "F16": dequantize_f16_gpu,
-    "BF16": dequantize_f16_gpu,
+    "BF16": dequantize_bf16_gpu,
     "Q4_0": dequantize_q4_0_gpu,
     "Q5_0": dequantize_q5_0_gpu,
     "Q8_0": dequantize_q8_0_gpu,
...
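The `dequantize_*_gpu` helpers above now hand KTransformersOps a raw address instead of a torch tensor, which sidesteps the "NumPy array is not writable" warning that `torch.from_numpy` raises on an mmapped buffer. A minimal, self-contained illustration of that pointer hand-off (no KTransformersOps involved; the asserts only show the address aliases the same bytes):

```python
import ctypes
import numpy as np

raw = bytes(range(16))                    # stands in for an mmapped GGUF block buffer
data = np.frombuffer(raw, dtype=np.int8)  # zero-copy, read-only view

# Same pattern as in the diff: take the address of the first element as a plain int.
c_pointer = ctypes.addressof(
    ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents
)
assert c_pointer == data.ctypes.data      # it is literally the buffer's base address

# A C/CUDA kernel receiving (c_pointer, data.size) can read the same memory:
alias = (ctypes.c_int8 * data.size).from_address(c_pointer)
assert bytes(alias) == raw
```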
@@ -90,7 +90,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
                 raise Exception(f"can't find {translated_key} in GGUF file!")
 def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
-    # print(f"recursively loading weights {prefix},{return_when_injected=}, {only_load_injected=}")
+    #print(f"recursively loading weights {prefix}")
     if not isinstance(module, base_operator.BaseInjectedModule):
         load_cur_state_dict(module, gguf_loader, prefix)
         for name, child in module._modules.items():
...
#!/usr/bin/env python
# coding=utf-8
'''
Description  :
Author       : chenxl
Date         : 2024-07-27 16:15:27
Version      : 1.0.0
LastEditors  : chenxl
LastEditTime : 2024-08-14 16:36:19
Adapted from:
https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
Copyright (c) 2023, Tri Dao.
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os
@@ -30,6 +30,11 @@ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
 from setuptools import setup, Extension
 from cpufeature.extension import CPUFeature
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+try:
+    from torch_musa.utils.simple_porting import SimplePorting
+    from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
+except ImportError:
+    MUSA_HOME=None
 class CpuInstructInfo:
     CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
@@ -40,7 +45,7 @@ class CpuInstructInfo:
     CMAKE_FANCY = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON -DLLAMA_AVX512_FANCY_SIMD=ON"
     CMAKE_AVX512 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON"
     CMAKE_AVX2 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON"
 class VersionInfo:
     THIS_DIR = os.path.dirname(os.path.abspath(__file__))
     PACKAGE_NAME = "ktransformers"
@@ -49,6 +54,16 @@ class VersionInfo:
     )
     FORCE_BUILD = os.getenv("KTRANSFORMERS_FORCE_BUILD", "FALSE") == "TRUE"
+    def get_musa_bare_metal_version(self, musa_dir):
+        raw_output = subprocess.run(
+            [musa_dir + "/bin/mcc", "-v"], check=True,
+            stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.decode("utf-8")
+        output = raw_output.split()
+        release_idx = output.index("version") + 1
+        bare_metal_version = parse(output[release_idx].split(",")[0])
+        musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
+        return musa_version
     def get_cuda_bare_metal_version(self, cuda_dir):
         raw_output = subprocess.check_output(
             [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@@ -58,7 +73,7 @@ class VersionInfo:
         cuda_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
         return cuda_version
-    def get_cuda_version_of_torch(self,):
+    def get_cuda_version_of_torch(self):
         torch_cuda_version = parse(torch.version.cuda)
         cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
         return cuda_version
@@ -117,7 +132,7 @@ class VersionInfo:
         torch_version_raw = parse(torch.__version__)
         torch_version = f"{torch_version_raw.major}{torch_version_raw.minor}"
         return torch_version
     def get_flash_version(self,):
         version_file = os.path.join(
             Path(VersionInfo.THIS_DIR), VersionInfo.PACKAGE_NAME, "__init__.py")
@@ -128,12 +143,21 @@
         return flash_version
     def get_package_version(self, full_version=False):
-        flash_version = self.get_flash_version()
-        package_version = f"{str(flash_version)}+cu{self.get_cuda_bare_metal_version(CUDA_HOME)}torch{self.get_torch_version()}{self.get_cpu_instruct()}"
+        flash_version = str(self.get_flash_version())
+        torch_version = self.get_torch_version()
+        cpu_instruct = self.get_cpu_instruct()
+        backend_version = ""
+        if CUDA_HOME is not None:
+            backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
+        elif MUSA_HOME is not None:
+            backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
+        else:
+            raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+        package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
         if full_version:
             return package_version
         if not VersionInfo.FORCE_BUILD:
-            return str(flash_version)
+            return flash_version
         return package_version
@@ -218,11 +242,19 @@ class CMakeBuild(BuildExtension):
             f"-DPYTHON_EXECUTABLE={sys.executable}",
             f"-DCMAKE_BUILD_TYPE={cfg}",  # not used on MSVC, but no harm
         ]
+        if CUDA_HOME is not None:
+            cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
+        elif MUSA_HOME is not None:
+            cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
+        else:
+            raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
         build_args = []
         if "CMAKE_ARGS" in os.environ:
             cmake_args += [
                 item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
         if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:
             cpu_args = CpuInstructInfo.CMAKE_FANCY
         elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:
@@ -231,7 +263,7 @@ class CMakeBuild(BuildExtension):
             cpu_args = CpuInstructInfo.CMAKE_AVX2
         else:
             cpu_args = CpuInstructInfo.CMAKE_NATIVE
         cmake_args += [
             item for item in cpu_args.split(" ") if item
         ]
@@ -276,8 +308,13 @@ class CMakeBuild(BuildExtension):
                     "-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]
         if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
+            cpu_count = os.cpu_count()
+            if cpu_count is None:
+                cpu_count = 1
             if hasattr(self, "parallel") and self.parallel:
-                build_args += [f"-j{self.parallel}"]
+                build_args += [f"--parallel={self.parallel}"]
+            else:
+                build_args += [f"--parallel={cpu_count}"]
         print("CMake args:", cmake_args)
         build_temp = Path(ext.sourcedir) / "build"
         if not build_temp.exists():
@@ -288,28 +325,55 @@ class CMakeBuild(BuildExtension):
         print("Standard output:", result.stdout)
         print("Standard error:", result.stderr)
         subprocess.run(
-            ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
+            ["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
         )
+if CUDA_HOME is not None:
+    ops_module = CUDAExtension('KTransformersOps', [
+        'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
+        'ktransformers/ktransformers_ext/cuda/binding.cpp',
+        'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
+    ],
+        extra_compile_args={
+            'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
+            'nvcc': [
+                '-O3',
+                '--use_fast_math',
+                '-Xcompiler', '-fPIC',
+                '-DKTRANSFORMERS_USE_CUDA',
+            ]
+        }
+    )
+elif MUSA_HOME is not None:
+    SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={
+        # Common rules
+        "at::cuda": "at::musa",
+        "#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
+        "#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
+    }).run()
+    ops_module = MUSAExtension('KTransformersOps', [
+        'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
+        'ktransformers/ktransformers_ext/cuda_musa/binding.cpp',
+        # TODO: Add Marlin support for MUSA.
+        # 'ktransformers/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
+    ],
+        extra_compile_args={
+            'cxx': ['force_mcc'],
+            'mcc': [
+                '-O3',
+                '-DKTRANSFORMERS_USE_MUSA',
+                '-DTHRUST_IGNORE_CUB_VERSION_CHECK',
+            ]
+        }
+    )
+else:
+    raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
 setup(
     version=VersionInfo().get_package_version(),
     cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
     ext_modules=[
         CMakeExtension("cpuinfer_ext"),
-        CUDAExtension('KTransformersOps', [
-            'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
-            'ktransformers/ktransformers_ext/cuda/binding.cpp',
-            'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
-        ],
-            extra_compile_args={
-                'cxx': ['-O3'],
-                'nvcc': [
-                    '-O3',
-                    '--use_fast_math',
-                    '-Xcompiler', '-fPIC',
-                ]
-            }
-        )
+        ops_module,
     ]
 )
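A condensed sketch of the CUDA/MUSA backend selection and the resulting wheel version tag that setup.py now builds (standalone reimplementation with hard-coded example inputs; the real values come from nvcc/mcc, torch, and cpufeature):

```python
# Minimal sketch, assuming CUDA_HOME/MUSA_HOME have already been resolved the same
# way setup.py does (via torch.utils.cpp_extension / torch_musa, else None).
def backend_tag(cuda_home, musa_home, cuda_version="123", musa_version="31"):
    if cuda_home is not None:
        return f"cu{cuda_version}"   # e.g. from get_cuda_bare_metal_version()
    if musa_home is not None:
        return f"mu{musa_version}"   # e.g. from get_musa_bare_metal_version()
    raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")

def package_version(flash_version, backend, torch_version, cpu_instruct):
    return f"{flash_version}+{backend}torch{torch_version}{cpu_instruct}"

# Example: a hypothetical CUDA 12.3 build on torch 2.4 with AVX512 kernels.
print(package_version("0.2.2", backend_tag("/usr/local/cuda", None), "24", "avx512"))
# -> 0.2.2+cu123torch24avx512
```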