Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
ktransformers
Commits
18c42e67
Commit
18c42e67
authored
Jul 27, 2024
by
chenxl
Browse files
Initial commit
parents
Changes
247
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2746 additions
and
0 deletions
+2746
-0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq.py
...transformers_ext/operators/custom_marlin/quantize/gptq.py
+206
-0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq_marlin.py
...rmers_ext/operators/custom_marlin/quantize/gptq_marlin.py
+458
-0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/quantizer.py
...formers_ext/operators/custom_marlin/quantize/quantizer.py
+140
-0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/repack.py
...ansformers_ext/operators/custom_marlin/quantize/repack.py
+99
-0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/__init__.py
...rs_ext/operators/custom_marlin/quantize/utils/__init__.py
+0
-0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/format_24.py
...s_ext/operators/custom_marlin/quantize/utils/format_24.py
+308
-0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_24_perms.py
...operators/custom_marlin/quantize/utils/marlin_24_perms.py
+60
-0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_perms.py
...xt/operators/custom_marlin/quantize/utils/marlin_perms.py
+60
-0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py
...xt/operators/custom_marlin/quantize/utils/marlin_utils.py
+232
-0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py
...ext/operators/custom_marlin/quantize/utils/quant_utils.py
+146
-0
ktransformers/ktransformers_ext/operators/llamafile/conversion.h
...ormers/ktransformers_ext/operators/llamafile/conversion.h
+33
-0
ktransformers/ktransformers_ext/operators/llamafile/linear.cpp
...sformers/ktransformers_ext/operators/llamafile/linear.cpp
+48
-0
ktransformers/ktransformers_ext/operators/llamafile/linear.h
ktransformers/ktransformers_ext/operators/llamafile/linear.h
+56
-0
ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp
ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp
+103
-0
ktransformers/ktransformers_ext/operators/llamafile/mlp.h
ktransformers/ktransformers_ext/operators/llamafile/mlp.h
+67
-0
ktransformers/ktransformers_ext/operators/llamafile/moe.cpp
ktransformers/ktransformers_ext/operators/llamafile/moe.cpp
+311
-0
ktransformers/ktransformers_ext/operators/llamafile/moe.h
ktransformers/ktransformers_ext/operators/llamafile/moe.h
+97
-0
ktransformers/local_chat.py
ktransformers/local_chat.py
+115
-0
ktransformers/models/__init__.py
ktransformers/models/__init__.py
+0
-0
ktransformers/models/configuration_deepseek.py
ktransformers/models/configuration_deepseek.py
+207
-0
No files found.
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq.py
0 → 100644
View file @
18c42e67
import
math
import
os
import
time
from
logging
import
getLogger
import
torch
import
torch.nn
as
nn
import
transformers
from
.quantizer
import
Quantizer
# Module-level logger named after this module.
logger = getLogger(__name__)

