Commit f5f79f5c authored by chenxl

[ADD] support multi-gpu qlen>1 q5_k

parent f2938031
@@ -21,6 +21,7 @@ class CUDAGraphRunner:
         position_ids,
         cache_position,
         past_key_values,
+        main_device,
         **kwargs,
     ) -> None:
         assert self.graph is None
@@ -29,15 +30,24 @@ class CUDAGraphRunner:
         self.graph = torch.cuda.CUDAGraph()
         #self.graph.enable_debug_mode()
         self.model = model
-        inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to("cuda")
-        with torch.cuda.graph(self.graph):
+        inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(main_device)
+        # torch.cuda.set_device can't set "cuda", must have an index
+        if main_device == "cuda":
+            main_device = "cuda:0"
+        torch.cuda.set_device(main_device)
+        self.main_device = main_device
+        capture_stream = torch.cuda.Stream()
+        with torch.cuda.graph(self.graph, stream = capture_stream):
            logits=model(inputs_embeds=inputs_embeds,
                         position_ids=position_ids,
                         cache_position=cache_position,
                         past_key_values=past_key_values,
                         **kwargs)[0]
+        capture_stream.wait_stream(torch.cuda.current_stream())
+        torch.cuda.set_device(main_device)
+        torch.cuda.set_stream(capture_stream)
         past_key_values.change_seq_length(-1)
-        torch.cuda.synchronize()
+        torch.cuda.synchronize(self.main_device)
         #self.graph.debug_dump("cuda_graph_hooked.dot")
         # Save the input and output buffers.
@@ -65,7 +75,7 @@ class CUDAGraphRunner:
         #print("begin replay")
         #time.sleep(1)
         self.graph.replay()
-        torch.cuda.synchronize()
+        torch.cuda.synchronize(self.main_device)
         # Return the output tensor.
         return self.output_buffers["logits"]
...
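Note: a minimal sketch of how the capture/replay pair above is meant to be driven (not part of the commit; `model`, the cache object, and the token tensors are placeholders, and only the capture() signature and the output_buffers["logits"] return shown in this diff are assumed):

    def decode_step(model, runner, cur_token, position_ids, cache_position, past_key_values,
                    main_device="cuda:0"):
        # One-time capture: records a single decode step on `main_device` into the CUDA graph.
        if runner.graph is None:
            runner.capture(model, cur_token, position_ids, cache_position, past_key_values,
                           main_device, return_dict=False, use_cache=True)
        # Later steps replay the graph; the runner is expected to copy the new inputs into the
        # captured buffers and return self.output_buffers["logits"].
        return runner(cur_token, position_ids, cache_position)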
@@ -5,8 +5,11 @@ Description :
 Author : Azure-Tang, Boxin Zhang, chenht2022
 Date : 2024-07-26 08:48:54
 Version : 1.0.0
-LastEditors : Azure
-LastEditTime : 2024-07-26 09:28:25
+LastEditors : kkk1nak0
+LastEditTime : 2024-08-09 08:03:44
+Adapted from https://github.com/99991/pygguf/blob/main/gguf.py
+Copyright (c) 2023-2024 The ggml authors
+Copyright (c) 2024 Thomas Germer
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''
 # copied from llama.cpp/gguf-py/gguf/constants.py to satisfy dependence of gguf
@@ -15,6 +18,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 import struct
 import warnings
 import numpy as np
+import re
 import numpy.typing as npt
 from typing import Sequence
 import os
@@ -96,6 +100,8 @@ def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantization
 GGML_TYPES = {
     "F32": 0,
     "F16": 1,
+    "Q4_0": 2,
+    "Q5_0": 6,
     "Q8_0": 8,
     "Q2_K": 10,
     "Q3_K": 11,
@@ -109,6 +115,8 @@ GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
 GGML_BLOCK_SIZES = {
     "F32": 4,
     "F16": 2,
+    "Q4_0": 2 + 16,
+    "Q5_0": 2 + 4 + 16,
     "Q8_0": 2 + 32,
     "Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
     "Q3_K": 256 // 8 + 256 // 4 + 12 + 2,
@@ -120,6 +128,8 @@ GGML_BLOCK_SIZES = {
 GGML_ELEMENTS_PER_BLOCK = {
     "F32": 1,
     "F16": 1,
+    "Q4_0": 32,
+    "Q5_0": 32,
     "Q8_0": 32,
     "Q2_K": 256,
     "Q3_K": 256,
@@ -128,14 +138,6 @@ GGML_ELEMENTS_PER_BLOCK = {
     "Q6_K": 256,
 }
-# DATA_TYPES = {
-#     "uint32": 4,
-#     "int32": 5,
-#     "float32": 6,
-#     "string": 8,
-#     "array": 9,
-#     "uint64": 10,
-# }
 DATA_TYPES = {
     "uint8": 0,
     "int8": 1,
@@ -167,6 +169,7 @@ class GGUFLoader:
         self.tensor_file_map = {}
         self.file_data_map = {}
         self.gguf_file_meta = {}
+        self.tensor_device_map = {}
         # Walk through all the .gguf files in the directory
         for root, dirs, files in os.walk(gguf_path):
@@ -272,7 +275,7 @@ class GGUFLoader:
     def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor:
         t = self.tensor_info[name]
         shape = t["shape"]
         ggml_type = t["ggml_type"]
@@ -282,15 +285,28 @@ class GGUFLoader:
         ggml_name = GGML_NAMES[ggml_type]
         data = self.get_mmap_tensor(name)
         if "cuda" in device.lower():
             values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
+            #values = GGML_DEQUANTIZE[ggml_name](data)
+            #print("load_gguf_tensor")
+            #values = torch.from_numpy(values).to(device = device)
         else:
             values = GGML_DEQUANTIZE[ggml_name](data)
             values = torch.from_numpy(values)
-        return values.view(shape[::-1])
+        values = values.view(shape[::-1])
+        if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
+            n_head = self.gguf_file_meta['llama.attention.head_count']
+            values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
+                      .swapaxes(1, 2)
+                      .reshape(values.shape))
+        elif "attn_k" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
+            n_head = self.gguf_file_meta['llama.attention.head_count_kv']
+            values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
+                      .swapaxes(1, 2)
+                      .reshape(values.shape))
+        return values
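Note on the attn_q/attn_k branch above: GGUF files converted from HF LLaMA checkpoints store the query/key rows in llama.cpp's rope-permuted order, and the reshape/swapaxes pair appears to apply the inverse permutation so the loaded tensor matches the layout the HF modeling code expects. A toy-sized sketch of the same operation (dimensions are illustrative, not from the commit):

    import torch

    n_head, rows_per_head, hidden = 2, 4, 8
    w = torch.arange(n_head * rows_per_head * hidden).reshape(n_head * rows_per_head, hidden)
    restored = (w.reshape(n_head, w.shape[0] // n_head // 2, 2, *w.shape[1:])
                 .swapaxes(1, 2)
                 .reshape(w.shape))
    # Per head, rows come out in the order (0, 2, ..., 1, 3, ...): interleaved pairs are
    # regrouped into the two contiguous halves that rotate_half-style rope expects.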
 def read_value(f, data_type):
     if data_type == DATA_TYPES["string"]:
@@ -375,7 +391,7 @@ def dequantize_q2_k(data):
     return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
 def dequantize_q2_k_gpu(data):
-    pass
+    raise NotImplementedError()
 def dequantize_q3_k(data):
     # C implementation
@@ -420,7 +436,7 @@ def dequantize_q3_k(data):
     ], axis=1)
 def dequantize_q3_k_gpu(data):
-    pass
+    raise NotImplementedError()
 def dequantize_q4_k(data):
     # C implementation
@@ -429,20 +445,16 @@ def dequantize_q4_k(data):
     # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116
     block_size = GGML_BLOCK_SIZES["Q4_K"]
     num_blocks = len(data) // block_size
     data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
     data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
     # Casting to float32 because float16 is very slow on CPU
     scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)
     scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)
     qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
     qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)
     # Dequantize scales and offsets (6 bits and 4 + 2 bits)
     factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1)
     offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1)
     # Interleave low and high quantized bits
     qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)
     # Dequantize final weights using scales and offsets
@@ -512,9 +524,14 @@ def dequantize_q5_k(data):
         d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
     ], axis=1)
-def dequantize_q5_k_gpu(data):
-    pass
+def dequantize_q5_k_gpu(data, device:str ="cuda"):
+    block_size = GGML_BLOCK_SIZES["Q5_K"]
+    data = np.frombuffer(data, dtype=data.dtype)
+    device = torch.device(device)
+    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable;
+    # the best way to fix this is to pass the ptr to KTransformersOps instead of a Tensor.
+    data = torch.from_numpy(data)
+    return KTransformersOps.dequantize_q5_k(data, block_size, device)
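Note: the GPU path above hands the raw mmap'd block buffer, the per-block byte stride, and the target device straight to the KTransformersOps CUDA kernel. A small sanity-check sketch (not from the commit; the 176-byte figure follows ggml's block_q5_K layout of two f16 scale fields, 12 scale bytes, 32 high-bit bytes, and 128 nibble bytes):

    def q5_k_element_count(raw: bytes) -> int:
        block_size = 4 + 12 + 256 // 8 + 256 // 2   # 176 bytes per Q5_K block
        assert len(raw) % block_size == 0, "buffer is not a whole number of Q5_K blocks"
        return (len(raw) // block_size) * 256       # QK_K = 256 weights per block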
 def dequantize_q6_k(data):
     # C implementation
@@ -571,7 +588,49 @@ def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"):
     num_blocks = len(data) // block_size
     data = np.frombuffer(data, dtype=data.dtype)
     data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q6_k(data, 210, device)
+    return KTransformersOps.dequantize_q6_k(data, block_size, device)
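(The replaced constant was already correct: Q6_K's block works out to 256 // 2 + 256 // 4 + 256 // 16 + 2 = 210 bytes, so passing block_size is equivalent but no longer hard-coded.)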
+def dequantize_q4_0(data):
+    # C implementation
+    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1515
+    # C struct definition
+    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L141
+    num_blocks = len(data) // GGML_BLOCK_SIZES["Q4_0"]
+    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)[:, :1].astype(np.float32)
+    qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)[:, 2:]
+    return np.concatenate([
+        scales * ((qs & 0xf).astype(np.int8) - 8),
+        scales * ((qs >> 4).astype(np.int8) - 8),
+    ], axis=1)
+def dequantize_q4_0_gpu(data):
+    raise NotImplementedError()
+def dequantize_q5_0(data):
+    # C implementation
+    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1556
+    # C struct definition
+    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L161
+    num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_0"]
+    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32)
+    qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4]
+    qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:]
+    bits = np.unpackbits(qh, axis=-1, bitorder="little")
+    x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16
+    x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16
+    return np.concatenate([
+        scales * x0,
+        scales * x1,
+    ], axis=1)
+def dequantize_q5_0_gpu(data):
+    raise NotImplementedError()
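Note: a quick worked example for the new Q4_0 path (not from the commit; the scale and quant values are arbitrary). A Q4_0 block is one float16 scale followed by 16 bytes packing 32 4-bit quants, matching GGML_BLOCK_SIZES["Q4_0"] = 2 + 16 above:

    import numpy as np

    scale = np.float16(0.5)
    quants = np.arange(32, dtype=np.uint8) % 16               # toy 4-bit values
    packed = (quants[:16] | (quants[16:] << 4)).astype(np.uint8)
    block = scale.tobytes() + packed.tobytes()                # 18 bytes = one Q4_0 block

    out = dequantize_q4_0(block)                              # shape (1, 32)
    # Each weight is scale * (q - 8), per ggml's Q4_0 formula.
    assert np.allclose(out[0, :16], 0.5 * (quants[:16].astype(np.int8) - 8))

Q5_0 follows the same pattern but adds a 4-byte qh field whose unpacked bits supply the fifth bit of each quant, hence its 2 + 4 + 16 block size.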
 def dequantize_q8_0(data):
     # C struct definition
@@ -615,6 +674,8 @@ def dequantize_f16_gpu(data, device):
 GGML_DEQUANTIZE = {
     "F32": dequantize_f32,
     "F16": dequantize_f16,
+    "Q4_0": dequantize_q4_0,
+    "Q5_0": dequantize_q5_0,
     "Q8_0": dequantize_q8_0,
     "Q2_K": dequantize_q2_k,
     "Q3_K": dequantize_q3_k,
@@ -626,6 +687,8 @@ GGML_DEQUANTIZE = {
 GGML_DEQUANTIZE_GPU = {
     "F32": dequantize_f32_gpu,
     "F16": dequantize_f16_gpu,
+    "Q4_0": dequantize_q4_0_gpu,
+    "Q5_0": dequantize_q5_0_gpu,
     "Q8_0": dequantize_q8_0_gpu,
     "Q2_K": dequantize_q2_k_gpu,
     "Q3_K": dequantize_q3_k_gpu,
@@ -634,7 +697,34 @@ GGML_DEQUANTIZE_GPU = {
     "Q6_K": dequantize_q6_k_gpu,
 }
+def translate_name_to_gguf_mixtral(name):
+    replacement_template = {
+        "w1.weight": "ffn_gate",
+        "w2.weight": "ffn_down",
+        "w3.weight": "ffn_up"
+    }
+    pattern = re.compile(r"model.layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.(w\d\.weight)")
+    def replace_match(match):
+        blk_id = match.group(1)
+        expert_id = match.group(2)
+        weight_type = match.group(3)
+        if weight_type in replacement_template:
+            return f"blk.{blk_id}.{replacement_template[weight_type]}.{expert_id}.weight"
+        else:
+            return match.group(0)
+    new_name = re.sub(pattern, replace_match, name)
+    return new_name
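For example (illustrative tensor names run through the helper above):

    translate_name_to_gguf_mixtral("model.layers.0.block_sparse_moe.experts.3.w1.weight")
    # -> "blk.0.ffn_gate.3.weight"
    translate_name_to_gguf_mixtral("model.layers.0.self_attn.q_proj.weight")
    # -> returned unchanged; the pattern only rewrites per-expert w1/w2/w3 weights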
 def translate_name_to_gguf(name):
+    name = translate_name_to_gguf_mixtral(name)
     name = name.replace("lm_head.", "output.")
     name = name.replace("model.embed_tokens.", "token_embd.")
     name = name.replace("model.norm.", "output_norm.")
@@ -671,9 +761,14 @@ def translate_name_to_gguf(name):
     name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps")
     name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps")
+    name = name.replace(".block_sparse_moe.gate.", ".ffn_gate_inp.")
+    name = name.replace(".block_sparse_moe.experts", "")
     return name
 if __name__ == '__main__':
     gguf_path = '/mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH'
     loader = GGUFLoader(gguf_path)
     loader.load_gguf_tensor('token_embd.weight')
@@ -39,6 +39,22 @@ def set_param(module: nn.Module, name: str, weights: torch.Tensor):
         param.unsqueeze_(0)
     setattr(module, name, param)
+def get_device(gguf_module_key:str, device_map:dict):
+    if gguf_module_key in device_map:
+        return device_map[gguf_module_key]["generate_device"]
+    else:
+        return "cuda"
+def get_all_used_cuda_device(device_map:dict):
+    all_device_list = set()
+    for key in device_map:
+        all_device_list.add(device_map[key]["generate_device"]) if "generate_device" in device_map[key] else None
+        all_device_list.add(device_map[key]["prefill_device"]) if "prefill_device" in device_map[key] else None
+    if "cpu" in all_device_list:
+        all_device_list.remove("cpu")
+    all_device_list = list(all_device_list)
+    return all_device_list
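Note: a small illustration of the two helpers above (the module keys and devices are hypothetical):

    device_map = {
        "blk.0.self_attn": {"generate_device": "cuda:0", "prefill_device": "cuda:0"},
        "blk.1.self_attn": {"generate_device": "cuda:1", "prefill_device": "cuda:1"},
        "blk.0.ffn":       {"generate_device": "cpu",    "prefill_device": "cuda:0"},
    }
    get_device("blk.0.self_attn", device_map)   # -> "cuda:0"
    get_device("token_embd", device_map)        # -> "cuda" (default for unmapped keys)
    get_all_used_cuda_device(device_map)        # -> ["cuda:0", "cuda:1"] (set order; "cpu" is filtered out)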
 def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str = ""):
     prefix = prefix.replace("orig_module.", "")
     persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
@@ -47,18 +63,19 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
     for name, param in local_state.items():
         key = prefix + name
         translated_key = translate_name_to_gguf(key)
-        print("default loading weights", key, translated_key)
         if translated_key in gguf_loader.tensor_file_map:
             target_dtype = torch.get_default_dtype()
-            device = "cpu" if "embd" in translated_key else "cuda"
+            device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
+            print(f"loading {translated_key} to {device}")
+            # device = "cpu" if "embd" in translated_key else "cuda"
             weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
             set_param(module, name, weights)
             del weights
         else:
             #print(load_config.tensor_file_map.keys())
-            raise Exception(f"can't fand {translated_key} in GGUF file!")
+            raise Exception(f"can't find {translated_key} in GGUF file!")
-def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix='', return_when_injected:bool = False, only_load_injected:bool = False):
+def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
     # print(f"recursively loading weights {prefix},{return_when_injected=}, {only_load_injected=}")
     if not isinstance(module, base_operator.BaseInjectedModule):
         load_cur_state_dict(module, gguf_loader, prefix)
@@ -66,27 +83,36 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix='', return_whe
             load_weights(child, gguf_loader, prefix+name+".")
     else:
         module.load()
-def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
+def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     torch._dynamo.config.suppress_errors = True
     batch_size, seq_length = inputs.shape
-    torch_device = inputs.device
+    device_map = model.config.gguf_loader.tensor_device_map
+    torch_device = get_device('blk.0.self_attn', device_map)
+    torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+    inputs = inputs.to(torch_device)
+    all_cuda_device = get_all_used_cuda_device(device_map)
     tokens = []
-    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values):
-        logits = cuda_graph_runner(cur_token, position_ids, cache_position)
+    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
+        if use_cuda_graph:
+            logits = cuda_graph_runner(cur_token, position_ids, cache_position)
+        else:
+            # custom_stream = torch.cuda.Stream()
+            torch.cuda.set_device(torch_device)
+            inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
+            # with torch.cuda.stream(custom_stream):
            logits=model(inputs_embeds=inputs_embeds,
                         position_ids=position_ids,
                         cache_position=cache_position,
                         past_key_values=past_key_values,
                         return_dict=False, use_cache=True)[0]
         past_key_values.change_seq_length(1)
-        """
-        with torch.cuda.stream(custom_stream):
-            logits=model(cur_token,
-                         position_ids=position_ids,
-                         cache_position=cache_position,
-                         past_key_values=past_key_values,
-                         return_dict=False, use_cache=True)[0]
-        #"""
-        torch.cuda.synchronize()
+        for device in all_cuda_device:
+            torch.cuda.synchronize(device)
         #print(logits)
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
@@ -95,11 +121,12 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
         else:
             next_token = torch.argmax(next_token_scores, dim=-1)
         return next_token
+    torch.cuda.set_device(torch_device)
     with torch.no_grad():
         stream = TextStreamer(tokenizer)
         past_key_values = StaticCache(
-            config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = torch_device, dtype = model.dtype
+            config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
         )
         cache_position = torch.arange(seq_length, device=torch_device)
         generated_ids = torch.zeros(
@@ -108,23 +135,22 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
         generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
         past_key_values.cur_idx=cache_position
         start_time = time.time()
-        #custom_stream = torch.cuda.Stream()
-        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to("cuda")
+        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
         logits = model(
             inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
-        )[0][:,-1,:].unsqueeze(0).clone()
+        )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
         generation_config, model_kwargs = model._prepare_generation_config(
             None, max_length=max_new_tokens,
             do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config
         )
         try: # transformers==4.43
             logits_warper = (
-                model._get_logits_warper(generation_config,device=inputs.device) if generation_config.do_sample else None
+                model._get_logits_warper(generation_config,device=inputs.device)
             )
         except:
             logits_warper = (
-                model._get_logits_warper(generation_config) if generation_config.do_sample else None
+                model._get_logits_warper(generation_config)
             )
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
@@ -136,7 +162,6 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
         prefill_count = seq_length
         prefill_time = first_token_time
         print(stream.put(next_token.item()), end="", flush=True)
         generated_ids[:, seq_length] = next_token
         tokens.append(next_token)
@@ -144,12 +169,16 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
         cache_position = torch.tensor([seq_length], device=torch_device)
         position_ids = cache_position.unsqueeze(0)
         seq_length += 1
-        cuda_graph_runner = CUDAGraphRunner()
-        cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, return_dict=False, use_cache=True)
+        if use_cuda_graph:
+            cuda_graph_runner = CUDAGraphRunner()
+            cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
+        else:
+            cuda_graph_runner = None
         start_time = time.time()
         for _ in range(1, max_new_tokens):
-            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values)
+            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
            inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
            generated_ids[:, cache_position] = next_token.int()
            tokens.append(next_token.int())
@@ -162,6 +191,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
            print(stream.put(next_token.item()), end="", flush=True)
            cache_position += 1
            position_ids = cache_position.unsqueeze(0)
         total_time = time.time() - start_time
         tokens_generated = len(tokens)
...
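Note: a minimal calling sketch for the updated generation entry point (illustrative only; `model` and `tokenizer` are assumed to come from the usual ktransformers loading flow):

    input_ids = tokenizer("Hello", return_tensors="pt").input_ids
    # Device placement now comes from model.config.gguf_loader.tensor_device_map, so the
    # prompt no longer has to be moved to a particular GPU by the caller.
    prefill_and_generate(model, tokenizer, input_ids, max_new_tokens=128, use_cuda_graph=True)   # graph decode
    prefill_and_generate(model, tokenizer, input_ids, max_new_tokens=128, use_cuda_graph=False)  # eager decode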
@@ -3,7 +3,8 @@ requires = [
     "setuptools",
     "torch >= 2.3.0",
     "ninja",
-    "packaging"
+    "packaging",
+    "cpufeature"
 ]
 build-backend = "setuptools.build_meta"
...
@@ -6,7 +6,7 @@ Author : chenxl
 Date : 2024-07-27 16:15:27
 Version : 1.0.0
 LastEditors : chenxl
-LastEditTime : 2024-07-31 09:44:46
+LastEditTime : 2024-08-08 02:45:15
 Adapted from:
   https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
 Copyright (c) 2023, Tri Dao.
@@ -19,6 +19,7 @@ import re
 import ast
 import subprocess
 import platform
+import shutil
 import http.client
 import urllib.request
 import urllib.error
@@ -27,6 +28,7 @@ from packaging.version import parse
 import torch.version
 from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
 from setuptools import setup, Extension
+from cpufeature.extension import CPUFeature
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
 class CpuInstructInfo:
@@ -67,6 +69,8 @@ class VersionInfo:
         """
         if sys.platform.startswith("linux"):
             return f'linux_{platform.uname().machine}'
+        elif sys.platform == "win32":
+            return "win_amd64"
         else:
             raise ValueError("Unsupported platform: {}".format(sys.platform))
@@ -97,6 +101,15 @@ class VersionInfo:
                 return 'avx2'
             raise ValueError(
                 "Unsupported cpu Instructions: {}".format(flags_line))
+        elif sys.platform == "win32":
+            if CPUFeature.get("AVX512bw", False):
+                return 'fancy'
+            if CPUFeature.get("AVX512f", False):
+                return 'avx512'
+            if CPUFeature.get("AVX2", False):
+                return 'avx2'
+            raise ValueError(
+                "Unsupported cpu Instructions: {}".format(str(CPUFeature)))
         else:
             raise ValueError("Unsupported platform: {}".format(sys.platform))
@@ -154,7 +167,7 @@ class BuildWheelsCommand(_bdist_wheel):
                 wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
                 print("Raw wheel path", wheel_path)
-                os.rename(wheel_filename, wheel_path)
+                shutil.move(wheel_filename, wheel_path)
             except (urllib.error.HTTPError, urllib.error.URLError, http.client.RemoteDisconnected):
                 print("Precompiled wheel not found. Building from source...")
                 # If the wheel could not be downloaded, build from source
...
@@ -22,7 +22,7 @@
 #include <cstring>
 #include <type_traits>
-#if defined __x86_64__ || defined __aarch64__
+#if defined __x86_64__ || defined __aarch64__ || defined(_M_X64)
 #include "llama.cpp/ggml-impl.h"
 #include "llama.cpp/ggml-quants.h"
@@ -225,7 +225,7 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const voi
     return true;
 }
-#if defined __x86_64__
+#if defined __x86_64__ || defined(_M_X64)
 #if defined HAVE_FANCY_SIMD
 #undef HAVE_FANCY_SIMD
@@ -1412,7 +1412,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
 bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int) {
-    row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);
+    if (ne00 % ggml_blck_size(GGML_TYPE_Q8_K) == 0)
+        row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);
     switch (typeA) {
         case GGML_TYPE_Q2_K:
...
@@ -3,6 +3,6 @@
 // Copyrigth 2024 Iwan Kawrakow.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #include "iqk_mul_mat.inc"
 #endif // __x86_64__
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Iwan Kawrakow.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define iqk_mul_mat iqk_mul_mat_zen4
 #define iqk_mul_mat_moe iqk_mul_mat_moe_zen4
 #include "iqk_mul_mat.inc"
...
@@ -22,19 +22,22 @@
 #include "sgemm.h"
 // #include <cosmo.h>
-#include <cpuid.h>
+// #include <cpuid.h>
 // #include <libc/sysv/consts/hwcap.h>
 #include <stdio.h>
-#include <sys/auxv.h>
+// #include <sys/auxv.h>
 #include <cassert>
 // #include "llamafile.h"
 static const struct GemmFuncs {
-    typeof(llamafile_sgemm)* sgemm;
-    typeof(llamafile_mixmul)* mixmul;
-    typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
+    bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
+    bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
+    bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
+    // typeof(llamafile_sgemm)* sgemm;
+    // typeof(llamafile_mixmul)* mixmul;
+    // typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
     GemmFuncs() {
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 // if (X86_HAVE(AVX)) {
 //   if (X86_HAVE(FMA)) {
 //     if (X86_HAVE(AVX2)) {
@@ -86,10 +89,12 @@ static const struct GemmFuncs {
 //   sgemm = llamafile_sgemm_unsupported;
 //   mixmul = llamafile_mixmul_unsupported;
 // }
 #if defined(__AVX__)
-#if defined(__FMA__)
+#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
 #if defined(__AVX2__)
 #if defined(__AVX512F__)
+    printf("__AVX512F__\n");
 #if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)
     // AMD Zen4+ (2023-)
     sgemm = llamafile_sgemm_amd_zen4;
...
@@ -223,7 +223,7 @@ inline float32x4_t badder(float32x4_t a, float b, float32x4_t c, float32x4_t* e)
 }
 #endif
-#if defined(__FMA__)
+#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
 template <>
 inline __m256 madd(__m256 a, __m256 b, __m256 c) {
...
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_avx
 #include "tinyblas_cpu_mixmul.inc"
...
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_avx2
 #include "tinyblas_cpu_mixmul.inc"
 #endif // __x86_64__
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_avx512f
 #include "tinyblas_cpu_mixmul.inc"
 #endif // __x86_64__
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_avxvnni
 #include "tinyblas_cpu_mixmul.inc"
 #endif // __x86_64__
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_fma
 #include "tinyblas_cpu_mixmul.inc"
 #endif // __x86_64__
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_zen4
 #include "tinyblas_cpu_mixmul.inc"
 #endif // __x86_64__
@@ -321,8 +321,8 @@ bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void
     assert(ith < nth);
 #if QK_K == 256
-#if defined(__x86_64__)
-#if defined(__AVX2__) && defined(__FMA__)
+#if defined(__x86_64__) || defined(_M_X64)
+#if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))
 // if (X86_CHECK(AVX2) && X86_CHECK(FMA)) {
     if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
         if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {
...
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_sgemm llamafile_sgemm_amd_avx
 #include "tinyblas_cpu_sgemm.inc"
 #endif // __x86_64__
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_sgemm llamafile_sgemm_amd_avx2
 #include "tinyblas_cpu_sgemm.inc"
 #endif // __x86_64__
@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_sgemm llamafile_sgemm_amd_avx512f
 #include "tinyblas_cpu_sgemm.inc"
 #endif // __x86_64__