OpenDAS / ktransformers

Commit f5f79f5c, authored Aug 12, 2024 by chenxl
Parent: f2938031

[ADD] support multi-gpu qlen>1 q5_k

Showing 20 changed files with 240 additions and 85 deletions
ktransformers/util/cuda_graph_runner.py                      +14  -4
ktransformers/util/custom_gguf.py                            +119 -24
ktransformers/util/utils.py                                  +60  -30
pyproject.toml                                               +2   -1
setup.py                                                     +15  -2
third_party/llamafile/iqk_mul_mat.inc                        +4   -3
third_party/llamafile/iqk_mul_mat_amd_avx2.cpp               +1   -1
third_party/llamafile/iqk_mul_mat_amd_zen4.cpp               +1   -1
third_party/llamafile/sgemm.cpp                              +12  -7
third_party/llamafile/tinyblas_cpu.h                         +1   -1
third_party/llamafile/tinyblas_cpu_mixmul_amd_avx.cpp        +1   -1
third_party/llamafile/tinyblas_cpu_mixmul_amd_avx2.cpp       +1   -1
third_party/llamafile/tinyblas_cpu_mixmul_amd_avx512f.cpp    +1   -1
third_party/llamafile/tinyblas_cpu_mixmul_amd_avxvnni.cpp    +1   -1
third_party/llamafile/tinyblas_cpu_mixmul_amd_fma.cpp        +1   -1
third_party/llamafile/tinyblas_cpu_mixmul_amd_zen4.cpp       +1   -1
third_party/llamafile/tinyblas_cpu_sgemm.inc                 +2   -2
third_party/llamafile/tinyblas_cpu_sgemm_amd_avx.cpp         +1   -1
third_party/llamafile/tinyblas_cpu_sgemm_amd_avx2.cpp        +1   -1
third_party/llamafile/tinyblas_cpu_sgemm_amd_avx512f.cpp     +1   -1
ktransformers/util/cuda_graph_runner.py

@@ -21,6 +21,7 @@ class CUDAGraphRunner:
         position_ids,
         cache_position,
         past_key_values,
+        main_device,
         **kwargs,
     ) -> None:
         assert self.graph is None

@@ -29,15 +30,24 @@ class CUDAGraphRunner:
         self.graph = torch.cuda.CUDAGraph()
         #self.graph.enable_debug_mode()
         self.model = model
-        inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to("cuda")
-        with torch.cuda.graph(self.graph):
+        inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(main_device)
+        # torch.cuda.set_device can't set "cuda", must have a index
+        if main_device == "cuda":
+            main_device = "cuda:0"
+        torch.cuda.set_device(main_device)
+        self.main_device = main_device
+        capture_stream = torch.cuda.Stream()
+        with torch.cuda.graph(self.graph, stream=capture_stream):
             logits = model(inputs_embeds=inputs_embeds,
                            position_ids=position_ids,
                            cache_position=cache_position,
                            past_key_values=past_key_values,
                            **kwargs)[0]
+        capture_stream.wait_stream(torch.cuda.current_stream())
+        torch.cuda.set_device(main_device)
+        torch.cuda.set_stream(capture_stream)
         past_key_values.change_seq_length(-1)
-        torch.cuda.synchronize()
+        torch.cuda.synchronize(self.main_device)
         #self.graph.debug_dump("cuda_graph_hooked.dot")
         # Save the input and output buffers.

@@ -65,7 +75,7 @@ class CUDAGraphRunner:
         #print("begin replay")
         #time.sleep(1)
         self.graph.replay()
-        torch.cuda.synchronize()
+        torch.cuda.synchronize(self.main_device)
         # Return the output tensor.
         return self.output_buffers["logits"]
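Note: the change above pins graph capture to an explicit main_device and a dedicated stream. As a point of reference, here is a minimal, self-contained sketch of that capture pattern in plain PyTorch (illustrative only; capture_add_one and its buffers are made up and are not part of the ktransformers API):

import torch

def capture_add_one(device: str = "cuda:0"):
    # torch.cuda.set_device() needs an indexed device, so normalize bare "cuda" first.
    if device == "cuda":
        device = "cuda:0"
    torch.cuda.set_device(device)

    static_x = torch.zeros(8, device=device)
    graph = torch.cuda.CUDAGraph()
    capture_stream = torch.cuda.Stream(device=device)

    # Warm up on a side stream, as recommended before graph capture.
    capture_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(capture_stream):
        static_y = static_x + 1
    torch.cuda.current_stream().wait_stream(capture_stream)

    # Capture on an explicit stream; replay reuses the static input/output buffers.
    with torch.cuda.graph(graph, stream=capture_stream):
        static_y = static_x + 1

    static_x.fill_(41.0)
    graph.replay()
    torch.cuda.synchronize(device)
    return static_y  # filled with 42.0 after replay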
ktransformers/util/custom_gguf.py

@@ -5,8 +5,11 @@ Description  :
 Author       : Azure-Tang, Boxin Zhang, chenht2022
 Date         : 2024-07-26 08:48:54
 Version      : 1.0.0
-LastEditors  : Azure
-LastEditTime : 2024-07-26 09:28:25
+LastEditors  : kkk1nak0
+LastEditTime : 2024-08-09 08:03:44
+Adapted from https://github.com/99991/pygguf/blob/main/gguf.py
+Copyright (c) 2023-2024 The ggml authors
+Copyright (c) 2024 Thomas Germer
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''
 # copied from llama.cpp/gguf-py/gguf/constants.py to satisfy dependence of gguf

@@ -15,6 +18,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 import struct
 import warnings
 import numpy as np
+import re
 import numpy.typing as npt
 from typing import Sequence
 import os
@@ -96,6 +100,8 @@ def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantization
 GGML_TYPES = {
     "F32": 0,
     "F16": 1,
+    "Q4_0": 2,
+    "Q5_0": 6,
     "Q8_0": 8,
     "Q2_K": 10,
     "Q3_K": 11,

@@ -109,6 +115,8 @@ GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
 GGML_BLOCK_SIZES = {
     "F32": 4,
     "F16": 2,
+    "Q4_0": 2 + 16,
+    "Q5_0": 2 + 4 + 16,
     "Q8_0": 2 + 32,
     "Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
     "Q3_K": 256 // 8 + 256 // 4 + 12 + 2,

@@ -120,6 +128,8 @@ GGML_BLOCK_SIZES = {
 GGML_ELEMENTS_PER_BLOCK = {
     "F32": 1,
     "F16": 1,
+    "Q4_0": 32,
+    "Q5_0": 32,
     "Q8_0": 32,
     "Q2_K": 256,
     "Q3_K": 256,

@@ -128,14 +138,6 @@ GGML_ELEMENTS_PER_BLOCK = {
     "Q6_K": 256,
 }

-# DATA_TYPES = {
-#     "uint32": 4,
-#     "int32": 5,
-#     "float32": 6,
-#     "string": 8,
-#     "array": 9,
-#     "uint64": 10,
-# }
 DATA_TYPES = {
     "uint8": 0,
     "int8": 1,

@@ -167,6 +169,7 @@ class GGUFLoader:
         self.tensor_file_map = {}
         self.file_data_map = {}
         self.gguf_file_meta = {}
+        self.tensor_device_map = {}

         # Walk through all the .gguf files in the directory
         for root, dirs, files in os.walk(gguf_path):
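Note: the two tables extended above describe the Q4_0/Q5_0 block layouts (bytes per block and weights per block). A small illustrative helper, using only the values shown in the diff, of how such tables are typically combined to size a quantized tensor (quantized_nbytes is hypothetical, not a ktransformers function):

# Illustrative only: how the two tables above combine to size a tensor.
GGML_BLOCK_SIZES = {"Q4_0": 2 + 16, "Q5_0": 2 + 4 + 16}      # bytes per block (18 and 22)
GGML_ELEMENTS_PER_BLOCK = {"Q4_0": 32, "Q5_0": 32}           # weights per block

def quantized_nbytes(n_elements: int, ggml_name: str) -> int:
    """Bytes needed to store n_elements weights in the given GGML format."""
    per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name]
    assert n_elements % per_block == 0, "GGML tensors are padded to whole blocks"
    return (n_elements // per_block) * GGML_BLOCK_SIZES[ggml_name]

# A 4096x4096 Q4_0 matrix: 4096*4096/32 blocks of 18 bytes each = 9,437,184 bytes.
print(quantized_nbytes(4096 * 4096, "Q4_0"))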
@@ -272,7 +275,7 @@ class GGUFLoader:
     def load_gguf_tensor(self, name: str, device: str = "cpu") -> torch.Tensor:
         t = self.tensor_info[name]
         shape = t["shape"]
         ggml_type = t["ggml_type"]

@@ -282,15 +285,28 @@ class GGUFLoader:
         ggml_name = GGML_NAMES[ggml_type]

         data = self.get_mmap_tensor(name)

         if "cuda" in device.lower():
             values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
             #values = GGML_DEQUANTIZE[ggml_name](data)
             #print("load_gguf_tensor")
             #values = torch.from_numpy(values).to(device = device)
         else:
             values = GGML_DEQUANTIZE[ggml_name](data)
             values = torch.from_numpy(values)
-        return values.view(shape[::-1])
+
+        values = values.view(shape[::-1])
+        if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
+            n_head = self.gguf_file_meta['llama.attention.head_count']
+            values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
+                      .swapaxes(1, 2)
+                      .reshape(values.shape))
+        elif "attn_k" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
+            n_head = self.gguf_file_meta['llama.attention.head_count_kv']
+            values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
+                      .swapaxes(1, 2)
+                      .reshape(values.shape))
+        return values

 def read_value(f, data_type):
     if data_type == DATA_TYPES["string"]:
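Note: the reshape/swapaxes/reshape chain added above regroups attn_q/attn_k rows per attention head. A toy numpy illustration of what that permutation does mechanically (sizes are made up: 2 heads, 4 rows per head, 3 columns):

import numpy as np

n_head, rows, cols = 2, 8, 3
values = np.arange(rows * cols).reshape(rows, cols)

permuted = (values.reshape(n_head, rows // n_head // 2, 2, cols)
                  .swapaxes(1, 2)
                  .reshape(values.shape))

# Within each head's block of 4 rows, rows [0, 1, 2, 3] are reordered to [0, 2, 1, 3],
# i.e. the two halves used by rotary embeddings are re-interleaved per head.
print(permuted[:, 0])   # [ 0  6  3  9 12 18 15 21]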
@@ -375,7 +391,7 @@ def dequantize_q2_k(data):
     return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)

 def dequantize_q2_k_gpu(data):
-    pass
+    raise NotImplementedError()

 def dequantize_q3_k(data):
     # C implementation

@@ -420,7 +436,7 @@ def dequantize_q3_k(data):
     ], axis=1)

 def dequantize_q3_k_gpu(data):
-    pass
+    raise NotImplementedError()

 def dequantize_q4_k(data):
     # C implementation
@@ -429,20 +445,16 @@ def dequantize_q4_k(data):
     # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116
     block_size = GGML_BLOCK_SIZES["Q4_K"]
     num_blocks = len(data) // block_size
     data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
     data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
     # Casting to float32 because float16 is very slow on CPU
     scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)
     scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)
     qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
     qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)
     # Dequantize scales and offsets (6 bits and 4 + 2 bits)
     factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1)
     offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1)
     # Interleave low and high quantized bits
     qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)
     # Dequantize final weights using scales and offsets
@@ -512,9 +524,14 @@ def dequantize_q5_k(data):
         d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
     ], axis=1)

-def dequantize_q5_k_gpu(data):
-    pass
+def dequantize_q5_k_gpu(data, device: str = "cuda"):
+    block_size = GGML_BLOCK_SIZES["Q5_K"]
+    data = np.frombuffer(data, dtype=data.dtype)
+    device = torch.device(device)
+    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
+    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
+    data = torch.from_numpy(data)
+    return KTransformersOps.dequantize_q5_k(data, block_size, device)

 def dequantize_q6_k(data):
     # C implementation
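Note: the TODO added above refers to a known PyTorch warning: torch.from_numpy over a read-only buffer (such as an mmap'd GGUF region) produces a non-writable tensor. A tiny illustration of the effect and the usual copying workaround, independent of KTransformersOps:

import numpy as np
import torch

raw = bytes(range(8))                      # stands in for an mmap'd GGUF region
arr = np.frombuffer(raw, dtype=np.uint8)   # read-only view over the buffer

t = torch.from_numpy(arr.copy())           # copying avoids the "not writable" warning
# torch.from_numpy(arr) would warn and yield a tensor that must not be written to.
print(t)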
@@ -571,7 +588,49 @@ def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"):
     num_blocks = len(data) // block_size
     data = np.frombuffer(data, dtype=data.dtype)
     data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q6_k(data, 210, device)
+    return KTransformersOps.dequantize_q6_k(data, block_size, device)
+
+def dequantize_q4_0(data):
+    # C implementation
+    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1515
+    # C struct definition
+    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L141
+    num_blocks = len(data) // GGML_BLOCK_SIZES["Q4_0"]
+
+    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)[:, :1].astype(np.float32)
+    qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)[:, 2:]
+
+    return np.concatenate([
+        scales * ((qs & 0xf).astype(np.int8) - 8),
+        scales * ((qs >> 4).astype(np.int8) - 8),
+    ], axis=1)
+
+def dequantize_q4_0_gpu(data):
+    raise NotImplementedError()
+
+def dequantize_q5_0(data):
+    # C implementation
+    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1556
+    # C struct definition
+    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L161
+    num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_0"]
+
+    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32)
+    qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4]
+    qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:]
+
+    bits = np.unpackbits(qh, axis=-1, bitorder="little")
+
+    x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16
+    x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16
+
+    return np.concatenate([
+        scales * x0,
+        scales * x1,
+    ], axis=1)
+
+def dequantize_q5_0_gpu(data):
+    raise NotImplementedError()

 def dequantize_q8_0(data):
     # C struct definition
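Note: for reference, a hypothetical round trip through the new dequantize_q4_0 above, assuming the function is in scope. One Q4_0 block is 18 bytes (an fp16 scale followed by 16 bytes of packed 4-bit codes), and each code q decodes to d * (q - 8):

import numpy as np

d = np.float16(0.5)
quants = np.arange(32, dtype=np.uint8) % 16           # 32 4-bit codes in [0, 15]
packed = (quants[:16] | (quants[16:] << 4)).astype(np.uint8)
block = d.tobytes() + packed.tobytes()                # one 18-byte Q4_0 block

weights = dequantize_q4_0(block)                      # shape (1, 32)
# Code 0 decodes to -4.0 and code 15 to 3.5 with d = 0.5.
assert weights.shape == (1, 32)
assert np.isclose(weights[0, 0], d * (quants[0].astype(np.int8) - 8))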
@@ -615,6 +674,8 @@ def dequantize_f16_gpu(data, device):
 GGML_DEQUANTIZE = {
     "F32": dequantize_f32,
     "F16": dequantize_f16,
+    "Q4_0": dequantize_q4_0,
+    "Q5_0": dequantize_q5_0,
     "Q8_0": dequantize_q8_0,
     "Q2_K": dequantize_q2_k,
     "Q3_K": dequantize_q3_k,

@@ -626,6 +687,8 @@ GGML_DEQUANTIZE = {
 GGML_DEQUANTIZE_GPU = {
     "F32": dequantize_f32_gpu,
     "F16": dequantize_f16_gpu,
+    "Q4_0": dequantize_q4_0_gpu,
+    "Q5_0": dequantize_q5_0_gpu,
     "Q8_0": dequantize_q8_0_gpu,
     "Q2_K": dequantize_q2_k_gpu,
     "Q3_K": dequantize_q3_k_gpu,
@@ -634,7 +697,34 @@ GGML_DEQUANTIZE_GPU = {
     "Q6_K": dequantize_q6_k_gpu,
 }

+def translate_name_to_gguf_mixtral(name):
+
+    replacement_template = {
+        "w1.weight": "ffn_gate",
+        "w2.weight": "ffn_down",
+        "w3.weight": "ffn_up"
+    }
+
+    pattern = re.compile(r"model.layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.(w\d\.weight)")
+
+    def replace_match(match):
+        blk_id = match.group(1)
+        expert_id = match.group(2)
+        weight_type = match.group(3)
+        if weight_type in replacement_template:
+            return f"blk.{blk_id}.{replacement_template[weight_type]}.{expert_id}.weight"
+        else:
+            return match.group(0)
+
+    new_name = re.sub(pattern, replace_match, name)
+
+    return new_name
+
 def translate_name_to_gguf(name):
+
+    name = translate_name_to_gguf_mixtral(name)
+
     name = name.replace("lm_head.", "output.")
     name = name.replace("model.embed_tokens.", "token_embd.")
     name = name.replace("model.norm.", "output_norm.")

@@ -671,9 +761,14 @@ def translate_name_to_gguf(name):
     name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps")
     name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps")

+    name = name.replace(".block_sparse_moe.gate.", ".ffn_gate_inp.")
+    name = name.replace(".block_sparse_moe.experts", "")
+
     return name

+if __name__ == '__main__':
+    gguf_path = '/mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH'
+    loader = GGUFLoader(gguf_path)
+    loader.load_gguf_tensor('token_embd.weight')
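Note: a quick, hypothetical check of what the new Mixtral name mapping produces (the expert tensor name follows the HF Mixtral convention; translate_name_to_gguf_mixtral is the function in the diff above):

hf_name = "model.layers.7.block_sparse_moe.experts.3.w2.weight"
print(translate_name_to_gguf_mixtral(hf_name))
# -> "blk.7.ffn_down.3.weight"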
ktransformers/util/utils.py

@@ -39,6 +39,22 @@ def set_param(module: nn.Module, name: str, weights: torch.Tensor):
         param.unsqueeze_(0)
     setattr(module, name, param)

+def get_device(gguf_module_key: str, device_map: dict):
+    if gguf_module_key in device_map:
+        return device_map[gguf_module_key]["generate_device"]
+    else:
+        return "cuda"
+
+def get_all_used_cuda_device(device_map: dict):
+    all_device_list = set()
+    for key in device_map:
+        all_device_list.add(device_map[key]["generate_device"]) if "generate_device" in device_map[key] else None
+        all_device_list.add(device_map[key]["prefill_device"]) if "prefill_device" in device_map[key] else None
+    if "cpu" in all_device_list:
+        all_device_list.remove("cpu")
+    all_device_list = list(all_device_list)
+    return all_device_list
+
 def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str = ""):
     prefix = prefix.replace("orig_module.", "")
     persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
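Note: the helpers added above read a per-module device map. An illustrative map (keys and devices are invented) and what the helpers return for it:

device_map = {
    "blk.0.self_attn": {"generate_device": "cuda:0", "prefill_device": "cuda:0"},
    "blk.1.self_attn": {"generate_device": "cuda:1", "prefill_device": "cuda:1"},
    "blk.0.ffn":       {"generate_device": "cpu"},
}

print(get_device("blk.1.self_attn", device_map))      # "cuda:1"
print(get_device("not.in.map", device_map))           # falls back to "cuda"
print(sorted(get_all_used_cuda_device(device_map)))   # ["cuda:0", "cuda:1"]  ("cpu" is dropped)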
@@ -47,18 +63,19 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
     for name, param in local_state.items():
         key = prefix + name
         translated_key = translate_name_to_gguf(key)
         print("default loading weights", key, translated_key)
         if translated_key in gguf_loader.tensor_file_map:
             target_dtype = torch.get_default_dtype()
-            device = "cpu" if "embd" in translated_key else "cuda"
+            device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
+            print(f"loading {translated_key} to {device}")
+            # device = "cpu" if "embd" in translated_key else "cuda"
             weights = gguf_loader.load_gguf_tensor(translated_key, device=device).to(dtype=target_dtype)
             set_param(module, name, weights)
             del weights
         else:
             #print(load_config.tensor_file_map.keys())
-            raise Exception(f"can't fand {translated_key} in GGUF file!")
+            raise Exception(f"can't find {translated_key} in GGUF file!")

-def load_weights(module: nn.Module, gguf_loader: GGUFLoader, prefix='', return_when_injected: bool = False, only_load_injected: bool = False):
+def load_weights(module: nn.Module, gguf_loader: GGUFLoader, prefix=''):
     # print(f"recursively loading weights {prefix},{return_when_injected=}, {only_load_injected=}")
     if not isinstance(module, base_operator.BaseInjectedModule):
         load_cur_state_dict(module, gguf_loader, prefix)
@@ -66,27 +83,36 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix='', return_whe
             load_weights(child, gguf_loader, prefix + name + ".")
     else:
         module.load()

-def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
+def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     torch._dynamo.config.suppress_errors = True
     batch_size, seq_length = inputs.shape
-    torch_device = inputs.device
+    device_map = model.config.gguf_loader.tensor_device_map
+    torch_device = get_device('blk.0.self_attn', device_map)
+    torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+    inputs = inputs.to(torch_device)
+    all_cuda_device = get_all_used_cuda_device(device_map)
     tokens = []

-    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values):
-        logits = cuda_graph_runner(cur_token, position_ids, cache_position)
+    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
+        if use_cuda_graph:
+            logits = cuda_graph_runner(cur_token, position_ids, cache_position)
+        else:
+            # custom_stream = torch.cuda.Stream()
+            torch.cuda.set_device(torch_device)
+            inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
+            # with torch.cuda.stream(custom_stream):
+            logits = model(inputs_embeds=inputs_embeds,
+                           position_ids=position_ids,
+                           cache_position=cache_position,
+                           past_key_values=past_key_values,
+                           return_dict=False, use_cache=True)[0]
+            past_key_values.change_seq_length(1)
         """
         with torch.cuda.stream(custom_stream):
             logits=model(cur_token,
                 position_ids=position_ids,
                 cache_position=cache_position,
                 past_key_values=past_key_values,
                 return_dict=False, use_cache=True)[0]
         #"""
-        torch.cuda.synchronize()
+        for device in all_cuda_device:
+            torch.cuda.synchronize(device)
         #print(logits)
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
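Note: the new use_cuda_graph flag is threaded through prefill_and_generate and decode_one_tokens. A hypothetical call site (model, tokenizer and input_ids are placeholders from the usual ktransformers setup):

prefill_and_generate(model, tokenizer, input_ids, max_new_tokens=256)                        # decode via the captured CUDA graph
prefill_and_generate(model, tokenizer, input_ids, max_new_tokens=256, use_cuda_graph=False)  # eager per-token decode (easier to debug)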
@@ -95,11 +121,12 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
         else:
             next_token = torch.argmax(next_token_scores, dim=-1)
         return next_token

+    torch.cuda.set_device(torch_device)
     with torch.no_grad():
         stream = TextStreamer(tokenizer)
         past_key_values = StaticCache(
-            config=model.config, max_batch_size=1, max_cache_len=seq_length + max_new_tokens, device=torch_device, dtype=model.dtype
+            config=model.config, max_batch_size=1, max_cache_len=seq_length + max_new_tokens, device=device_map, dtype=model.dtype
         )
         cache_position = torch.arange(seq_length, device=torch_device)
         generated_ids = torch.zeros(
@@ -108,23 +135,22 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
         generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
         past_key_values.cur_idx = cache_position
         start_time = time.time()

         #custom_stream = torch.cuda.Stream()
-        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to("cuda")
+        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
         logits = model(
             inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
-        )[0][:, -1, :].unsqueeze(0).clone()
+        )[0][:, -1, :].unsqueeze(0).clone().to(torch_device)
         generation_config, model_kwargs = model._prepare_generation_config(
             None, max_length=max_new_tokens,
             do_sample=True, top_k=5, top_p=0.85, temperature=0.1  # change this to modify generate config
         )
         try:  # transformers==4.43
             logits_warper = (
-                model._get_logits_warper(generation_config, device=inputs.device) if generation_config.do_sample else None
+                model._get_logits_warper(generation_config, device=inputs.device)
             )
         except:
             logits_warper = (
-                model._get_logits_warper(generation_config) if generation_config.do_sample else None
+                model._get_logits_warper(generation_config)
             )
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
@@ -136,7 +162,6 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
         prefill_count = seq_length
         prefill_time = first_token_time

         print(stream.put(next_token.item()), end="", flush=True)
         generated_ids[:, seq_length] = next_token
         tokens.append(next_token)
@@ -144,12 +169,16 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
         cache_position = torch.tensor([seq_length], device=torch_device)
         position_ids = cache_position.unsqueeze(0)
         seq_length += 1

-        cuda_graph_runner = CUDAGraphRunner()
-        cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, return_dict=False, use_cache=True)
+        if use_cuda_graph:
+            cuda_graph_runner = CUDAGraphRunner()
+            cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
+        else:
+            cuda_graph_runner = None

         start_time = time.time()
         for _ in range(1, max_new_tokens):
-            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values)
+            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
             inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
             generated_ids[:, cache_position] = next_token.int()
             tokens.append(next_token.int())
@@ -162,6 +191,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
             print(stream.put(next_token.item()), end="", flush=True)
             cache_position += 1
             position_ids = cache_position.unsqueeze(0)

         total_time = time.time() - start_time
         tokens_generated = len(tokens)
pyproject.toml

@@ -3,7 +3,8 @@ requires = [
     "setuptools",
     "torch >= 2.3.0",
     "ninja",
-    "packaging"
+    "packaging",
+    "cpufeature"
 ]
 build-backend = "setuptools.build_meta"
setup.py

@@ -6,7 +6,7 @@ Author       : chenxl
 Date         : 2024-07-27 16:15:27
 Version      : 1.0.0
 LastEditors  : chenxl
-LastEditTime : 2024-07-31 09:44:46
+LastEditTime : 2024-08-08 02:45:15
 Adapted from:
 https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
 Copyright (c) 2023, Tri Dao.

@@ -19,6 +19,7 @@ import re
 import ast
 import subprocess
 import platform
+import shutil
 import http.client
 import urllib.request
 import urllib.error

@@ -27,6 +28,7 @@ from packaging.version import parse
 import torch.version
 from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
 from setuptools import setup, Extension
+from cpufeature.extension import CPUFeature
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME

 class CpuInstructInfo:

@@ -67,6 +69,8 @@ class VersionInfo:
         """
         if sys.platform.startswith("linux"):
             return f'linux_{platform.uname().machine}'
+        elif sys.platform == "win32":
+            return "win_amd64"
         else:
             raise ValueError("Unsupported platform: {}".format(sys.platform))

@@ -97,6 +101,15 @@ class VersionInfo:
                     return 'avx2'
                 raise ValueError("Unsupported cpu Instructions: {}".format(flags_line))
+        elif sys.platform == "win32":
+            if CPUFeature.get("AVX512bw", False):
+                return 'fancy'
+            if CPUFeature.get("AVX512f", False):
+                return 'avx512'
+            if CPUFeature.get("AVX2", False):
+                return 'avx2'
+            raise ValueError("Unsupported cpu Instructions: {}".format(str(CPUFeature)))
         else:
             raise ValueError("Unsupported platform: {}".format(sys.platform))

@@ -154,7 +167,7 @@ class BuildWheelsCommand(_bdist_wheel):
             wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
             print("Raw wheel path", wheel_path)
-            os.rename(wheel_filename, wheel_path)
+            shutil.move(wheel_filename, wheel_path)
         except (urllib.error.HTTPError, urllib.error.URLError, http.client.RemoteDisconnected):
             print("Precompiled wheel not found. Building from source...")
             # If the wheel could not be downloaded, build from source
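Note: the Windows branch added above keys off the cpufeature package. A rough standalone sketch of the same detection order (the dictionary keys mirror those in the diff, and the variant names are the ones setup.py already uses):

from cpufeature.extension import CPUFeature

def pick_cpu_variant() -> str:
    if CPUFeature.get("AVX512bw", False):
        return "fancy"      # AVX-512 with byte/word instructions
    if CPUFeature.get("AVX512f", False):
        return "avx512"
    if CPUFeature.get("AVX2", False):
        return "avx2"
    raise ValueError(f"Unsupported cpu Instructions: {CPUFeature}")

print(pick_cpu_variant())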
third_party/llamafile/iqk_mul_mat.inc

@@ -22,7 +22,7 @@
 #include <cstring>
 #include <type_traits>

-#if defined __x86_64__ || defined __aarch64__
+#if defined __x86_64__ || defined __aarch64__ || defined(_M_X64)

 #include "llama.cpp/ggml-impl.h"
 #include "llama.cpp/ggml-quants.h"

@@ -225,7 +225,7 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const voi
     return true;
 }

-#if defined __x86_64__
+#if defined __x86_64__ || defined(_M_X64)

 #if defined HAVE_FANCY_SIMD
 #undef HAVE_FANCY_SIMD

@@ -1412,7 +1412,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
 bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int) {
-    row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);
+    if (ne00 % ggml_blck_size(GGML_TYPE_Q8_K) == 0)
+        row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);
     switch (typeA) {
         case GGML_TYPE_Q2_K:
third_party/llamafile/iqk_mul_mat_amd_avx2.cpp

@@ -3,6 +3,6 @@
 // Copyrigth 2024 Iwan Kawrakow.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #include "iqk_mul_mat.inc"
 #endif // __x86_64__
third_party/llamafile/iqk_mul_mat_amd_zen4.cpp

@@ -3,7 +3,7 @@
 // Copyrigth 2024 Iwan Kawrakow.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define iqk_mul_mat iqk_mul_mat_zen4
 #define iqk_mul_mat_moe iqk_mul_mat_moe_zen4
 #include "iqk_mul_mat.inc"
third_party/llamafile/sgemm.cpp

@@ -22,19 +22,22 @@
 #include "sgemm.h"
 // #include <cosmo.h>
-#include <cpuid.h>
+// #include <cpuid.h>
 // #include <libc/sysv/consts/hwcap.h>
 #include <stdio.h>
-#include <sys/auxv.h>
+// #include <sys/auxv.h>
 #include <cassert>
 // #include "llamafile.h"

 static const struct GemmFuncs {
-    typeof(llamafile_sgemm)* sgemm;
-    typeof(llamafile_mixmul)* mixmul;
-    typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
+    bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);
+    bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);
+    bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);
+    // typeof(llamafile_sgemm)* sgemm;
+    // typeof(llamafile_mixmul)* mixmul;
+    // typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;
     GemmFuncs() {
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 // if (X86_HAVE(AVX)) {
 //   if (X86_HAVE(FMA)) {
 //     if (X86_HAVE(AVX2)) {

@@ -86,10 +89,12 @@ static const struct GemmFuncs {
 //     sgemm = llamafile_sgemm_unsupported;
 //     mixmul = llamafile_mixmul_unsupported;
 // }

 #if defined(__AVX__)
-#if defined(__FMA__)
+#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
 #if defined(__AVX2__)
 #if defined(__AVX512F__)
+        printf("__AVX512F__\n");
 #if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)
         // AMD Zen4+ (2023-)
         sgemm = llamafile_sgemm_amd_zen4;
third_party/llamafile/tinyblas_cpu.h

@@ -223,7 +223,7 @@ inline float32x4_t badder(float32x4_t a, float b, float32x4_t c, float32x4_t* e)
 }
 #endif

-#if defined(__FMA__)
+#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
 template <>
 inline __m256 madd(__m256 a, __m256 b, __m256 c) {
third_party/llamafile/tinyblas_cpu_mixmul_amd_avx.cpp

@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_avx
 #include "tinyblas_cpu_mixmul.inc"

third_party/llamafile/tinyblas_cpu_mixmul_amd_avx2.cpp

@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_avx2
 #include "tinyblas_cpu_mixmul.inc"
 #endif // __x86_64__

third_party/llamafile/tinyblas_cpu_mixmul_amd_avx512f.cpp

@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_avx512f
 #include "tinyblas_cpu_mixmul.inc"
 #endif // __x86_64__

third_party/llamafile/tinyblas_cpu_mixmul_amd_avxvnni.cpp

@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_avxvnni
 #include "tinyblas_cpu_mixmul.inc"
 #endif // __x86_64__

third_party/llamafile/tinyblas_cpu_mixmul_amd_fma.cpp

@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_fma
 #include "tinyblas_cpu_mixmul.inc"
 #endif // __x86_64__

third_party/llamafile/tinyblas_cpu_mixmul_amd_zen4.cpp

@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_mixmul llamafile_mixmul_amd_zen4
 #include "tinyblas_cpu_mixmul.inc"
 #endif // __x86_64__
third_party/llamafile/tinyblas_cpu_sgemm.inc

@@ -321,8 +321,8 @@ bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void
     assert(ith < nth);

 #if QK_K == 256
-#if defined(__x86_64__)
-#if defined(__AVX2__) && defined(__FMA__)
+#if defined(__x86_64__) || defined(_M_X64)
+#if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))
     // if (X86_CHECK(AVX2) && X86_CHECK(FMA)) {
     if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
         if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {
third_party/llamafile/tinyblas_cpu_sgemm_amd_avx.cpp

@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_sgemm llamafile_sgemm_amd_avx
 #include "tinyblas_cpu_sgemm.inc"
 #endif // __x86_64__

third_party/llamafile/tinyblas_cpu_sgemm_amd_avx2.cpp

@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_sgemm llamafile_sgemm_amd_avx2
 #include "tinyblas_cpu_sgemm.inc"
 #endif // __x86_64__

third_party/llamafile/tinyblas_cpu_sgemm_amd_avx512f.cpp

@@ -3,7 +3,7 @@
 // Copyrigth 2024 Mozilla Foundation.
 // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.

-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(_M_X64)
 #define llamafile_sgemm llamafile_sgemm_amd_avx512f
 #include "tinyblas_cpu_sgemm.inc"
 #endif // __x86_64__