# Turn TF32 off for CUDA matmuls and cuDNN ops.
# NOTE(review): presumably so the accumulation/Cholesky math in GPTQ runs in
# full float32 precision -- confirm intent with the original author.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
class GPTQ:
    """Quantize a single layer's weight column by column with error feedback.

    Usage visible in this class: construct with a layer, feed calibration
    inputs through ``add_batch`` (which accumulates the matrix ``H``), then
    call ``fasterquant`` once to overwrite ``layer.weight.data`` with the
    quantized values, and finally ``free`` to drop cached tensors.
    """

    def __init__(self, layer):
        """Record the layer and allocate the accumulators used by ``add_batch``.

        Args:
            layer: module exposing ``weight``. ``nn.Conv2d`` weights are
                flattened from dim 1 and ``Conv1D`` weights are transposed
                before ``rows``/``columns`` are read from the 2-D shape.
        """
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.pytorch_utils.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        # Square (columns x columns) accumulator filled by add_batch.
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.quantizer = Quantizer()

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into the running matrix ``H``.

        Args:
            inp: input tensor that was fed to the layer. A 2-D input gets a
                leading batch dimension added.
            out: matching layer output; only stored (together with ``inp``)
                when the ``DEBUG`` environment variable is set.
        """
        if os.environ.get("DEBUG"):
            # Kept so fasterquant can log the layer's output error.
            self.inp1 = inp
            self.out1 = out
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        # Number of samples contributed by this batch (leading dimension).
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            # After the transpose the feature dimension comes first.
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            # Unfold with the layer's own geometry, then bring the unfolded
            # feature dimension to the front and flatten the rest.
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        # Running average: shrink the old H by the old/new sample ratio ...
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        # inp = inp.float()
        # ... and pre-scale the input by sqrt(2 / nsamples) so that the product
        # below adds (2 / nsamples) * inp @ inp.T.
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        # self.H += 2 / self.nsamples * inp.matmul(inp.t())
        self.H += inp.matmul(inp.t())

    def fasterquant(
        self,
        blocksize=128,
        percdamp=0.01,
        group_size=-1,
        actorder=False,
        static_groups=False,
    ):
        """Quantize the layer weight in place and return the quantization params.

        Consumes ``self.H`` (the attribute is deleted), so ``add_batch`` must
        have been called beforehand and this method can only run once per
        accumulation.

        Args:
            blocksize: number of columns processed per outer-loop block.
            percdamp: fraction of the mean diagonal of ``H`` added to the
                diagonal before the Cholesky factorisations.
            group_size: columns sharing one scale/zero; ``-1`` disables
                grouping (treated as ``self.columns`` when building ``g_idx``).
            actorder: when true, columns are processed in order of decreasing
                ``diag(H)`` and the result is permuted back at the end.
            static_groups: when true, per-group quantizers are fitted up front
                on the unmodified weight instead of during the sweep.

        Returns:
            Tuple ``(scale, zero, g_idx)``: the collected scales and zeros
            concatenated along dim 1, and an int32 tensor mapping each column
            to its group index.
        """
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        del self.H
        # Columns whose diagonal entry is zero: give them a unit diagonal and
        # zero the corresponding weights.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        g_idx = []
        scale = []
        zero = []
        # Index (1-based) of the next group whose scale/zero is to be recorded.
        now_idx = 1

        if static_groups:
            import copy

            # Fit one quantizer per column group before any error feedback
            # has modified W.
            groups = []
            for i in range(0, self.columns, group_size):
                quantizer = copy.deepcopy(self.quantizer)
                quantizer.find_params(W[:, i : (i + group_size)], weight=True)
                scale.append(quantizer.scale)
                zero.append(quantizer.zero)
                groups.append(quantizer)

        if actorder:
            # Reorder columns by decreasing diag(H); invperm undoes it later.
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]
            invperm = torch.argsort(perm)

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        # Dampen the diagonal, then compute the upper Cholesky factor of the
        # inverse of H (cholesky -> cholesky_inverse -> cholesky(upper=True)).
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            # Work on a copy of the block; W itself is updated once per block.
            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if group_size != -1:
                    if not static_groups:
                        # Refit the shared quantizer at the first column of
                        # each group, using the current (already updated) W.
                        if (i1 + i) % group_size == 0:
                            self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + group_size)], weight=True)

                        # Record scale/zero once per group: true exactly when
                        # the column's group index equals now_idx - 1.
                        if ((i1 + i) // group_size) - now_idx == -1:
                            scale.append(self.quantizer.scale)
                            zero.append(self.quantizer.zero)
                            now_idx += 1
                    else:
                        # Pick the pre-fitted quantizer of the column's
                        # original (un-permuted) group.
                        idx = i1 + i
                        if actorder:
                            idx = perm[idx]
                        self.quantizer = groups[idx // group_size]

                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d**2

                # Spread this column's scaled error over the remaining columns
                # of the block.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            # Propagate the block's accumulated errors to all later columns.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

            if os.environ.get("DEBUG"):
                # Temporarily install the partially quantized weight and log
                # the squared output error against the cached batch.
                self.layer.weight.data[:, :i2] = Q[:, :i2]
                self.layer.weight.data[:, i2:] = W[:, i2:]
                logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
                logger.debug(torch.sum(Losses))

        # NOTE(review): called unconditionally, so this presumably requires a
        # CUDA-enabled runtime -- confirm before using on CPU-only hosts.
        torch.cuda.synchronize()
        logger.info(f"duration: {(time.time() - tick)}")
        logger.info(f"avg loss: {torch.sum(Losses).item() / self.nsamples}")

        group_size = group_size if group_size != -1 else self.columns
        if static_groups and actorder:
            g_idx = [perm[i] // group_size for i in range(self.columns)]
        else:
            g_idx = [i // group_size for i in range(self.columns)]
        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
        if actorder:
            # Restore the original column order.
            Q = Q[:, invperm]
            g_idx = g_idx[invperm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()
        self.layer.weight.data = Q.reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
        if os.environ.get("DEBUG"):
            logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))

        # Nothing was recorded during the sweep (e.g. no grouping): fall back
        # to the single shared quantizer's parameters.
        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero, g_idx

    def free(self):
        """Drop cached tensors and release unused CUDA cache memory."""
        if os.environ.get("DEBUG"):
            self.inp1 = None
            self.out1 = None
        self.H = None
        # Losses/Trace are not assigned anywhere else in this class; they are
        # simply set to None here.
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()
# Public API of this module: only the GPTQ class is exported.
__all__ = ["GPTQ"]
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq_marlin.py
0 → 100644
View file @
18c42e67
import
enum
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
# Module-level logger created through vLLM's logging helper.
logger = init_logger(__name__)

# Kernel geometry constants copied onto every GPTQMarlinConfig instance.
GPTQ_MARLIN_TILE = 16  # stored as config.tile_size
GPTQ_MARLIN_MIN_THREAD_N = 64  # output size per partition must be a multiple
GPTQ_MARLIN_MIN_THREAD_K = 128  # input size per partition must be a multiple
GPTQ_MARLIN_MAX_PARALLEL = 16  # multiplier used when sizing the workspace

# Configuration values accepted by GPTQMarlinConfig; anything else is rejected.
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]  # -1: no grouping
GPTQ_MARLIN_SUPPORTED_SYM = [True]  # only symmetric quantization
# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits: int):
    """Build the two fixed index permutations used to shuffle Marlin scales.

    Args:
        num_bits: accepted for interface compatibility; the permutations do
            not depend on it.

    Returns:
        Tuple ``(scale_perm, scale_perm_single)``: a 64-entry permutation
        (an 8x8 index grid read column by column) and a 32-entry permutation
        built from four interleaved pairs-of-eight.
    """
    # Entry order: for each i in 0..7 emit i, i + 8, ..., i + 56.
    scale_perm: List[int] = [
        base + 8 * step for base in range(8) for step in range(8)
    ]

    pair_offsets = (0, 1, 8, 9, 16, 17, 24, 25)
    # Entry order: for each i in 0..3 emit 2 * i shifted by every offset.
    scale_perm_single: List[int] = [
        2 * block + offset for block in range(4) for offset in pair_offsets
    ]
    return scale_perm, scale_perm_single
def get_pack_factor(num_bits: int):
    """Return how many ``num_bits``-wide values fit into one 32-bit word.

    Args:
        num_bits: bit width; must be listed in
            ``GPTQ_MARLIN_SUPPORTED_NUM_BITS``.

    Returns:
        ``32 // num_bits``.

    Raises:
        AssertionError: if ``num_bits`` is not a supported width.
    """
    is_supported = num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
    assert is_supported, f"Unsupported num_bits = {num_bits}"
    word_bits = 32
    return word_bits // num_bits
def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                          group_size: int, num_bits: int):
    """Reorder a scales tensor with the permutation from ``get_scale_perms``.

    Args:
        s: scales tensor; its element count must be divisible by the length
            of the selected permutation and by ``size_n``.
        size_k: compared against ``group_size`` to choose the permutation.
        size_n: number of columns of the returned tensor.
        group_size: quantization group size, ``-1`` meaning no grouping.
        num_bits: forwarded to ``get_scale_perms``.

    Returns:
        Contiguous tensor of shape ``(-1, size_n)`` holding the permuted
        scales.
    """
    grouped_perm, single_perm = get_scale_perms(num_bits)

    # Grouped scales (a real group size smaller than K) use the 64-entry
    # permutation; everything else uses the 32-entry one.
    is_grouped = group_size != -1 and group_size < size_k
    chosen_perm = grouped_perm if is_grouped else single_perm

    chunks = s.reshape((-1, len(chosen_perm)))
    permuted = chunks[:, chosen_perm]
    return permuted.reshape((-1, size_n)).contiguous()
class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool) -> None:
        """Validate the quantization settings and derive kernel parameters.

        Args:
            weight_bits: bit width of the quantized weights.
            group_size: quantization group size, ``-1`` for no grouping.
            desc_act: whether activation reordering is enabled; forced to
                ``False`` when ``group_size`` is ``-1``.
            is_sym: whether quantization is symmetric.

        Raises:
            ValueError: if ``weight_bits``, ``group_size`` or ``is_sym`` is
                not in the corresponding ``GPTQ_MARLIN_SUPPORTED_*`` list.
        """
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym

        # Verify
        if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"Marlin does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
                "are supported.")
        if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
            raise ValueError(
                f"Marlin does not support is_sym = {self.is_sym}. "
                f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")

        # Init
        self.pack_factor = get_pack_factor(weight_bits)
        self.tile_size = GPTQ_MARLIN_TILE
        self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
        self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
        self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability, encoded as major*10+minor."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the file names searched for the quantization config."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        """Build a config from the ``bits``/``group_size``/``desc_act``/``sym``
        entries of ``config``."""
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        return cls(weight_bits, group_size, desc_act, is_sym)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Decide whether this method should replace the requested one.

        Returns:
            ``get_name()`` when the checkpoint config is Marlin-compatible and
            the user asked for nothing or for ``"marlin"``; otherwise ``None``
            (with an informational log when the user explicitly chose
            ``"gptq"`` for a compatible model).
        """
        can_convert = cls.is_marlin_compatible(hf_quant_cfg)

        is_valid_user_quant = (user_quant is None or user_quant == "marlin")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_marlin"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_marlin for"
                        " faster inference")
        return None

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
        """Return a ``GPTQMarlinLinearMethod`` for ``LinearBase`` layers,
        ``None`` for every other layer type."""
        if isinstance(layer, LinearBase):
            return GPTQMarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the (empty) list of activation names needing scaling."""
        return []

    @classmethod
    def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
        """Report whether a checkpoint's quantization config can use Marlin.

        Args:
            quant_config: mapping read for the keys ``bits``, ``group_size``,
                ``sym`` and ``desc_act``.

        Returns:
            ``True`` only if all four keys are present, CUDA is available, the
            current device meets ``get_min_capability()`` and the values are
            in the supported lists; ``False`` otherwise.
        """
        # Extract data from quant config.
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        sym = quant_config.get("sym", None)
        desc_act = quant_config.get("desc_act", None)

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        # Without a usable CUDA device there is no capability to query:
        # get_device_capability() would raise instead of answering, so
        # report "cannot convert" explicitly.
        if not torch.cuda.is_available():
            return False

        # If the capability of the device is too low, cannot convert.
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < cls.get_min_capability():
            return False

        # Otherwise, can convert if model satisfies marlin constraints.
        return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
                and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
                and sym in GPTQ_MARLIN_SUPPORTED_SYM)
class GPTQMarlinState(Enum):
    """Per-layer marker telling ``apply`` whether weights still need repacking."""

    # Set by create_weights; the first apply() call repacks the weights and
    # permutes the scales.
    REPACK = enum.auto()
    # Set by apply() once the one-time repack has been started; later calls
    # go straight to the GEMM.
    READY = enum.auto()
class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Allocate the quantized parameters and bookkeeping fields on ``layer``.

        Registers ``qweight``, ``g_idx``, ``scales`` and ``qzeros`` as
        parameters and stores ``g_idx_sort_indices``, ``workspace``, the
        partition sizes, ``is_k_full`` and ``marlin_state`` (initially
        ``REPACK``) as plain attributes.

        Args:
            layer: module receiving the parameters and attributes.
            input_size_per_partition: size of this partition's input dim.
            output_partition_sizes: output sizes; their sum is the output
                size per partition.
            input_size: full (unsharded) input dimension.
            output_size: unused (deleted immediately).
            params_dtype: dtype of the scales; float16 or bfloat16 only.
            **extra_weight_attrs: extra attributes attached to every
                parameter via ``set_weight_attrs``.

        Raises:
            ValueError: on an unsupported ``params_dtype`` or when a partition
                size is not divisible by ``min_thread_n``, ``min_thread_k`` or
                the group size.
        """
        del output_size

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Validate dtype
        if params_dtype not in [torch.float16, torch.bfloat16]:
            raise ValueError(f"The params dtype must be float16 "
                             f"or bfloat16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_thread_n != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f" min_thread_n = {self.quant_config.min_thread_n}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_thread_k != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible "
                f"by min_thread_k = {self.quant_config.min_thread_k}.")

        if (group_size < input_size
                and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition}"
                f" is not divisible by group_size = {group_size}.")

        # Detect sharding of scales/zp

        # By default, no sharding over "input dim"
        scales_and_zp_size = input_size // group_size
        scales_and_zp_input_dim = None

        if self.quant_config.desc_act:
            # Act-order case
            assert self.quant_config.group_size != -1

            is_k_full = input_size_per_partition == input_size

        else:
            # No act-order case

            # K is always full due to full alignment with
            # group-size and shard of scales/zp
            is_k_full = True

            # If this is a row-parallel case, then shard scales/zp
            if (input_size != input_size_per_partition
                    and self.quant_config.group_size != -1):
                scales_and_zp_size = input_size_per_partition // group_size
                scales_and_zp_input_dim = 0

        # Init buffers

        # Quantized weights
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                **extra_weight_attrs,
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            },
        )

        # Activation order
        g_idx = Parameter(
            torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(
            g_idx,
            {
                **extra_weight_attrs, "input_dim": 0,
                "ignore_warning": True
            },
        )

        # Plain tensor (not a Parameter) with the same shape as g_idx.
        g_idx_sort_indices = torch.empty(
            g_idx.shape,
            dtype=torch.int32,
        )

        # Scales
        scales = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
            },
        )

        # Quantized zero-points
        # Allocated on the "meta" device, i.e. shape/dtype only, no storage;
        # apply() in this class never reads qzeros.
        qzeros = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
                device="meta",
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros,
            {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            },
        )

        # Allocate marlin workspace
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_thread_n) * self.quant_config.max_parallel
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                requires_grad=False)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)
        layer.g_idx_sort_indices = g_idx_sort_indices
        layer.workspace = workspace
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.is_k_full = is_k_full
        # The first apply() call will repack the weights.
        layer.marlin_state = GPTQMarlinState.REPACK

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the quantized linear layer, repacking weights on first use.

        Args:
            layer: module prepared by ``create_weights``.
            x: input tensor; all leading dimensions are flattened into one
                before the GEMM.
            bias: optional bias added in place to the GEMM output.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``layer.output_size_per_partition``.
        """
        reshaped_x = x.reshape(-1, x.shape[-1])

        size_m = reshaped_x.shape[0]
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition
        full_size_k = layer.input_size

        out_shape = x.shape[:-1] + (part_size_n, )

        if layer.marlin_state == GPTQMarlinState.REPACK:
            # Flip the state first so the repack below runs only once.
            layer.marlin_state = GPTQMarlinState.READY

            # Newly generated tensors need to replace existing tensors that are
            # already registered as parameters by vLLM (and won't be freed)
            def replace_tensor(name, new_t):
                # It is important to use resize_() here since it ensures
                # the same buffer is reused
                getattr(layer, name).resize_(new_t.shape)
                getattr(layer, name).copy_(new_t)
                del new_t

            cur_device = layer.qweight.device

            # Process act_order
            if self.quant_config.desc_act:
                # Get sorting based on g_idx
                g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)

                sorted_g_idx = layer.g_idx[g_idx_sort_indices]

                replace_tensor("g_idx", sorted_g_idx)
                replace_tensor("g_idx_sort_indices", g_idx_sort_indices)

            else:
                # Reset g_idx related tensors
                layer.g_idx = Parameter(
                    torch.empty(0, dtype=torch.int, device=cur_device),
                    requires_grad=False,
                )
                layer.g_idx_sort_indices = Parameter(
                    torch.empty(0, dtype=torch.int, device=cur_device),
                    requires_grad=False,
                )

            # Repack weights
            marlin_qweight = ops.gptq_marlin_repack(
                layer.qweight,
                layer.g_idx_sort_indices,
                part_size_k,
                part_size_n,
                self.quant_config.weight_bits,
            )
            replace_tensor("qweight", marlin_qweight)

            # Permute scales
            scales_size_k = part_size_k
            scales_size_n = part_size_n
            if self.quant_config.desc_act:
                # With act-order the full K is passed for the scales.
                scales_size_k = full_size_k

            marlin_scales = marlin_permute_scales(
                layer.scales,
                scales_size_k,
                scales_size_n,
                self.quant_config.group_size,
                self.quant_config.weight_bits,
            )
            replace_tensor("scales", marlin_scales)

        output = ops.gptq_marlin_gemm(
            reshaped_x,
            layer.qweight,
            layer.scales,
            layer.g_idx,
            layer.g_idx_sort_indices,
            layer.workspace,
            self.quant_config.weight_bits,
            size_m,
            part_size_n,
            part_size_k,
            layer.is_k_full,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/quantizer.py
0 → 100644
View file @
18c42e67
from
logging
import
getLogger
import
torch
import
torch.nn
as
nn
# Module-level logger named after this module.
logger = getLogger(__name__)
def quantize(x, scale, zero, maxq):
    """Map ``x`` onto the quantization grid and return the dequantized values.

    Args:
        x: tensor to quantize.
        scale: step size (or, when ``maxq < 0``, the positive level).
        zero: zero point (or, when ``maxq < 0``, the negative level).
        maxq: largest integer level; a negative value selects the three-level
            branch.

    Returns:
        Tensor of the values ``x`` snaps to, in the same units as ``x``.
    """
    if maxq < 0:
        # Three-level branch: values above scale/2 become `scale`, values
        # below zero/2 become `zero`, everything else becomes 0.
        takes_upper = (x > scale / 2).float()
        takes_lower = (x < zero / 2).float()
        return takes_upper * scale + takes_lower * zero

    # Round to the nearest level, shift by the zero point, clip to
    # [0, maxq], then map back to real values.
    levels = torch.round(x / scale) + zero
    levels = torch.clamp(levels, 0, maxq)
    return scale * (levels - zero)
class Quantizer(nn.Module):
    """Holds quantization parameters (``maxq``, ``scale``, ``zero``) as buffers.

    ``configure`` must be called before ``find_params``; it sets the
    attributes (``perchannel``, ``sym``, ``mse``, ...) that ``find_params``
    reads.
    """

    def __init__(self, shape=1):
        """Register the three buffers; ``scale`` and ``zero`` start as zeros."""
        super(Quantizer, self).__init__()
        initial_buffers = (
            ("maxq", torch.tensor(0)),
            ("scale", torch.zeros(shape)),
            ("zero", torch.zeros(shape)),
        )
        for buffer_name, initial_value in initial_buffers:
            self.register_buffer(buffer_name, initial_value)

    def configure(
        self,
        bits,
        perchannel=False,
        sym=True,
        mse=False,
        norm=2.4,
        grid=100,
        maxshrink=0.8,
        trits=False,
    ):
        """Store the quantization settings.

        Args:
            bits: bit width; ``maxq`` becomes ``2**bits - 1``.
            perchannel: fit one scale/zero per row instead of one overall.
            sym: use a symmetric range.
            mse: refine scale/zero with a grid search over shrunken ranges.
            norm: exponent of the error used by the grid search.
            grid: number of grid steps per unit of shrink.
            maxshrink: fraction of ``grid`` that is actually searched.
            trits: when true, ``maxq`` is set to ``-1`` (three-level mode).
        """
        max_level = torch.tensor(2**bits - 1)
        if trits:
            max_level = torch.tensor(-1)
        self.maxq = max_level

        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink

    def find_params(self, x, weight=False):
        """Fit ``scale`` and ``zero`` to ``x`` and store them on ``self``.

        Args:
            x: tensor whose range is analysed.
            weight: when true, ``x`` is treated as a weight (rows are the
                channels) and the results are reshaped to
                ``[-1, 1, ..., 1]``; otherwise the layout depends on the
                number of dimensions of ``x``.
        """
        device = x.device
        self.maxq = self.maxq.to(device)

        orig_shape = x.shape
        ndim = len(orig_shape)

        # Bring x into a 2-D (channels, values) layout.
        if not self.perchannel:
            x = x.flatten().unsqueeze(0)
        elif weight:
            x = x.flatten(1)
        elif ndim == 4:
            x = x.permute([1, 0, 2, 3]).flatten(1)
        elif ndim == 3:
            x = x.reshape((-1, orig_shape[-1])).t()
        elif ndim == 2:
            x = x.t()

        # Per-row range, always including zero.
        zeros = torch.zeros(x.shape[0], device=device)
        xmin = torch.minimum(x.min(1)[0], zeros)
        xmax = torch.maximum(x.max(1)[0], zeros)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            negative = xmin < 0
            if torch.any(negative):
                xmin[negative] = -xmax[negative]

        # Rows that are entirely zero get the dummy range [-1, 1].
        all_zero = (xmin == 0) & (xmax == 0)
        xmin[all_zero] = -1
        xmax[all_zero] = +1

        if self.maxq < 0:
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            # Try progressively shrunken ranges and keep, per row, the one
            # with the smallest summed |error| ** norm.
            best_err = torch.full([x.shape[0]], float("inf"), device=device)
            n_steps = int(self.maxshrink * self.grid)
            for step in range(n_steps):
                shrink = 1 - step / self.grid
                cand_min = shrink * xmin
                cand_max = shrink * xmax
                cand_scale = (cand_max - cand_min) / self.maxq
                if self.sym:
                    cand_zero = self.zero
                else:
                    cand_zero = torch.round(-cand_min / cand_scale)
                q = quantize(x, cand_scale.unsqueeze(1), cand_zero.unsqueeze(1), self.maxq)
                q.sub_(x).abs_().pow_(self.norm)
                err = torch.sum(q, 1)
                improved = err < best_err
                if torch.any(improved):
                    best_err[improved] = err[improved]
                    self.scale[improved] = cand_scale[improved]
                    self.zero[improved] = cand_zero[improved]

        if not self.perchannel:
            # Broadcast the single scale/zero to one entry per channel.
            if weight:
                n_channels = orig_shape[0]
            elif ndim != 3:
                n_channels = orig_shape[1]
            else:
                n_channels = orig_shape[2]
            self.scale = self.scale.repeat(n_channels)
            self.zero = self.zero.repeat(n_channels)

        if weight:
            target = [-1] + [1] * (ndim - 1)
            self.scale = self.scale.reshape(target)
            self.zero = self.zero.reshape(target)
            return

        if ndim == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        elif ndim == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        elif ndim == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        """Quantize ``x`` with the stored parameters, or return it unchanged
        when the parameters have not been fitted yet."""
        if not self.ready():
            return x
        return quantize(x, self.scale, self.zero, self.maxq)

    def enabled(self):
        """Return whether ``maxq`` is positive."""
        return self.maxq > 0

    def ready(self):
        """Return whether every entry of ``scale`` is non-zero."""
        return torch.all(self.scale != 0)
# Public API of this module: only the Quantizer class is exported (the
# module-level quantize() helper is not).
__all__ = ["Quantizer"]
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/repack.py
0 → 100644
View file @
18c42e67
import
torch
import
enum
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
from
torch.nn.parameter
import
Parameter
# NOTE(review): this is a module-level function that takes `self`; presumably
# it is meant to be bound to an object exposing `quant_config` -- confirm.
# NOTE(review): `GPTQMarlinState`, `ops` and `marlin_permute_scales` are used
# below but are not among the imports at the top of this file; calling this
# function as-is would raise NameError unless they are injected elsewhere.
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the quantized linear layer, repacking weights on first use.

    Args:
        self: object providing ``quant_config`` (``desc_act``,
            ``weight_bits``, ``group_size``).
        layer: module carrying ``qweight``, ``scales``, ``g_idx``,
            ``g_idx_sort_indices``, ``workspace``, the partition sizes,
            ``is_k_full`` and ``marlin_state``.
        x: input tensor; all leading dimensions are flattened into one before
            the GEMM.
        bias: optional bias added in place to the GEMM output.

    Returns:
        Tensor shaped like ``x`` with the last dimension replaced by
        ``layer.output_size_per_partition``.
    """
    reshaped_x = x.reshape(-1, x.shape[-1])

    size_m = reshaped_x.shape[0]
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    full_size_k = layer.input_size

    out_shape = x.shape[:-1] + (part_size_n, )

    if layer.marlin_state == GPTQMarlinState.REPACK:
        # Flip the state first so the repack below runs only once.
        layer.marlin_state = GPTQMarlinState.READY

        # Newly generated tensors need to replace existing tensors that are
        # already registered as parameters by vLLM (and won't be freed)
        def replace_tensor(name, new_t):
            # It is important to use resize_() here since it ensures
            # the same buffer is reused
            getattr(layer, name).resize_(new_t.shape)
            getattr(layer, name).copy_(new_t)
            del new_t

        cur_device = layer.qweight.device

        # Process act_order
        if self.quant_config.desc_act:
            # Get sorting based on g_idx
            g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)

            sorted_g_idx = layer.g_idx[g_idx_sort_indices]

            replace_tensor("g_idx", sorted_g_idx)
            replace_tensor("g_idx_sort_indices", g_idx_sort_indices)

        else:
            # Reset g_idx related tensors
            layer.g_idx = Parameter(
                torch.empty(0, dtype=torch.int, device=cur_device),
                requires_grad=False,
            )
            layer.g_idx_sort_indices = Parameter(
                torch.empty(0, dtype=torch.int, device=cur_device),
                requires_grad=False,
            )

        # Repack weights
        marlin_qweight = ops.gptq_marlin_repack(
            layer.qweight,
            layer.g_idx_sort_indices,
            part_size_k,
            part_size_n,
            self.quant_config.weight_bits,
        )
        replace_tensor("qweight", marlin_qweight)

        # Permute scales
        scales_size_k = part_size_k
        scales_size_n = part_size_n
        if self.quant_config.desc_act:
            # With act-order the full K is passed for the scales.
            scales_size_k = full_size_k

        marlin_scales = marlin_permute_scales(
            layer.scales,
            scales_size_k,
            scales_size_n,
            self.quant_config.group_size,
            self.quant_config.weight_bits,
        )
        replace_tensor("scales", marlin_scales)

    output = ops.gptq_marlin_gemm(
        reshaped_x,
        layer.qweight,
        layer.scales,
        layer.g_idx,
        layer.g_idx_sort_indices,
        layer.workspace,
        self.quant_config.weight_bits,
        size_m,
        part_size_n,
        part_size_k,
        layer.is_k_full,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/__init__.py
0 → 100644
View file @
18c42e67
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/format_24.py
0 → 100644
View file @
18c42e67
#
# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
#
import
torch
# This is a PyTorch implementation of the main part of the reorder_meta()
# function, from the tools/util/include/cutlass/util/host_reorder.h file
# of the CUTLASS source tree. Furthermore, the CUTLASS template for sparse
# GEMM decides upon the layout of this matrix, and at the moment, for the
# sparse GEMM executed on tensor cores, this is the layout described by
# the ColumnMajorInterleaved<2> data structure, in
# include/cutlass/layout/matrix.h of the CUTLASS source tree. The
# reordering of the meta matrix into the meta_reordered matrix, calculated
# according to these segments of CUTLASS code, is re-implemented here.
# Note that this calculation produces offsets for scattering metadata
# matrix elements into reordered metadata matrix elements (or,
# equivalently, for gathering reordered metadata matrix elements back
# into metadata matrix elements).
def
_calculate_meta_reordering_scatter_offsets
(
m
,
meta_ncols
,
meta_dtype
,
device
):
dst_rows
=
torch
.
arange
(
0
,
m
,
device
=
device
)[:,
None
].
repeat
(
1
,
meta_ncols
)
dst_cols
=
torch
.
arange
(
0
,
meta_ncols
,
device
=
device
).
repeat
(
m
,
1
)
# Reorder the rows, then swizzle the 2x2 blocks.
group_x
=
64
group_y
=
32
if
meta_dtype
.
itemsize
==
2
else
16
dst_rows
=
(
dst_rows
//
group_x
*
group_x
+
(
dst_rows
%
2
)
*
2
+
(
dst_rows
%
8
)
//
4
+
((
dst_rows
%
group_y
)
%
4
)
//
2
*
32
+
((
dst_rows
%
group_x
)
//
8
)
*
4
)
topright
=
((
dst_rows
%
2
==
0
)
&
(
dst_cols
%
2
==
1
)).
to
(
torch
.
int8
)
bottomleft
=
((
dst_rows
%
2
==
1
)
&
(
dst_cols
%
2
==
0
)).
to
(
torch
.
int8
)
dst_rows
+=
topright
-
bottomleft
dst_cols
-=
topright
-
bottomleft
# Assumed that meta tensor is to be stored in CUTLASS
# InterleavedColumnMajor layout, and reverse engineered
# corresponding code to store values into this tensor.
interleave
=
2
cols_maj
=
dst_cols
//
interleave
cols_min
=
dst_cols
%
interleave
return
(
cols_maj
*
m
*
interleave
+
dst_rows
*
interleave
+
cols_min
).
view
(
-
1
)
# This function converts dense matrix into sparse semi-structured
# representation, producing "compressed" matrix, in the layout used by
# CUTLASS backend, and corresponding metadata matrix.
def sparse_semi_structured_from_dense_cutlass(dense):
    """Compress a dense 2-D tensor into (values, metadata).

    Args:
        dense: 2-D tensor of dtype int8, half, bfloat16, float or int32.
            The row count must be a multiple of 16 (int8) or 32 (other
            dtypes); the column count must be a multiple of
            ``4 * quadbits_per_meta_elem``.

    Returns:
        Tuple ``(sparse, meta_reordered)``: ``sparse`` has shape
        ``(m, k // 2)`` and holds the two selected values of each group;
        ``meta_reordered`` has shape ``(m, meta_ncols)`` and dtype int32
        (int8 input) or int16 (other inputs).

    Raises:
        RuntimeError: on wrong rank, unsupported dtype or bad dimensions.
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = dense.shape
    device = dense.device

    # int8 values use 32-bit metadata words, everything else 16-bit ones.
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")

    # Every 4-bit nibble of a metadata word describes one group.
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError(
            "Invalid number of elements per meta element calculated")

    row_multiple = 16 if meta_dtype == torch.int32 else 32
    if m % row_multiple != 0:
        raise RuntimeError(
            f"Number of rows of dense matrix {m} must be divisible by {row_multiple}"  # noqa: E501
        )
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
        )

    # torch.float is treated as 1:2 sparsity (groups of two values); all
    # other dtypes as 2:4 sparsity (groups of four values).
    is_float = dense.dtype == torch.float
    ksparse = 2 if is_float else 4
    groups = dense.view(-1, k // ksparse, ksparse)
    nonzero = (groups != 0).unbind(-1)
    if is_float:
        # Each float flag stands for a pair of quadruple positions.
        m0 = m1 = nonzero[0]
        m3 = nonzero[1]
    else:
        m0, m1, _, m3 = nonzero
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    #     [True,  True,  False, False] -> 0b0100
    #     [True,  False, True,  False] -> 0b1000
    #     [False, True,  True,  False] -> 0b1001
    #     [True,  False, False, True ] -> 0b1100
    #     [False, True,  False, True ] -> 0b1101
    #     [False, False, True,  True ] -> 0b1110
    # Thus, lower two bits in the encoding are index of the True value
    # at the lowest index in the quadruple, and the higher two bits in
    # the encoding are index of the other True value in the quadruple.
    # In case there are fewer than two True values, then False value or
    # values at some index or indices are considered True for the
    # encoding.  In case there are more than two True values, then the
    # excess True value(s) at some indices are considered False for
    # the encoding.  The exact encodings used for these cases are as
    # follows:
    #     [False, False, False, False] -> 0b1110
    #     [False, False, False, True ] -> 0b1110
    #     [False, False, True,  False] -> 0b1110
    #     [False, True,  False, False] -> 0b1001
    #     [False, True,  True,  True ] -> 0b1101
    #     [True,  False, False, False] -> 0b1000
    #     [True,  False, True,  True ] -> 0b1100
    #     [True,  True,  False, True ] -> 0b0100
    #     [True,  True,  True,  False] -> 0b0100
    #     [True,  True,  True,  True ] -> 0b0100
    # These particular encodings are chosen, with the help of Espresso
    # logic minimizer software, for the purpose of minimization of
    # corresponding Boolean functions, that translate non-zero flags
    # into encoding bits.  Note also possible choices for the first
    # and last of these encodings were limited only to (0b0100,
    # 0b1110), in order to produce valid encodings for 1:2 sparsity
    # case.
    both_leading = m0 & m1
    only_second = ~m0 & m1
    no_leading = ~m0 & ~m1

    bit0 = only_second
    bit1 = no_leading
    bit2 = both_leading | no_leading | m3
    bit3 = only_second | ~m1

    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    if is_float:
        # Quadruple indices are halved to address the two-element group.
        sparse = groups.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)
    else:
        first_kept = groups.gather(-1, idxs0.unsqueeze(-1))
        second_kept = groups.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((first_kept, second_kept),
                             dim=-1).view(m, k // 2)

    # One nibble per group, then pack the nibbles of each metadata word.
    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view(
        (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
    meta = meta_n[:, :, 0]
    for nibble in range(1, quadbits_per_meta_elem):
        meta = meta | (meta_n[:, :, nibble] << (4 * nibble))

    # Reorder meta tensor elements.
    meta_reordered = meta.new_empty((m * meta_ncols, ))
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device)
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

    return (sparse, meta_reordered.view(m, meta_ncols))
# This function performs reverse of the function above - it
# reconstructs dense matrix from a pair of "compressed" matrix, given
# in the layout used by CUTLASS backend, and accompanying metadata
# matrix.
def
sparse_semi_structured_to_dense_cutlass
(
sparse
,
meta_reordered
):
if
sparse
.
dim
()
!=
2
:
raise
RuntimeError
(
f
"Expected 2-dimensional sparse tensor, got
{
sparse
.
dim
()
}
-dimensional tensor"
# noqa: E501
)
m
,
k
=
sparse
.
shape
device
=
sparse
.
device
if
meta_reordered
.
dim
()
!=
2
:
raise
RuntimeError
(
f
"Expected 2-dimensional meta tensor, got
{
meta_reordered
.
dim
()
}
-dimensional tensor"
# noqa: E501
)
if
meta_reordered
.
device
!=
device
:
raise
RuntimeError
(
f
"Expected meta matrix to be on
{
device
}
device, got matrix on
{
meta_reordered
.
device
}
device"
# noqa: E501
)
meta_dtype
=
meta_reordered
.
dtype
if
meta_dtype
not
in
(
torch
.
int16
,
torch
.
int32
):
raise
RuntimeError
(
f
"Invalid datatype
{
meta_dtype
}
of meta matrix"
)
quadbits_per_meta_elem
=
meta_dtype
.
itemsize
*
8
//
4
ksparse
=
4
if
sparse
.
dtype
!=
torch
.
float
else
2
meta_nrows
,
meta_ncols
=
meta_reordered
.
shape
if
meta_nrows
!=
m
:
raise
RuntimeError
(
f
"Number of rows of meta matrix
{
meta_nrows
}
must be equal to number of columns of spase matrix
{
m
}
"
# noqa: E501
)
if
meta_ncols
*
ksparse
*
quadbits_per_meta_elem
!=
2
*
k
:
raise
RuntimeError
(
f
"Number of columns of sparse matrix
{
k
}
different from the
{
meta_ncols
*
ksparse
*
quadbits_per_meta_elem
//
2
}
, "
# noqa: E501
"expected according to the number of columns of meta matrix"
)
# Undo meta tensor elements reordering.
meta_offsets
=
_calculate_meta_reordering_scatter_offsets
(
m
,
meta_ncols
,
meta_dtype
,
device
)
meta
=
torch
.
gather
(
meta_reordered
.
view
(
-
1
),
0
,
meta_offsets
).
view
(
m
,
meta_ncols
)
# Unpack sparse tensor back to original dense tensor, using
# information provided by meta tensor. Note that torch.float
# datatype is handled pretty much the same as
# torch.half/torch.bfloat16, as metadata for a pair of torch.float
# value is encoded as if underlying 8 bytes contain four
# torch.half/torch.bfloat16 values, where either first two or last
# two are zeros.
meta_2
=
torch
.
empty
(
(
m
,
meta_ncols
,
2
*
quadbits_per_meta_elem
),
dtype
=
meta_dtype
,
device
=
device
,
)
if
quadbits_per_meta_elem
==
4
:
meta_2
[:,
:,
0
]
=
meta
&
0b11
meta_2
[:,
:,
1
]
=
(
meta
>>
2
)
&
0b11
meta_2
[:,
:,
2
]
=
(
meta
>>
4
)
&
0b11
meta_2
[:,
:,
3
]
=
(
meta
>>
6
)
&
0b11
meta_2
[:,
:,
4
]
=
(
meta
>>
8
)
&
0b11
meta_2
[:,
:,
5
]
=
(
meta
>>
10
)
&
0b11
meta_2
[:,
:,
6
]
=
(
meta
>>
12
)
&
0b11
meta_2
[:,
:,
7
]
=
(
meta
>>
14
)
&
0b11
elif
quadbits_per_meta_elem
==
8
:
meta_2
[:,
:,
0
]
=
meta
&
0b11
meta_2
[:,
:,
1
]
=
(
meta
>>
2
)
&
0b11
meta_2
[:,
:,
2
]
=
(
meta
>>
4
)
&
0b11
meta_2
[:,
:,
3
]
=
(
meta
>>
6
)
&
0b11
meta_2
[:,
:,
4
]
=
(
meta
>>
8
)
&
0b11
meta_2
[:,
:,
5
]
=
(
meta
>>
10
)
&
0b11
meta_2
[:,
:,
6
]
=
(
meta
>>
12
)
&
0b11
meta_2
[:,
:,
7
]
=
(
meta
>>
14
)
&
0b11
meta_2
[:,
:,
8
]
=
(
meta
>>
16
)
&
0b11
meta_2
[:,
:,
9
]
=
(
meta
>>
18
)
&
0b11
meta_2
[:,
:,
10
]
=
(
meta
>>
20
)
&
0b11
meta_2
[:,
:,
11
]
=
(
meta
>>
22
)
&
0b11
meta_2
[:,
:,
12
]
=
(
meta
>>
24
)
&
0b11
meta_2
[:,
:,
13
]
=
(
meta
>>
26
)
&
0b11
meta_2
[:,
:,
14
]
=
(
meta
>>
28
)
&
0b11
meta_2
[:,
:,
15
]
=
(
meta
>>
30
)
&
0b11
dense_offsets
=
meta_2
.
view
(
-
1
)
+
(
torch
.
arange
(
0
,
2
*
m
*
k
//
ksparse
,
device
=
device
)
*
4
).
view
(
-
1
,
1
).
repeat
(
1
,
2
).
view
(
-
1
)
dense
=
torch
.
zeros
((
m
*
2
*
k
,
),
dtype
=
sparse
.
dtype
,
device
=
device
)
if
sparse
.
dtype
!=
torch
.
float
:
# dense.scatter_(0, dense_offsets, sparse.view(-1))
dense
.
scatter_
(
0
,
dense_offsets
,
sparse
.
reshape
(
-
1
))
else
:
dense
.
view
(
torch
.
half
).
scatter_
(
0
,
dense_offsets
,
sparse
.
view
(
torch
.
half
).
view
(
-
1
))
return
dense
.
view
(
m
,
2
*
k
)
def mask_creator(tensor):
    """
    Create a 2:4 (N:M) sparsity mask for ``tensor``.

    The tensor is flattened into consecutive groups of M = 4 elements.
    In every group the M - N = 2 entries with the smallest absolute
    value are marked 0 and the other N = 2 entries are marked 1.

    :param tensor: tensor whose element count is divisible by 4
    :return: mask of ones and zeros with the shape of ``tensor``, built
        with ``torch.ones`` on the device of ``tensor``
    :raises ValueError: if the element count is not divisible by 4
    """
    N = 2
    M = 4

    if tensor.numel() % M != 0:
        raise ValueError(
            f"Tensor of size {tensor.shape} can't be evenly divided into "
            f"{M} groups")

    group_count = tensor.numel() // M

    # N:M sparsity for linear layers
    magnitudes = tensor.detach().abs().reshape(group_count, M)
    # Ascending argsort: the first M - N columns are the entries to prune.
    pruned_idx = torch.argsort(magnitudes, dim=1)[:, :int(M - N)]

    keep = torch.ones(magnitudes.shape, device=magnitudes.device)
    return keep.scatter_(dim=1, index=pruned_idx,
                         value=0).reshape(tensor.shape)
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_24_perms.py
0 → 100644
View file @
18c42e67
"""This file is used for /tests and /benchmarks"""
from
typing
import
Dict
,
List
import
numpy
import
torch
# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501
#
# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def get_perms_24(num_bits: int):
    """Build the Marlin 2:4 weight and scale permutations.

    Args:
        num_bits: weight bit-width, 4 or 8.

    Returns:
        Tuple ``(perm, scale_perm, scale_perm_single)``: ``perm`` is a
        1-D tensor with 1024 indices, the other two are lists of 64
        integers each.

    Raises:
        ValueError: if ``num_bits`` is neither 4 nor 8.
    """
    weight_order: List[int] = []
    for i in range(32):
        col = i // 4
        col_o = col // 2
        lane = i % 4
        rows = (2 * lane, 2 * lane + 1, 2 * (lane + 4), 2 * (lane + 4) + 1)
        tile = [
            16 * row + col_o * 256 + 8 * (col % 2) + 4 * block
            for block in (0, 1) for row in rows
        ]
        # Four consecutive copies of the tile pattern, offset by 0..3.
        for j in range(4):
            weight_order.extend(p + 1 * j for p in tile)
    perm = numpy.array(weight_order)

    interleave_by_bits = {
        4: [0, 2, 4, 6, 1, 3, 5, 7],
        8: [0, 2, 1, 3],
    }
    if num_bits not in interleave_by_bits:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
    interleave = numpy.array(interleave_by_bits[num_bits])

    # Interleave inside every run of len(interleave) indices.
    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)

    scale_perm: List[int] = [
        i * 8 + j for i in range(8) for j in [0, 4, 1, 5, 2, 6, 3, 7]
    ]
    scale_perm_single: List[int] = [
        8 * i + j for i in range(8) for j in [0, 1, 2, 3, 4, 5, 6, 7]
    ]
    return perm, scale_perm, scale_perm_single
# Lookup tables keyed by weight bit-width (4 or 8), filled once at import
# time from get_perms_24().
marlin_24_perm: Dict[int, torch.Tensor] = {}
marlin_24_scale_perm: Dict[int, List[int]] = {}
marlin_24_scale_perm_single: Dict[int, List[int]] = {}
for num_bits in (4, 8):
    perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits)
    marlin_24_perm[num_bits] = perm_24
    marlin_24_scale_perm[num_bits] = scale_perm_24
    marlin_24_scale_perm_single[num_bits] = scale_perm_single_24
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_perms.py
0 → 100644
View file @
18c42e67
"""This file is used for /tests and /benchmarks"""
from
typing
import
Dict
,
List
import
numpy
import
torch
# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
#
# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def get_perms(num_bits: int):
    """Build the Marlin weight and scale permutations.

    Args:
        num_bits: weight bit-width, 4 or 8.

    Returns:
        Tuple ``(perm, scale_perm, scale_perm_single)``: ``perm`` is a
        1-D tensor with 1024 indices, ``scale_perm`` a list of 64
        integers and ``scale_perm_single`` a list of 32 integers.

    Raises:
        ValueError: if ``num_bits`` is neither 4 nor 8 (same exception
            type as ``get_perms_24``; still an ``Exception`` subclass).
    """
    # Validate before doing any work.
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        # The same pattern is repeated for four blocks of 256 elements.
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm_list)
    # Interleave inside every run of len(interleave) indices.
    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)

    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])

    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])

    return perm, scale_perm, scale_perm_single
# Lookup tables keyed by weight bit-width (4 or 8), filled once at import
# time from get_perms().
marlin_perm: Dict[int, torch.Tensor] = {}
marlin_scale_perm: Dict[int, List[int]] = {}
marlin_scale_perm_single: Dict[int, List[int]] = {}
for num_bits in (4, 8):
    perm, scale_perm, scale_perm_single = get_perms(num_bits)
    marlin_perm[num_bits] = perm
    marlin_scale_perm[num_bits] = scale_perm
    marlin_scale_perm_single[num_bits] = scale_perm_single
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py
0 → 100644
View file @
18c42e67
"""This file is used for /tests and /benchmarks"""
import
random
import
numpy
import
torch
from
ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.format_24
import
(
mask_creator
,
sparse_semi_structured_from_dense_cutlass
)
from
ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_24_perms
import
(
marlin_24_perm
,
marlin_24_scale_perm
,
marlin_24_scale_perm_single
)
from
ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_perms
import
(
marlin_perm
,
marlin_scale_perm
,
marlin_scale_perm_single
)
from
ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.quant_utils
import
(
get_pack_factor
,
quantize_weights
,
sort_weights
)
# Compute capability of the current CUDA device, queried once at import
# time.  Without CUDA the query raises, so fall back to (0, 0); this
# keeps the module importable and makes is_marlin_supported() report
# False on such hosts.
__cuda_arch = (torch.cuda.get_device_capability()
               if torch.cuda.is_available() else (0, 0))

# Default tile edge used by marlin_permute_weights().
MARLIN_TILE = 16

# NOTE(review): the GPTQ_MARLIN_* values are not read by the helpers in
# this module; presumably they mirror limits of the GPTQ-Marlin kernel -
# confirm against the kernel sources before changing them.
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
def is_marlin_supported():
    """Return True when the cached CUDA capability has major version >= 8."""
    major = __cuda_arch[0]
    return major >= 8
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):
    """Rearrange a ``(size_k, size_n)`` matrix into Marlin tile order.

    Args:
        q_w: tensor of shape ``(size_k, size_n)``.
        size_k: number of rows; must be a multiple of ``tile``.
        size_n: number of columns; must be a multiple of ``tile``.
        perm: 1-D index tensor applied to every run of ``perm.numel()``
            consecutive elements of the tiled matrix.
        tile: edge length of the square tiles.

    Returns:
        Tensor of shape ``(size_k // tile, size_n * tile)``.

    Raises:
        AssertionError: if the shape or divisibility checks fail.
    """
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    # Report the dimension that actually failed (was labelled size_k).
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles: split both axes into tiles,
    # then flatten each tile so its elements are contiguous.
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    """Permute quantized weights to Marlin order and pack them.

    Args:
        q_w: quantized weights of shape ``(size_k, size_n)``.
        size_k: number of rows of ``q_w``.
        size_n: number of columns of ``q_w``.
        num_bits: bit-width of one quantized value.
        perm: permutation passed on to ``marlin_permute_weights``.

    Returns:
        int32 tensor on the device of the permuted weights, in which
        every element packs ``get_pack_factor(num_bits)`` neighbouring
        columns.
    """
    # Permute
    permuted = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = permuted.device

    unpacked = permuted.cpu().numpy().astype(numpy.uint32)
    packed = numpy.zeros(
        (unpacked.shape[0], unpacked.shape[1] // pack_factor),
        dtype=numpy.uint32)
    # Column `lane` of every group lands at bit offset num_bits * lane.
    for lane in range(pack_factor):
        packed |= unpacked[:, lane::pack_factor] << num_bits * lane

    return torch.from_numpy(packed.astype(numpy.int32)).to(orig_device)
def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,
                          scale_perm_single):
    """Reorder quantization scales with one of two permutations.

    Args:
        s: scale tensor.
        size_k: K dimension of the weight matrix.
        size_n: N dimension; the result has ``size_n`` columns.
        group_size: quantization group size, ``-1`` for a single group.
        scale_perm: permutation used when ``group_size < size_k`` and
            ``group_size != -1``.
        scale_perm_single: permutation used otherwise.

    Returns:
        Contiguous tensor of shape ``(-1, size_n)``.
    """
    single_group = group_size == -1 or group_size >= size_k
    order = scale_perm_single if single_group else scale_perm

    # Apply the permutation to every run of len(order) scales.
    permuted = s.reshape((-1, len(order)))[:, order]
    return permuted.reshape((-1, size_n)).contiguous()
def marlin_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
    act_order: bool,
):
    """Quantize ``w`` and convert the result to the Marlin layout.

    Args:
        w: weight matrix of shape ``(size_k, size_n)``.
        num_bits: quantization bit-width.
        group_size: quantization group size; ``-1`` means ``size_k``.
        act_order: whether to apply the act_order permutation and sort
            the rows by group index afterwards.

    Returns:
        List ``[w_ref, marlin_q_w, marlin_s, g_idx, sort_indices,
        rand_perm]`` with every tensor moved to ``w.device``.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(
        w, num_bits, group_size, act_order)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing; otherwise the sort indices stay empty.
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
    else:
        sort_indices = torch.empty(0, dtype=torch.int, device=w.device)

    # Reformat to marlin
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,
                                marlin_perm[num_bits])
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size,
                                     marlin_scale_perm[num_bits],
                                     marlin_scale_perm_single[num_bits])

    # Create result
    outputs = (w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm)
    return [tensor.to(w.device) for tensor in outputs]
