Commit ca7366d2 authored by Azure

Merge remote-tracking branch 'upstream/develop-0.2.2' into support-fp8

parents 581a524f cdb6f896
@@ -21,6 +21,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
     MarlinWorkspace,
     marlin_quantize,
     GPTQ_MARLIN_MIN_THREAD_N,
+    GPTQ_MARLIN_MIN_THREAD_K,
     GPTQ_MARLIN_MAX_PARALLEL,
 )
 from ktransformers.operators.base_operator import BaseInjectedModule
@@ -65,6 +66,8 @@ class KLinearBase(ABC):
             self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0]
             self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1]
+        self.loaded = False # for lm_head pre-load, TODO: use new way to do lm_head pre-load when layer wise prefill.
+
     @abstractmethod
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         pass
@@ -141,6 +144,7 @@ class KLinearTorch(KLinearBase):
         return x

     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
+        if self.loaded: return
         if device is None: device = self.device
         if w is None: w = self.load_weight(device=device)
         # else: self.out_features = w.shape[0], self.in_features = w.shape[1]
@@ -164,6 +168,7 @@ class KLinearTorch(KLinearBase):
         self.weight = self.weight.to(device)
         if self.has_bias:
             self.bias = self.bias.to(device)
+        self.loaded = True

     def unload(self):
         if self.weight is not None:
@@ -251,20 +256,36 @@ class KLinearMarlin(KLinearBase):
         self.group_size = group_size
         self.act_order = act_order
         self.is_k_full = is_k_full
+        self.padding = False
+        self.orin_in_features = self.in_features
+        self.orin_out_features = self.out_features
+        if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0:
+            #print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding")
+            self.padding = True
+            self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K
+            self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N
+            #print(f"After padding: in_features={in_features}, out_features={out_features}")
+        self.k = self.in_features
+        self.n = self.out_features

     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
+        if self.loaded: return
         if device is None: device = self.device
         assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
+        #if self.in_features * self.out_features:
         if w is None:
             w = self.load_weight(device=device)

         if isinstance(w, nn.Parameter):
             # pad weight
-            weight = w.view(self.out_features, self.in_features).T
+            weight = w.view(self.orin_out_features, self.orin_in_features).T
             self.has_bias = False
         elif isinstance(w, tuple):
             w = list(w)
-            weight = w[0].view(self.out_features, self.in_features).T
+            weight = w[0].view(self.orin_out_features, self.orin_in_features).T
+            self.bias = w[1].view(self.orin_out_features)
             self.bias = w[1]
             self.has_bias = True
         else:
@@ -272,8 +293,14 @@ class KLinearMarlin(KLinearBase):
         weight = weight.to(device)
         if self.has_bias:
             self.bias = self.bias.to(device)
+        if self.padding:
+            padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)
+            padded_weight[:self.orin_in_features, :self.orin_out_features] = weight
+            weight = padded_weight
+
         # Pack Marlin linear
-        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
+        marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
             weight, self.num_bits, self.group_size, self.act_order
         )
         self.workspace = MarlinWorkspace(
@@ -286,6 +313,7 @@
         self.sort_indices = sort_indices
         self.k = weight.shape[0]
         self.n = weight.shape[1]
+        self.loaded = True

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Only support input x as BF16 and FP16
@@ -293,6 +321,11 @@
         orig_shape = list(x.shape)
         orig_dtype = x.dtype
         x = x.reshape(-1, orig_shape[-1])
+        x = x.reshape(-1, x.shape[-1])
+        if self.padding:
+            padding_input=torch.empty(x.shape[0], self.in_features, device=x.device, dtype=x.dtype)
+            padding_input[:,:self.orin_in_features] = x
+            x = padding_input
         marlin_s = self.marlin_s.to(x.dtype)
         x = KTransformersOps.gptq_marlin_gemm(
             x,
@@ -307,6 +340,11 @@
             x.shape[-1],
             self.is_k_full,
         )
+        if self.padding:
+            x = x[:,:self.orin_out_features]
+            orig_shape[-1] = self.orin_out_features
+        else:
+            orig_shape[-1] = self.out_features
         if self.has_bias:
             x = x + self.bias
         orig_shape[-1] = self.n
@@ -450,23 +488,12 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
         # build all the linear operators
         if prefill_op is not None:
             assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported"
-            if prefill_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
-                print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.")
-                print(f"module info: key:{key} orig_module:{orig_module}")
-                self.prefill_linear = KLinearTorch(key, gguf_loader, config, orig_module, prefill_device, **kwargs)
-            else:
-                self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+            self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)
         else:
             self.prefill_linear = None

         if generate_op is not None:
             assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported"
-            if generate_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
-                print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.")
-                print(f"module info: key:{key} orig_module:{orig_module}")
-                self.generate_op = "KLinearTorch"
-                self.generate_linear = KLinearTorch(key, gguf_loader, config, orig_module, generate_device, **kwargs)
-            else:
-                self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
+            self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
        else:
            self.generate_linear = None
...
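The padding logic added to KLinearMarlin above comes down to three steps: round both dimensions up to the kernel's thread-tile multiples at construction time, zero-pad the transposed weight once in load(), and pad the activations / slice the extra output columns in forward(). A minimal sketch of that flow, using assumed tile constants and a plain matmul in place of KTransformersOps.gptq_marlin_gemm:

# Minimal sketch of the round-up-and-pad flow (illustrative sizes; the tile
# constants below are assumed stand-ins for GPTQ_MARLIN_MIN_THREAD_K/N, and a
# plain matmul stands in for the Marlin GEMM).
import torch

MIN_THREAD_K = 128   # assumed stand-in for GPTQ_MARLIN_MIN_THREAD_K
MIN_THREAD_N = 64    # assumed stand-in for GPTQ_MARLIN_MIN_THREAD_N

def round_up(x: int, m: int) -> int:
    # same arithmetic as (x + m - 1) // m * m in the diff
    return (x + m - 1) // m * m

orin_in, orin_out = 300, 100                      # a deliberately unaligned layer
in_features = round_up(orin_in, MIN_THREAD_K)     # 384
out_features = round_up(orin_out, MIN_THREAD_N)   # 128

w = torch.randn(orin_out, orin_in, dtype=torch.float64)   # [out, in], nn.Linear layout
padded_w = torch.zeros(in_features, out_features, dtype=torch.float64)
padded_w[:orin_in, :orin_out] = w.T               # load(): pad the transposed weight

x = torch.randn(3, orin_in, dtype=torch.float64)
padded_x = torch.zeros(3, in_features, dtype=torch.float64)
padded_x[:, :orin_in] = x                         # forward(): pad the activations

y = (padded_x @ padded_w)[:, :orin_out]           # GEMM, then strip the padding
assert torch.allclose(y, x @ w.T)                 # matches the unpadded result

Because the padded weight rows and activation columns are zero, slicing off the extra output columns recovers exactly the unpadded product, which is why the KLinearTorch fallback for non-aligned shapes could be removed from KTransformersLinear.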
@@ -126,6 +126,8 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
     gguf_loader=GGUFLoader(gguf_path)
     with torch.device("meta"):
         inject(module, optimize_config, model_config, gguf_loader)
+    # pre load lm_head because its big inter result
+    load_weights(module.lm_head, gguf_loader, "lm_head.")
     load_weights(module, gguf_loader)
     module.gguf_loader = gguf_loader
     del_meta(module)
...
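The lm_head pre-load added here relies on the new `loaded` flag in the linear operators: load() returns early on the second visit, so the later full-module load_weights pass does not reload or re-quantize the tensor. A small sketch of that guard with a hypothetical class (the real operators are the KLinear* classes above):

# Sketch of the load-once guard pattern; LazyLinear is a hypothetical stand-in
# for the KLinear* operators that now carry self.loaded.
class LazyLinear:
    def __init__(self):
        self.loaded = False

    def load(self):
        if self.loaded:        # second call is a no-op
            return
        # ... fetch and dequantize the weight here ...
        self.loaded = True

head = LazyLinear()
head.load()   # explicit pre-load, as optimize_and_load_gguf now does for lm_head
head.load()   # the later recursive load_weights pass skips it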
@@ -219,8 +219,20 @@
     kwargs:
       generate_device: "cuda:2"
       prefill_device: "cuda:2"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:3"
+      prefill_device: "cuda:3"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
-    name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)|(^lm_head)"
+    name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -118,7 +118,18 @@
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)"
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
+- match:
+    name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -15,6 +15,18 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
...
@@ -118,7 +118,18 @@
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)"
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
+- match:
+    name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -15,6 +15,18 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
...
@@ -188,7 +188,7 @@
 # !!!Do remember 'close' cuda graph if you are using marlin expert.!!!
 # !!!KExpertsTorch is untested, we don't have enough VRAM.!!!
-# # GPU 0: layers 3–4
+# GPU 0: layers 3–4
 # - match:
 #     name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$"
 #   replace:
@@ -363,11 +363,20 @@
       generate_device: "cuda:2"
       prefill_device: "cuda:2"

-# don't inject lm_head if already inject marlin experts
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:3"
+      prefill_device: "cuda:3"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
-# For final modules (model.norm and lm_head), ensure they are on GPU 3 (as in your original config)
+# For final modules (model.norm), ensure they are on GPU 3 (as in your original config)
 - match:
-    name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)"
+    name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -713,11 +713,20 @@
       generate_device: "cuda:7"
       prefill_device: "cuda:7"

-# don't inject lm_head if already inject marlin experts
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:7"
+      prefill_device: "cuda:7"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
-# For final modules (model.norm and lm_head), ensure they are on GPU 7 (as in your original config)
+# For final modules (model.norm), ensure they are on GPU 7 (as in your original config)
 - match:
-    name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)"
+    name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -153,7 +153,18 @@
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
+- match:
+    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -135,7 +135,18 @@
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
+- match:
+    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -5,6 +5,18 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
+- match:
+    name: "^lm_head$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
     class: torch.nn.Linear  # only match modules matching name and class simultaneously
...
@@ -15,6 +15,16 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.block_sparse_moe$"
     class: ktransformers.models.modeling_mixtral.MixtralSparseMoeBlock
...
@@ -77,9 +77,19 @@
     kwargs:
       generate_device: "cpu"
       prefill_device: "cpu"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
-    name: "(^model.norm)|(^lm_head)"
+    name: "(^model.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -15,6 +15,16 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"
     class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
...
@@ -25,10 +25,10 @@ class KTransformersThreadContext(TransformersThreadContext):
 class KTransformersInterface(TransformersInterface):
     def __init__(self, args: ConfigArgs = default_args):
         self.args = args
-        torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
+        torch.set_default_dtype(config.torch_dtype)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
...
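The change above stops forcing bfloat16 and instead takes the default dtype from the checkpoint's own config, which is why the call had to move below the AutoConfig load. A minimal sketch of the same idea with a placeholder model path (config.torch_dtype can be None for some configs, in which case the framework default is kept):

# Sketch: derive the default dtype from the model config instead of hard-coding it.
# "/path/to/model" is a placeholder.
import torch
from transformers import AutoConfig

config = AutoConfig.from_pretrained("/path/to/model", trust_remote_code=True)
if config.torch_dtype is not None:
    torch.set_default_dtype(config.torch_dtype)   # e.g. torch.bfloat16 or torch.float16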
@@ -176,7 +176,7 @@ if __name__ == "__main__":
     parser.add_argument("--result", type=str, default="./mmlu_pro.json", help="Path to save the result JSON file")
     parser.add_argument("--log", type=str, default="./mmlu_pro.log", help="Path to save the log file")
     parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path")
-    parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
+    parser.add_argument("--api_url", type=str, default="http://localhost:15488/v1/chat/completions", help="API URL")
     # parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
     args = parser.parse_args()
...
@@ -26,6 +26,7 @@ from enum import IntEnum
 import torch
 import KTransformersOps
 from .custom_loader import SafeTensorLoader
+import ctypes

 class GGMLQuantizationType(IntEnum):
     F32 = 0
@@ -305,7 +306,7 @@ class GGUFLoader:
         data = torch.from_numpy(data)
         return data, ggml_type

-    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu")->torch.Tensor:
+    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "cuda", target_dtype = torch.get_default_dtype())->torch.Tensor:
         t = self.tensor_info[name]
         if device.lower() == "cpu":
             print(f"loading expert {expert_id} of {name} with CPU")
@@ -324,19 +325,21 @@
         data = data[offset: offset + block_size * blocks_per_experts]

         if "cuda" in device.lower():
-            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
+            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype)
         else:
             values = GGML_DEQUANTIZE[ggml_name](data)
-            values = torch.from_numpy(values)
+            values = torch.from_numpy(values.copy())

         values = values.view(shape[-2::-1])

         return values

-    def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor:
+    def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->torch.Tensor:
         t = self.tensor_info[name]
         if device.lower() == "cpu":
             print(f"loading {name} with CPU")
+        if target_dtype == None:
+            target_dtype = torch.get_default_dtype()

         shape = t["shape"]
         ggml_type = t["ggml_type"]
@@ -348,16 +351,38 @@
         data = self.get_mmap_tensor(name)

+        block_size = GGML_BLOCK_SIZES[ggml_name]
+        elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name]
+        num_elements = int(np.prod(shape))
+        num_blocks = num_elements // elements_per_block
+
+        blocks_per_iter = 16384
+        if num_blocks > blocks_per_iter: # dequant large tensor
+            values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device)
+            for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):
+                blocks_begin = i * blocks_per_iter
+                blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
+                if "cuda" in device.lower():
+                    cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
+                else:
+                    cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
+                    cur_values = torch.from_numpy(cur_values.copy())
+
+                cur_values = cur_values.view(-1, elements_per_block)
+                if ggml_name == "BF16":
+                    cur_values = cur_values.view(torch.bfloat16)
+                values[blocks_begin : blocks_end] = cur_values
+        else:
             if "cuda" in device.lower():
                 values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
+                #values = GGML_DEQUANTIZE[ggml_name](data)
+                #print("load_gguf_tensor")
+                #values = torch.from_numpy(values).to(device = device)
             else:
                 values = GGML_DEQUANTIZE[ggml_name](data)
                 values = torch.from_numpy(values)

             if ggml_name == "BF16":
                 values = values.view(torch.bfloat16)

         values = values.view(shape[::-1])

         if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
             n_head = self.gguf_file_meta['llama.attention.head_count']
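The large-tensor branch added above dequantizes in slices of blocks_per_iter quantization blocks and writes each slice into a pre-allocated output tensor, so the full dequantized buffer is never materialized twice. A simplified sketch of that loop, with a toy block format standing in for the GGML kernels (all constants below are illustrative, not real GGML values):

# Simplified chunked-dequantization loop; toy_dequantize stands in for
# GGML_DEQUANTIZE[_GPU] and the block constants are illustrative only.
import numpy as np
import torch

BLOCK_SIZE = 33          # toy: 1 scale byte + 32 quantized values per block
ELEMENTS_PER_BLOCK = 32
BLOCKS_PER_ITER = 4      # the diff uses 16384

def toy_dequantize(raw: bytes) -> np.ndarray:
    blocks = np.frombuffer(raw, dtype=np.int8).reshape(-1, BLOCK_SIZE)
    scales = blocks[:, :1].astype(np.float32)     # toy per-block scale
    qs = blocks[:, 1:].astype(np.float32)         # toy quantized values
    return (scales * qs).reshape(-1)

num_blocks = 10
raw = np.random.randint(-8, 8, size=num_blocks * BLOCK_SIZE, dtype=np.int8).tobytes()

values = torch.empty((num_blocks, ELEMENTS_PER_BLOCK), dtype=torch.float32)
for i in range((num_blocks + BLOCKS_PER_ITER - 1) // BLOCKS_PER_ITER):
    begin = i * BLOCKS_PER_ITER
    end = min(begin + BLOCKS_PER_ITER, num_blocks)
    chunk = toy_dequantize(raw[begin * BLOCK_SIZE:end * BLOCK_SIZE])
    values[begin:end] = torch.from_numpy(chunk.copy()).view(-1, ELEMENTS_PER_BLOCK)

print(values.shape)  # torch.Size([10, 32])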
@@ -456,14 +481,15 @@ def dequantize_q2_k(data):
     return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)

-def dequantize_q2_k_gpu(data, device:str ="cuda"):
+def dequantize_q2_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q2_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q2_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q2_k(data, block_size, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q2_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)

 def dequantize_q3_k(data):
     # C implementation
@@ -507,14 +533,15 @@ def dequantize_q3_k(data):
         (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
     ], axis=1)

-def dequantize_q3_k_gpu(data, device:str ="cuda"):
+def dequantize_q3_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q3_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q3_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q3_k(data, block_size, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q3_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)

 def dequantize_q4_k(data):
     # C implementation
@@ -538,13 +565,15 @@ def dequantize_q4_k(data):
     # Dequantize final weights using scales and offsets
     return factors * qs2 - offsets

-def dequantize_q4_k_gpu(data, device:str ="cuda"):
+def dequantize_q4_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
+    block_size = GGML_BLOCK_SIZES["Q4_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q4_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q4_k(data, 144, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q4_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)

 def dequantize_q5_k(data):
     # C implementation
@@ -602,14 +631,15 @@ def dequantize_q5_k(data):
         d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
     ], axis=1)

-def dequantize_q5_k_gpu(data, device:str ="cuda"):
+def dequantize_q5_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q5_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q5_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q5_k(data, block_size, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q5_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)

 def dequantize_q6_k(data):
     # C implementation
@@ -660,13 +690,14 @@ def dequantize_q6_k(data):
     ], axis=1)

 # @torch.jit.script
-def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"):
+def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q6_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q6_K"]
     device = torch.device(device)
     num_blocks = len(data) // block_size
     data = np.frombuffer(data, dtype=data.dtype)
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q6_k(data, block_size, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q6_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)

 kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
@@ -700,13 +731,14 @@ def dequantize_iq4_xs(data):
     return y.flatten()

-def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda"):
+def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["IQ4_XS"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["IQ4_XS"]
     device = torch.device(device)
     num_blocks = len(data) // block_size
     data = np.frombuffer(data, dtype=data.dtype)
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_iq4_xs(data, block_size, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_iq4_xs(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)

 def dequantize_q4_0(data):
     # C implementation
@@ -723,7 +755,7 @@ def dequantize_q4_0(data):
         scales * ((qs >> 4).astype(np.int8) - 8),
     ], axis=1)

-def dequantize_q4_0_gpu(data):
+def dequantize_q4_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     raise NotImplementedError()

 def dequantize_q5_0(data):
@@ -747,7 +779,7 @@ def dequantize_q5_0(data):
         scales * x1,
     ], axis=1)

-def dequantize_q5_0_gpu(data):
+def dequantize_q5_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     raise NotImplementedError()

 def dequantize_q8_0(data):
@@ -759,32 +791,41 @@ def dequantize_q8_0(data):
     qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]
     return scales * qs

-def dequantize_q8_0_gpu(data, device:str = "cuda"):
+def dequantize_q8_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     # C struct definition
     # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
-    num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"]
+    block_size = GGML_BLOCK_SIZES["Q8_0"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q8_0"]
     device = torch.device(device)
     data = np.frombuffer(data, dtype=data.dtype)
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q8_0(data, 34, device)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q8_0(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)

 def dequantize_f32(data):
     return np.frombuffer(data, dtype=np.float32)

-def dequantize_f32_gpu(data, device):
+def dequantize_f32_gpu(data, device, target_dtype = torch.get_default_dtype()):
     data = np.frombuffer(data, dtype=np.float32)
-    res = torch.from_numpy(data)
-    res_gpu = torch.empty_like(res, device=device)
+    res = torch.from_numpy(data.copy())
+    res_gpu = torch.empty_like(res, device=device, dtype=target_dtype)
     res_gpu.copy_(res)
     return res_gpu

 def dequantize_f16(data):
     return np.frombuffer(data, dtype=np.float16)

-def dequantize_f16_gpu(data, device):
+def dequantize_f16_gpu(data, device, target_dtype = torch.get_default_dtype()):
+    data = np.frombuffer(data, dtype=np.float16)
+    res = torch.from_numpy(data.copy())
+    res_gpu = torch.empty_like(res, device=device, dtype=target_dtype)
+    res_gpu.copy_(res)
+    return res_gpu
+
+def dequantize_bf16_gpu(data, device, target_dtype = torch.get_default_dtype()):
     data = np.frombuffer(data, dtype=np.float16)
-    res = torch.from_numpy(data)
+    res = torch.from_numpy(data.copy())
     res_gpu = torch.empty_like(res, device=device)
     res_gpu.copy_(res)
     return res_gpu
@@ -807,7 +848,7 @@ GGML_DEQUANTIZE = {
 GGML_DEQUANTIZE_GPU = {
     "F32": dequantize_f32_gpu,
     "F16": dequantize_f16_gpu,
-    "BF16": dequantize_f16_gpu,
+    "BF16": dequantize_bf16_gpu,
     "Q4_0": dequantize_q4_0_gpu,
     "Q5_0": dequantize_q5_0_gpu,
     "Q8_0": dequantize_q8_0_gpu,
...
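The rewritten dequantize_*_gpu helpers above no longer wrap the read-only numpy view in a torch tensor (the source of the "numpy is not writable" warning noted in the TODO); they hand the extension a raw address plus the element count. A sketch of that pointer hand-off, with a toy ctypes read-back standing in for the KTransformersOps.dequantize_* call:

# Sketch of the raw-pointer hand-off; the "consumer" here is a toy ctypes
# read-back, not the real KTransformersOps kernel.
import ctypes
import numpy as np

raw = bytes(range(16))                       # stands in for a slice of the GGUF mmap
data = np.frombuffer(raw, dtype=np.int8)     # read-only view, no copy

c_pointer = ctypes.addressof(
    ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents
)

# A real caller would pass (c_pointer, data.size, block_size, ele_per_blk, device,
# target_dtype) to the extension; here we just read the first bytes back.
first_four = list((ctypes.c_int8 * 4).from_address(c_pointer))
print(first_four, data.size)                 # [0, 1, 2, 3] 16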
@@ -90,7 +90,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
             raise Exception(f"can't find {translated_key} in GGUF file!")

 def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
-    # print(f"recursively loading weights {prefix},{return_when_injected=}, {only_load_injected=}")
+    #print(f"recursively loading weights {prefix}")
     if not isinstance(module, base_operator.BaseInjectedModule):
         load_cur_state_dict(module, gguf_loader, prefix)
         for name, child in module._modules.items():
...
@@ -30,6 +30,11 @@ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
 from setuptools import setup, Extension
 from cpufeature.extension import CPUFeature
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+try:
+    from torch_musa.utils.simple_porting import SimplePorting
+    from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
+except ImportError:
+    MUSA_HOME=None

 class CpuInstructInfo:
     CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
@@ -49,6 +54,16 @@ class VersionInfo:
     )
     FORCE_BUILD = os.getenv("KTRANSFORMERS_FORCE_BUILD", "FALSE") == "TRUE"

+    def get_musa_bare_metal_version(self, musa_dir):
+        raw_output = subprocess.run(
+            [musa_dir + "/bin/mcc", "-v"], check=True,
+            stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.decode("utf-8")
+        output = raw_output.split()
+        release_idx = output.index("version") + 1
+        bare_metal_version = parse(output[release_idx].split(",")[0])
+        musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
+        return musa_version
+
     def get_cuda_bare_metal_version(self, cuda_dir):
         raw_output = subprocess.check_output(
             [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@@ -58,7 +73,7 @@
         cuda_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
         return cuda_version

-    def get_cuda_version_of_torch(self,):
+    def get_cuda_version_of_torch(self):
         torch_cuda_version = parse(torch.version.cuda)
         cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
         return cuda_version
@@ -128,12 +143,21 @@ class VersionInfo:
         return flash_version

     def get_package_version(self, full_version=False):
-        flash_version = self.get_flash_version()
-        package_version = f"{str(flash_version)}+cu{self.get_cuda_bare_metal_version(CUDA_HOME)}torch{self.get_torch_version()}{self.get_cpu_instruct()}"
+        flash_version = str(self.get_flash_version())
+        torch_version = self.get_torch_version()
+        cpu_instruct = self.get_cpu_instruct()
+        backend_version = ""
+        if CUDA_HOME is not None:
+            backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
+        elif MUSA_HOME is not None:
+            backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
+        else:
+            raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+        package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
         if full_version:
             return package_version
         if not VersionInfo.FORCE_BUILD:
-            return str(flash_version)
+            return flash_version
         return package_version
@@ -218,6 +242,14 @@ class CMakeBuild(BuildExtension):
             f"-DPYTHON_EXECUTABLE={sys.executable}",
             f"-DCMAKE_BUILD_TYPE={cfg}",  # not used on MSVC, but no harm
         ]
+
+        if CUDA_HOME is not None:
+            cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
+        elif MUSA_HOME is not None:
+            cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
+        else:
+            raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+
         build_args = []
         if "CMAKE_ARGS" in os.environ:
             cmake_args += [
@@ -276,8 +308,13 @@ class CMakeBuild(BuildExtension):
                 "-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]

         if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
+            cpu_count = os.cpu_count()
+            if cpu_count is None:
+                cpu_count = 1
             if hasattr(self, "parallel") and self.parallel:
-                build_args += [f"-j{self.parallel}"]
+                build_args += [f"--parallel={self.parallel}"]
+            else:
+                build_args += [f"--parallel={cpu_count}"]
         print("CMake args:", cmake_args)
         build_temp = Path(ext.sourcedir) / "build"
         if not build_temp.exists():
@@ -288,28 +325,55 @@ class CMakeBuild(BuildExtension):
         print("Standard output:", result.stdout)
         print("Standard error:", result.stderr)
         subprocess.run(
-            ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
+            ["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
         )

-setup(
-    version=VersionInfo().get_package_version(),
-    cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
-    ext_modules=[
-        CMakeExtension("cpuinfer_ext"),
-        CUDAExtension('KTransformersOps', [
+if CUDA_HOME is not None:
+    ops_module = CUDAExtension('KTransformersOps', [
         'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
         'ktransformers/ktransformers_ext/cuda/binding.cpp',
         'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
         ],
         extra_compile_args={
-            'cxx': ['-O3'],
+            'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
             'nvcc': [
                 '-O3',
                 '--use_fast_math',
                 '-Xcompiler', '-fPIC',
+                '-DKTRANSFORMERS_USE_CUDA',
+            ]
+        }
+    )
+elif MUSA_HOME is not None:
+    SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={
+        # Common rules
+        "at::cuda": "at::musa",
+        "#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
+        "#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
+    }).run()
+    ops_module = MUSAExtension('KTransformersOps', [
+        'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
+        'ktransformers/ktransformers_ext/cuda_musa/binding.cpp',
+        # TODO: Add Marlin support for MUSA.
+        # 'ktransformers/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
+    ],
+    extra_compile_args={
+        'cxx': ['force_mcc'],
+        'mcc': [
+            '-O3',
+            '-DKTRANSFORMERS_USE_MUSA',
+            '-DTHRUST_IGNORE_CUB_VERSION_CHECK',
         ]
     }
     )
+else:
+    raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+
+setup(
+    version=VersionInfo().get_package_version(),
+    cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
+    ext_modules=[
+        CMakeExtension("cpuinfer_ext"),
+        ops_module,
     ]
 )