OpenDAS / ktransformers

Commit 18c42e67, authored Jul 27, 2024 by chenxl

Initial commit
Changes: 247. Showing 20 changed files with 2746 additions and 0 deletions (+2746, -0).
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq.py  +206 -0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq_marlin.py  +458 -0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/quantizer.py  +140 -0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/repack.py  +99 -0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/__init__.py  +0 -0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/format_24.py  +308 -0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_24_perms.py  +60 -0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_perms.py  +60 -0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py  +232 -0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py  +146 -0
ktransformers/ktransformers_ext/operators/llamafile/conversion.h  +33 -0
ktransformers/ktransformers_ext/operators/llamafile/linear.cpp  +48 -0
ktransformers/ktransformers_ext/operators/llamafile/linear.h  +56 -0
ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp  +103 -0
ktransformers/ktransformers_ext/operators/llamafile/mlp.h  +67 -0
ktransformers/ktransformers_ext/operators/llamafile/moe.cpp  +311 -0
ktransformers/ktransformers_ext/operators/llamafile/moe.h  +97 -0
ktransformers/local_chat.py  +115 -0
ktransformers/models/__init__.py  +0 -0
ktransformers/models/configuration_deepseek.py  +207 -0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq.py (new file, mode 100644)

import math
import os
import time
from logging import getLogger

import torch
import torch.nn as nn
import transformers

from .quantizer import Quantizer

logger = getLogger(__name__)

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


class GPTQ:
    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.pytorch_utils.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.quantizer = Quantizer()

    def add_batch(self, inp, out):
        if os.environ.get("DEBUG"):
            self.inp1 = inp
            self.out1 = out
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        # inp = inp.float()
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        # self.H += 2 / self.nsamples * inp.matmul(inp.t())
        self.H += inp.matmul(inp.t())

    def fasterquant(
        self,
        blocksize=128,
        percdamp=0.01,
        group_size=-1,
        actorder=False,
        static_groups=False,
    ):
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        del self.H
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        g_idx = []
        scale = []
        zero = []
        now_idx = 1

        if static_groups:
            import copy

            groups = []
            for i in range(0, self.columns, group_size):
                quantizer = copy.deepcopy(self.quantizer)
                quantizer.find_params(W[:, i : (i + group_size)], weight=True)
                scale.append(quantizer.scale)
                zero.append(quantizer.zero)
                groups.append(quantizer)

        if actorder:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]
            invperm = torch.argsort(perm)

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if group_size != -1:
                    if not static_groups:
                        if (i1 + i) % group_size == 0:
                            self.quantizer.find_params(
                                W[:, (i1 + i) : (i1 + i + group_size)], weight=True
                            )

                        if ((i1 + i) // group_size) - now_idx == -1:
                            scale.append(self.quantizer.scale)
                            zero.append(self.quantizer.zero)
                            now_idx += 1
                    else:
                        idx = i1 + i
                        if actorder:
                            idx = perm[idx]
                        self.quantizer = groups[idx // group_size]

                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d**2

                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

            if os.environ.get("DEBUG"):
                self.layer.weight.data[:, :i2] = Q[:, :i2]
                self.layer.weight.data[:, i2:] = W[:, i2:]
                logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
                logger.debug(torch.sum(Losses))

        torch.cuda.synchronize()
        logger.info(f"duration: {(time.time() - tick)}")
        logger.info(f"avg loss: {torch.sum(Losses).item() / self.nsamples}")

        group_size = group_size if group_size != -1 else self.columns
        if static_groups and actorder:
            g_idx = [perm[i] // group_size for i in range(self.columns)]
        else:
            g_idx = [i // group_size for i in range(self.columns)]
        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
        if actorder:
            Q = Q[:, invperm]
            g_idx = g_idx[invperm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()
        self.layer.weight.data = Q.reshape(self.layer.weight.shape).type_as(
            self.layer.weight.data
        )
        if os.environ.get("DEBUG"):
            logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))

        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero, g_idx

    def free(self):
        if os.environ.get("DEBUG"):
            self.inp1 = None
            self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()


__all__ = ["GPTQ"]
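The GPTQ class above is driven by feeding calibration activations through add_batch and then calling fasterquant. Below is a minimal usage sketch, not part of this commit: it assumes the package layout of this commit is importable, a plain nn.Linear layer, random calibration data, and a CUDA-enabled PyTorch build (fasterquant ends with torch.cuda.synchronize()).

# Hypothetical sketch (not part of this commit): quantize one nn.Linear with
# the GPTQ/Quantizer classes from this directory. Requires a CUDA-enabled
# PyTorch build because fasterquant() calls torch.cuda.synchronize().
import torch
import torch.nn as nn

from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.gptq import GPTQ

layer = nn.Linear(512, 256)
gptq = GPTQ(layer)
gptq.quantizer.configure(bits=4, perchannel=True, sym=True)

# Accumulate the Hessian estimate from a few calibration batches.
for _ in range(8):
    x = torch.randn(16, 512)
    gptq.add_batch(x, layer(x))

# Quantize the layer weights in place; collect scales, zeros, and group indices.
scale, zero, g_idx = gptq.fasterquant(blocksize=128, percdamp=0.01, group_size=128)
gptq.free()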
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq_marlin.py (new file, mode 100644)

import enum
from enum import Enum
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)

logger = init_logger(__name__)

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]


# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits: int):
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


def get_pack_factor(num_bits: int):
    assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
            ), f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                          group_size: int, num_bits: int):
    scale_perm, scale_perm_single = get_scale_perms(num_bits)
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()

    return s


class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool) -> None:
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym

        # Verify
        if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"Marlin does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
                "are supported.")
        if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
            raise ValueError(
                f"Marlin does not support is_sym = {self.is_sym}. "
                f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")

        # Init
        self.pack_factor = get_pack_factor(weight_bits)
        self.tile_size = GPTQ_MARLIN_TILE
        self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
        self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
        self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        return cls(weight_bits, group_size, desc_act, is_sym)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        can_convert = cls.is_marlin_compatible(hf_quant_cfg)

        is_valid_user_quant = (user_quant is None or user_quant == "marlin")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_marlin"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_marlin for"
                        " faster inference")
        return None

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
        if isinstance(layer, LinearBase):
            return GPTQMarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
        # Extract data from quant config.
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        sym = quant_config.get("sym", None)
        desc_act = quant_config.get("desc_act", None)

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        # If the capability of the device is too low, cannot convert.
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < cls.get_min_capability():
            return False

        # Otherwise, can convert if model satisfies marlin constraints.
        return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
                and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
                and sym in GPTQ_MARLIN_SUPPORTED_SYM)


class GPTQMarlinState(Enum):
    REPACK = enum.auto()
    READY = enum.auto()


class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Validate dtype
        if params_dtype not in [torch.float16, torch.bfloat16]:
            raise ValueError(f"The params dtype must be float16 "
                             f"or bfloat16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_thread_n != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f" min_thread_n = {self.quant_config.min_thread_n}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_thread_k != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible "
                f"by min_thread_k = {self.quant_config.min_thread_k}.")

        if (group_size < input_size
                and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition}"
                f" is not divisible by group_size = {group_size}.")

        # Detect sharding of scales/zp

        # By default, no sharding over "input dim"
        scales_and_zp_size = input_size // group_size
        scales_and_zp_input_dim = None

        if self.quant_config.desc_act:
            # Act-order case
            assert self.quant_config.group_size != -1

            is_k_full = input_size_per_partition == input_size

        else:
            # No act-order case

            # K is always full due to full alignment with
            # group-size and shard of scales/zp
            is_k_full = True

            # If this is a row-parallel case, then shard scales/zp
            if (input_size != input_size_per_partition
                    and self.quant_config.group_size != -1):
                scales_and_zp_size = input_size_per_partition // group_size
                scales_and_zp_input_dim = 0

        # Init buffers

        # Quantized weights
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                **extra_weight_attrs,
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            },
        )

        # Activation order
        g_idx = Parameter(
            torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(
            g_idx,
            {
                **extra_weight_attrs, "input_dim": 0,
                "ignore_warning": True
            },
        )

        g_idx_sort_indices = torch.empty(
            g_idx.shape,
            dtype=torch.int32,
        )

        # Scales
        scales = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
            },
        )

        # Quantized zero-points
        qzeros = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
                device="meta",
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros,
            {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            },
        )

        # Allocate marlin workspace
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_thread_n) * self.quant_config.max_parallel
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                requires_grad=False)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)
        layer.g_idx_sort_indices = g_idx_sort_indices
        layer.workspace = workspace
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.is_k_full = is_k_full
        layer.marlin_state = GPTQMarlinState.REPACK

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        reshaped_x = x.reshape(-1, x.shape[-1])

        size_m = reshaped_x.shape[0]
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition
        full_size_k = layer.input_size

        out_shape = x.shape[:-1] + (part_size_n, )

        if layer.marlin_state == GPTQMarlinState.REPACK:
            layer.marlin_state = GPTQMarlinState.READY

            # Newly generated tensors need to replace existing tensors that are
            # already registered as parameters by vLLM (and won't be freed)
            def replace_tensor(name, new_t):
                # It is important to use resize_() here since it ensures
                # the same buffer is reused
                getattr(layer, name).resize_(new_t.shape)
                getattr(layer, name).copy_(new_t)
                del new_t

            cur_device = layer.qweight.device

            # Process act_order
            if self.quant_config.desc_act:
                # Get sorting based on g_idx
                g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)

                sorted_g_idx = layer.g_idx[g_idx_sort_indices]

                replace_tensor("g_idx", sorted_g_idx)
                replace_tensor("g_idx_sort_indices", g_idx_sort_indices)

            else:
                # Reset g_idx related tensors
                layer.g_idx = Parameter(
                    torch.empty(0, dtype=torch.int, device=cur_device),
                    requires_grad=False,
                )
                layer.g_idx_sort_indices = Parameter(
                    torch.empty(0, dtype=torch.int, device=cur_device),
                    requires_grad=False,
                )

            # Repack weights
            marlin_qweight = ops.gptq_marlin_repack(
                layer.qweight,
                layer.g_idx_sort_indices,
                part_size_k,
                part_size_n,
                self.quant_config.weight_bits,
            )
            replace_tensor("qweight", marlin_qweight)

            # Permute scales
            scales_size_k = part_size_k
            scales_size_n = part_size_n
            if self.quant_config.desc_act:
                scales_size_k = full_size_k

            marlin_scales = marlin_permute_scales(
                layer.scales,
                scales_size_k,
                scales_size_n,
                self.quant_config.group_size,
                self.quant_config.weight_bits,
            )
            replace_tensor("scales", marlin_scales)

        output = ops.gptq_marlin_gemm(
            reshaped_x,
            layer.qweight,
            layer.scales,
            layer.g_idx,
            layer.g_idx_sort_indices,
            layer.workspace,
            self.quant_config.weight_bits,
            size_m,
            part_size_n,
            part_size_k,
            layer.is_k_full,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
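A small sketch of how the configuration half of gptq_marlin.py might be exercised, not part of this commit. It assumes vLLM is installed (the module imports vllm at load time) and uses a hypothetical quantize_config dict; is_marlin_compatible additionally needs a visible CUDA device of compute capability 8.0 or higher.

# Hypothetical sketch (not part of this commit): build a GPTQMarlinConfig from
# the kind of dict stored in a GPTQ checkpoint's quantize_config.json.
from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.gptq_marlin import (
    GPTQMarlinConfig)

quant_cfg = {"bits": 4, "group_size": 128, "desc_act": False, "sym": True}

config = GPTQMarlinConfig.from_config(quant_cfg)
print(config)              # GPTQMarlinConfig(weight_bits=4, group_size=128, desc_act=False)
print(config.pack_factor)  # 8: eight 4-bit values per packed int32 word

# Queries the compute capability of the current CUDA device.
print(GPTQMarlinConfig.is_marlin_compatible(quant_cfg))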
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/quantizer.py (new file, mode 100644)

from logging import getLogger

import torch
import torch.nn as nn

logger = getLogger(__name__)


def quantize(x, scale, zero, maxq):
    if maxq < 0:
        return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)


class Quantizer(nn.Module):
    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        self.register_buffer("maxq", torch.tensor(0))
        self.register_buffer("scale", torch.zeros(shape))
        self.register_buffer("zero", torch.zeros(shape))

    def configure(
        self,
        bits,
        perchannel=False,
        sym=True,
        mse=False,
        norm=2.4,
        grid=100,
        maxshrink=0.8,
        trits=False,
    ):
        self.maxq = torch.tensor(2**bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
            self.maxq = torch.tensor(-1)

    def find_params(self, x, weight=False):
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        if self.maxq < 0:
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            best = torch.full([x.shape[0]], float("inf"), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        if not self.perchannel:
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        if self.ready():
            return quantize(x, self.scale, self.zero, self.maxq)
        return x

    def enabled(self):
        return self.maxq > 0

    def ready(self):
        return torch.all(self.scale != 0)


__all__ = ["Quantizer"]
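A short round-trip sketch for the Quantizer above, not part of this commit: per-channel symmetric 4-bit fake quantization of a random weight matrix, assuming the package path from this commit is importable.

# Hypothetical sketch (not part of this commit): per-channel 4-bit
# fake-quantization round trip with the Quantizer defined above.
import torch

from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.quantizer import Quantizer

quantizer = Quantizer()
quantizer.configure(bits=4, perchannel=True, sym=True)

w = torch.randn(256, 512)
quantizer.find_params(w, weight=True)  # one scale/zero per output row
w_q = quantizer.quantize(w)            # dequantized ("fake quantized") weights

print(quantizer.scale.shape)           # torch.Size([256, 1])
print((w - w_q).abs().mean().item())   # mean absolute quantization error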
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/repack.py (new file, mode 100644)

import torch
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
from torch.nn.parameter import Parameter


def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    reshaped_x = x.reshape(-1, x.shape[-1])

    size_m = reshaped_x.shape[0]
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    full_size_k = layer.input_size

    out_shape = x.shape[:-1] + (part_size_n, )

    if layer.marlin_state == GPTQMarlinState.REPACK:
        layer.marlin_state = GPTQMarlinState.READY

        # Newly generated tensors need to replace existing tensors that are
        # already registered as parameters by vLLM (and won't be freed)
        def replace_tensor(name, new_t):
            # It is important to use resize_() here since it ensures
            # the same buffer is reused
            getattr(layer, name).resize_(new_t.shape)
            getattr(layer, name).copy_(new_t)
            del new_t

        cur_device = layer.qweight.device

        # Process act_order
        if self.quant_config.desc_act:
            # Get sorting based on g_idx
            g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)

            sorted_g_idx = layer.g_idx[g_idx_sort_indices]

            replace_tensor("g_idx", sorted_g_idx)
            replace_tensor("g_idx_sort_indices", g_idx_sort_indices)

        else:
            # Reset g_idx related tensors
            layer.g_idx = Parameter(
                torch.empty(0, dtype=torch.int, device=cur_device),
                requires_grad=False,
            )
            layer.g_idx_sort_indices = Parameter(
                torch.empty(0, dtype=torch.int, device=cur_device),
                requires_grad=False,
            )

        # Repack weights
        marlin_qweight = ops.gptq_marlin_repack(
            layer.qweight,
            layer.g_idx_sort_indices,
            part_size_k,
            part_size_n,
            self.quant_config.weight_bits,
        )
        replace_tensor("qweight", marlin_qweight)

        # Permute scales
        scales_size_k = part_size_k
        scales_size_n = part_size_n
        if self.quant_config.desc_act:
            scales_size_k = full_size_k

        marlin_scales = marlin_permute_scales(
            layer.scales,
            scales_size_k,
            scales_size_n,
            self.quant_config.group_size,
            self.quant_config.weight_bits,
        )
        replace_tensor("scales", marlin_scales)

    output = ops.gptq_marlin_gemm(
        reshaped_x,
        layer.qweight,
        layer.scales,
        layer.g_idx,
        layer.g_idx_sort_indices,
        layer.workspace,
        self.quant_config.weight_bits,
        size_m,
        part_size_n,
        part_size_k,
        layer.is_k_full,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/__init__.py (new file, mode 100644; empty)
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/format_24.py (new file, mode 100644)

#
# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
#

import torch


# This is PyTorch implementation of main part of reorder_meta()
# function, from tools/util/include/cutlass/util/host_reorder.h file
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
# GEMM decides upon layout of this matrix, and at the moment for the
# sparse GEMM executed on tensor cores, this is layout described by
# ColumnMajorInterleaved<2> data structure, in
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
# reordering of meta matrix into meta_reordered matrix calculated
# according to these segments of CUTLASS code is re-implemented here.
# Note that this calculation produces offsets for scattering metadata
# matrix elements into reordered metadata matrix elements (or,
# equivalently, for gathering reordered metadata matrix element back
# into metadata matrix elements).
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype,
                                               device):
    dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
    dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)

    # Reorder the rows, then swizzle the 2x2 blocks.
    group_x = 64
    group_y = 32 if meta_dtype.itemsize == 2 else 16

    dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 +
                (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 +
                ((dst_rows % group_x) // 8) * 4)

    topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
    bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
    dst_rows += topright - bottomleft
    dst_cols -= topright - bottomleft

    # Assumed that meta tensor is to be stored in CUTLASS
    # InterleavedColumnMajor layout, and reverse engineered
    # corresponding code to store values into this tensor.
    interleave = 2
    cols_maj = dst_cols // interleave
    cols_min = dst_cols % interleave
    return (cols_maj * m * interleave + dst_rows * interleave +
            cols_min).view(-1)


# This function converts dense matrix into sparse semi-structured
# representation, producing "compressed" matrix, in the layout used by
# CUTLASS backend, and corresponding metadata matrix.
def sparse_semi_structured_from_dense_cutlass(dense):
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = dense.shape
    device = dense.device

    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError(
            "Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16")
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32")
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
        )

    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    #     [True,  True,  False, False] -> 0b0100
    #     [True,  False, True,  False] -> 0b1000
    #     [False, True,  True,  False] -> 0b1001
    #     [True,  False, False, True ] -> 0b1100
    #     [False, True,  False, True ] -> 0b1101
    #     [False, False, True,  True ] -> 0b1110
    # Thus, lower two bits in the encoding are index of the True value
    # at the lowest index in the quadruple, and the higher two bits in
    # the encoding are index of the other True value in the quadruple.
    # In case there are less than two True values, than False value or
    # values at some index or indices are considered True for the
    # encoding. In case there are more than two True values, then the
    # excess True value(s) at some indices are considered False for
    # the encoding. The exact encodings used for these cases are as
    # follows:
    #     [False, False, False, False] -> 0b1110
    #     [False, False, False, True ] -> 0b1110
    #     [False, False, True,  False] -> 0b1110
    #     [False, True,  False, False] -> 0b1001
    #     [False, True,  True,  True ] -> 0b1101
    #     [True,  False, False, False] -> 0b1000
    #     [True,  False, True,  True ] -> 0b1100
    #     [True,  True,  False, True ] -> 0b0100
    #     [True,  True,  True,  False] -> 0b0100
    #     [True,  True,  True,  True ] -> 0b0100
    # These particular encodings are chosen, with the help of Espresso
    # logic minimizer software, for the purpose of minimization of
    # corresponding Boolean functions, that translate non-zero flags
    # into encoding bits. Note also possible choices for the first
    # and last of these encodings were limited only to (0b0100,
    # 0b1110), in order to produce valid encodings for 1:2 sparsity
    # case.

    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    if dense.dtype != torch.float:
        sparse0 = dense_4.gather(
            -1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1,
                                idxs0.unsqueeze(-1) // 2).view(m, k // 2)  # type: ignore[possibly-undefined]

    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = (meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) |
                (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12))
    elif quadbits_per_meta_elem == 8:
        meta = (meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) |
                (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12) |
                (meta_n[:, :, 4] << 16) | (meta_n[:, :, 5] << 20) |
                (meta_n[:, :, 6] << 24) | (meta_n[:, :, 7] << 28))

    # Reorder meta tensor elements.
    meta_reordered = meta.new_empty(
        (m * meta_ncols, ))  # type: ignore[possibly-undefined]
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device)
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

    return (sparse, meta_reordered.view(m, meta_ncols))


# This function performs reverse of the function above - it
# reconstructs dense matrix from a pair of "compressed" matrix, given
# in the layout used by CUTLASS backend, and accompanying metadata
# matrix.
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
    if sparse.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = sparse.shape
    device = sparse.device

    if meta_reordered.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"  # noqa: E501
        )
    if meta_reordered.device != device:
        raise RuntimeError(
            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"  # noqa: E501
        )

    meta_dtype = meta_reordered.dtype
    if meta_dtype not in (torch.int16, torch.int32):
        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4

    ksparse = 4 if sparse.dtype != torch.float else 2

    meta_nrows, meta_ncols = meta_reordered.shape
    if meta_nrows != m:
        raise RuntimeError(
            f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}"  # noqa: E501
        )
    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
        raise RuntimeError(
            f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "  # noqa: E501
            "expected according to the number of columns of meta matrix")

    # Undo meta tensor elements reordering.
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device)
    meta = torch.gather(meta_reordered.view(-1), 0,
                        meta_offsets).view(m, meta_ncols)

    # Unpack sparse tensor back to original dense tensor, using
    # information provided by meta tensor. Note that torch.float
    # datatype is handled pretty much the same as
    # torch.half/torch.bfloat16, as metadata for a pair of torch.float
    # value is encoded as if underlying 8 bytes contain four
    # torch.half/torch.bfloat16 values, where either first two or last
    # two are zeros.
    meta_2 = torch.empty(
        (m, meta_ncols, 2 * quadbits_per_meta_elem),
        dtype=meta_dtype,
        device=device,
    )
    if quadbits_per_meta_elem == 4:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
    elif quadbits_per_meta_elem == 8:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
        meta_2[:, :, 8] = (meta >> 16) & 0b11
        meta_2[:, :, 9] = (meta >> 18) & 0b11
        meta_2[:, :, 10] = (meta >> 20) & 0b11
        meta_2[:, :, 11] = (meta >> 22) & 0b11
        meta_2[:, :, 12] = (meta >> 24) & 0b11
        meta_2[:, :, 13] = (meta >> 26) & 0b11
        meta_2[:, :, 14] = (meta >> 28) & 0b11
        meta_2[:, :, 15] = (meta >> 30) & 0b11

    dense_offsets = meta_2.view(-1) + (
        torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view(
            -1, 1).repeat(1, 2).view(-1)

    dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device)
    if sparse.dtype != torch.float:
        # dense.scatter_(0, dense_offsets, sparse.view(-1))
        dense.scatter_(0, dense_offsets, sparse.reshape(-1))
    else:
        dense.view(torch.half).scatter_(0, dense_offsets,
                                        sparse.view(torch.half).view(-1))

    return dense.view(m, 2 * k)


def mask_creator(tensor):
    """
    Class for creating N:M sparsity masks.
    Masks will be created using the N:M ratio, where for every block of
    M weights, N will be pruned based on ranked weight value. Each mask
    will correspond to the given tensor.

    :param N: The number of weights in a group to keep
    :param M: The size of a weight group
    """
    N = 2
    M = 4

    mask = None
    # for i, tensor in enumerate(tensors):
    if tensor.numel() % M != 0:
        raise ValueError(
            f"Tensor of size {tensor.shape} can't be evenly divided into "
            f"{M} groups")

    num_groups = tensor.numel() // M

    # N:M sparsity for linear layers
    tensor_temp = tensor.detach().abs().reshape(num_groups, M)
    index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)]

    w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
    mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)

    return mask
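The two CUTLASS helpers above are inverses of each other. Below is a small round-trip sketch, not part of this commit: it prunes a half-precision matrix to 2:4 sparsity with mask_creator, compresses it, and reconstructs it; the shape is chosen to satisfy the divisibility checks (rows divisible by 32, columns by 16), and the package path from this commit is assumed importable.

# Hypothetical sketch (not part of this commit): 2:4 prune, compress to the
# CUTLASS layout, then reconstruct the pruned dense matrix.
import torch

from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.format_24 import (
    mask_creator, sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass)

w = torch.randn(64, 128, dtype=torch.half)

mask = mask_creator(w).bool()   # keep the 2 largest of every 4 consecutive weights
w_24 = (w * mask).contiguous()  # still half precision

sparse, meta = sparse_semi_structured_from_dense_cutlass(w_24)
print(sparse.shape, meta.shape)  # compressed values (64, 64), reordered metadata (64, 8)

w_back = sparse_semi_structured_to_dense_cutlass(sparse, meta)
print(torch.equal(w_back, w_24))  # expected True: the round trip is lossless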
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_24_perms.py (new file, mode 100644)

"""This file is used for /tests and /benchmarks"""
from typing import Dict, List

import numpy
import torch

# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501
#
# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def get_perms_24(num_bits: int):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        col_o = col // 2
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col_o * 256 + 8 * (col % 2) +
                             4 * block)
        for j in range(4):
            perm_list.extend([p + 1 * j for p in perm1])
    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
    scale_perm_single: List[int] = []
    for i in range(8):
        scale_perm_single.extend(
            [8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
    return perm, scale_perm, scale_perm_single


marlin_24_perm: Dict[int, torch.Tensor] = {}
marlin_24_scale_perm: Dict[int, List[int]] = {}
marlin_24_scale_perm_single: Dict[int, List[int]] = {}
for num_bits in [4, 8]:
    perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits)
    marlin_24_perm[num_bits] = perm_24
    marlin_24_scale_perm[num_bits] = scale_perm_24
    marlin_24_scale_perm_single[num_bits] = scale_perm_single_24
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_perms.py (new file, mode 100644)

"""This file is used for /tests and /benchmarks"""
from typing import Dict, List

import numpy
import torch

# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
#
# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def get_perms(num_bits: int):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return perm, scale_perm, scale_perm_single


marlin_perm: Dict[int, torch.Tensor] = {}
marlin_scale_perm: Dict[int, List[int]] = {}
marlin_scale_perm_single: Dict[int, List[int]] = {}
for num_bits in [4, 8]:
    perm, scale_perm, scale_perm_single = get_perms(num_bits)
    marlin_perm[num_bits] = perm
    marlin_scale_perm[num_bits] = scale_perm
    marlin_scale_perm_single[num_bits] = scale_perm_single
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py (new file, mode 100644)

"""This file is used for /tests and /benchmarks"""
import random

import numpy
import torch

from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.format_24 import (
    mask_creator, sparse_semi_structured_from_dense_cutlass)
from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_24_perms import (
    marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single)
from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_perms import (
    marlin_perm, marlin_scale_perm, marlin_scale_perm_single)
from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.quant_utils import (
    get_pack_factor, quantize_weights, sort_weights)

__cuda_arch = torch.cuda.get_device_capability()

MARLIN_TILE = 16

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]


def is_marlin_supported():
    return __cuda_arch[0] >= 8


def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w


def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
                           dtype=numpy.uint32)
    for i in range(pack_factor):
        q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)

    return q_packed


def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,
                          scale_perm_single):
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()

    return s


def marlin_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
    act_order: bool,
):
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
                                                       act_order)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reformat to marlin
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,
                                marlin_perm[num_bits])
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size,
                                     marlin_scale_perm[num_bits],
                                     marlin_scale_perm_single[num_bits])

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list


def inject_24(w, size_k, size_n):
    assert w.shape == (size_k, size_n)

    mask = mask_creator(w.t()).t().cuda().bool()

    return (mask * w).contiguous(), mask.contiguous()


def check_24(w, num_rows_to_sample=50, _verbose=False):
    BLOCK_SIZE = 4
    MAX_NON_ZEROS = 2

    w = w.t().contiguous()

    print("check_24: w.shape = {}".format(w.shape))

    num_rows, num_cols = w.shape
    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
    if _verbose:
        print(f"Sampled row idxs = {sampled_row_idxs}")

    total_segments = 0
    non_24_segments = 0
    for i in sampled_row_idxs:
        for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
            total_segments += 1
            block = w[i, j:j + BLOCK_SIZE]
            num_nonzero = torch.count_nonzero(block)
            if num_nonzero > MAX_NON_ZEROS:
                print("i = {} j = {} block = {}".format(i, j, block))
                non_24_segments += 1

    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")


def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
    assert q_24.shape == (size_k, size_n)

    # Remove zp to normalize over 0
    max_q_val = (1 << num_bits) - 1
    zp = (max_q_val + 1) // 2
    q_24_no_zp = q_24 - zp

    # Compress
    q_24_no_zp = q_24_no_zp.t().contiguous()
    q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(
        q_24_no_zp)
    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()

    # Restore zp
    q_24_comp = q_24_no_zp_comp + zp

    # Resize meta to its actual shape (without moving any data)
    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

    return q_24_comp, meta


def marlin_24_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Inject 2:4 sparsity
    w_24, mask_24 = inject_24(w, size_k, size_n)

    # Quantize
    w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,
                                                             num_bits,
                                                             group_size,
                                                             act_order=False)

    # Compress quantized weight
    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
                                                     num_bits)
    size_k_comp = size_k // 2

    # Reformat to marlin
    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
                                        num_bits, marlin_24_perm[num_bits])
    marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size,
                                        marlin_24_scale_perm[num_bits],
                                        marlin_24_scale_perm_single[num_bits])

    # Create result
    res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list


def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref))


class MarlinWorkspace:

    def __init__(self, out_features, min_thread_n, max_parallel):
        assert (out_features % min_thread_n == 0), (
            "out_features = {} is undivisible by min_thread_n = {}".format(
                out_features, min_thread_n))

        max_workspace_size = ((out_features // min_thread_n) * max_parallel)

        self.scratch = torch.zeros(max_workspace_size,
                                   dtype=torch.int,
                                   device="cuda")
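A host-side packing sketch for marlin_quantize above, not part of this commit. Importing marlin_utils needs a CUDA-visible device (it reads torch.cuda.get_device_capability() at import time) and the package layout from this commit; the sizes below are picked to satisfy the 16x16 tile and group-size constraints.

# Hypothetical sketch (not part of this commit): quantize and repack an fp16
# weight matrix into the Marlin tile layout on the host.
import torch

from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
    marlin_quantize)

w = torch.randn(4096, 4096, dtype=torch.half, device="cuda")

# 4-bit, group size 128, no act_order: returns the dequantized reference, the
# tile-permuted packed int32 weights, permuted scales, and empty
# g_idx / sort_indices / rand_perm tensors.
w_ref, q_packed, s_packed, g_idx, sort_indices, rand_perm = marlin_quantize(
    w, num_bits=4, group_size=128, act_order=False)

print(q_packed.shape)  # (4096 // 16, 4096 * 16 // 8) = (256, 8192)
print(s_packed.shape)  # (4096 // 128, 4096) = (32, 4096)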
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py (new file, mode 100644)

"""This file is used for /tests and /benchmarks"""
import numpy
import torch

SUPPORTED_NUM_BITS = [4, 8]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


def get_pack_factor(num_bits):
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    g_idx = torch.zeros((k_size, ), dtype=torch.int32)
    for i in range(k_size):
        g_idx[i] = i // group_size

    # Simulate act_order by doing a random permutation on K
    rand_perm = torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )


def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
                     act_order: bool):
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    max_q_val = 2**num_bits - 1
    half_q_val = (max_q_val + 1) // 2

    # Reshape to [groupsize, -1]
    if group_size < size_k:
        w = w.view((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
    s *= 2 / max_q_val  # 2 => symmetric

    # Quantize
    q_w = torch.round(w / s).int()
    q_w += half_q_val
    q_w = torch.clamp(q_w, 0, max_q_val)

    # Compute ref (dequantized)
    w_ref = (q_w - half_q_val).half() * s

    # Restore original shapes
    if group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        q_w = reshape_w(q_w)
        w_ref = reshape_w(w_ref)

    s = s.reshape((-1, size_n)).contiguous()

    # Apply act_order
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert (
            group_size < size_k
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k)

        w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        s.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )


def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    orig_device = q_w.device

    sort_indices = torch.argsort(g_idx).to(dtype=torch.int32)

    # Sort based on g_idx
    g_idx = g_idx[sort_indices].contiguous()
    q_w = q_w[sort_indices, :].contiguous()

    return (
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        sort_indices.to(device=orig_device),
    )


def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[i::pack_factor, :] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)

    return q_res
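A reference sketch for quantize_weights and gptq_pack above, not part of this commit: symmetric 4-bit quantization with group size 128, followed by GPTQ-style packing of eight 4-bit codes per int32 along the K dimension; it assumes the package path from this commit is importable.

# Hypothetical sketch (not part of this commit): group-wise symmetric 4-bit
# quantization, then GPTQ-style bit packing along K.
import torch

from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.quant_utils import (
    gptq_pack, quantize_weights)

w = torch.randn(1024, 768, dtype=torch.half)

w_ref, q_w, s, g_idx, rand_perm = quantize_weights(
    w, num_bits=4, group_size=128, act_order=False)

print(int(q_w.min()), int(q_w.max()))  # integer codes within [0, 15]
print(s.shape)                         # (1024 // 128, 768) = (8, 768)

packed = gptq_pack(q_w, num_bits=4, size_k=1024, size_n=768)
print(packed.shape, packed.dtype)      # (128, 768) torch.int32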
ktransformers/ktransformers_ext/operators/llamafile/conversion.h (new file, mode 100644)

/**
 * @Description  :
 * @Author       : chenht2022
 * @Date         : 2024-07-12 10:07:58
 * @Version      : 1.0.0
 * @LastEditors  : chenht2022
 * @LastEditTime : 2024-07-25 10:34:55
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
#ifndef CPUINFER_CONVERSION_H
#define CPUINFER_CONVERSION_H

#include <memory.h>

#include "llama.cpp/ggml.h"

inline void to_float(const void* input, float* output, int size, ggml_type type) {
    if (type == ggml_type::GGML_TYPE_F32) {
        memcpy(output, input, size * sizeof(float));
    } else {
        ggml_internal_get_type_traits(type).to_float(input, output, size);
    }
}

inline void from_float(const float* input, void* output, int size, ggml_type type) {
    if (type == ggml_type::GGML_TYPE_F32) {
        memcpy(output, input, size * sizeof(float));
    } else {
        ggml_internal_get_type_traits(type).from_float(input, output, size);
    }
}

#endif
\ No newline at end of file
ktransformers/ktransformers_ext/operators/llamafile/linear.cpp
0 → 100644
View file @
18c42e67
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-12 10:07:58
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:34:58
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "linear.h"
Linear::Linear(LinearConfig config) {
    config_ = config;
    proj_ = config_.proj;

    input_fp32_.resize(config_.input_size);
    proj_input_.resize(config_.input_size * 4);
    proj_output_.resize(config_.output_size);
}

void Linear::warm_up(Backend* backend) {
    std::vector<float> input_fp32(config_.input_size);
    std::vector<uint8_t> input(config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
    std::vector<uint8_t> output(config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
    for (int i = 0; i < config_.input_size; i++) {
        input_fp32[i] = 0;
    }
    from_float(input_fp32.data(), input.data(), config_.input_size, config_.hidden_type);
    forward(input.data(), output.data(), backend);
}

void Linear::forward(const void* input, void* output, Backend* backend) {
    const void* proj_input_ptr;
    if (config_.hidden_type == ggml_internal_get_type_traits(config_.proj_type).vec_dot_type) {
        proj_input_ptr = input;
    } else {
        to_float(input, input_fp32_.data(), config_.input_size, config_.hidden_type);
        from_float(input_fp32_.data(), proj_input_.data(), config_.input_size, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type);
        proj_input_ptr = proj_input_.data();
    }
    int nth = config_.output_size / config_.stride;
    backend->do_work_stealing_job(nth, [&](int task_id) {
        int ith = task_id % nth;
        llamafile_sgemm(config_.output_size, 1, config_.input_size / ggml_blck_size(config_.proj_type), proj_, config_.input_size / ggml_blck_size(config_.proj_type), proj_input_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_output_.data(), config_.output_size, ith, nth, GGML_TASK_TYPE_COMPUTE, config_.proj_type, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
    });
    from_float(proj_output_.data(), output, config_.output_size, config_.hidden_type);
}
\ No newline at end of file
ktransformers/ktransformers_ext/operators/llamafile/linear.h
0 → 100644
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-12 10:07:58
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:00
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_OPERATOR_LINEAR_H
#define CPUINFER_OPERATOR_LINEAR_H
#include <cmath>
#include <cstdio>
#include <functional>
#include <mutex>
#include <vector>
#include "../../cpu_backend/backend.h"
#include "conversion.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
struct LinearConfig {
    int input_size;
    int output_size;
    int stride;
    void* proj;
    ggml_type proj_type;
    ggml_type hidden_type;

    LinearConfig() {}

    LinearConfig(int input_size, int output_size, int stride, void* proj, ggml_type proj_type, ggml_type hidden_type)
        : input_size(input_size), output_size(output_size), stride(stride), proj(proj), proj_type(proj_type), hidden_type(hidden_type) {}
};

class Linear {
   public:
    Linear(LinearConfig);
    void warm_up(Backend* backend);
    void forward(const void* input, void* output, Backend* backend);

   private:
    LinearConfig config_;
    void* proj_;  // [output_size * input_size ( /32 if quantized)]

    std::vector<float> input_fp32_;    // [input_size]
    std::vector<uint8_t> proj_input_;  // [input_size * 4]
    std::vector<float> proj_output_;   // [output_size]
};

#endif
\ No newline at end of file
ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp
0 → 100644
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-16 10:43:18
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:04
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "mlp.h"
MLP::MLP(MLPConfig config) {
    config_ = config;
    gate_proj_ = config_.gate_proj;
    up_proj_ = config_.up_proj;
    down_proj_ = config_.down_proj;

    input_fp32_.resize(config_.hidden_size);
    gate_input_.resize(config_.hidden_size * 4);
    up_input_.resize(config_.hidden_size * 4);
    gate_output_.resize(config_.intermediate_size);
    up_output_.resize(config_.intermediate_size);
    intermediate_fp32_.resize(config_.intermediate_size);
    down_input_.resize(config_.intermediate_size * 4);
    down_output_.resize(config_.hidden_size);
}

void MLP::warm_up(Backend* backend) {
    std::vector<float> input_fp32(config_.hidden_size);
    std::vector<uint8_t> input(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
    std::vector<uint8_t> output(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
    for (int i = 0; i < config_.hidden_size; i++) {
        input_fp32[i] = 0;
    }
    from_float(input_fp32.data(), input.data(), config_.hidden_size, config_.hidden_type);
    forward(input.data(), output.data(), backend);
}

static float act_fn(float x) {
    return x / (1.0f + expf(-x));
}

void MLP::forward(const void* input, void* output, Backend* backend) {
    const void* gate_input_ptr;
    const void* up_input_ptr;
    if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
        gate_input_ptr = up_input_ptr = input;
    } else {
        to_float(input, input_fp32_.data(), config_.hidden_size, config_.hidden_type);
        if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
            from_float(input_fp32_.data(), gate_input_.data(), config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
            gate_input_ptr = up_input_ptr = gate_input_.data();
        } else {
            if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {
                from_float(input_fp32_.data(), gate_input_.data(), config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
                gate_input_ptr = gate_input_.data();
            } else {
                gate_input_ptr = input;
            }
            if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
                from_float(input_fp32_.data(), up_input_.data(), config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
                up_input_ptr = up_input_.data();
            } else {
                up_input_ptr = input;
            }
        }
    }
    int nth = config_.intermediate_size / config_.stride;
    backend->do_work_stealing_job(nth, [&](int task_id) {
        int ith = task_id;
        void* gate_proj_ptr = gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
        float* gate_output_ptr = gate_output_.data() + ith * config_.stride;
        llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
        void* up_proj_ptr = up_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
        float* up_output_ptr = up_output_.data() + ith * config_.stride;
        llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
        for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
            intermediate_fp32_[i] = act_fn(gate_output_[i]) * up_output_[i];
        }
        if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {
            float* intermediate_fp32_ptr = intermediate_fp32_.data() + ith * config_.stride;
            void* down_input_ptr = down_input_.data() + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
            from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
        }
    });
    if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) {
        from_float(intermediate_fp32_.data(), down_input_.data(), config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
    }
    nth = config_.hidden_size / config_.stride;
    backend->do_work_stealing_job(nth, [&](int task_id) {
        int ith = task_id;
        void* down_proj_ptr = down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
        float* down_output_ptr = down_output_.data() + ith * config_.stride;
        llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_.data(), config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
        if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
            void* output_ptr = output + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
            from_float(down_output_ptr, output_ptr, config_.stride, config_.hidden_type);
        }
    });
    if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {
        from_float(down_output_.data(), output, config_.hidden_size, config_.hidden_type);
    }
}
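As a reference for what MLP::forward computes once the type conversions and work-stealing partitioning are stripped away, here is a minimal PyTorch sketch of the same gated-MLP math in fp32. The tensor names and sizes are illustrative assumptions, not part of this extension.

import torch

def silu(x: torch.Tensor) -> torch.Tensor:
    # Same activation as act_fn above: x / (1 + exp(-x)).
    return x / (1.0 + torch.exp(-x))

hidden_size, intermediate_size = 8, 32            # toy sizes
x = torch.randn(hidden_size)                      # one token's hidden state
gate_proj = torch.randn(intermediate_size, hidden_size)
up_proj = torch.randn(intermediate_size, hidden_size)
down_proj = torch.randn(hidden_size, intermediate_size)

intermediate = silu(gate_proj @ x) * (up_proj @ x)
output = down_proj @ intermediate                 # shape: [hidden_size]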
ktransformers/ktransformers_ext/operators/llamafile/mlp.h
0 → 100644
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-12 10:07:58
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:06
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_OPERATOR_MLP_H
#define CPUINFER_OPERATOR_MLP_H
#include <cmath>
#include <cstdio>
#include <functional>
#include <mutex>
#include <vector>
#include "../../cpu_backend/backend.h"
#include "conversion.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
struct MLPConfig {
    int hidden_size;
    int intermediate_size;
    int stride;
    void* gate_proj;
    void* up_proj;
    void* down_proj;
    ggml_type gate_type;
    ggml_type up_type;
    ggml_type down_type;
    ggml_type hidden_type;

    MLPConfig() {}

    MLPConfig(int hidden_size, int intermediate_size, int stride, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)
        : hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}
};

class MLP {
   public:
    MLP(MLPConfig);
    void warm_up(Backend* backend);
    void forward(const void* input, void* output, Backend* backend);

   private:
    MLPConfig config_;
    void* gate_proj_;  // [intermediate_size * hidden_size ( /32 if quantized)]
    void* up_proj_;    // [intermediate_size * hidden_size ( /32 if quantized)]
    void* down_proj_;  // [hidden_size * intermediate_size ( /32 if quantized)]

    std::vector<float> input_fp32_;         // [hidden_size]
    std::vector<uint8_t> gate_input_;       // [hidden_size * 4]
    std::vector<uint8_t> up_input_;         // [hidden_size * 4]
    std::vector<float> gate_output_;        // [intermediate_size]
    std::vector<float> up_output_;          // [intermediate_size]
    std::vector<float> intermediate_fp32_;  // [intermediate_size]
    std::vector<uint8_t> down_input_;       // [intermediate_size * 4]
    std::vector<float> down_output_;        // [hidden_size]
};

#endif
\ No newline at end of file
ktransformers/ktransformers_ext/operators/llamafile/moe.cpp
0 → 100644
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-22 02:03:22
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:07
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "moe.h"
#include <iostream>
#include "unistd.h"
void* MOE::buffer_ = nullptr;

MOE::MOE(MOEConfig config) {
    config_ = config;
    gate_proj_ = config_.gate_proj;
    up_proj_ = config_.up_proj;
    down_proj_ = config_.down_proj;

    if (MOE::buffer_ == nullptr) {
        uint64_t buffer_size = 0;
        buffer_size += sizeof(float) * config_.group_max_len * config_.hidden_size;
        buffer_size += config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
        buffer_size += config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
        buffer_size += config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
        buffer_size += config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
        buffer_size += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
        buffer_size += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
        buffer_size += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
        buffer_size += config_.routed_expert_num * config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
        buffer_size += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.hidden_size;
        buffer_size += sizeof(float) * config_.group_max_len * config_.hidden_size;
        buffer_ = malloc(buffer_size);
    }

    uint64_t offset = 0;
    s_input_fp32_ = (float*)(buffer_ + offset);
    offset += sizeof(float) * config_.hidden_size;
    s_gate_input_ = (uint8_t*)(buffer_ + offset);
    offset += config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
    s_up_input_ = (uint8_t*)(buffer_ + offset);
    offset += config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
    s_gate_output_.resize(config_.routed_expert_num);
    s_up_output_.resize(config_.routed_expert_num);
    s_intermediate_fp32_.resize(config_.routed_expert_num);
    s_down_input_.resize(config_.routed_expert_num);
    s_down_output_.resize(config_.routed_expert_num);
    for (int i = 0; i < config_.routed_expert_num; i++) {
        s_gate_output_[i] = (float*)(buffer_ + offset);
        offset += sizeof(float) * config_.intermediate_size;
        s_up_output_[i] = (float*)(buffer_ + offset);
        offset += sizeof(float) * config_.intermediate_size;
        s_intermediate_fp32_[i] = (float*)(buffer_ + offset);
        offset += sizeof(float) * config_.intermediate_size;
        s_down_input_[i] = (uint8_t*)(buffer_ + offset);
        offset += config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
        s_down_output_[i] = (float*)(buffer_ + offset);
        offset += sizeof(float) * config_.hidden_size;
    }
    s_output_fp32_ = (float*)(buffer_ + offset);

    offset = 0;
    m_input_fp32_.resize(config_.group_max_len);
    m_gate_input_.resize(config_.group_max_len);
    m_up_input_.resize(config_.group_max_len);
    for (int i = 0; i < config_.group_max_len; i++) {
        m_input_fp32_[i] = (float*)(buffer_ + offset);
        offset += sizeof(float) * config_.hidden_size;
        m_gate_input_[i] = (uint8_t*)(buffer_ + offset);
        offset += config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
        m_up_input_[i] = (uint8_t*)(buffer_ + offset);
        offset += config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
    }
    m_local_gate_input_ = (uint8_t*)(buffer_ + offset);
    offset += config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
    m_local_up_input_ = (uint8_t*)(buffer_ + offset);
    offset += config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
    m_local_gate_output_ = (float*)(buffer_ + offset);
    offset += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
    m_local_up_output_ = (float*)(buffer_ + offset);
    offset += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
    m_local_intermediate_fp32_ = (float*)(buffer_ + offset);
    offset += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
    m_local_down_input_ = (uint8_t*)(buffer_ + offset);
    offset += config_.routed_expert_num * config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
    m_local_down_output_ = (float*)(buffer_ + offset);
    offset += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.hidden_size;
    m_output_fp32_.resize(config_.group_max_len);
    for (int i = 0; i < config_.group_max_len; i++) {
        m_output_fp32_[i] = (float*)(buffer_ + offset);
        offset += sizeof(float) * config_.hidden_size;
    }

    m_local_pos_.resize(config_.group_max_len);
    for (int i = 0; i < config_.group_max_len; i++) {
        m_local_pos_[i].reserve(config_.expert_num);
    }
    m_local_num_.resize(config_.expert_num);
    m_local_gate_input_ptr_.resize(config_.expert_num);
    m_local_up_input_ptr_.resize(config_.expert_num);
    m_local_gate_output_ptr_.resize(config_.expert_num);
    m_local_up_output_ptr_.resize(config_.expert_num);
    m_local_intermediate_fp32_ptr_.resize(config_.expert_num);
    m_local_down_input_ptr_.resize(config_.expert_num);
    m_local_down_output_ptr_.resize(config_.expert_num);
}

void MOE::warm_up(Backend* backend) {
    std::vector<float> input_fp32(config_.hidden_size);
    std::vector<uint8_t> input(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
    std::vector<uint8_t> output(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
    for (int i = 0; i < config_.hidden_size; i++) {
        input_fp32[i] = 0;
    }
    from_float(input_fp32.data(), input.data(), config_.hidden_size, config_.hidden_type);
    for (int i = 0; i < config_.expert_num; i++) {
        uint64_t expert_ids = i;
        float weights = 0;
        forward_one(1, &expert_ids, &weights, input.data(), output.data(), backend);
    }
}

static float act_fn(float x) {
    return x / (1.0f + expf(-x));
}

void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {
    const void* gate_input_ptr;
    const void* up_input_ptr;
    if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
        gate_input_ptr = up_input_ptr = input;
    } else {
        to_float(input, s_input_fp32_, config_.hidden_size, config_.hidden_type);
        if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
            from_float(s_input_fp32_, s_gate_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
            gate_input_ptr = up_input_ptr = s_gate_input_;
        } else {
            if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {
                from_float(s_input_fp32_, s_gate_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
                gate_input_ptr = s_gate_input_;
            } else {
                gate_input_ptr = input;
            }
            if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
                from_float(s_input_fp32_, s_up_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
                up_input_ptr = s_up_input_;
            } else {
                up_input_ptr = input;
            }
        }
    }
    int nth = config_.intermediate_size / config_.stride;
    backend->do_work_stealing_job(nth * k, [&](int task_id) {
        int expert_idx = task_id / nth;
        uint64_t expert_id = expert_ids[expert_idx];
        int ith = task_id % nth;
        void* gate_proj_ptr = gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
        float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride;
        llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
        void* up_proj_ptr = up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
        float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride;
        llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
        for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
            s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];
        }
        if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {
            float* intermediate_fp32_ptr = s_intermediate_fp32_[expert_idx] + ith * config_.stride;
            void* down_input_ptr = s_down_input_[expert_idx] + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
            from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
        }
    });
    if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) {
        for (int i = 0; i < k; i++) {
            from_float(s_intermediate_fp32_[i], s_down_input_[i], config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
        }
    }
    nth = config_.hidden_size / config_.stride;
    backend->do_work_stealing_job(nth, [&](int task_id) {
        int ith = task_id;
        for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
            s_output_fp32_[i] = 0;
        }
        for (int expert_idx = 0; expert_idx < k; expert_idx++) {
            uint64_t expert_id = expert_ids[expert_idx];
            void* down_proj_ptr = down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
            float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride;
            llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
            for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
                s_output_fp32_[i] += s_down_output_[expert_idx][i] * weights[expert_idx];
            }
        }
        if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
            float* output_fp32_ptr = s_output_fp32_ + ith * config_.stride;
            void* output_ptr = output + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
            from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
        }
    });
    if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {
        from_float(s_output_fp32_, output, config_.hidden_size, config_.hidden_type);
    }
}

void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {
    for (int i = 0; i < config_.expert_num; i++) {
        m_local_num_[i] = 0;
    }
    for (int i = 0; i < qlen; i++) {
        for (int j = 0; j < k; j++) {
            m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
        }
    }
    uint64_t offset = 0;
    for (int i = 0; i < config_.expert_num; i++) {
        m_local_gate_input_ptr_[i] = m_local_gate_input_ + offset * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
        m_local_up_input_ptr_[i] = m_local_up_input_ + offset * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
        m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
        m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
        m_local_intermediate_fp32_ptr_[i] = m_local_intermediate_fp32_ + offset * config_.intermediate_size;
        m_local_down_input_ptr_[i] = m_local_down_input_ + offset * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
        m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
        offset += m_local_num_[i];
    }
    backend->do_work_stealing_job(qlen, [&](int i) {
        const void* gate_input_ptr;
        const void* up_input_ptr;
        if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
            gate_input_ptr = up_input_ptr = input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
        } else {
            to_float(input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), m_input_fp32_[i], config_.hidden_size, config_.hidden_type);
            if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
                from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
                gate_input_ptr = up_input_ptr = m_gate_input_[i];
            } else {
                if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {
                    from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
                    gate_input_ptr = m_gate_input_[i];
                } else {
                    gate_input_ptr = input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
                }
                if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
                    from_float(m_input_fp32_[i], m_up_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
                    up_input_ptr = m_up_input_[i];
                } else {
                    up_input_ptr = input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
                }
            }
        }
        for (int j = 0; j < k; j++) {
            memcpy(m_local_gate_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type), gate_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type));
            memcpy(m_local_up_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type), up_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type));
        }
    });
    int stride = QK_K;
    int nth = config_.intermediate_size / stride;
    backend->do_work_stealing_job(nth * config_.expert_num, [&](int task_id) {
        int expert_idx = task_id / nth;
        int ith = task_id % nth;
        void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx];
        void* gate_proj_ptr = gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
        float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride;
        llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
        void* up_input_ptr = m_local_up_input_ptr_[expert_idx];
        void* up_proj_ptr = up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
        float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride;
        llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
        for (int i = 0; i < m_local_num_[expert_idx]; i++) {
            for (int j = ith * stride; j < (ith + 1) * stride; j++) {
                m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];
            }
            float* intermediate_fp32_ptr = m_local_intermediate_fp32_ptr_[expert_idx] + i * config_.intermediate_size + ith * stride;
            void* down_input_ptr = m_local_down_input_ptr_[expert_idx] + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
            from_float(intermediate_fp32_ptr, down_input_ptr, stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
        }
    });
    stride = QK_K;
    nth = config_.hidden_size / stride;
    backend->do_work_stealing_job(nth * config_.expert_num, [&](int task_id) {
        int expert_idx = task_id / nth;
        int ith = task_id % nth;
        void* down_input_ptr = m_local_down_input_ptr_[expert_idx];
        void* down_proj_ptr = down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
        float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride;
        llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
    });
    backend->do_work_stealing_job(qlen, [&](int i) {
        for (int e = 0; e < config_.hidden_size; e++) {
            m_output_fp32_[i][e] = 0;
        }
        for (int j = 0; j < k; j++) {
            for (int e = 0; e < config_.hidden_size; e++) {
                m_output_fp32_[i][e] += m_local_down_output_ptr_[expert_ids[i * k + j]][m_local_pos_[i][j] * config_.hidden_size + e] * weights[i * k + j];
            }
        }
        from_float(m_output_fp32_[i], output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), config_.hidden_size, config_.hidden_type);
    });
}

void MOE::forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {
    if (qlen < config_.group_min_len) {
        for (int i = 0; i < qlen; i++) {
            forward_one(k, expert_ids + i * k, weights + i * k, input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
        }
        return;
    }
    int forward_len = std::min(config_.group_max_len, qlen);
    forward_many(forward_len, k, expert_ids, weights, input, output, backend);
    forward(qlen - forward_len, k, expert_ids + forward_len * k, weights + forward_len * k, input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
}
\ No newline at end of file
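The same kind of reference sketch for MOE::forward_one: a plain PyTorch rendering of the routed-expert computation (per-expert gated MLP followed by a router-weighted sum), ignoring quantization, batching, and threading. All sizes, the routing, and the weights below are toy values chosen for illustration; nothing here is part of the C++ API.

import torch

def silu(x: torch.Tensor) -> torch.Tensor:
    return x / (1.0 + torch.exp(-x))

expert_num, k = 4, 2                   # total experts, experts routed per token
hidden_size, intermediate_size = 8, 32
x = torch.randn(hidden_size)           # one token's hidden state

gate = torch.randn(expert_num, intermediate_size, hidden_size)
up = torch.randn(expert_num, intermediate_size, hidden_size)
down = torch.randn(expert_num, hidden_size, intermediate_size)

expert_ids = [1, 3]                    # chosen by the router (not shown here)
weights = [0.7, 0.3]                   # router weights for the chosen experts

output = torch.zeros(hidden_size)
for e, w in zip(expert_ids, weights):
    # Per-expert gated MLP, then accumulate with the router weight.
    output += w * (down[e] @ (silu(gate[e] @ x) * (up[e] @ x)))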
ktransformers/ktransformers_ext/operators/llamafile/moe.h
0 → 100644
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-22 02:03:22
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:10
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_OPERATOR_MOE_H
#define CPUINFER_OPERATOR_MOE_H
#include <cmath>
#include <cstdio>
#include <functional>
#include <mutex>
#include <vector>
#include "../../cpu_backend/backend.h"
#include "conversion.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
struct MOEConfig {
    int expert_num;
    int routed_expert_num;
    int hidden_size;
    int intermediate_size;
    int stride;
    int group_min_len;
    int group_max_len;
    void* gate_proj;
    void* up_proj;
    void* down_proj;
    ggml_type gate_type;
    ggml_type up_type;
    ggml_type down_type;
    ggml_type hidden_type;

    MOEConfig() {}

    MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)
        : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_min_len(group_min_len), group_max_len(group_max_len), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}
};

class MOE {
   public:
    MOE(MOEConfig);
    void warm_up(Backend* backend);
    void forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend);
    void forward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend);
    void forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend);

   private:
    static void* buffer_;

    MOEConfig config_;
    void* gate_proj_;  // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
    void* up_proj_;    // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
    void* down_proj_;  // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]

    float* s_input_fp32_;    // [hidden_size]
    uint8_t* s_gate_input_;  // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]
    uint8_t* s_up_input_;    // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]
    std::vector<float*> s_gate_output_;        // [routed_expert_num, intermediate_size]
    std::vector<float*> s_up_output_;          // [routed_expert_num, intermediate_size]
    std::vector<float*> s_intermediate_fp32_;  // [routed_expert_num, intermediate_size]
    std::vector<uint8_t*> s_down_input_;       // [routed_expert_num, intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]
    std::vector<float*> s_down_output_;        // [routed_expert_num, hidden_size]
    float* s_output_fp32_;                     // [hidden_size]

    std::vector<float*> m_input_fp32_;    // [group_max_len, hidden_size]
    std::vector<uint8_t*> m_gate_input_;  // [group_max_len, hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]
    std::vector<uint8_t*> m_up_input_;    // [group_max_len, hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]
    uint8_t* m_local_gate_input_;         // [routed_expert_num * group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]
    uint8_t* m_local_up_input_;           // [routed_expert_num * group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]
    float* m_local_gate_output_;          // [routed_expert_num * group_max_len * intermediate_size]
    float* m_local_up_output_;            // [routed_expert_num * group_max_len * intermediate_size]
    float* m_local_intermediate_fp32_;    // [routed_expert_num * group_max_len * intermediate_size]
    uint8_t* m_local_down_input_;         // [routed_expert_num * group_max_len * intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]
    float* m_local_down_output_;          // [routed_expert_num * group_max_len * hidden_size]
    std::vector<float*> m_output_fp32_;   // [group_max_len, hidden_size]

    std::vector<std::vector<int>> m_local_pos_;           // [group_max_len, routed_expert_num]
    std::vector<int> m_local_num_;                        // [expert_num]
    std::vector<uint8_t*> m_local_gate_input_ptr_;        // [expert_num]
    std::vector<uint8_t*> m_local_up_input_ptr_;          // [expert_num]
    std::vector<float*> m_local_gate_output_ptr_;         // [expert_num]
    std::vector<float*> m_local_up_output_ptr_;           // [expert_num]
    std::vector<float*> m_local_intermediate_fp32_ptr_;   // [expert_num]
    std::vector<uint8_t*> m_local_down_input_ptr_;        // [expert_num]
    std::vector<float*> m_local_down_output_ptr_;         // [expert_num]
};

#endif
\ No newline at end of file
ktransformers/local_chat.py
0 → 100644
# Copyright 2024 Shaoyuan Chen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform
import sys

project_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, project_dir)
import torch
import logging
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    GenerationConfig,
    TextStreamer,
)
import json
import fire
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.util.utils import prefill_and_generate
from ktransformers.server.config.config import Config

custom_models = {
    "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
    "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
}

ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
default_optimize_rules = {
    "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
}


def local_chat(
    model_path: str,
    optimize_rule_path: str = None,
    gguf_path: str = None,
    max_new_tokens: int = 1000,
    cpu_infer: int = Config().cpu_infer,
):
    torch.set_grad_enabled(False)

    Config().cpu_infer = cpu_infer
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    torch.set_default_dtype(config.torch_dtype)

    with torch.device("meta"):
        if config.architectures[0] in custom_models:
            print("using custom modeling_xxx.py.")
            if "Qwen2Moe" in config.architectures[0]:
                # Qwen2Moe must use flash_attention_2 to avoid overflow.
                config._attn_implementation = "flash_attention_2"
            model = custom_models[config.architectures[0]](config)
        else:
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=True, attn_implementation="flash_attention_2"
            )

    if optimize_rule_path is None:
        if config.architectures[0] in default_optimize_rules:
            print("using default_optimize_rule for", config.architectures[0])
            optimize_rule_path = default_optimize_rules[config.architectures[0]]
        else:
            optimize_rule_path = input(
                "please input the path of your rule file (a yaml file containing optimize rules):"
            )

    if gguf_path is None:
        gguf_path = input(
            "please input the path of your gguf file (all gguf files in the directory must belong to the current model):"
        )
    optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)

    model.generation_config = GenerationConfig.from_pretrained(model_path)
    if model.generation_config.pad_token_id is None:
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
    model.eval()
    logging.basicConfig(level=logging.INFO)

    system = platform.system()
    if system == "Windows":
        os.system("cls")
    else:
        os.system("clear")

    while True:
        content = input("Chat: ")
        # Fall back to a demo prompt when the user enters nothing.
        if content == "":
            content = "Please write a piece of quicksort code in C++."
        messages = [{"role": "user", "content": content}]
        input_tensor = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        )
        torch.set_default_dtype(torch.bfloat16)  # TODO: Remove this, replace dtype using config
        generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens)


if __name__ == "__main__":
    fire.Fire(local_chat)
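Because the entry point is wrapped in fire.Fire, local_chat can also be called programmatically. A hypothetical invocation, assuming the package is installed so that ktransformers.local_chat is importable; both paths below are placeholders for a local DeepSeek-V2 checkpoint and its GGUF weights.

from ktransformers.local_chat import local_chat

# Starts the interactive chat loop with the given checkpoint and GGUF weights.
local_chat(
    model_path="/path/to/DeepSeek-V2-Chat",
    gguf_path="/path/to/DeepSeek-V2-Chat-GGUF",
    max_new_tokens=512,
)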
ktransformers/models/__init__.py
0 → 100644
ktransformers/models/configuration_deepseek.py
0 → 100644
# Adapted from
# https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat-0628/blob/main/configuration_deepseek.py
# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)

DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class DeepseekV2Config(PretrainedConfig):
    r"""
This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the DeepSeek-V2.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 102400):
Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`DeepseekV2Model`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
moe_intermediate_size (`int`, *optional*, defaults to 1407):
Dimension of the MoE representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
n_shared_experts (`int`, *optional*, defaults to None):
Number of shared experts, None means dense model.
n_routed_experts (`int`, *optional*, defaults to None):
Number of routed experts, None means dense model.
routed_scaling_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for routed experts.
topk_method (`str`, *optional*, defaults to `gready`):
Topk method used in routed gate.
n_group (`int`, *optional*, defaults to None):
Number of groups for routed experts.
topk_group (`int`, *optional*, defaults to None):
Number of selected groups for each token (ensuring that the selected experts are only within `topk_group` groups).
num_experts_per_tok (`int`, *optional*, defaults to None):
Number of selected experts, None means dense model.
moe_layer_freq (`int`, *optional*, defaults to 1):
The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
first_k_dense_replace (`int`, *optional*, defaults to 0):
Number of dense layers in the shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).
\--k dense layers--/
norm_topk_prob (`bool`, *optional*, defaults to False):
Whether to normalize the weights of the routed experts.
scoring_func (`str`, *optional*, defaults to 'softmax'):
Method of computing expert weights.
aux_loss_alpha (`float`, *optional*, defaults to 0.001):
Auxiliary loss weight coefficient.
seq_aux (`bool`, *optional*, defaults to True):
Whether to compute the auxiliary loss for each individual sample.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import DeepseekV2Model, DeepseekV2Config
>>> # Initializing a Deepseek-V2 style configuration
>>> configuration = DeepseekV2Config()
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
    model_type = "deepseek_v2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=11008,
        moe_intermediate_size=1407,
        num_hidden_layers=30,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts=None,
        n_routed_experts=None,
        ep_size=1,
        routed_scaling_factor=1.0,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method='gready',
        n_group=None,
        topk_group=None,
        num_experts_per_tok=None,
        moe_layer_freq=1,
        first_k_dense_replace=0,
        norm_topk_prob=False,
        scoring_func='softmax',
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        cpu_quant=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.cpu_quant = cpu_quant

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
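A minimal sketch of constructing this config directly, assuming the module is importable from the installed package. The hyperparameter values below are toy numbers chosen for illustration, not DeepSeek-V2's real settings.

from ktransformers.models.configuration_deepseek import DeepseekV2Config

# Builds a small MoE-style configuration in memory; no weights are loaded.
config = DeepseekV2Config(
    hidden_size=1024,
    moe_intermediate_size=256,
    n_routed_experts=8,
    n_shared_experts=1,
    num_experts_per_tok=2,
    first_k_dense_replace=1,
)
print(config.model_type)  # "deepseek_v2"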