def inject_24(w, size_k, size_n):
    """Apply a 2:4 sparsity mask along the K dimension of ``w``.

    Args:
        w: weight matrix of shape ``(size_k, size_n)``.
        size_k: number of rows of ``w``.
        size_n: number of columns of ``w``.

    Returns:
        Tuple ``(masked_w, mask)``, both contiguous; ``mask`` is a
        boolean tensor on the device of ``w``.
    """
    assert w.shape == (size_k, size_n)

    # The mask is built on the transposed matrix and transposed back, so
    # the groups of four run along K.  It is placed on w's own device
    # (instead of the current CUDA device) so the product below is valid
    # wherever w lives.
    mask = mask_creator(w.t()).t().to(w.device).bool()

    return (mask * w).contiguous(), mask.contiguous()
def check_24(w, num_rows_to_sample=50, _verbose=False):
    """Print how many sampled 4-element segments violate 2:4 sparsity.

    ``w`` is transposed first; ``num_rows_to_sample`` rows of the
    transposed matrix are drawn with replacement and each row is scanned
    in consecutive blocks of four elements.  A block with more than two
    non-zero entries counts as a violation.

    Args:
        w: 2-D weight tensor.
        num_rows_to_sample: number of row indices to draw.
        _verbose: also print the sampled row indices.
    """
    BLOCK_SIZE = 4
    MAX_NON_ZEROS = 2

    w = w.t().contiguous()

    print("check_24: w.shape = {}".format(w.shape))

    num_rows, num_cols = w.shape
    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
    if _verbose:
        print(f"Sampled row idxs = {sampled_row_idxs}")

    total_segments = 0
    non_24_segments = 0
    for i in sampled_row_idxs:
        # "+ 1" so the last full block of the row is inspected as well;
        # the previous bound stopped one block short.
        for j in range(0, num_cols - BLOCK_SIZE + 1, BLOCK_SIZE):
            total_segments += 1
            block = w[i, j:j + BLOCK_SIZE]
            num_nonzero = torch.count_nonzero(block)
            if num_nonzero > MAX_NON_ZEROS:
                print("i = {} j = {} block = {}".format(i, j, block))
                non_24_segments += 1

    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
    """Compress a 2:4-sparse quantized weight matrix.

    Args:
        q_24: quantized weights of shape ``(size_k, size_n)``.
        size_k: number of rows of ``q_24``.
        size_n: number of columns of ``q_24``.
        num_bits: quantization bit-width; fixes the zero point.

    Returns:
        Tuple ``(q_24_comp, meta)``: the compressed quantized weights
        and the metadata tensor returned by
        ``sparse_semi_structured_from_dense_cutlass`` after an in-place
        resize.
    """
    assert q_24.shape == (size_k, size_n)

    # Remove zp to normalize over 0
    max_q_val = (1 << num_bits) - 1
    zp = (max_q_val + 1) // 2

    # Compress on the transposed, zero-centred matrix.
    centered = (q_24 - zp).t().contiguous()
    compressed, meta = sparse_semi_structured_from_dense_cutlass(centered)

    # Transpose back and restore zp
    q_24_comp = compressed.t().contiguous() + zp

    # Resize meta to its actual shape (without moving any data)
    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

    return q_24_comp, meta
def marlin_24_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    """Sparsify ``w`` to 2:4, quantize it and convert to Marlin 2:4 layout.

    Args:
        w: weight matrix of shape ``(size_k, size_n)``.
        num_bits: quantization bit-width.
        group_size: quantization group size; ``-1`` means ``size_k``.

    Returns:
        List ``[w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]`` with
        every tensor moved to ``w.device``.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Inject 2:4 sparsity (the returned mask is not needed here)
    w_24, _mask_24 = inject_24(w, size_k, size_n)

    # Quantize; only the reference weights, codes and scales are used.
    w_24_ref, q_w_24, s, _g_idx, _rand_perm = quantize_weights(
        w_24, num_bits, group_size, act_order=False)

    # Compress quantized weight; compression halves the K dimension.
    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
                                                     num_bits)
    size_k_comp = size_k // 2

    # Reformat to marlin
    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
                                        num_bits, marlin_24_perm[num_bits])
    marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size,
                                        marlin_24_scale_perm[num_bits],
                                        marlin_24_scale_perm_single[num_bits])

    # Create result
    outputs = (w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s)
    return [tensor.to(w.device) for tensor in outputs]
def compute_max_diff(output, output_ref):
    """Return mean(|output - output_ref|) / mean(|output_ref|).

    Despite the name, this is a ratio of means, not a maximum.
    """
    mean_abs_error = torch.mean(torch.abs(output - output_ref))
    mean_abs_ref = torch.mean(torch.abs(output_ref))
    return mean_abs_error / mean_abs_ref
class MarlinWorkspace:
    """Zero-initialised int32 scratch buffer, exposed as ``self.scratch``."""

    def __init__(self, out_features, min_thread_n, max_parallel,
                 device="cuda"):
        """Allocate the scratch buffer.

        Args:
            out_features: output feature count; must be divisible by
                ``min_thread_n``.
            min_thread_n: divisor of ``out_features``.
            max_parallel: multiplier applied to
                ``out_features // min_thread_n``.
            device: device of the buffer; defaults to ``"cuda"``, the
                previously hard-coded value.

        Raises:
            AssertionError: if ``out_features`` is not divisible by
                ``min_thread_n``.
        """
        assert (out_features % min_thread_n == 0), (
            "out_features = {} is undivisible by min_thread_n = {}".format(
                out_features, min_thread_n))

        max_workspace_size = ((out_features // min_thread_n) * max_parallel)

        self.scratch = torch.zeros(max_workspace_size,
                                   dtype=torch.int,
                                   device=device)
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py
0 → 100644
View file @
18c42e67
"""This file is used for /tests and /benchmarks"""
import
numpy
import
torch
# Bit-widths accepted by get_pack_factor() and quantize_weights().
SUPPORTED_NUM_BITS = [4, 8]

# Group sizes accepted by quantize_weights(); -1 selects a single group
# that spans the whole K dimension.
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
def get_pack_factor(num_bits):
    """Return how many ``num_bits``-wide values fit into 32 bits.

    Raises:
        AssertionError: if ``num_bits`` is not in ``SUPPORTED_NUM_BITS``.
    """
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    word_bits = 32
    return word_bits // num_bits
def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
    """Shuffle the rows of ``q_w`` and ``w_ref`` with one random permutation.

    Args:
        q_w: quantized weights of shape ``(k_size, n)``.
        w_ref: reference weights with the same shape as ``q_w``.
        group_size: number of consecutive original rows per group.

    Returns:
        Tuple ``(w_ref, q_w, g_idx, rand_perm)`` on the original device
        of ``q_w``; ``g_idx[i]`` is the group index
        (``original_row // group_size``, int32) of permuted row ``i``.

    Raises:
        AssertionError: if the two shapes differ.
    """
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    # Group id of every original row, computed in one vectorised step
    # instead of a Python loop over individual tensor elements.
    g_idx = torch.arange(k_size, dtype=torch.int32) // group_size

    # Simulate act_order by doing a random permutation on K
    rand_perm = torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )
def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
                     act_order: bool):
    """Symmetrically quantize ``w`` per group along the K dimension.

    Args:
        w: floating-point weights of shape ``(size_k, size_n)``.
        num_bits: bit-width, one of ``SUPPORTED_NUM_BITS``.
        group_size: one of ``SUPPORTED_GROUP_SIZES`` or ``size_k``;
            ``-1`` means ``size_k``.
        act_order: when True, the rows are additionally shuffled by
            ``permute_rows`` (requires ``group_size < size_k``).

    Returns:
        Tuple ``(w_ref, q_w, s, g_idx, rand_perm)`` on the original
        device of ``w``: the dequantized reference weights (the integer
        codes are cast to half before scaling), the integer codes in
        ``[0, 2**num_bits - 1]``, the scales reshaped to
        ``(-1, size_n)``, and the group indices / permutation (empty
        tensors unless ``act_order`` is set).

    Raises:
        AssertionError: if any argument check fails.
    """
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    max_q_val = 2**num_bits - 1
    half_q_val = (max_q_val + 1) // 2

    grouped = group_size < size_k

    # Reshape to [groupsize, -1]: every column then holds one group.
    if grouped:
        w = w.view((-1, group_size, size_n)).permute(1, 0, 2).reshape(
            (group_size, -1))

    # Compute scale for each group
    s = torch.abs(w).max(dim=0, keepdim=True).values
    s *= 2 / max_q_val  # 2 => symmetric

    # Quantize: round, shift into the unsigned range, clamp.
    q_w = torch.clamp(
        torch.round(w / s).int() + half_q_val, 0, max_q_val)

    # Compute ref (dequantized)
    w_ref = (q_w - half_q_val).half() * s

    # Restore original shapes
    if grouped:

        def ungroup(t):
            t = t.reshape((group_size, -1, size_n)).permute(1, 0, 2)
            return t.reshape((size_k, size_n)).contiguous()

        q_w = ungroup(q_w)
        w_ref = ungroup(w_ref)

    s = s.reshape((-1, size_n)).contiguous()

    # Apply act_order
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert (
            group_size < size_k
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k)

        w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        s.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    """Reorder the rows of ``q_w`` so that ``g_idx`` becomes ascending.

    Args:
        q_w: 2-D tensor whose rows correspond to the entries of ``g_idx``.
        g_idx: 1-D tensor of group indices.

    Returns:
        Tuple ``(q_w, g_idx, sort_indices)`` on the original device of
        ``q_w``; ``sort_indices`` is the int32 argsort of ``g_idx``.
    """
    orig_device = q_w.device

    # Sort based on g_idx
    sort_indices = torch.argsort(g_idx).to(dtype=torch.int32)
    sorted_g_idx = g_idx[sort_indices].contiguous()
    sorted_q_w = q_w[sort_indices, :].contiguous()

    return (
        sorted_q_w.to(device=orig_device),
        sorted_g_idx.to(device=orig_device),
        sort_indices.to(device=orig_device),
    )
def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack quantized values along the K dimension into int32 words.

    Args:
        q_w: quantized weights of shape ``(size_k, size_n)``.
        num_bits: bit-width of one quantized value.
        size_k: number of rows; must be divisible by the pack factor.
        size_n: number of columns.

    Returns:
        int32 tensor of shape ``(size_k // pack_factor, size_n)`` on the
        device of ``q_w``.

    Raises:
        AssertionError: if the shape or divisibility checks fail.
    """
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    orig_device = q_w.device

    unpacked = q_w.cpu().numpy().astype(numpy.uint32)
    packed = numpy.zeros((size_k // pack_factor, size_n),
                         dtype=numpy.uint32)

    # Row `lane` of every group lands at bit offset num_bits * lane.
    for lane in range(pack_factor):
        packed |= unpacked[lane::pack_factor, :] << num_bits * lane

    return torch.from_numpy(packed.astype(numpy.int32)).to(orig_device)
ktransformers/ktransformers_ext/operators/llamafile/conversion.h
0 → 100644
View file @
18c42e67
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-12 10:07:58
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:34:55
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_CONVERSION_H
#define CPUINFER_CONVERSION_H
#include <memory.h>
#include "llama.cpp/ggml.h"
// Converts `size` elements of ggml type `type` stored at `input` into
// 32-bit floats written to `output`.  F32 data is copied as-is; every
// other type is handled by the `to_float` routine of its ggml type traits.
inline void to_float(const void* input, float* output, int size, ggml_type type) {
    if (type != ggml_type::GGML_TYPE_F32) {
        ggml_internal_get_type_traits(type).to_float(input, output, size);
        return;
    }
    memcpy(output, input, size * sizeof(float));
}
// Converts `size` 32-bit floats at `input` into ggml type `type`, written
// to `output`.  F32 data is copied as-is; every other type is handled by
// the `from_float` routine of its ggml type traits.
inline void from_float(const float* input, void* output, int size, ggml_type type) {
    if (type != ggml_type::GGML_TYPE_F32) {
        ggml_internal_get_type_traits(type).from_float(input, output, size);
        return;
    }
    memcpy(output, input, size * sizeof(float));
}
#endif
\ No newline at end of file
ktransformers/ktransformers_ext/operators/llamafile/linear.cpp
0 → 100644
View file @
18c42e67
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-12 10:07:58
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:34:58
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "linear.h"
// Stores the configuration and sizes the scratch buffers used by forward().
Linear::Linear(LinearConfig config) : config_(config), proj_(config.proj) {
    // fp32 copy of the input vector.
    input_fp32_.resize(config_.input_size);
    // Input converted to the projection's vec_dot type; 4 bytes are
    // reserved per input element.
    proj_input_.resize(config_.input_size * 4);
    // fp32 result of the matrix-vector product.
    proj_output_.resize(config_.output_size);
}
// Runs one forward pass on an all-zero input; the output is discarded.
void Linear::warm_up(Backend* backend) {
    const auto type_size = ggml_type_size(config_.hidden_type);
    const auto block_size = ggml_blck_size(config_.hidden_type);

    // All-zero fp32 input.
    std::vector<float> input_fp32(config_.input_size, 0.0f);
    // Byte buffers sized for input/output vectors stored as hidden_type.
    std::vector<uint8_t> input(config_.input_size * type_size / block_size);
    std::vector<uint8_t> output(config_.output_size * type_size / block_size);

    from_float(input_fp32.data(), input.data(), config_.input_size, config_.hidden_type);
    forward(input.data(), output.data(), backend);
}
// Computes the projection of one input vector.
//   input  : input_size elements of hidden_type
//   output : output_size elements of hidden_type
// The input is converted to the vec_dot type of the projection weights
// when it is not already in that type; the product is accumulated in
// fp32 and converted back to hidden_type at the end.
void Linear::forward(const void* input, void* output, Backend* backend) {
    const auto vec_dot_type = ggml_internal_get_type_traits(config_.proj_type).vec_dot_type;

    const void* proj_input_ptr = input;
    if (config_.hidden_type != vec_dot_type) {
        // Go through fp32 to reach the type expected by the dot kernel.
        to_float(input, input_fp32_.data(), config_.input_size, config_.hidden_type);
        from_float(input_fp32_.data(), proj_input_.data(), config_.input_size, vec_dot_type);
        proj_input_ptr = proj_input_.data();
    }

    // One task per `stride` output rows.
    int nth = config_.output_size / config_.stride;
    // Row length and leading dimensions, in blocks of the weight type.
    const auto k_blocks = config_.input_size / ggml_blck_size(config_.proj_type);

    backend->do_work_stealing_job(nth, [&](int task_id) {
        int ith = task_id % nth;
        llamafile_sgemm(config_.output_size, 1, k_blocks,
                        proj_, k_blocks,
                        proj_input_ptr, k_blocks,
                        proj_output_.data(), config_.output_size,
                        ith, nth, GGML_TASK_TYPE_COMPUTE,
                        config_.proj_type, vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
    });

    from_float(proj_output_.data(), output, config_.output_size, config_.hidden_type);
}
\ No newline at end of file
ktransformers/ktransformers_ext/operators/llamafile/linear.h
0 → 100644
View file @
18c42e67
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-12 10:07:58
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:00
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_OPERATOR_LINEAR_H
#define CPUINFER_OPERATOR_LINEAR_H
#include <cmath>
#include <cstdio>
#include <functional>
#include <mutex>
#include <vector>
#include "../../cpu_backend/backend.h"
#include "conversion.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
// Parameters of a Linear operator.
//
// Every field has a default member initializer, so a default-constructed
// config (for example the `config_` member of Linear before it is
// assigned) holds well-defined values instead of indeterminate ones.
struct LinearConfig {
    int input_size = 0;                       // length of the input vector
    int output_size = 0;                      // length of the output vector
    int stride = 0;                           // output_size / stride tasks are scheduled per forward()
    void* proj = nullptr;                     // projection weights, not owned
    ggml_type proj_type = GGML_TYPE_F32;      // ggml type of the weights
    ggml_type hidden_type = GGML_TYPE_F32;    // ggml type of input and output vectors

    LinearConfig() {}

    LinearConfig(int input_size, int output_size, int stride, void* proj, ggml_type proj_type, ggml_type hidden_type)
        : input_size(input_size), output_size(output_size), stride(stride), proj(proj), proj_type(proj_type), hidden_type(hidden_type) {}
};
// Single linear projection (matrix-vector product) executed through a
// Backend.
class Linear {
   public:
    // Copies the configuration and sizes the internal scratch buffers.
    Linear(LinearConfig);
    // Runs forward() once on an all-zero input and discards the result.
    void warm_up(Backend* backend);
    // Projects `input` (input_size elements of hidden_type) into `output`
    // (output_size elements of hidden_type).
    void forward(const void* input, void* output, Backend* backend);

   private:
    LinearConfig config_;
    // Projection weights taken from config_.proj.
    void* proj_;  // [output_size * input_size ( /32 if quantized)]

    // fp32 copy of the input, used when a type conversion is needed.
    std::vector<float> input_fp32_;    // [input_size]
    // Input converted to the vec_dot type of the weights (byte buffer).
    std::vector<uint8_t> proj_input_;  // [input_size * 4]
    // fp32 result of the product before conversion to hidden_type.
    std::vector<float> proj_output_;   // [output_size]
};
#endif
\ No newline at end of file
ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp
0 → 100644
View file @
18c42e67
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-16 10:43:18
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:04
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "mlp.h"
/**
 * Stores the configuration, caches the three weight pointers and sizes all
 * scratch buffers once so that forward() never allocates.
 */
MLP::MLP(MLPConfig config)
    : config_(config),
      gate_proj_(config.gate_proj),
      up_proj_(config.up_proj),
      down_proj_(config.down_proj) {
    const int hidden = config_.hidden_size;
    const int inter = config_.intermediate_size;

    // Byte buffers get 4 bytes per element.
    input_fp32_.resize(hidden);
    gate_input_.resize(hidden * 4);
    up_input_.resize(hidden * 4);

    gate_output_.resize(inter);
    up_output_.resize(inter);
    intermediate_fp32_.resize(inter);
    down_input_.resize(inter * 4);

    down_output_.resize(hidden);
}
/**
 * Runs one forward pass over an all-zero input; the result is discarded.
 */
void MLP::warm_up(Backend* backend) {
    // Size in bytes of one hidden vector stored as hidden_type.
    const auto io_bytes = config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);

    std::vector<float> zeros(config_.hidden_size, 0.0f);
    std::vector<uint8_t> input(io_bytes);
    std::vector<uint8_t> output(io_bytes);

    // Encode the zero vector in the caller-visible hidden type, then run the MLP once.
    from_float(zeros.data(), input.data(), config_.hidden_size, config_.hidden_type);
    forward(input.data(), output.data(), backend);
}
/** Activation applied to the gate projection: x * sigmoid(x), written as x / (1 + e^-x). */
static float act_fn(float x) {
    const float denominator = 1.0f + expf(-x);
    return x / denominator;
}
void
MLP
::
forward
(
const
void
*
input
,
void
*
output
,
Backend
*
backend
)
{
const
void
*
gate_input_ptr
;
const
void
*
up_input_ptr
;
if
(
config_
.
hidden_type
==
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
&&
config_
.
hidden_type
==
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
{
gate_input_ptr
=
up_input_ptr
=
input
;
}
else
{
to_float
(
input
,
input_fp32_
.
data
(),
config_
.
hidden_size
,
config_
.
hidden_type
);
if
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
==
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
{
from_float
(
input_fp32_
.
data
(),
gate_input_
.
data
(),
config_
.
hidden_size
,
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
gate_input_ptr
=
up_input_ptr
=
gate_input_
.
data
();
}
else
{
if
(
config_
.
hidden_type
!=
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
{
from_float
(
input_fp32_
.
data
(),
gate_input_
.
data
(),
config_
.
hidden_size
,
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
gate_input_ptr
=
gate_input_
.
data
();
}
else
{
gate_input_ptr
=
input
;
}
if
(
config_
.
hidden_type
!=
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
{
from_float
(
input_fp32_
.
data
(),
up_input_
.
data
(),
config_
.
hidden_size
,
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
);
up_input_ptr
=
up_input_
.
data
();
}
else
{
up_input_ptr
=
input
;
}
}
}
int
nth
=
config_
.
intermediate_size
/
config_
.
stride
;
backend
->
do_work_stealing_job
(
nth
,
[
&
](
int
task_id
)
{
int
ith
=
task_id
;
void
*
gate_proj_ptr
=
gate_proj_
+
ith
*
config_
.
stride
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
gate_type
)
/
ggml_blck_size
(
config_
.
gate_type
);
float
*
gate_output_ptr
=
gate_output_
.
data
()
+
ith
*
config_
.
stride
;
llamafile_sgemm
(
config_
.
stride
,
1
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
gate_type
),
gate_proj_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
gate_type
),
gate_input_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
gate_type
),
gate_output_ptr
,
config_
.
stride
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
gate_type
,
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
void
*
up_proj_ptr
=
up_proj_
+
ith
*
config_
.
stride
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
up_type
)
/
ggml_blck_size
(
config_
.
up_type
);
float
*
up_output_ptr
=
up_output_
.
data
()
+
ith
*
config_
.
stride
;
llamafile_sgemm
(
config_
.
stride
,
1
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
up_type
),
up_proj_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
up_type
),
up_input_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
up_type
),
up_output_ptr
,
config_
.
stride
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
up_type
,
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
for
(
int
i
=
ith
*
config_
.
stride
;
i
<
(
ith
+
1
)
*
config_
.
stride
;
i
++
)
{
intermediate_fp32_
[
i
]
=
act_fn
(
gate_output_
[
i
])
*
up_output_
[
i
];
}
if
(
config_
.
stride
%
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
==
0
)
{
float
*
intermediate_fp32_ptr
=
intermediate_fp32_
.
data
()
+
ith
*
config_
.
stride
;
void
*
down_input_ptr
=
down_input_
.
data
()
+
ith
*
config_
.
stride
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
from_float
(
intermediate_fp32_ptr
,
down_input_ptr
,
config_
.
stride
,
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
}
});
if
(
config_
.
stride
%
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
!=
0
)
{
from_float
(
intermediate_fp32_
.
data
(),
down_input_
.
data
(),
config_
.
intermediate_size
,
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
}
nth
=
config_
.
hidden_size
/
config_
.
stride
;
backend
->
do_work_stealing_job
(
nth
,
[
&
](
int
task_id
)
{
int
ith
=
task_id
;
void
*
down_proj_ptr
=
down_proj_
+
ith
*
config_
.
stride
*
config_
.
intermediate_size
*
ggml_type_size
(
config_
.
down_type
)
/
ggml_blck_size
(
config_
.
down_type
);
float
*
down_output_ptr
=
down_output_
.
data
()
+
ith
*
config_
.
stride
;
llamafile_sgemm
(
config_
.
stride
,
1
,
config_
.
intermediate_size
/
ggml_blck_size
(
config_
.
down_type
),
down_proj_ptr
,
config_
.
intermediate_size
/
ggml_blck_size
(
config_
.
down_type
),
down_input_
.
data
(),
config_
.
intermediate_size
/
ggml_blck_size
(
config_
.
down_type
),
down_output_ptr
,
config_
.
stride
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
down_type
,
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
if
(
config_
.
stride
%
ggml_blck_size
(
config_
.
hidden_type
)
==
0
)
{
void
*
output_ptr
=
output
+
ith
*
config_
.
stride
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
);
from_float
(
down_output_ptr
,
output_ptr
,
config_
.
stride
,
config_
.
hidden_type
);
}
});
if
(
config_
.
stride
%
ggml_blck_size
(
config_
.
hidden_type
)
!=
0
)
{
from_float
(
down_output_
.
data
(),
output
,
config_
.
hidden_size
,
config_
.
hidden_type
);
}
}
ktransformers/ktransformers_ext/operators/llamafile/mlp.h
0 → 100644
View file @
18c42e67
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-12 10:07:58
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:06
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_OPERATOR_MLP_H
#define CPUINFER_OPERATOR_MLP_H
#include <cmath>
#include <cstdio>
#include <functional>
#include <mutex>
#include <vector>
#include "../../cpu_backend/backend.h"
#include "conversion.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
/**
 * Construction parameters for an MLP operator (gate / up / down projections).
 *
 * Every field has a deterministic default so that a default-constructed
 * config never exposes indeterminate values (the default constructor used to
 * leave all members uninitialized).
 */
struct MLPConfig {
    int hidden_size = 0;          // length of the input and output vectors
    int intermediate_size = 0;    // length of the gate/up projection outputs
    int stride = 0;               // output rows processed by one backend task
    void* gate_proj = nullptr;    // gate weights; not owned by the config
    void* up_proj = nullptr;      // up weights; not owned by the config
    void* down_proj = nullptr;    // down weights; not owned by the config
    ggml_type gate_type{};        // storage type of gate_proj
    ggml_type up_type{};          // storage type of up_proj
    ggml_type down_type{};        // storage type of down_proj
    ggml_type hidden_type{};      // type of the caller-visible input/output buffers

    /** Creates an empty config; all sizes are 0 and all pointers are null. */
    MLPConfig() {}

    /** Creates a fully specified config; arguments map 1:1 onto the fields. */
    MLPConfig(int hidden_size, int intermediate_size, int stride, void* gate_proj, void* up_proj, void* down_proj,
              ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)
        : hidden_size(hidden_size),
          intermediate_size(intermediate_size),
          stride(stride),
          gate_proj(gate_proj),
          up_proj(up_proj),
          down_proj(down_proj),
          gate_type(gate_type),
          up_type(up_type),
          down_type(down_type),
          hidden_type(hidden_type) {}
};
class
MLP
{
public:
MLP
(
MLPConfig
);
void
warm_up
(
Backend
*
backend
);
void
forward
(
const
void
*
input
,
void
*
output
,
Backend
*
backend
);
private:
MLPConfig
config_
;
void
*
gate_proj_
;
// [intermediate_size * hidden_size ( /32 if quantized)]
void
*
up_proj_
;
// [intermediate_size * hidden_size ( /32 if quantized)]
void
*
down_proj_
;
// [hidden_size * intermediate_size ( /32 if quantized)]
std
::
vector
<
float
>
input_fp32_
;
// [hidden_size]
std
::
vector
<
uint8_t
>
gate_input_
;
// [hidden_size * 4]
std
::
vector
<
uint8_t
>
up_input_
;
// [hidden_size * 4]
std
::
vector
<
float
>
gate_output_
;
// [intermediate_size]
std
::
vector
<
float
>
up_output_
;
// [intermediate_size]
std
::
vector
<
float
>
intermediate_fp32_
;
// [intermediate_size]
std
::
vector
<
uint8_t
>
down_input_
;
// [intermediate_size * 4]
std
::
vector
<
float
>
down_output_
;
// [hidden_size]
};
#endif
\ No newline at end of file
ktransformers/ktransformers_ext/operators/llamafile/moe.cpp
0 → 100644
View file @
18c42e67
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-22 02:03:22
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:07
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "moe.h"
#include <iostream>
#include <new>
#include "unistd.h"
// Scratch memory shared by every MOE instance; allocated by the first constructor call.
void* MOE::buffer_ = nullptr;

/**
 * Stores the configuration, allocates the shared scratch buffer on first use
 * and carves it into the working areas used by forward_one() ("s_" members)
 * and forward_many() ("m_" members).
 *
 * Both layouts start at offset 0 of the same buffer, so the single-token and
 * batched paths reuse the same memory.
 *
 * NOTE(review): the buffer is sized from the config of the first instance
 * only. Later instances with larger dimensions would index past its end —
 * confirm that all instances share the same dimensions.
 *
 * @throws std::bad_alloc if the scratch buffer cannot be allocated.
 */
MOE::MOE(MOEConfig config) {
    config_ = config;
    gate_proj_ = config_.gate_proj;
    up_proj_ = config_.up_proj;
    down_proj_ = config_.down_proj;

    // Activation type each weight matrix expects on its right-hand side.
    const ggml_type gate_vec_type = ggml_internal_get_type_traits(config_.gate_type).vec_dot_type;
    const ggml_type up_vec_type = ggml_internal_get_type_traits(config_.up_type).vec_dot_type;
    const ggml_type down_vec_type = ggml_internal_get_type_traits(config_.down_type).vec_dot_type;

    // Bytes occupied by `elements` values stored as `type`.
    auto bytes_for = [](uint64_t elements, ggml_type type) -> uint64_t {
        return elements * ggml_type_size(type) / ggml_blck_size(type);
    };

    // 64-bit copies of the dimensions so that size products cannot overflow int.
    const uint64_t hidden = config_.hidden_size;
    const uint64_t inter = config_.intermediate_size;
    const uint64_t tokens = config_.group_max_len;
    const uint64_t routed = config_.routed_expert_num;

    if (MOE::buffer_ == nullptr) {
        // Total size of the batched layout (the larger of the two layouts when group_max_len >= 1).
        uint64_t buffer_size = 0;
        buffer_size += sizeof(float) * tokens * hidden;                     // m_input_fp32_
        buffer_size += bytes_for(tokens * hidden, gate_vec_type);           // m_gate_input_
        buffer_size += bytes_for(tokens * hidden, up_vec_type);             // m_up_input_
        buffer_size += bytes_for(routed * tokens * hidden, gate_vec_type);  // m_local_gate_input_
        buffer_size += bytes_for(routed * tokens * hidden, up_vec_type);    // m_local_up_input_
        buffer_size += sizeof(float) * routed * tokens * inter;             // m_local_gate_output_
        buffer_size += sizeof(float) * routed * tokens * inter;             // m_local_up_output_
        buffer_size += sizeof(float) * routed * tokens * inter;             // m_local_intermediate_fp32_
        buffer_size += bytes_for(routed * tokens * inter, down_vec_type);   // m_local_down_input_
        buffer_size += sizeof(float) * routed * tokens * hidden;            // m_local_down_output_
        buffer_size += sizeof(float) * tokens * hidden;                     // m_output_fp32_
        buffer_ = malloc(buffer_size);
        if (buffer_ == nullptr && buffer_size != 0) {
            // Fail loudly here instead of dereferencing a null buffer later.
            throw std::bad_alloc();
        }
    }

    // Bump allocator over the shared buffer.
    uint8_t* const base = (uint8_t*)buffer_;
    uint64_t offset = 0;
    auto take = [&](uint64_t bytes) -> uint8_t* {
        uint8_t* region = base + offset;
        offset += bytes;
        return region;
    };

    // ---- Layout used by forward_one(): one token, up to routed_expert_num experts ----
    s_input_fp32_ = (float*)take(sizeof(float) * hidden);
    s_gate_input_ = (uint8_t*)take(bytes_for(hidden, gate_vec_type));
    s_up_input_ = (uint8_t*)take(bytes_for(hidden, up_vec_type));
    s_gate_output_.resize(config_.routed_expert_num);
    s_up_output_.resize(config_.routed_expert_num);
    s_intermediate_fp32_.resize(config_.routed_expert_num);
    s_down_input_.resize(config_.routed_expert_num);
    s_down_output_.resize(config_.routed_expert_num);
    for (int i = 0; i < config_.routed_expert_num; i++) {
        s_gate_output_[i] = (float*)take(sizeof(float) * inter);
        s_up_output_[i] = (float*)take(sizeof(float) * inter);
        s_intermediate_fp32_[i] = (float*)take(sizeof(float) * inter);
        s_down_input_[i] = (uint8_t*)take(bytes_for(inter, down_vec_type));
        s_down_output_[i] = (float*)take(sizeof(float) * hidden);
    }
    s_output_fp32_ = (float*)(base + offset);

    // ---- Layout used by forward_many(): restarts at the beginning of the buffer ----
    offset = 0;
    m_input_fp32_.resize(config_.group_max_len);
    m_gate_input_.resize(config_.group_max_len);
    m_up_input_.resize(config_.group_max_len);
    for (int i = 0; i < config_.group_max_len; i++) {
        m_input_fp32_[i] = (float*)take(sizeof(float) * hidden);
        m_gate_input_[i] = (uint8_t*)take(bytes_for(hidden, gate_vec_type));
        m_up_input_[i] = (uint8_t*)take(bytes_for(hidden, up_vec_type));
    }
    m_local_gate_input_ = (uint8_t*)take(bytes_for(routed * tokens * hidden, gate_vec_type));
    m_local_up_input_ = (uint8_t*)take(bytes_for(routed * tokens * hidden, up_vec_type));
    m_local_gate_output_ = (float*)take(sizeof(float) * routed * tokens * inter);
    m_local_up_output_ = (float*)take(sizeof(float) * routed * tokens * inter);
    m_local_intermediate_fp32_ = (float*)take(sizeof(float) * routed * tokens * inter);
    m_local_down_input_ = (uint8_t*)take(bytes_for(routed * tokens * inter, down_vec_type));
    m_local_down_output_ = (float*)take(sizeof(float) * routed * tokens * hidden);
    m_output_fp32_.resize(config_.group_max_len);
    for (int i = 0; i < config_.group_max_len; i++) {
        m_output_fp32_[i] = (float*)take(sizeof(float) * hidden);
    }

    // Per-token slot table. forward_many() writes m_local_pos_[i][j] with
    // operator[], so the inner vectors need a real size: reserve() only sets
    // capacity and would leave every such access out of bounds.
    m_local_pos_.resize(config_.group_max_len);
    for (int i = 0; i < config_.group_max_len; i++) {
        m_local_pos_[i].resize(config_.expert_num);
    }

    // Per-expert bookkeeping, indexed by expert id.
    m_local_num_.resize(config_.expert_num);
    m_local_gate_input_ptr_.resize(config_.expert_num);
    m_local_up_input_ptr_.resize(config_.expert_num);
    m_local_gate_output_ptr_.resize(config_.expert_num);
    m_local_up_output_ptr_.resize(config_.expert_num);
    m_local_intermediate_fp32_ptr_.resize(config_.expert_num);
    m_local_down_input_ptr_.resize(config_.expert_num);
    m_local_down_output_ptr_.resize(config_.expert_num);
}
/**
 * Runs forward_one() once per expert on an all-zero input with weight 0;
 * the result is discarded.
 */
void MOE::warm_up(Backend* backend) {
    // Size in bytes of one hidden vector stored as hidden_type.
    const auto io_bytes = config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);

    std::vector<float> zeros(config_.hidden_size, 0.0f);
    std::vector<uint8_t> input(io_bytes);
    std::vector<uint8_t> output(io_bytes);
    from_float(zeros.data(), input.data(), config_.hidden_size, config_.hidden_type);

    for (int expert = 0; expert < config_.expert_num; expert++) {
        uint64_t expert_ids = expert;
        float weights = 0;
        forward_one(1, &expert_ids, &weights, input.data(), output.data(), backend);
    }
}
/** Activation applied to the gate projection: x * sigmoid(x), written as x / (1 + e^-x). */
static float act_fn(float x) {
    const float one_plus_exp = 1.0f + expf(-x);
    return x / one_plus_exp;
}
void
MOE
::
forward_one
(
int
k
,
const
uint64_t
*
expert_ids
,
const
float
*
weights
,
const
void
*
input
,
void
*
output
,
Backend
*
backend
)
{
const
void
*
gate_input_ptr
;
const
void
*
up_input_ptr
;
if
(
config_
.
hidden_type
==
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
&&
config_
.
hidden_type
==
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
{
gate_input_ptr
=
up_input_ptr
=
input
;
}
else
{
to_float
(
input
,
s_input_fp32_
,
config_
.
hidden_size
,
config_
.
hidden_type
);
if
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
==
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
{
from_float
(
s_input_fp32_
,
s_gate_input_
,
config_
.
hidden_size
,
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
gate_input_ptr
=
up_input_ptr
=
s_gate_input_
;
}
else
{
if
(
config_
.
hidden_type
!=
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
{
from_float
(
s_input_fp32_
,
s_gate_input_
,
config_
.
hidden_size
,
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
gate_input_ptr
=
s_gate_input_
;
}
else
{
gate_input_ptr
=
input
;
}
if
(
config_
.
hidden_type
!=
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
{
from_float
(
s_input_fp32_
,
s_up_input_
,
config_
.
hidden_size
,
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
);
up_input_ptr
=
s_up_input_
;
}
else
{
up_input_ptr
=
input
;
}
}
}
int
nth
=
config_
.
intermediate_size
/
config_
.
stride
;
backend
->
do_work_stealing_job
(
nth
*
k
,
[
&
](
int
task_id
)
{
int
expert_idx
=
task_id
/
nth
;
uint64_t
expert_id
=
expert_ids
[
expert_idx
];
int
ith
=
task_id
%
nth
;
void
*
gate_proj_ptr
=
gate_proj_
+
(
expert_id
*
config_
.
intermediate_size
+
ith
*
config_
.
stride
)
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
gate_type
)
/
ggml_blck_size
(
config_
.
gate_type
);
float
*
gate_output_ptr
=
s_gate_output_
[
expert_idx
]
+
ith
*
config_
.
stride
;
llamafile_sgemm
(
config_
.
stride
,
1
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
gate_type
),
gate_proj_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
gate_type
),
gate_input_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
gate_type
),
gate_output_ptr
,
config_
.
stride
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
gate_type
,
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
void
*
up_proj_ptr
=
up_proj_
+
(
expert_id
*
config_
.
intermediate_size
+
ith
*
config_
.
stride
)
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
up_type
)
/
ggml_blck_size
(
config_
.
up_type
);
float
*
up_output_ptr
=
s_up_output_
[
expert_idx
]
+
ith
*
config_
.
stride
;
llamafile_sgemm
(
config_
.
stride
,
1
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
up_type
),
up_proj_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
up_type
),
up_input_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
up_type
),
up_output_ptr
,
config_
.
stride
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
up_type
,
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
for
(
int
i
=
ith
*
config_
.
stride
;
i
<
(
ith
+
1
)
*
config_
.
stride
;
i
++
)
{
s_intermediate_fp32_
[
expert_idx
][
i
]
=
act_fn
(
s_gate_output_
[
expert_idx
][
i
])
*
s_up_output_
[
expert_idx
][
i
];
}
if
(
config_
.
stride
%
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
==
0
)
{
float
*
intermediate_fp32_ptr
=
s_intermediate_fp32_
[
expert_idx
]
+
ith
*
config_
.
stride
;
void
*
down_input_ptr
=
s_down_input_
[
expert_idx
]
+
ith
*
config_
.
stride
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
from_float
(
intermediate_fp32_ptr
,
down_input_ptr
,
config_
.
stride
,
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
}
});
if
(
config_
.
stride
%
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
!=
0
)
{
for
(
int
i
=
0
;
i
<
k
;
i
++
)
{
from_float
(
s_intermediate_fp32_
[
i
],
s_down_input_
[
i
],
config_
.
intermediate_size
,
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
}
}
nth
=
config_
.
hidden_size
/
config_
.
stride
;
backend
->
do_work_stealing_job
(
nth
,
[
&
](
int
task_id
)
{
int
ith
=
task_id
;
for
(
int
i
=
ith
*
config_
.
stride
;
i
<
(
ith
+
1
)
*
config_
.
stride
;
i
++
)
{
s_output_fp32_
[
i
]
=
0
;
}
for
(
int
expert_idx
=
0
;
expert_idx
<
k
;
expert_idx
++
)
{
uint64_t
expert_id
=
expert_ids
[
expert_idx
];
void
*
down_proj_ptr
=
down_proj_
+
(
expert_id
*
config_
.
hidden_size
+
ith
*
config_
.
stride
)
*
config_
.
intermediate_size
*
ggml_type_size
(
config_
.
down_type
)
/
ggml_blck_size
(
config_
.
down_type
);
float
*
down_output_ptr
=
s_down_output_
[
expert_idx
]
+
ith
*
config_
.
stride
;
llamafile_sgemm
(
config_
.
stride
,
1
,
config_
.
intermediate_size
/
ggml_blck_size
(
config_
.
down_type
),
down_proj_ptr
,
config_
.
intermediate_size
/
ggml_blck_size
(
config_
.
down_type
),
s_down_input_
[
expert_idx
],
config_
.
intermediate_size
/
ggml_blck_size
(
config_
.
down_type
),
down_output_ptr
,
config_
.
stride
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
down_type
,
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
for
(
int
i
=
ith
*
config_
.
stride
;
i
<
(
ith
+
1
)
*
config_
.
stride
;
i
++
)
{
s_output_fp32_
[
i
]
+=
s_down_output_
[
expert_idx
][
i
]
*
weights
[
expert_idx
];
}
}
if
(
config_
.
stride
%
ggml_blck_size
(
config_
.
hidden_type
)
==
0
)
{
float
*
output_fp32_ptr
=
s_output_fp32_
+
ith
*
config_
.
stride
;
void
*
output_ptr
=
output
+
ith
*
config_
.
stride
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
);
from_float
(
output_fp32_ptr
,
output_ptr
,
config_
.
stride
,
config_
.
hidden_type
);
}
});
if
(
config_
.
stride
%
ggml_blck_size
(
config_
.
hidden_type
)
!=
0
)
{
from_float
(
s_output_fp32_
,
output
,
config_
.
hidden_size
,
config_
.
hidden_type
);
}
}
void
MOE
::
forward_many
(
int
qlen
,
int
k
,
const
uint64_t
*
expert_ids
,
const
float
*
weights
,
const
void
*
input
,
void
*
output
,
Backend
*
backend
)
{
for
(
int
i
=
0
;
i
<
config_
.
expert_num
;
i
++
)
{
m_local_num_
[
i
]
=
0
;
}
for
(
int
i
=
0
;
i
<
qlen
;
i
++
)
{
for
(
int
j
=
0
;
j
<
k
;
j
++
)
{
m_local_pos_
[
i
][
j
]
=
m_local_num_
[
expert_ids
[
i
*
k
+
j
]]
++
;
}
}
uint64_t
offset
=
0
;
for
(
int
i
=
0
;
i
<
config_
.
expert_num
;
i
++
)
{
m_local_gate_input_ptr_
[
i
]
=
m_local_gate_input_
+
offset
*
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
m_local_up_input_ptr_
[
i
]
=
m_local_up_input_
+
offset
*
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
);
m_local_gate_output_ptr_
[
i
]
=
m_local_gate_output_
+
offset
*
config_
.
intermediate_size
;
m_local_up_output_ptr_
[
i
]
=
m_local_up_output_
+
offset
*
config_
.
intermediate_size
;
m_local_intermediate_fp32_ptr_
[
i
]
=
m_local_intermediate_fp32_
+
offset
*
config_
.
intermediate_size
;
m_local_down_input_ptr_
[
i
]
=
m_local_down_input_
+
offset
*
config_
.
intermediate_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
m_local_down_output_ptr_
[
i
]
=
m_local_down_output_
+
offset
*
config_
.
hidden_size
;
offset
+=
m_local_num_
[
i
];
}
backend
->
do_work_stealing_job
(
qlen
,
[
&
](
int
i
)
{
const
void
*
gate_input_ptr
;
const
void
*
up_input_ptr
;
if
(
config_
.
hidden_type
==
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
&&
config_
.
hidden_type
==
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
{
gate_input_ptr
=
up_input_ptr
=
input
+
i
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
);
}
else
{
to_float
(
input
+
i
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
),
m_input_fp32_
[
i
],
config_
.
hidden_size
,
config_
.
hidden_type
);
if
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
==
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
{
from_float
(
m_input_fp32_
[
i
],
m_gate_input_
[
i
],
config_
.
hidden_size
,
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
gate_input_ptr
=
up_input_ptr
=
m_gate_input_
[
i
];
}
else
{
if
(
config_
.
hidden_type
!=
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
{
from_float
(
m_input_fp32_
[
i
],
m_gate_input_
[
i
],
config_
.
hidden_size
,
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
gate_input_ptr
=
m_gate_input_
[
i
];
}
else
{
gate_input_ptr
=
input
+
i
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
);
}
if
(
config_
.
hidden_type
!=
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
{
from_float
(
m_input_fp32_
[
i
],
m_up_input_
[
i
],
config_
.
hidden_size
,
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
);
up_input_ptr
=
m_up_input_
[
i
];
}
else
{
up_input_ptr
=
input
+
i
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
);
}
}
}
for
(
int
j
=
0
;
j
<
k
;
j
++
)
{
memcpy
(
m_local_gate_input_ptr_
[
expert_ids
[
i
*
k
+
j
]]
+
m_local_pos_
[
i
][
j
]
*
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
),
gate_input_ptr
,
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
));
memcpy
(
m_local_up_input_ptr_
[
expert_ids
[
i
*
k
+
j
]]
+
m_local_pos_
[
i
][
j
]
*
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
),
up_input_ptr
,
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
));
}
});
int
stride
=
QK_K
;
int
nth
=
config_
.
intermediate_size
/
stride
;
backend
->
do_work_stealing_job
(
nth
*
config_
.
expert_num
,
[
&
](
int
task_id
)
{
int
expert_idx
=
task_id
/
nth
;
int
ith
=
task_id
%
nth
;
void
*
gate_input_ptr
=
m_local_gate_input_ptr_
[
expert_idx
];
void
*
gate_proj_ptr
=
gate_proj_
+
(
expert_idx
*
config_
.
intermediate_size
+
ith
*
stride
)
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
gate_type
)
/
ggml_blck_size
(
config_
.
gate_type
);
float
*
gate_output_ptr
=
m_local_gate_output_ptr_
[
expert_idx
]
+
ith
*
stride
;
llamafile_sgemm
(
stride
,
m_local_num_
[
expert_idx
],
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
gate_type
),
gate_proj_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
gate_type
),
gate_input_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
gate_type
),
gate_output_ptr
,
config_
.
intermediate_size
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
gate_type
,
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
void
*
up_input_ptr
=
m_local_up_input_ptr_
[
expert_idx
];
void
*
up_proj_ptr
=
up_proj_
+
(
expert_idx
*
config_
.
intermediate_size
+
ith
*
stride
)
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
up_type
)
/
ggml_blck_size
(
config_
.
up_type
);
float
*
up_output_ptr
=
m_local_up_output_ptr_
[
expert_idx
]
+
ith
*
stride
;
llamafile_sgemm
(
stride
,
m_local_num_
[
expert_idx
],
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
up_type
),
up_proj_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
up_type
),
up_input_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
up_type
),
up_output_ptr
,
config_
.
intermediate_size
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
up_type
,
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
for
(
int
i
=
0
;
i
<
m_local_num_
[
expert_idx
];
i
++
)
{
for
(
int
j
=
ith
*
stride
;
j
<
(
ith
+
1
)
*
stride
;
j
++
)
{
m_local_intermediate_fp32_ptr_
[
expert_idx
][
i
*
config_
.
intermediate_size
+
j
]
=
act_fn
(
m_local_gate_output_ptr_
[
expert_idx
][
i
*
config_
.
intermediate_size
+
j
])
*
m_local_up_output_ptr_
[
expert_idx
][
i
*
config_
.
intermediate_size
+
j
];
}
float
*
intermediate_fp32_ptr
=
m_local_intermediate_fp32_ptr_
[
expert_idx
]
+
i
*
config_
.
intermediate_size
+
ith
*
stride
;
void
*
down_input_ptr
=
m_local_down_input_ptr_
[
expert_idx
]
+
i
*
config_
.
intermediate_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
+
ith
*
stride
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
from_float
(
intermediate_fp32_ptr
,
down_input_ptr
,
stride
,
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
}
});
stride
=
QK_K
;
nth
=
config_
.
hidden_size
/
stride
;
backend
->
do_work_stealing_job
(
nth
*
config_
.
expert_num
,
[
&
](
int
task_id
)
{
int
expert_idx
=
task_id
/
nth
;
int
ith
=
task_id
%
nth
;
void
*
down_input_ptr
=
m_local_down_input_ptr_
[
expert_idx
];
void
*
down_proj_ptr
=
down_proj_
+
(
expert_idx
*
config_
.
hidden_size
+
ith
*
stride
)
*
config_
.
intermediate_size
*
ggml_type_size
(
config_
.
down_type
)
/
ggml_blck_size
(
config_
.
down_type
);
float
*
down_output_ptr
=
m_local_down_output_ptr_
[
expert_idx
]
+
ith
*
stride
;
llamafile_sgemm
(
stride
,
m_local_num_
[
expert_idx
],
config_
.
intermediate_size
/
ggml_blck_size
(
config_
.
down_type
),
down_proj_ptr
,
config_
.
intermediate_size
/
ggml_blck_size
(
config_
.
down_type
),
down_input_ptr
,
config_
.
intermediate_size
/
ggml_blck_size
(
config_
.
down_type
),
down_output_ptr
,
config_
.
hidden_size
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
down_type
,
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
});
backend
->
do_work_stealing_job
(
qlen
,
[
&
](
int
i
)
{
for
(
int
e
=
0
;
e
<
config_
.
hidden_size
;
e
++
)
{
m_output_fp32_
[
i
][
e
]
=
0
;
}
for
(
int
j
=
0
;
j
<
k
;
j
++
)
{
for
(
int
e
=
0
;
e
<
config_
.
hidden_size
;
e
++
)
{
m_output_fp32_
[
i
][
e
]
+=
m_local_down_output_ptr_
[
expert_ids
[
i
*
k
+
j
]][
m_local_pos_
[
i
][
j
]
*
config_
.
hidden_size
+
e
]
*
weights
[
i
*
k
+
j
];
}
}
from_float
(
m_output_fp32_
[
i
],
output
+
i
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
),
config_
.
hidden_size
,
config_
.
hidden_type
);
});
}
void
MOE
::
forward
(
int
qlen
,
int
k
,
const
uint64_t
*
expert_ids
,
const
float
*
weights
,
const
void
*
input
,
void
*
output
,
Backend
*
backend
)
{
if
(
qlen
<
config_
.
group_min_len
)
{
for
(
int
i
=
0
;
i
<
qlen
;
i
++
)
{
forward_one
(
k
,
expert_ids
+
i
*
k
,
weights
+
i
*
k
,
input
+
i
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
),
output
+
i
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
),
backend
);
}
return
;
}
int
forward_len
=
std
::
min
(
config_
.
group_max_len
,
qlen
);
forward_many
(
forward_len
,
k
,
expert_ids
,
weights
,
input
,
output
,
backend
);
forward
(
qlen
-
forward_len
,
k
,
expert_ids
+
forward_len
*
k
,
weights
+
forward_len
*
k
,
input
+
forward_len
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
),
output
+
forward_len
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
),
backend
);
}
\ No newline at end of file
ktransformers/ktransformers_ext/operators/llamafile/moe.h
0 → 100644
View file @
18c42e67
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-22 02:03:22
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:10
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_OPERATOR_MOE_H
#define CPUINFER_OPERATOR_MOE_H
#include <cmath>
#include <cstdio>
#include <functional>
#include <mutex>
#include <vector>
#include "../../cpu_backend/backend.h"
#include "conversion.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
/**
 * Plain parameter bundle describing one MoE layer: its dimensions, the
 * batching thresholds used by MOE::forward, and the location / encoding of the
 * three expert weight tensors.
 *
 * Every field has a default member initializer so that a default-constructed
 * MOEConfig holds well-defined values (zeros / null pointers / enumerator 0)
 * instead of indeterminate ones.
 */
struct MOEConfig {
    int expert_num = 0;         // number of experts; the per-expert pointer tables in MOE are sized by it
    int routed_expert_num = 0;  // sizes the per-token scratch buffers in MOE; presumably experts used per token - TODO confirm
    int hidden_size = 0;        // elements per input/output row
    int intermediate_size = 0;  // elements per gate/up output row
    int stride = 0;             // usage is not visible in this header - see moe.cpp
    int group_min_len = 0;      // MOE::forward uses forward_one for batches shorter than this
    int group_max_len = 0;      // MOE::forward passes at most this many tokens to one forward_many call
    void* gate_proj = nullptr;  // gate projection weights, encoded as gate_type
    void* up_proj = nullptr;    // up projection weights, encoded as up_type
    void* down_proj = nullptr;  // down projection weights, encoded as down_type
    ggml_type gate_type{};      // value-initialized, i.e. enumerator 0
    ggml_type up_type{};
    ggml_type down_type{};
    ggml_type hidden_type{};    // encoding of the input/output activations

    MOEConfig() = default;

    MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)
        : expert_num(expert_num),
          routed_expert_num(routed_expert_num),
          hidden_size(hidden_size),
          intermediate_size(intermediate_size),
          stride(stride),
          group_min_len(group_min_len),
          group_max_len(group_max_len),
          gate_proj(gate_proj),
          up_proj(up_proj),
          down_proj(down_proj),
          gate_type(gate_type),
          up_type(up_type),
          down_type(down_type),
          hidden_type(hidden_type) {}
};
class
MOE
{
public:
MOE
(
MOEConfig
);
void
warm_up
(
Backend
*
backend
);
void
forward_one
(
int
k
,
const
uint64_t
*
expert_ids
,
const
float
*
weights
,
const
void
*
input
,
void
*
output
,
Backend
*
backend
);
void
forward_many
(
int
qlen
,
int
k
,
const
uint64_t
*
expert_ids
,
const
float
*
weights
,
const
void
*
input
,
void
*
output
,
Backend
*
backend
);
void
forward
(
int
qlen
,
int
k
,
const
uint64_t
*
expert_ids
,
const
float
*
weights
,
const
void
*
input
,
void
*
output
,
Backend
*
backend
);
private:
static
void
*
buffer_
;
MOEConfig
config_
;
void
*
gate_proj_
;
// [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
void
*
up_proj_
;
// [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
void
*
down_proj_
;
// [expert_num * hidden_size * intermediate_size ( /32 if quantized)]
float
*
s_input_fp32_
;
// [hidden_size]
uint8_t
*
s_gate_input_
;
// [hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]
uint8_t
*
s_up_input_
;
// [hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]
std
::
vector
<
float
*>
s_gate_output_
;
// [routed_expert_num, intermediate_size]
std
::
vector
<
float
*>
s_up_output_
;
// [routed_expert_num, intermediate_size]
std
::
vector
<
float
*>
s_intermediate_fp32_
;
// [routed_expert_num, intermediate_size]
std
::
vector
<
uint8_t
*>
s_down_input_
;
// [routed_expert_num, intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]
std
::
vector
<
float
*>
s_down_output_
;
// [routed_expert_num, hidden_size]
float
*
s_output_fp32_
;
// [hidden_size]
std
::
vector
<
float
*>
m_input_fp32_
;
// [group_max_len, hidden_size]
std
::
vector
<
uint8_t
*>
m_gate_input_
;
// [group_max_len, hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]
std
::
vector
<
uint8_t
*>
m_up_input_
;
// [group_max_len, hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]
uint8_t
*
m_local_gate_input_
;
// [routed_expert_num * group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]
uint8_t
*
m_local_up_input_
;
// [routed_expert_num * group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]
float
*
m_local_gate_output_
;
// [routed_expert_num * group_max_len * intermediate_size]
float
*
m_local_up_output_
;
// [routed_expert_num * group_max_len * intermediate_size]
float
*
m_local_intermediate_fp32_
;
// [routed_expert_num * group_max_len * intermediate_size]
uint8_t
*
m_local_down_input_
;
// [routed_expert_num * group_max_len * intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]
float
*
m_local_down_output_
;
// [routed_expert_num * group_max_len * hidden_size]
std
::
vector
<
float
*>
m_output_fp32_
;
// [group_max_len, hidden_size]
std
::
vector
<
std
::
vector
<
int
>>
m_local_pos_
;
// [group_max_len, routed_expert_num]
std
::
vector
<
int
>
m_local_num_
;
// [expert_num]
std
::
vector
<
uint8_t
*>
m_local_gate_input_ptr_
;
// [expert_num]
std
::
vector
<
uint8_t
*>
m_local_up_input_ptr_
;
// [expert_num]
std
::
vector
<
float
*>
m_local_gate_output_ptr_
;
// [expert_num]
std
::
vector
<
float
*>
m_local_up_output_ptr_
;
// [expert_num]
std
::
vector
<
float
*>
m_local_intermediate_fp32_ptr_
;
// [expert_num]
std
::
vector
<
uint8_t
*>
m_local_down_input_ptr_
;
// [expert_num]
std
::
vector
<
float
*>
m_local_down_output_ptr_
;
// [expert_num]
};
#endif
\ No newline at end of file
ktransformers/local_chat.py
0 → 100644
View file @
18c42e67
# Copyright 2024 Shaoyuan Chen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
platform
import
sys
# Put the parent of this file's directory at the front of the import path so
# the absolute ``ktransformers.*`` imports below resolve when this file is run
# directly. ``abspath`` keeps the result independent of the current working
# directory: with a relative ``__file__`` the double ``dirname`` would collapse
# to an empty string and the wrong directory would be searched.
project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_dir)
import
torch
import
logging
from
transformers
import
(
AutoTokenizer
,
AutoConfig
,
AutoModelForCausalLM
,
GenerationConfig
,
TextStreamer
,
)
import
json
import
fire
from
ktransformers.optimize.optimize
import
optimize_and_load_gguf
from
ktransformers.models.modeling_deepseek
import
DeepseekV2ForCausalLM
from
ktransformers.models.modeling_qwen2_moe
import
Qwen2MoeForCausalLM
from
ktransformers.util.utils
import
prefill_and_generate
from
ktransformers.server.config.config
import
Config
# Architecture name (the first entry of a checkpoint's ``config.architectures``)
# mapped to the model class shipped with ktransformers.
custom_models = {
    "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
    "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
}

# Directory of the bundled optimize-rule YAML files, next to this module.
# The trailing slash is kept so file names can simply be appended.
ktransformer_rules_dir = f"{os.path.dirname(os.path.abspath(__file__))}/optimize/optimize_rules/"

# Rule file used for an architecture when the caller does not supply one.
default_optimize_rules = {
    architecture: ktransformer_rules_dir + rule_file
    for architecture, rule_file in (
        ("DeepseekV2ForCausalLM", "DeepSeek-V2-Chat.yaml"),
        ("Qwen2MoeForCausalLM", "Qwen2-57B-A14B-Instruct.yaml"),
    )
}
def local_chat(
    model_path: str,
    optimize_rule_path: str = None,
    gguf_path: str = None,
    max_new_tokens: int = 1000,
    cpu_infer: int = Config().cpu_infer
):
    """Load a model through ktransformers and run an interactive chat prompt.

    The model skeleton is created on the ``meta`` device,
    ``optimize_and_load_gguf`` is then called with the rule file and the GGUF
    path, and finally the function loops reading prompts from stdin and passing
    them to ``prefill_and_generate``. Each prompt is sent as a fresh
    single-message conversation; no history is kept between turns.

    Args:
        model_path: Passed to the ``from_pretrained`` calls for the tokenizer,
            the model config and the generation config.
        optimize_rule_path: YAML file with optimize rules. When ``None`` the
            entry of ``default_optimize_rules`` for the model architecture is
            used; if there is none, the path is asked for on stdin.
        gguf_path: Path handed to ``optimize_and_load_gguf``. Asked for on
            stdin when ``None``.
        max_new_tokens: Forwarded to ``prefill_and_generate``.
        cpu_infer: Value written to ``Config().cpu_infer`` before the model is
            built (``Config`` is presumably a shared singleton - confirm).

    Returns:
        None. The loop ends when stdin is closed or the user presses Ctrl-C at
        the prompt.
    """
    torch.set_grad_enabled(False)

    Config().cpu_infer = cpu_infer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    torch.set_default_dtype(config.torch_dtype)

    architecture = config.architectures[0]

    with torch.device("meta"):
        if architecture in custom_models:
            print("using custom modeling_xxx.py.")
            if "Qwen2Moe" in architecture:
                # Qwen2Moe must use flash_attention_2 to avoid overflow.
                config._attn_implementation = "flash_attention_2"
            model = custom_models[architecture](config)
        else:
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=True, attn_implementation="flash_attention_2"
            )

    if optimize_rule_path is None:
        if architecture in default_optimize_rules:
            print("using default_optimize_rule for", architecture)
            optimize_rule_path = default_optimize_rules[architecture]
        else:
            optimize_rule_path = input(
                "please input the path of your rule file(yaml file containing optimize rules):"
            )

    if gguf_path is None:
        gguf_path = input(
            "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
        )
    optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)

    model.generation_config = GenerationConfig.from_pretrained(model_path)
    if model.generation_config.pad_token_id is None:
        # Fall back to the EOS token when the checkpoint defines no padding token.
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
    model.eval()
    logging.basicConfig(level=logging.INFO)

    # Start the session on a clean terminal.
    os.system("cls" if platform.system() == "Windows" else "clear")

    while True:
        try:
            content = input("Chat: ")
        except (EOFError, KeyboardInterrupt):
            # Closing stdin or pressing Ctrl-C at the prompt ends the session
            # instead of surfacing a traceback.
            print()
            break
        if content == "":
            # An empty line submits a canned demo prompt.
            content = "Please write a piece of quicksort code in C++."

        messages = [{"role": "user", "content": content}]
        input_tensor = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        )
        torch.set_default_dtype(torch.bfloat16)  # TODO: Remove this, replace dtype using config
        prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens)
# Script entry point: hand local_chat to python-fire so it can be invoked from
# the command line.
if __name__ == "__main__":
    fire.Fire(local_chat)
ktransformers/models/__init__.py
0 → 100644
View file @
18c42e67
ktransformers/models/configuration_deepseek.py
0 → 100644
View file @
18c42e67
# Adapted from
# https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat-0628/blob/main/configuration_deepseek.py
# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.utils
import
logging
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)

# Archive map for pretrained configs; no entries are registered here.
DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class DeepseekV2Config(PretrainedConfig):
    r"""
    Configuration class for a [`DeepseekV2Model`]. It is used to instantiate a DeepSeek-V2 model according to the
    specified arguments, which define the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation of [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 102400):
            Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`DeepseekV2Model`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 1407):
            Dimension of the MoE representations.
        num_hidden_layers (`int`, *optional*, defaults to 30):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            Number of key/value heads used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model uses Multi Head Attention (MHA); if
            `num_key_value_heads=1`, it uses Multi Query Attention (MQA); otherwise GQA is used. When converting a
            multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by
            mean-pooling all the original heads within that group. For more details see [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). Passing `None` falls back to `num_attention_heads`.
        n_shared_experts (`int`, *optional*, defaults to None):
            Number of shared experts, None means dense model.
        n_routed_experts (`int`, *optional*, defaults to None):
            Number of routed experts, None means dense model.
        ep_size (`int`, *optional*, defaults to 1):
            Stored on the config unchanged; presumably the expert-parallel group size - confirm against the
            modeling code.
        routed_scaling_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor of routed experts.
        kv_lora_rank (`int`, *optional*, defaults to 512):
            Stored on the config unchanged; presumably the rank of the low-rank key/value projection - confirm
            against the modeling code.
        q_lora_rank (`int`, *optional*, defaults to 1536):
            Stored on the config unchanged; presumably the rank of the low-rank query projection - confirm against
            the modeling code.
        qk_rope_head_dim (`int`, *optional*, defaults to 64):
            Stored on the config unchanged; presumably the per-head query/key dimension that receives rotary
            embeddings - confirm against the modeling code.
        v_head_dim (`int`, *optional*, defaults to 128):
            Stored on the config unchanged; presumably the per-head value dimension - confirm against the modeling
            code.
        qk_nope_head_dim (`int`, *optional*, defaults to 128):
            Stored on the config unchanged; presumably the per-head query/key dimension without rotary embeddings -
            confirm against the modeling code.
        topk_method (`str`, *optional*, defaults to `'gready'`):
            Topk method used in routed gate. The spelling of the default is the literal value in the signature.
        n_group (`int`, *optional*, defaults to None):
            Number of groups for routed experts.
        topk_group (`int`, *optional*, defaults to None):
            Number of selected groups for each token (for each token, the selected experts are guaranteed to lie
            within `topk_group` groups).
        num_experts_per_tok (`int`, *optional*, defaults to None):
            Number of selected experts, None means dense model.
        moe_layer_freq (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
        first_k_dense_replace (`int`, *optional*, defaults to 0):
            Number of dense layers in shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).
                                                            \--k dense layers--/
        norm_topk_prob (`bool`, *optional*, defaults to False):
            Whether to normalize the weights of the routed experts.
        scoring_func (`str`, *optional*, defaults to 'softmax'):
            Method of computing expert weights.
        aux_loss_alpha (`float`, *optional*, defaults to 0.001):
            Auxiliary loss weight coefficient.
        seq_aux (`bool`, *optional*, defaults to True):
            Whether to compute the auxiliary loss for each individual sample.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 100000):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 100001):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value
            is necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
            is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        cpu_quant (*optional*, defaults to None):
            Stored on the config unchanged; this class does not interpret it.

    ```python
    >>> from transformers import DeepseekV2Model, DeepseekV2Config

    >>> # Initializing a Deepseek-V2 style configuration
    >>> configuration = DeepseekV2Config()

    >>> # Initializing a model from the configuration
    >>> model = DeepseekV2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "deepseek_v2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=11008,
        moe_intermediate_size=1407,
        num_hidden_layers=30,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts=None,
        n_routed_experts=None,
        ep_size=1,
        routed_scaling_factor=1.0,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method='gready',
        n_group=None,
        topk_group=None,
        num_experts_per_tok=None,
        moe_layer_freq=1,
        first_k_dense_replace=0,
        norm_topk_prob=False,
        scoring_func='softmax',
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        cpu_quant=None,
        **kwargs,
    ):
        # Core dimensions.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Expert layout.
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor

        # Attention projection / head dimensions.
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim

        # Expert routing.
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux

        # For backward compatibility: an explicit None means "same as the
        # number of attention heads".
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.cpu_quant = cpu_quant

        # Token ids, embedding tying and any extra keyword arguments are
        # handled by the PretrainedConfig base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
Prev
1
2
3
4
5
6
7
8
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment