OpenDAS / ktransformers · Commits

Commit 77a34c28 (Unverified)
Authored Aug 15, 2024 by UnicornChan; committed via GitHub on Aug 15, 2024
Parents: 44f57270, 395cd3e7
Changes: 69

    Merge pull request #36 from kvcache-ai/develop-0.1.2

    Release v0.1.2

Showing 20 changed files with 617 additions and 1453 deletions (+617 −1453)
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp                              +8    −0
ktransformers/ktransformers_ext/cuda/custom_gguf/custom_ggml.h                            +0    −39
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu                               +180  −2
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h                                    +6    −3
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu                           +2    −2
ktransformers/ktransformers_ext/cuda/setup.py                                             +19   −11
ktransformers/ktransformers_ext/examples/test_linear.py                                   +22   −43
ktransformers/ktransformers_ext/examples/test_mlp.py                                      +35   −51
ktransformers/ktransformers_ext/examples/test_moe.py                                      +73   −65
ktransformers/ktransformers_ext/ext_bindings.cpp                                          +118  −203
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq.py                  +0    −206
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq_marlin.py           +0    −458
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/quantizer.py             +0    −140
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/repack.py                +0    −99
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py    +2    −2
ktransformers/ktransformers_ext/operators/llamafile/linear.cpp                            +38   −12
ktransformers/ktransformers_ext/operators/llamafile/linear.h                              +11   −7
ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp                               +55   −35
ktransformers/ktransformers_ext/operators/llamafile/mlp.h                                 +16   −12
ktransformers/ktransformers_ext/operators/llamafile/moe.cpp                               +32   −63
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp  (+8 −0)

@@ -12,14 +12,22 @@ int test(){
 }

 torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);

 PYBIND11_MODULE(cudaops, m) {
     m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("test", &test, "Function to test.");
 }
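For orientation, a minimal Python sketch of how the extended bindings above might be called once the extension is built. The import name and the Q6_K block size are assumptions for illustration (setup.py below names the extension KTransformersOps, while this file registers the module as cudaops); the random bytes stand in for real GGUF block data.

import torch
import KTransformersOps as ops  # assumed import name; see the note above

blk_size = 210                  # assumed Q6_K block size in bytes, not stated in this diff
num_blocks = 4
data = torch.randint(-128, 127, (num_blocks * blk_size,), dtype=torch.int8)

# Every dequantize_* binding takes (data, blk_size, device) and returns a float32 tensor
# of shape (num_blocks, 256) on the requested device.
out = ops.dequantize_q6_k(data, blk_size, torch.device("cuda:0"))
print(out.shape)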
ktransformers/ktransformers_ext/cuda/custom_gguf/custom_ggml.h  (deleted 100644 → 0, +0 −39)

-#include <cuda_fp16.h>
-
-__device__ float ggml_compute_fp16_to_fp32(uint16_t h) {
-    return __uint2float_rd(h);
-}
-
-static inline float ggml_compute_fp16_to_fp32(uint16_t h) {
-    uint16_t tmp;
-    memcpy(&tmp, &h, sizeof(ggml_fp16_t));
-    return (float)tmp;
-}
-
-// define the global table for fp16 to fp32 conversion
-__device__ float ggml_table_f32_f16[1 << 16];
-
-// CUDA Kernel to init the table
-__global__ void init_fp16_to_fp32_table() {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto blk_id = idx; blk_id < (1 << 16); blk_id += blockDim.x * gridDim.x){
-        ggml_table_f32_f16[blk_id] = GGML_COMPUTE_FP16_TO_FP32(blk_id);
-    }
-}
-
-#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-
-extern __device__ float ggml_table_f32_f16[1 << 16]; // Declare as __device__ if used within device code
-
-// This version of the function is designed to be called from within a CUDA kernel
-#if !defined(GGML_FP16_TO_FP32)
-__device__ float ggml_lookup_fp16_to_fp32(uint16_t f) {
-    return ggml_table_f32_f16[f];
-}
-#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
-#endif
\ No newline at end of file
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu  (+180 −2)

@@ -3,8 +3,8 @@
  * @Author : Azure-Tang, Boxin Zhang
  * @Date : 2024-07-25 13:38:30
  * @Version : 1.0.0
- * @LastEditors : Azure
- * @LastEditTime : 2024-07-26 11:58:50
+ * @LastEditors : kkk1nak0
+ * @LastEditTime : 2024-08-12 04:18:04
  * Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c
  * Copyright (c) 2023-2024 The ggml authors
  * Copyright (c) 2024 by KVCache.AI, All Rights Reserved.

@@ -14,6 +14,7 @@
 #include <torch/extension.h>
 #include <torch/torch.h>
 #include <cstdint>
+#include <c10/cuda/CUDAGuard.h>

 __global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) {
     int global_idx = blockIdx.x * blockDim.x + threadIdx.x;

@@ -35,6 +36,97 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
     }
 }

+__global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80)));
+        const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 82)));
+        const uint8_t* __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);
+
+        int is = 0;
+        float dl, ml;
+
+        for (int n = 0; n < 256; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+                uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));
+                uint8_t sc = *scales;
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+
+                scales = (uint8_t*)(data + block_id * blk_size + (is++));
+                sc = *scales;
+
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+
+                shift += 2;
+            }
+            q += 32;
+        }
+    }
+}
+
+__global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+
+        uint32_t aux[4];
+        const int8_t * scales = (const int8_t*)aux;
+
+        const float d_all = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 108)));
+
+        const uint8_t* __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32);
+        const uint8_t* __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);
+        uint8_t m = 1;
+
+        uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);
+
+        for (int i = 0; i < 3; i++) {
+            aux[i] = 0;
+            for (int j = 0; j < 4; j++) {
+                aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);
+            }
+        }
+
+        uint32_t tmp = aux[2];
+        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+
+        int is = 0;
+        float dl;
+        for (int n = 0; n < 256; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *output_blk++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
+                }
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
+                }
+
+                shift += 2;
+                m <<= 1;
+            }
+            q += 32;
+        }
+    }
+}
+
 __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
     int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){

@@ -59,6 +151,35 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
     }
 }

+__global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 0)));
+        const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 2)));
+
+        const uint8_t* __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);
+        const uint8_t* __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);
+
+        int is = 0;
+        uint8_t sc, m;
+        uint8_t u1 = 1, u2 = 2;
+        uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);
+
+        for (int j = 0; j < 256; j += 64) {
+            get_scale_min_k4(is + 0, scales, &sc, &m);
+            const float d1 = d * sc; const float m1 = min * m;
+            get_scale_min_k4(is + 1, scales, &sc, &m);
+            const float d2 = d * sc; const float m2 = min * m;
+            for (int l = 0; l < 32; ++l) *output_blk++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
+            for (int l = 0; l < 32; ++l) *output_blk++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
+            ql += 32; is += 2;
+            u1 <<= 2; u2 <<= 2;
+        }
+    }
+}
+
 __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
     int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){

@@ -94,6 +215,7 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
 torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device) {
     int num_blocks = data.numel() / blk_size;
+    const at::cuda::OptionalCUDAGuard device_guard(device);
     // create gpu
     auto options_scales = torch::TensorOptions().dtype(torch::kFloat32).device(device).memory_format(torch::MemoryFormat::Contiguous);
     auto options_qs = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);

@@ -128,6 +250,7 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
     // data.numel%blk_size should be 0, else raise err
     int num_blocks = data.numel() / blk_size;
+    const at::cuda::OptionalCUDAGuard device_guard(device);
     auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
     auto data_gpu = torch::empty({data.numel()}, options);

@@ -144,9 +267,28 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
     return output;
 }

+torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({data.numel()}, options);
+
+    data_gpu.copy_(data, false);
+
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+
+    // Launch kernel
+    dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+
+    cudaDeviceSynchronize();
+    return output;
+}
+
 torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device) {
     // data.numel%blk_size should be 0, else raise err
     int num_blocks = data.numel() / blk_size;
+    const at::cuda::OptionalCUDAGuard device_guard(device);
     auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
     auto data_gpu = torch::empty({data.numel()}, options);

@@ -162,3 +304,39 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device de
     cudaDeviceSynchronize();
     return output;
 }
+
+torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({data.numel()}, options);
+
+    data_gpu.copy_(data, false);
+
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+
+    // Launch kernel
+    dequantize_q3_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+
+    cudaDeviceSynchronize();
+    return output;
+}
+
+torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({data.numel()}, options);
+
+    data_gpu.copy_(data, false);
+
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+
+    // Launch kernel
+    dequantize_q2_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+
+    cudaDeviceSynchronize();
+    return output;
+}
\ No newline at end of file
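All of the new kernels follow the same pattern as the existing q4_k/q6_k ones: a grid-stride loop over blocks, 256 dequantized floats written per block, and a host wrapper that copies the raw bytes to the target device and launches with <<<512, 256>>>. A hedged sketch of the resulting Python-side contract (the helper name and the shape assertions are mine, not part of the diff):

import torch
import KTransformersOps as ops  # assumed import name, as above

def dequantize_q5_k_blocks(raw_bytes: torch.Tensor, blk_size: int, device: str = "cuda:0") -> torch.Tensor:
    # The C++ wrapper derives the block count exactly as data.numel() / blk_size.
    assert raw_bytes.dtype == torch.int8 and raw_bytes.numel() % blk_size == 0
    out = ops.dequantize_q5_k(raw_bytes, blk_size, torch.device(device))
    # Each k-quant block expands to 256 float32 weights.
    assert out.shape == (raw_bytes.numel() // blk_size, 256)
    return out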
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h  (+6 −3)

@@ -3,8 +3,8 @@
  * @Author : Azure-Tang
  * @Date : 2024-07-22 09:27:55
  * @Version : 1.0.0
- * @LastEditors : Azure
- * @LastEditTime : 2024-07-26 08:38:20
+ * @LastEditors : kkk1nak0
+ * @LastEditTime : 2024-08-12 03:48:46
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
 #pragma once

@@ -15,4 +15,7 @@
 torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
\ No newline at end of file
+torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
\ No newline at end of file
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu  (+2 −2)

@@ -23,7 +23,7 @@
 */

 #include "gptq_marlin.cuh"
 #include "gptq_marlin_dtypes.cuh"
+#include <c10/cuda/CUDAGuard.h>

 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)               \
   static_assert(std::is_same<scalar_t, half>::value ||          \
                 std::is_same<scalar_t, nv_bfloat16>::value,     \

@@ -1774,6 +1774,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& perm, torch::Tensor& workspace,
                                int64_t num_bits, int64_t size_m, int64_t size_n,
                                int64_t size_k, bool is_k_full) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   // Verify num_bits
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);

@@ -1816,7 +1817,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");

   // Alloc buffers
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
   torch::Tensor c = torch::empty({size_m, size_n}, options);
   torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
ktransformers/ktransformers_ext/cuda/setup.py  (+19 −11)

@@ -2,17 +2,25 @@
 from setuptools import setup, Extension
 from torch.utils import cpp_extension
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension

 setup(
-    # setup marlin gemm
     name='KTransformersOps',
     ext_modules=[
         CUDAExtension('KTransformersOps', [
             'custom_gguf/dequant.cu',
             'binding.cpp',
             'gptq_marlin/gptq_marlin.cu',
             # 'gptq_marlin_repack.cu',
-        ])
+        ],
+        extra_compile_args={
+            'cxx': ['-O3'],
+            'nvcc': [
+                '-O3',
+                '--use_fast_math',
+                '-Xcompiler', '-fPIC',
+            ]
+        },
+        )
     ],
-    cmdclass={
-        'build_ext': BuildExtension
-    })
+    cmdclass={'build_ext': BuildExtension}
+)
\ No newline at end of file
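A hedged sketch of inspecting the rebuilt extension after this change; the build command is the standard setuptools/cpp_extension flow and is not spelled out in the diff:

# Assumes the extension has already been built from this directory, e.g. with
# `python setup.py install` (standard setuptools flow; the exact command is an assumption).
import KTransformersOps as ops

# The new -O3 / --use_fast_math / -fPIC flags only affect code generation; the Python surface
# stays the dequantize_* and marlin GEMM entry points bound elsewhere in this commit.
print([name for name in dir(ops) if not name.startswith("_")])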
ktransformers/ktransformers_ext/examples/test_linear.py  (+22 −43)

@@ -6,7 +6,7 @@ Author : chenht2022
 Date : 2024-07-25 10:32:05
 Version : 1.0.0
 LastEditors : chenht2022
-LastEditTime : 2024-07-25 10:34:00
+LastEditTime : 2024-08-06 10:36:59
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''
 import os, sys

@@ -15,23 +15,23 @@ sys.path.append(os.path.dirname(__file__) + '/../build')
 import cpuinfer_ext
 import torch

-with torch.inference_mode(mode=True):
-    input_size = 16384
-    output_size = 5120
-    stride = 32
-    proj_type = 1 # ggml_type::GGML_TYPE_F16
-    hidden_type = 1 # ggml_type::GGML_TYPE_F16
-    layer_num = 10
-    CPUInfer = cpuinfer_ext.CPUInfer(48)
-    validation_iter = 100
-    warm_up_iter = 1000
-    test_iter = 10000
+input_size = 16384
+output_size = 5120
+stride = 32
+group_max_len = 1024
+proj_type = 1 # ggml_type::GGML_TYPE_F16
+hidden_type = 1 # ggml_type::GGML_TYPE_F16
+qlen = 30
+layer_num = 10
+CPUInfer = cpuinfer_ext.CPUInfer(48)
+validation_iter = 100

+with torch.inference_mode(mode=True):
     linears = []
     projs = []
     for _ in range(layer_num):
         proj = torch.randn((output_size, input_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
-        config = cpuinfer_ext.linear.LinearConfig(input_size, output_size, stride, proj.data_ptr(), proj_type, hidden_type)
+        config = cpuinfer_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type)
         linear = cpuinfer_ext.linear.Linear(config)
         projs.append(proj)
         linears.append(linear)

@@ -39,11 +39,17 @@ with torch.inference_mode(mode=True):
     # validation
     for i in range(validation_iter):
         linear = linears[i % layer_num]
-        input = torch.randn((1, input_size), dtype=torch.float16).contiguous()
-        output = torch.empty((1, output_size), dtype=torch.float16).contiguous()
+        input = torch.randn((qlen, input_size), dtype=torch.float16).contiguous()
+        output = torch.empty((qlen, output_size), dtype=torch.float16).contiguous()
         input = input / 100
-        CPUInfer.submit(linear.forward, input.data_ptr(), output.data_ptr())
+        CPUInfer.submit(linear.forward(qlen, input.data_ptr(), output.data_ptr()))
         CPUInfer.sync()
         # print('cpuinfer output', output)

@@ -54,30 +60,3 @@ with torch.inference_mode(mode=True):
         diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
         print('diff = ', diff)
         assert(diff < 0.001)
-
-    # warm up
-    for i in range(warm_up_iter):
-        linear = linears[i % layer_num]
-        input = torch.randn((1, input_size), dtype=torch.float16).contiguous()
-        output = torch.empty((1, output_size), dtype=torch.float16).contiguous()
-        input = input / 100
-        CPUInfer.submit(linear.forward, input.data_ptr(), output.data_ptr())
-        CPUInfer.sync()
-
-    # test
-    total_time = 0
-    for i in range(test_iter):
-        linear = linears[i % layer_num]
-        input = torch.randn((1, input_size), dtype=torch.float16).contiguous()
-        output = torch.empty((1, output_size), dtype=torch.float16).contiguous()
-        input = input / 100
-        start = time.perf_counter()
-        CPUInfer.submit(linear.forward, input.data_ptr(), output.data_ptr())
-        CPUInfer.sync()
-        end = time.perf_counter()
-        total_time += end - start
-    print('Time: ', total_time)
-    print('Iteration: ', test_iter)
-    print('Time per iteration: ', total_time / test_iter)
-    print('Bandwidth: ', input_size * output_size * 2 * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')
-    print("All tasks completed.")
\ No newline at end of file
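The interesting change here is the calling convention: linear.forward(...) no longer runs anything itself, it packs a (task function pointer, task argument pointer) pair that CPUInfer.submit enqueues (see the ext_bindings.cpp hunk later in this commit), and the example now validates a qlen-row batch instead of a single row. A hedged helper showing that flow in isolation (the function name and signature are mine):

import torch

def run_linear_once(CPUInfer, linear, qlen: int, input_size: int, output_size: int) -> torch.Tensor:
    # Build one batch exactly as the example does.
    x = torch.randn((qlen, input_size), dtype=torch.float16).contiguous() / 100
    y = torch.empty((qlen, output_size), dtype=torch.float16).contiguous()
    # v0.1.1 style was CPUInfer.submit(linear.forward, x.data_ptr(), y.data_ptr());
    # v0.1.2 passes the descriptor returned by linear.forward(...) instead.
    CPUInfer.submit(linear.forward(qlen, x.data_ptr(), y.data_ptr()))
    CPUInfer.sync()  # blocks until the CPU backend has filled y
    return y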
ktransformers/ktransformers_ext/examples/test_mlp.py  (+35 −51)

@@ -6,7 +6,7 @@ Author : chenht2022
 Date : 2024-07-25 10:32:05
 Version : 1.0.0
 LastEditors : chenht2022
-LastEditTime : 2024-07-25 10:34:03
+LastEditTime : 2024-08-06 10:37:28
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''
 import os, sys

@@ -15,20 +15,30 @@ sys.path.append(os.path.dirname(__file__) + '/../build')
 import cpuinfer_ext
 import torch

-with torch.inference_mode(mode=True):
-    hidden_size = 5120
-    intermediate_size = 3072
-    stride = 32
-    gate_type = 1 # ggml_type::GGML_TYPE_F16
-    up_type = 1 # ggml_type::GGML_TYPE_F16
-    down_type = 1 # ggml_type::GGML_TYPE_F16
-    hidden_type = 1 # ggml_type::GGML_TYPE_F16
-    layer_num = 10
-    CPUInfer = cpuinfer_ext.CPUInfer(48)
-    validation_iter = 100
-    warm_up_iter = 1000
-    test_iter = 10000
+hidden_size = 5120
+intermediate_size = 3072
+stride = 32
+group_max_len = 1024
+gate_type = 1 # ggml_type::GGML_TYPE_F16
+up_type = 1 # ggml_type::GGML_TYPE_F16
+down_type = 1 # ggml_type::GGML_TYPE_F16
+hidden_type = 1 # ggml_type::GGML_TYPE_F16
+qlen = 30
+layer_num = 10
+CPUInfer = cpuinfer_ext.CPUInfer(48)
+validation_iter = 100
+
+def act_fn(x):
+    return x / (1.0 + torch.exp(-x))
+
+def mlp_torch(input, gate_proj, up_proj, down_proj):
+    gate_buf = torch.mm(input, gate_proj.t())
+    up_buf = torch.mm(input, up_proj.t())
+    intermediate = act_fn(gate_buf) * up_buf
+    ret = torch.mm(intermediate, down_proj.t())
+    return ret

+with torch.inference_mode(mode=True):
     mlps = []
     gate_projs = []
     up_projs = []

@@ -37,7 +47,7 @@ with torch.inference_mode(mode=True):
         gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
         up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
         down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
-        config = cpuinfer_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)
+        config = cpuinfer_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)
         mlp = cpuinfer_ext.mlp.MLP(config)
         gate_projs.append(gate_proj)
         up_projs.append(up_proj)

@@ -47,52 +57,26 @@ with torch.inference_mode(mode=True):
     # validation
     for i in range(validation_iter):
         mlp = mlps[i % layer_num]
-        input = torch.randn((1, hidden_size), dtype=torch.float16).contiguous()
-        output = torch.empty((1, hidden_size), dtype=torch.float16).contiguous()
+        input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()
+        output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()
         input = input / 100
-        CPUInfer.submit(mlp.forward, input.data_ptr(), output.data_ptr())
+        CPUInfer.submit(mlp.forward(qlen, input.data_ptr(), output.data_ptr()))
         CPUInfer.sync()
         # print('cpuinfer output', output)

-        def act_fn(x):
-            return x / (1.0 + torch.exp(-x))
         gate_proj = gate_projs[i % layer_num]
         up_proj = up_projs[i % layer_num]
         down_proj = down_projs[i % layer_num]
-        gate_buf = torch.mm(input, gate_proj.t())
-        up_buf = torch.mm(input, up_proj.t())
-        intermediate = act_fn(gate_buf) * up_buf
-        t_output = torch.mm(intermediate, down_proj.t())
+        t_output = mlp_torch(input, gate_proj, up_proj, down_proj)
         # print('torch output', t_output)
         diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
         print('diff = ', diff)
         assert(diff < 0.001)
-
-    # warm up
-    for i in range(warm_up_iter):
-        mlp = mlps[i % layer_num]
-        input = torch.randn((1, hidden_size), dtype=torch.float16).contiguous()
-        output = torch.empty((1, hidden_size), dtype=torch.float16).contiguous()
-        input = input / 100
-        CPUInfer.submit(mlp.forward, input.data_ptr(), output.data_ptr())
-        CPUInfer.sync()
-
-    # test
-    total_time = 0
-    for i in range(test_iter):
-        mlp = mlps[i % layer_num]
-        input = torch.randn((1, hidden_size), dtype=torch.float16).contiguous()
-        output = torch.empty((1, hidden_size), dtype=torch.float16).contiguous()
-        input = input / 100
-        start = time.time()
-        CPUInfer.submit(mlp.forward, input.data_ptr(), output.data_ptr())
-        CPUInfer.sync()
-        end = time.time()
-        total_time += end - start
-    print('Time: ', total_time)
-    print('Iteration: ', test_iter)
-    print('Time per iteration: ', total_time / test_iter)
-    print('Bandwidth: ', hidden_size * intermediate_size * 3 * 2 * test_iter / total_time / 1024 / 1024 / 1024, 'GB/s')
-    print("All tasks completed.")
\ No newline at end of file
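As in test_linear.py, MLPConfig now takes group_max_len between stride and the weight pointers. A hedged convenience wrapper for the new constructor layout (the helper name and the single ggml_dtype default are mine; 1 is GGML_TYPE_F16, as in the example):

def make_mlp(cpuinfer_ext, hidden_size: int, intermediate_size: int, stride: int, group_max_len: int,
             gate_proj, up_proj, down_proj, ggml_dtype: int = 1):
    # Weight tensors are passed by raw pointer, exactly as the example does with data_ptr().
    config = cpuinfer_ext.mlp.MLPConfig(
        hidden_size, intermediate_size, stride, group_max_len,
        gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(),
        ggml_dtype, ggml_dtype, ggml_dtype, ggml_dtype,
    )
    return cpuinfer_ext.mlp.MLP(config)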
ktransformers/ktransformers_ext/examples/test_moe.py  (+73 −65)

@@ -6,7 +6,7 @@ Author : chenht2022
 Date : 2024-07-25 10:32:05
 Version : 1.0.0
 LastEditors : chenht2022
-LastEditTime : 2024-07-25 10:34:06
+LastEditTime : 2024-08-06 10:38:05
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''
 import os, sys

@@ -15,25 +15,64 @@ sys.path.append(os.path.dirname(__file__) + '/../build')
 import cpuinfer_ext
 import torch

-with torch.inference_mode(mode=True):
-    expert_num = 160
-    hidden_size = 5120
-    intermediate_size = 1536
-    stride = 32
-    group_min_len = 10
-    group_max_len = 1024
-    gate_type = 1 # ggml_type::GGML_TYPE_F16
-    up_type = 1 # ggml_type::GGML_TYPE_F16
-    down_type = 1 # ggml_type::GGML_TYPE_F16
-    hidden_type = 1 # ggml_type::GGML_TYPE_F16
-    n_routed_experts = 6
-    qlen = 30
-    layer_num = 10
-    CPUInfer = cpuinfer_ext.CPUInfer(48)
-    validation_iter = 100
-    warm_up_iter = 1000
-    test_iter = 10000
+expert_num = 10
+hidden_size = 5120
+intermediate_size = 1536
+stride = 32
+group_min_len = 10
+group_max_len = 1024
+gate_type = 1 # ggml_type::GGML_TYPE_F16
+up_type = 1 # ggml_type::GGML_TYPE_F16
+down_type = 1 # ggml_type::GGML_TYPE_F16
+hidden_type = 1 # ggml_type::GGML_TYPE_F16
+n_routed_experts = 6
+qlen = 30
+layer_num = 10
+CPUInfer = cpuinfer_ext.CPUInfer(48)
+validation_iter = 100
+
+def act_fn(x):
+    return x / (1.0 + torch.exp(-x))
+
+def mlp_torch(input, gate_proj, up_proj, down_proj):
+    gate_buf = torch.mm(input, gate_proj.t())
+    up_buf = torch.mm(input, up_proj.t())
+    intermediate = act_fn(gate_buf) * up_buf
+    ret = torch.mm(intermediate, down_proj.t())
+    return ret
+
+def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
+    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
+    cnts.scatter_(1, expert_ids, 1)
+    tokens_per_expert = cnts.sum(dim=0)
+    idxs = expert_ids.view(-1).argsort()
+    sorted_tokens = input[idxs // expert_ids.shape[1]]
+
+    outputs = []
+    start_idx = 0
+    for i, num_tokens in enumerate(tokens_per_expert):
+        end_idx = start_idx + num_tokens
+        if num_tokens == 0:
+            continue
+        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
+        outputs.append(expert_out)
+        start_idx = end_idx
+
+    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+
+    new_x = torch.empty_like(outs)
+    new_x[idxs] = outs
+    t_output = (
+        new_x.view(*expert_ids.shape, -1)
+        .type(weights.dtype)
+        .mul_(weights.unsqueeze(dim=-1))
+        .sum(dim=1)
+        .type(new_x.dtype)
+    )
+    return t_output

+with torch.inference_mode(mode=True):
     moes = []
     gate_projs = []
     up_projs = []

@@ -51,63 +90,32 @@ with torch.inference_mode(mode=True):
     # validation
     for i in range(validation_iter):
-        moe = moes[i % layer_num]
-        expert_ids = torch.stack([torch.randperm(expert_num)[:n_routed_experts] for _ in range(qlen)]).contiguous()
+        expert_ids = torch.randint(0, expert_num, (qlen, n_routed_experts), dtype=torch.int64).contiguous()
         weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()
-        input = torch.randn((qlen, 1, hidden_size), dtype=torch.float16).contiguous()
-        output = torch.empty((qlen, 1, hidden_size), dtype=torch.float16).contiguous()
+        input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()
+        output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()
         input = input / 100
-        CPUInfer.submit(moe.forward, qlen, n_routed_experts, expert_ids.data_ptr(), weights.data_ptr(), input.data_ptr(), output.data_ptr())
+        moe = moes[i % layer_num]
+        CPUInfer.submit(moe.forward(qlen, n_routed_experts, expert_ids.data_ptr(), weights.data_ptr(), input.data_ptr(), output.data_ptr()))
         CPUInfer.sync()
         # print('cpuinfer output', output)

-        def act_fn(x):
-            return x / (1.0 + torch.exp(-x))
-        t_output = torch.zeros((qlen, 1, hidden_size), dtype=torch.float32).contiguous()
         gate_proj = gate_projs[i % layer_num]
         up_proj = up_projs[i % layer_num]
         down_proj = down_projs[i % layer_num]
-        for token_idx in range(qlen):
-            for i, expert_id in enumerate(expert_ids[token_idx]):
-                gate_buf = torch.mm(input[token_idx], gate_proj[expert_id].t())
-                up_buf = torch.mm(input[token_idx], up_proj[expert_id].t())
-                intermediate = act_fn(gate_buf) * up_buf
-                expert_output = torch.mm(intermediate, down_proj[expert_id].t())
-                t_output[token_idx] += weights[token_idx][i] * expert_output
+        t_output = moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj)
         # print('torch output', t_output)
         diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
         print('diff = ', diff)
         assert(diff < 0.001)
-
-    # warm up
-    for i in range(warm_up_iter):
-        moe = moes[i % layer_num]
-        expert_ids = torch.randint(0, expert_num, (qlen, n_routed_experts), dtype=torch.int64).contiguous()
-        weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()
-        input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()
-        output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()
-        input = input / 100
-        CPUInfer.submit(moe.forward, qlen, n_routed_experts, expert_ids.data_ptr(), weights.data_ptr(), input.data_ptr(), output.data_ptr())
-        CPUInfer.sync()
-
-    # test
-    total_time = 0
-    for i in range(test_iter):
-        moe = moes[i % layer_num]
-        expert_ids = torch.randint(0, expert_num, (qlen, n_routed_experts), dtype=torch.int64).contiguous()
-        weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()
-        input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()
-        output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()
-        input = input / 100
-        start = time.perf_counter()
-        CPUInfer.submit(moe.forward, qlen, n_routed_experts, expert_ids.data_ptr(), weights.data_ptr(), input.data_ptr(), output.data_ptr())
-        CPUInfer.sync()
-        end = time.perf_counter()
-        total_time += end - start
-    print('Time: ', total_time)
-    print('Iteration: ', test_iter)
-    print('Time per iteration: ', total_time / test_iter)
-    print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * 2 * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')
-    print("All tasks completed.")
\ No newline at end of file
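One quiet change in the validation inputs: expert_ids used to be built with torch.randperm per token, so the routed experts for a token were always distinct; the new torch.randint draw allows repeats. A small sketch of the two constructions side by side (values reuse the constants defined in the script above):

import torch

expert_num, n_routed_experts, qlen = 10, 6, 30

# Old: a random permutation per token, guaranteeing 6 distinct experts.
ids_unique = torch.stack([torch.randperm(expert_num)[:n_routed_experts] for _ in range(qlen)]).contiguous()

# New: independent draws, so the same expert id may repeat within a token's top-k.
ids_random = torch.randint(0, expert_num, (qlen, n_routed_experts), dtype=torch.int64).contiguous()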
ktransformers/ktransformers_ext/ext_bindings.cpp  (+118 −203)

@@ -3,8 +3,8 @@
 * @Author : chenht2022
 * @Date : 2024-07-22 02:03:22
 * @Version : 1.0.0
 * @LastEditors : chenht2022
-* @LastEditTime : 2024-07-25 10:34:23
+* @LastEditTime : 2024-08-07 10:39:37
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
// Python bindings

@@ -12,7 +12,6 @@
 #include <iostream>
 #include <memory>
 #include "cpu_backend/cpuinfer.h"
-#include "cuda_runtime.h"
 #include "device_launch_parameters.h"
 #include "llamafile/flags.h"
 #include "operators/llamafile/linear.h"

@@ -26,239 +25,155 @@
This hunk removes the old py::args-based dispatch layer: LinearBindings/MLPBindings/MOEBindings with bind_forward, bind_warm_up and bind_functions, the MOEForwardArgs struct, submit_moe_forward_with_host_args_ptr, cpuinfer_sync, the placeholder Linear/MLP/MOE warm_up/forward lambdas that threw "!!! Doing nothing, please use CPUInfer.submit to call it!!!", and the CPUInfer class bound at the end of the module with lambda implementations of submit, submit_with_cuda_stream and sync_with_cuda_stream built on cudaLaunchHostFunc. The new side of the hunk is:

namespace py = pybind11;
using namespace pybind11::literals;

class LinearBindings {
   public:
    class WarmUpBindinds {
       public:
        struct Args {
            CPUInfer* cpuinfer;
            Linear* linear;
        };
        static void inner(void* args) {
            Args* args_ = (Args*)args;
            args_->cpuinfer->enqueue(&Linear::warm_up, args_->linear);
        }
        static std::pair<intptr_t, intptr_t> cpuinfer_interface(Linear& linear) {
            Args* args = new Args{nullptr, &linear};
            return std::make_pair((intptr_t)&inner, (intptr_t)args);
        }
    };
    class ForwardBindings {
       public:
        struct Args {
            CPUInfer* cpuinfer;
            Linear* linear;
            int qlen;
            const void* input;
            void* output;
        };
        static void inner(void* args) {
            Args* args_ = (Args*)args;
            args_->cpuinfer->enqueue(&Linear::forward, args_->linear, args_->qlen, args_->input, args_->output);
        }
        static std::pair<intptr_t, intptr_t> cpuinfer_interface(Linear& linear, int qlen, intptr_t input, intptr_t output) {
            Args* args = new Args{nullptr, &linear, qlen, (const void*)input, (void*)output};
            return std::make_pair((intptr_t)&inner, (intptr_t)args);
        }
    };
};

class MLPBindings {
   public:
    class WarmUpBindinds {
       public:
        struct Args {
            CPUInfer* cpuinfer;
            MLP* mlp;
        };
        static void inner(void* args) {
            Args* args_ = (Args*)args;
            args_->cpuinfer->enqueue(&MLP::warm_up, args_->mlp);
        }
        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP& mlp) {
            Args* args = new Args{nullptr, &mlp};
            return std::make_pair((intptr_t)&inner, (intptr_t)args);
        }
    };
    class ForwardBindings {
       public:
        struct Args {
            CPUInfer* cpuinfer;
            MLP* mlp;
            int qlen;
            const void* input;
            void* output;
        };
        static void inner(void* args) {
            Args* args_ = (Args*)args;
            args_->cpuinfer->enqueue(&MLP::forward, args_->mlp, args_->qlen, args_->input, args_->output);
        }
        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP& mlp, int qlen, intptr_t input, intptr_t output) {
            Args* args = new Args{nullptr, &mlp, qlen, (const void*)input, (void*)output};
            return std::make_pair((intptr_t)&inner, (intptr_t)args);
        }
    };
};

class MOEBindings {
   public:
    class WarmUpBindinds {
       public:
        struct Args {
            CPUInfer* cpuinfer;
            MOE* moe;
        };
        static void inner(void* args) {
            Args* args_ = (Args*)args;
            args_->cpuinfer->enqueue(&MOE::warm_up, args_->moe);
        }
        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE& moe) {
            Args* args = new Args{nullptr, &moe};
            return std::make_pair((intptr_t)&inner, (intptr_t)args);
        }
    };
    class ForwardBindings {
       public:
        struct Args {
            CPUInfer* cpuinfer;
            MOE* moe;
            int qlen;
            int k;
            const uint64_t* expert_ids;
            const float* weights;
            const void* input;
            void* output;
        };
        static void inner(void* args) {
            Args* args_ = (Args*)args;
            args_->cpuinfer->enqueue(&MOE::forward, args_->moe, args_->qlen, args_->k, args_->expert_ids, args_->weights, args_->input, args_->output);
        }
        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE& moe, int qlen, int k, intptr_t expert_ids, intptr_t weights, intptr_t input, intptr_t output) {
            Args* args = new Args{nullptr, &moe, qlen, k, (const uint64_t*)expert_ids, (const float*)weights, (const void*)input, (void*)output};
            return std::make_pair((intptr_t)&inner, (intptr_t)args);
        }
    };
};

PYBIND11_MODULE(cpuinfer_ext, m) {
    py::class_<CPUInfer>(m, "CPUInfer")
        .def(py::init<int>())
        .def("submit", &CPUInfer::submit)
        .def("submit_with_cuda_stream", &CPUInfer::submit_with_cuda_stream)
        .def("sync", &CPUInfer::sync)
        .def("sync_with_cuda_stream", &CPUInfer::sync_with_cuda_stream);

    auto linear_module = m.def_submodule("linear");
    py::class_<LinearConfig>(linear_module, "LinearConfig")
        .def(py::init([](int hidden_size, int intermediate_size, int stride, int group_max_len, intptr_t proj, int proj_type, int hidden_type) {
            return LinearConfig(hidden_size, intermediate_size, stride, group_max_len, (void*)proj, (ggml_type)proj_type, (ggml_type)hidden_type);
        }));
    py::class_<Linear>(linear_module, "Linear")
        .def(py::init<LinearConfig>())
        .def("warm_up", &LinearBindings::WarmUpBindinds::cpuinfer_interface)
        .def("forward", &LinearBindings::ForwardBindings::cpuinfer_interface);

    auto mlp_module = m.def_submodule("mlp");
    py::class_<MLPConfig>(mlp_module, "MLPConfig")
        .def(py::init([](int hidden_size, int intermediate_size, int stride, int group_max_len, intptr_t gate_proj, intptr_t up_proj, intptr_t down_proj, int gate_type, int up_type, int down_type, int hidden_type) {
            return MLPConfig(hidden_size, intermediate_size, stride, group_max_len, (void*)gate_proj, (void*)up_proj, (void*)down_proj, (ggml_type)gate_type, (ggml_type)up_type, (ggml_type)down_type, (ggml_type)hidden_type);
        }));
    py::class_<MLP>(mlp_module, "MLP")
        .def(py::init<MLPConfig>())
        .def("warm_up", &MLPBindings::WarmUpBindinds::cpuinfer_interface)
        .def("forward", &MLPBindings::ForwardBindings::cpuinfer_interface);

    auto moe_module = m.def_submodule("moe");
    py::class_<MOEConfig>(moe_module, "MOEConfig")
        .def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, intptr_t gate_proj, intptr_t up_proj, intptr_t down_proj, int gate_type, int up_type, int down_type, int hidden_type) {
            return MOEConfig(expert_num, routed_expert_num, hidden_size, intermediate_size, stride, group_min_len, group_max_len, (void*)gate_proj, (void*)up_proj, (void*)down_proj, (ggml_type)gate_type, (ggml_type)up_type, (ggml_type)down_type, (ggml_type)hidden_type);
        }));
    py::class_<MOE>(moe_module, "MOE")
        .def(py::init<MOEConfig>())
        .def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface)
        .def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface);
}
)
user_cuda_stream
,
(
cudaHostFn_t
)
cpuinfer_sync
,
(
void
*
)(
&
cpuinfer
));
})
.
def
(
"sync"
,
&
CPUInfer
::
sync
);
}
}
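For orientation, a rough Python-side sketch of how these bindings are meant to be driven. The module name cpuinfer_ext, the thread count, the tensor shapes, and the argument order (mirroring the parsing in submit_with_cuda_stream above) are assumptions for illustration, not part of this diff.

# Hypothetical driver for the CPUInfer bindings above; names and shapes are assumptions.
import torch
import cpuinfer_ext  # assumed name of the compiled extension module

qlen, k, hidden_size = 4, 2, 16
cpuinfer = cpuinfer_ext.CPUInfer(8)      # number of CPU worker threads (illustrative)
moe = cpuinfer_ext.moe.MOE(moe_config)   # moe_config: a MOEConfig built elsewhere with valid weight pointers

expert_ids = torch.zeros((qlen, k), dtype=torch.int64)
weights = torch.ones((qlen, k), dtype=torch.float32) / k
hidden_states = torch.randn(qlen, hidden_size, dtype=torch.float32)
output = torch.empty_like(hidden_states)

# submit() inspects func.__self__/__func__ and dispatches to MOEBindings; the
# work runs asynchronously on the CPUInfer thread pool, so sync() must complete
# before `output` is read.
cpuinfer.submit(moe.forward, qlen, k,
                expert_ids.data_ptr(), weights.data_ptr(),
                hidden_states.data_ptr(), output.data_ptr())
cpuinfer.sync()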
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq.py deleted 100644 → 0 View file @ 44f57270

import math
import os
import time
from logging import getLogger

import torch
import torch.nn as nn
import transformers

from .quantizer import Quantizer

logger = getLogger(__name__)

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


class GPTQ:
    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.pytorch_utils.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.quantizer = Quantizer()

    def add_batch(self, inp, out):
        if os.environ.get("DEBUG"):
            self.inp1 = inp
            self.out1 = out
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        # inp = inp.float()
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        # self.H += 2 / self.nsamples * inp.matmul(inp.t())
        self.H += inp.matmul(inp.t())

    def fasterquant(
        self,
        blocksize=128,
        percdamp=0.01,
        group_size=-1,
        actorder=False,
        static_groups=False,
    ):
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        del self.H
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        g_idx = []
        scale = []
        zero = []
        now_idx = 1

        if static_groups:
            import copy

            groups = []
            for i in range(0, self.columns, group_size):
                quantizer = copy.deepcopy(self.quantizer)
                quantizer.find_params(W[:, i : (i + group_size)], weight=True)
                scale.append(quantizer.scale)
                zero.append(quantizer.zero)
                groups.append(quantizer)

        if actorder:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]
            invperm = torch.argsort(perm)

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if group_size != -1:
                    if not static_groups:
                        if (i1 + i) % group_size == 0:
                            self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + group_size)], weight=True)
                        if ((i1 + i) // group_size) - now_idx == -1:
                            scale.append(self.quantizer.scale)
                            zero.append(self.quantizer.zero)
                            now_idx += 1
                    else:
                        idx = i1 + i
                        if actorder:
                            idx = perm[idx]
                        self.quantizer = groups[idx // group_size]

                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d**2

                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

            if os.environ.get("DEBUG"):
                self.layer.weight.data[:, :i2] = Q[:, :i2]
                self.layer.weight.data[:, i2:] = W[:, i2:]
                logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
                logger.debug(torch.sum(Losses))

        torch.cuda.synchronize()
        logger.info(f"duration: {(time.time() - tick)}")
        logger.info(f"avg loss: {torch.sum(Losses).item() / self.nsamples}")

        group_size = group_size if group_size != -1 else self.columns
        if static_groups and actorder:
            g_idx = [perm[i] // group_size for i in range(self.columns)]
        else:
            g_idx = [i // group_size for i in range(self.columns)]
        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
        if actorder:
            Q = Q[:, invperm]
            g_idx = g_idx[invperm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()
        self.layer.weight.data = Q.reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
        if os.environ.get("DEBUG"):
            logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))

        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero, g_idx

    def free(self):
        if os.environ.get("DEBUG"):
            self.inp1 = None
            self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()


__all__ = ["GPTQ"]
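Retrospectively, a minimal calibration sketch for the GPTQ helper that this commit removes; the layer size, bit width, and calibration batches below are illustrative assumptions, and GPTQ/Quantizer would have to be imported from the removed module for this to run.

# Hypothetical calibration loop for the removed GPTQ class above.
import torch
import torch.nn as nn

layer = nn.Linear(512, 512)
gptq = GPTQ(layer)
gptq.quantizer.configure(bits=4, perchannel=True, sym=True, mse=False)

for _ in range(8):                      # a handful of calibration batches
    x = torch.randn(16, 512)
    gptq.add_batch(x, layer(x))         # accumulates the input Hessian H

# Quantizes layer.weight in place and returns per-group scale/zero and g_idx.
scale, zero, g_idx = gptq.fasterquant(blocksize=128, group_size=128, actorder=True)
gptq.free()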
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq_marlin.py deleted 100644 → 0 View file @ 44f57270

import enum
from enum import Enum
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)

logger = init_logger(__name__)

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]


# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits: int):
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


def get_pack_factor(num_bits: int):
    assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
            ), f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                          group_size: int, num_bits: int):
    scale_perm, scale_perm_single = get_scale_perms(num_bits)
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()
    return s


class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool) -> None:
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym

        # Verify
        if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"Marlin does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
                "are supported.")
        if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
            raise ValueError(
                f"Marlin does not support is_sym = {self.is_sym}. "
                f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")

        # Init
        self.pack_factor = get_pack_factor(weight_bits)
        self.tile_size = GPTQ_MARLIN_TILE
        self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
        self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
        self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        return cls(weight_bits, group_size, desc_act, is_sym)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        can_convert = cls.is_marlin_compatible(hf_quant_cfg)

        is_valid_user_quant = (user_quant is None or user_quant == "marlin")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_marlin"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_marlin for"
                        " faster inference")
        return None

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
        if isinstance(layer, LinearBase):
            return GPTQMarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
        # Extract data from quant config.
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        sym = quant_config.get("sym", None)
        desc_act = quant_config.get("desc_act", None)

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        # If the capability of the device is too low, cannot convert.
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < cls.get_min_capability():
            return False

        # Otherwise, can convert if model satisfies marlin constraints.
        return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
                and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
                and sym in GPTQ_MARLIN_SUPPORTED_SYM)


class GPTQMarlinState(Enum):
    REPACK = enum.auto()
    READY = enum.auto()


class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Validate dtype
        if params_dtype not in [torch.float16, torch.bfloat16]:
            raise ValueError(f"The params dtype must be float16 "
                             f"or bfloat16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_thread_n != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f" min_thread_n = {self.quant_config.min_thread_n}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_thread_k != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible "
                f"by min_thread_k = {self.quant_config.min_thread_k}.")

        if (group_size < input_size
                and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition}"
                f" is not divisible by group_size = {group_size}.")

        # Detect sharding of scales/zp

        # By default, no sharding over "input dim"
        scales_and_zp_size = input_size // group_size
        scales_and_zp_input_dim = None

        if self.quant_config.desc_act:
            # Act-order case
            assert self.quant_config.group_size != -1

            is_k_full = input_size_per_partition == input_size
        else:
            # No act-order case

            # K is always full due to full alignment with
            # group-size and shard of scales/zp
            is_k_full = True

            # If this is a row-parallel case, then shard scales/zp
            if (input_size != input_size_per_partition
                    and self.quant_config.group_size != -1):
                scales_and_zp_size = input_size_per_partition // group_size
                scales_and_zp_input_dim = 0

        # Init buffers

        # Quantized weights
        qweight = Parameter(
            torch.empty(input_size_per_partition // self.quant_config.pack_factor,
                        output_size_per_partition, dtype=torch.int32),
            requires_grad=False)
        set_weight_attrs(qweight, {
            **extra_weight_attrs,
            "input_dim": 0, "output_dim": 1, "packed_dim": 0,
            "pack_factor": self.quant_config.pack_factor,
        })

        # Activation order
        g_idx = Parameter(torch.empty(input_size_per_partition, dtype=torch.int32),
                          requires_grad=False)
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(g_idx, {**extra_weight_attrs, "input_dim": 0, "ignore_warning": True})

        g_idx_sort_indices = torch.empty(g_idx.shape, dtype=torch.int32)

        # Scales
        scales = Parameter(
            torch.empty(scales_and_zp_size, output_size_per_partition, dtype=params_dtype),
            requires_grad=False)
        set_weight_attrs(scales, {
            **extra_weight_attrs,
            "input_dim": scales_and_zp_input_dim, "output_dim": 1,
        })

        # Quantized zero-points
        qzeros = Parameter(
            torch.empty(scales_and_zp_size,
                        output_size_per_partition // self.quant_config.pack_factor,
                        dtype=torch.int32, device="meta"),
            requires_grad=False)
        set_weight_attrs(qzeros, {
            **extra_weight_attrs,
            "input_dim": scales_and_zp_input_dim, "output_dim": 1, "packed_dim": 1,
            "pack_factor": self.quant_config.pack_factor,
        })

        # Allocate marlin workspace
        max_workspace_size = (output_size_per_partition //
                              self.quant_config.min_thread_n) * self.quant_config.max_parallel
        workspace = torch.zeros(max_workspace_size, dtype=torch.int, requires_grad=False)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)
        layer.g_idx_sort_indices = g_idx_sort_indices
        layer.workspace = workspace
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.is_k_full = is_k_full
        layer.marlin_state = GPTQMarlinState.REPACK

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        reshaped_x = x.reshape(-1, x.shape[-1])

        size_m = reshaped_x.shape[0]
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition
        full_size_k = layer.input_size

        out_shape = x.shape[:-1] + (part_size_n, )

        if layer.marlin_state == GPTQMarlinState.REPACK:
            layer.marlin_state = GPTQMarlinState.READY

            # Newly generated tensors need to replace existing tensors that are
            # already registered as parameters by vLLM (and won't be freed)
            def replace_tensor(name, new_t):
                # It is important to use resize_() here since it ensures
                # the same buffer is reused
                getattr(layer, name).resize_(new_t.shape)
                getattr(layer, name).copy_(new_t)
                del new_t

            cur_device = layer.qweight.device

            # Process act_order
            if self.quant_config.desc_act:
                # Get sorting based on g_idx
                g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
                sorted_g_idx = layer.g_idx[g_idx_sort_indices]

                replace_tensor("g_idx", sorted_g_idx)
                replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
            else:
                # Reset g_idx related tensors
                layer.g_idx = Parameter(
                    torch.empty(0, dtype=torch.int, device=cur_device),
                    requires_grad=False)
                layer.g_idx_sort_indices = Parameter(
                    torch.empty(0, dtype=torch.int, device=cur_device),
                    requires_grad=False)

            # Repack weights
            marlin_qweight = ops.gptq_marlin_repack(
                layer.qweight,
                layer.g_idx_sort_indices,
                part_size_k,
                part_size_n,
                self.quant_config.weight_bits,
            )
            replace_tensor("qweight", marlin_qweight)

            # Permute scales
            scales_size_k = part_size_k
            scales_size_n = part_size_n
            if self.quant_config.desc_act:
                scales_size_k = full_size_k

            marlin_scales = marlin_permute_scales(
                layer.scales,
                scales_size_k,
                scales_size_n,
                self.quant_config.group_size,
                self.quant_config.weight_bits,
            )
            replace_tensor("scales", marlin_scales)

        output = ops.gptq_marlin_gemm(
            reshaped_x,
            layer.qweight,
            layer.scales,
            layer.g_idx,
            layer.g_idx_sort_indices,
            layer.workspace,
            self.quant_config.weight_bits,
            size_m,
            part_size_n,
            part_size_k,
            layer.is_k_full,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
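A toy shape check for the scale permutation defined above; the sizes are arbitrary assumptions chosen to satisfy the Marlin constraints (group_size divides size_k, size_n is a multiple of 64), and the functions are taken from the removed module.

# Illustrative sanity check of get_scale_perms / marlin_permute_scales.
import torch

size_k, size_n, group_size, num_bits = 1024, 128, 128, 4
s = torch.randn(size_k // group_size, size_n, dtype=torch.half)  # (8, 128) group scales

perm, perm_single = get_scale_perms(num_bits)
assert len(perm) == 64 and len(perm_single) == 32

s_marlin = marlin_permute_scales(s, size_k, size_n, group_size, num_bits)
assert s_marlin.shape == (size_k // group_size, size_n)  # layout shuffled, shape preserved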
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/quantizer.py deleted 100644 → 0 View file @ 44f57270

from logging import getLogger

import torch
import torch.nn as nn

logger = getLogger(__name__)


def quantize(x, scale, zero, maxq):
    if maxq < 0:
        return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)


class Quantizer(nn.Module):
    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        self.register_buffer("maxq", torch.tensor(0))
        self.register_buffer("scale", torch.zeros(shape))
        self.register_buffer("zero", torch.zeros(shape))

    def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4,
                  grid=100, maxshrink=0.8, trits=False):
        self.maxq = torch.tensor(2**bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
            self.maxq = torch.tensor(-1)

    def find_params(self, x, weight=False):
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        if self.maxq < 0:
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            best = torch.full([x.shape[0]], float("inf"), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        if not self.perchannel:
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        if self.ready():
            return quantize(x, self.scale, self.zero, self.maxq)
        return x

    def enabled(self):
        return self.maxq > 0

    def ready(self):
        return torch.all(self.scale != 0)


__all__ = ["Quantizer"]
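A small usage sketch for the Quantizer above: 4-bit, asymmetric, per-output-channel weight quantization. The weight shape is an illustrative assumption.

# Hypothetical use of the removed Quantizer class.
import torch

W = torch.randn(256, 512)            # [out_features, in_features]

q = Quantizer()
q.configure(bits=4, perchannel=True, sym=False)
q.find_params(W, weight=True)        # per-row scale/zero, reshaped to [-1, 1]
W_q = q.quantize(W)                  # fake-quantized weights, same shape as W

assert W_q.shape == W.shape
print((W - W_q).abs().max())         # worst-case rounding error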
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/repack.py deleted 100644 → 0 View file @ 44f57270

import torch
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
from torch.nn.parameter import Parameter


def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    reshaped_x = x.reshape(-1, x.shape[-1])

    size_m = reshaped_x.shape[0]
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    full_size_k = layer.input_size

    out_shape = x.shape[:-1] + (part_size_n, )

    if layer.marlin_state == GPTQMarlinState.REPACK:
        layer.marlin_state = GPTQMarlinState.READY

        # Newly generated tensors need to replace existing tensors that are
        # already registered as parameters by vLLM (and won't be freed)
        def replace_tensor(name, new_t):
            # It is important to use resize_() here since it ensures
            # the same buffer is reused
            getattr(layer, name).resize_(new_t.shape)
            getattr(layer, name).copy_(new_t)
            del new_t

        cur_device = layer.qweight.device

        # Process act_order
        if self.quant_config.desc_act:
            # Get sorting based on g_idx
            g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
            sorted_g_idx = layer.g_idx[g_idx_sort_indices]

            replace_tensor("g_idx", sorted_g_idx)
            replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
        else:
            # Reset g_idx related tensors
            layer.g_idx = Parameter(
                torch.empty(0, dtype=torch.int, device=cur_device),
                requires_grad=False)
            layer.g_idx_sort_indices = Parameter(
                torch.empty(0, dtype=torch.int, device=cur_device),
                requires_grad=False)

        # Repack weights
        marlin_qweight = ops.gptq_marlin_repack(
            layer.qweight,
            layer.g_idx_sort_indices,
            part_size_k,
            part_size_n,
            self.quant_config.weight_bits,
        )
        replace_tensor("qweight", marlin_qweight)

        # Permute scales
        scales_size_k = part_size_k
        scales_size_n = part_size_n
        if self.quant_config.desc_act:
            scales_size_k = full_size_k

        marlin_scales = marlin_permute_scales(
            layer.scales,
            scales_size_k,
            scales_size_n,
            self.quant_config.group_size,
            self.quant_config.weight_bits,
        )
        replace_tensor("scales", marlin_scales)

    output = ops.gptq_marlin_gemm(
        reshaped_x,
        layer.qweight,
        layer.scales,
        layer.g_idx,
        layer.g_idx_sort_indices,
        layer.workspace,
        self.quant_config.weight_bits,
        size_m,
        part_size_n,
        part_size_k,
        layer.is_k_full,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py View file @ 77a34c28

@@ -220,7 +220,7 @@ def compute_max_diff(output, output_ref):
 class MarlinWorkspace:
-    def __init__(self, out_features, min_thread_n, max_parallel):
+    def __init__(self, out_features, min_thread_n, max_parallel, device):
         assert (out_features % min_thread_n == 0), (
             "out_features = {} is undivisible by min_thread_n = {}".format(
                 out_features, min_thread_n))
@@ -229,4 +229,4 @@ class MarlinWorkspace:
         self.scratch = torch.zeros(max_workspace_size,
                                    dtype=torch.int,
-                                   device="cuda")
+                                   device=device)
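A minimal usage sketch of the updated constructor; the out_features value is illustrative, and the point of the change is that the workspace no longer hard-codes device "cuda".

# MarlinWorkspace now takes the target device explicitly (values are assumptions).
workspace = MarlinWorkspace(out_features=4096,   # must be divisible by min_thread_n
                            min_thread_n=64,
                            max_parallel=16,
                            device="cuda:0")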
ktransformers/ktransformers_ext/operators/llamafile/linear.cpp View file @ 77a34c28

@@ -3,7 +3,7 @@
  * @Author       : chenht2022
  * @Date         : 2024-07-12 10:07:58
  * @Version      : 1.0.0
  * @LastEditors  : chenht2022
  * @LastEditTime : 2024-07-25 10:34:58
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
@@ -13,9 +13,15 @@ Linear::Linear(LinearConfig config) {
     config_ = config;
     proj_ = config_.proj;
-    input_fp32_.resize(config_.input_size);
-    proj_input_.resize(config_.input_size * 4);
-    proj_output_.resize(config_.output_size);
+    std::vector<std::pair<void**, uint64_t>> mem_requests;
+    mem_requests.push_back({(void**)&input_fp32_, sizeof(float) * config_.group_max_len * config_.input_size});
+    mem_requests.push_back({(void**)&proj_input_, config_.group_max_len * config_.input_size * ggml_type_size(ggml_internal_get_type_traits(config_.proj_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.proj_type).vec_dot_type)});
+    mem_requests.push_back({(void**)&proj_output_, sizeof(float) * config_.group_max_len * config_.output_size});
+    shared_mem_buffer.alloc(this, mem_requests);
+}
+
+Linear::~Linear() {
+    shared_mem_buffer.dealloc(this);
 }

 void Linear::warm_up(Backend* backend) {
@@ -26,22 +32,42 @@ void Linear::warm_up(Backend* backend) {
         input_fp32[i] = 0;
     }
     from_float(input_fp32.data(), input.data(), config_.input_size, config_.hidden_type);
-    forward(input.data(), output.data(), backend);
+    forward_many(1, input.data(), output.data(), backend);
 }

-void Linear::forward(const void* input, void* output, Backend* backend) {
+void Linear::forward_many(int qlen, const void* input, void* output, Backend* backend) {
     const void* proj_input_ptr;
     if (config_.hidden_type == ggml_internal_get_type_traits(config_.proj_type).vec_dot_type) {
         proj_input_ptr = input;
     } else {
-        to_float(input, input_fp32_.data(), config_.input_size, config_.hidden_type);
-        from_float(input_fp32_.data(), proj_input_.data(), config_.input_size, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type);
-        proj_input_ptr = proj_input_.data();
+        to_float(input, input_fp32_, qlen * config_.input_size, config_.hidden_type);
+        from_float(input_fp32_, proj_input_, qlen * config_.input_size, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type);
+        proj_input_ptr = proj_input_;
     }
     int nth = config_.output_size / config_.stride;
     backend->do_work_stealing_job(nth, [&](int task_id) {
-        int ith = task_id % nth;
-        llamafile_sgemm(config_.output_size, 1, config_.input_size / ggml_blck_size(config_.proj_type), proj_, config_.input_size / ggml_blck_size(config_.proj_type), proj_input_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_output_.data(), config_.output_size, ith, nth, GGML_TASK_TYPE_COMPUTE, config_.proj_type, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
+        int ith = task_id;
+        void* proj_ptr = (uint8_t*)proj_ + ith * config_.stride * config_.input_size * ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type);
+        float* proj_output_ptr = proj_output_ + ith * config_.stride;
+        llamafile_sgemm(config_.stride, qlen, config_.input_size / ggml_blck_size(config_.proj_type), proj_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_input_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_output_ptr, config_.output_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.proj_type, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
+        if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
+            for (int i = 0; i < qlen; i++) {
+                float* output_fp32_ptr = proj_output_ + i * config_.output_size + ith * config_.stride;
+                void* output_ptr = (uint8_t*)output + i * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
+                from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
+            }
+        }
     });
-    from_float(proj_output_.data(), output, config_.output_size, config_.hidden_type);
+    if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {
+        from_float(proj_output_, output, qlen * config_.output_size, config_.hidden_type);
+    }
+}
+
+void Linear::forward(int qlen, const void* input, void* output, Backend* backend) {
+    if (qlen <= 0) {
+        return;
+    }
+    int forward_len = std::min(qlen, config_.group_max_len);
+    forward_many(forward_len, input, output, backend);
+    forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
 }
\ No newline at end of file
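The new forward() above processes a batch of tokens in groups of at most group_max_len and recurses on the remainder. A small Python rendering of that control flow, purely illustrative (the real code advances raw byte pointers; here a list of token rows is sliced instead):

# Illustrative sketch of the chunked-forward control flow introduced here.
def forward(rows, group_max_len, forward_many):
    if len(rows) <= 0:
        return
    n = min(len(rows), group_max_len)
    forward_many(rows[:n])                           # one group of <= group_max_len tokens
    forward(rows[n:], group_max_len, forward_many)   # recurse on the remaining tokens

forward(list(range(10)), group_max_len=4, forward_many=print)  # prints chunks of 4, 4, 2 rows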
ktransformers/ktransformers_ext/operators/llamafile/linear.h View file @ 77a34c28

@@ -3,7 +3,7 @@
  * @Author       : chenht2022
  * @Date         : 2024-07-12 10:07:58
  * @Version      : 1.0.0
  * @LastEditors  : chenht2022
  * @LastEditTime : 2024-07-25 10:35:00
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
@@ -22,34 +22,38 @@
 #include "llama.cpp/ggml-quants.h"
 #include "llama.cpp/ggml.h"
 #include "llamafile/sgemm.h"
+#include "shared_mem_buffer.h"

 struct LinearConfig {
     int input_size;
     int output_size;
     int stride;
+    int group_max_len;
     void* proj;
     ggml_type proj_type;
     ggml_type hidden_type;

     LinearConfig() {}
-    LinearConfig(int input_size, int output_size, int stride, void* proj, ggml_type proj_type, ggml_type hidden_type)
-        : input_size(input_size), output_size(output_size), stride(stride), proj(proj), proj_type(proj_type), hidden_type(hidden_type) {}
+    LinearConfig(int input_size, int output_size, int stride, int group_max_len, void* proj, ggml_type proj_type, ggml_type hidden_type)
+        : input_size(input_size), output_size(output_size), stride(stride), group_max_len(group_max_len), proj(proj), proj_type(proj_type), hidden_type(hidden_type) {}
 };

 class Linear {
    public:
     Linear(LinearConfig);
+    ~Linear();
     void warm_up(Backend* backend);
-    void forward(const void* input, void* output, Backend* backend);
+    void forward_many(int qlen, const void* input, void* output, Backend* backend);
+    void forward(int qlen, const void* input, void* output, Backend* backend);

    private:
     LinearConfig config_;
     void* proj_;  // [output_size * input_size ( /32 if quantized)]
-    std::vector<float> input_fp32_;    // [input_size]
-    std::vector<uint8_t> proj_input_;  // [input_size * 4]
-    std::vector<float> proj_output_;   // [output_size]
+    float* input_fp32_;    // [group_max_len * input_size]
+    uint8_t* proj_input_;  // [group_max_len * input_size * ggml_type_size(ggml_internal_get_type_traits(proj_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(proj_type).vec_dot_type)]
+    float* proj_output_;   // [group_max_len * output_size]
 };

 #endif
\ No newline at end of file
ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp View file @ 77a34c28

@@ -3,7 +3,7 @@
  * @Author       : chenht2022
  * @Date         : 2024-07-16 10:43:18
  * @Version      : 1.0.0
  * @LastEditors  : chenht2022
  * @LastEditTime : 2024-07-25 10:35:04
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
@@ -15,14 +15,20 @@ MLP::MLP(MLPConfig config) {
     up_proj_ = config_.up_proj;
     down_proj_ = config_.down_proj;
-    input_fp32_.resize(config_.hidden_size);
-    gate_input_.resize(config_.hidden_size * 4);
-    up_input_.resize(config_.hidden_size * 4);
-    gate_output_.resize(config_.intermediate_size);
-    up_output_.resize(config_.intermediate_size);
-    intermediate_fp32_.resize(config_.intermediate_size);
-    down_input_.resize(config_.intermediate_size * 4);
-    down_output_.resize(config_.hidden_size);
+    std::vector<std::pair<void**, uint64_t>> mem_requests;
+    mem_requests.push_back({(void**)&input_fp32_, sizeof(float) * config_.group_max_len * config_.hidden_size});
+    mem_requests.push_back({(void**)&gate_input_, config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});
+    mem_requests.push_back({(void**)&up_input_, config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});
+    mem_requests.push_back({(void**)&gate_output_, sizeof(float) * config_.group_max_len * config_.intermediate_size});
+    mem_requests.push_back({(void**)&up_output_, sizeof(float) * config_.group_max_len * config_.intermediate_size});
+    mem_requests.push_back({(void**)&intermediate_fp32_, sizeof(float) * config_.group_max_len * config_.intermediate_size});
+    mem_requests.push_back({(void**)&down_input_, config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});
+    mem_requests.push_back({(void**)&down_output_, sizeof(float) * config_.group_max_len * config_.hidden_size});
+    shared_mem_buffer.alloc(this, mem_requests);
+}
+
+MLP::~MLP() {
+    shared_mem_buffer.dealloc(this);
 }

 void MLP::warm_up(Backend* backend) {
@@ -33,33 +39,33 @@ void MLP::warm_up(Backend* backend) {
         input_fp32[i] = 0;
     }
     from_float(input_fp32.data(), input.data(), config_.hidden_size, config_.hidden_type);
-    forward(input.data(), output.data(), backend);
+    forward_many(1, input.data(), output.data(), backend);
 }

 static float act_fn(float x) {
     return x / (1.0f + expf(-x));
 }

-void MLP::forward(const void* input, void* output, Backend* backend) {
+void MLP::forward_many(int qlen, const void* input, void* output, Backend* backend) {
     const void* gate_input_ptr;
     const void* up_input_ptr;
     if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
         gate_input_ptr = up_input_ptr = input;
     } else {
-        to_float(input, input_fp32_.data(), config_.hidden_size, config_.hidden_type);
+        to_float(input, input_fp32_, qlen * config_.hidden_size, config_.hidden_type);
         if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
-            from_float(input_fp32_.data(), gate_input_.data(), config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
-            gate_input_ptr = up_input_ptr = gate_input_.data();
+            from_float(input_fp32_, gate_input_, qlen * config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
+            gate_input_ptr = up_input_ptr = gate_input_;
         } else {
             if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {
-                from_float(input_fp32_.data(), gate_input_.data(), config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
-                gate_input_ptr = gate_input_.data();
+                from_float(input_fp32_, gate_input_, qlen * config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
+                gate_input_ptr = gate_input_;
             } else {
                 gate_input_ptr = input;
             }
             if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
-                from_float(input_fp32_.data(), up_input_.data(), config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
-                up_input_ptr = up_input_.data();
+                from_float(input_fp32_, up_input_, qlen * config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
+                up_input_ptr = up_input_;
             } else {
                 up_input_ptr = input;
             }
         }
     }
@@ -69,35 +75,49 @@ void MLP::forward(const void* input, void* output, Backend* backend) {
     backend->do_work_stealing_job(nth, [&](int task_id) {
         int ith = task_id;
         void* gate_proj_ptr = (uint8_t*)gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
-        float* gate_output_ptr = gate_output_.data() + ith * config_.stride;
-        llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
+        float* gate_output_ptr = gate_output_ + ith * config_.stride;
+        llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
         void* up_proj_ptr = (uint8_t*)up_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
-        float* up_output_ptr = up_output_.data() + ith * config_.stride;
-        llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
-        for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
-            intermediate_fp32_[i] = act_fn(gate_output_[i]) * up_output_[i];
-        }
-        if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {
-            float* intermediate_fp32_ptr = intermediate_fp32_.data() + ith * config_.stride;
-            void* down_input_ptr = (uint8_t*)down_input_.data() + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
-            from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
-        }
+        float* up_output_ptr = up_output_ + ith * config_.stride;
+        llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
+        for (int i = 0; i < qlen; i++) {
+            for (int j = ith * config_.stride; j < (ith + 1) * config_.stride; j++) {
+                intermediate_fp32_[i * config_.intermediate_size + j] = act_fn(gate_output_[i * config_.intermediate_size + j]) * up_output_[i * config_.intermediate_size + j];
+            }
+            if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {
+                float* intermediate_fp32_ptr = intermediate_fp32_ + i * config_.intermediate_size + ith * config_.stride;
+                void* down_input_ptr = (uint8_t*)down_input_ + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
+                from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
+            }
+        }
     });
     if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) {
-        from_float(intermediate_fp32_.data(), down_input_.data(), config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
+        from_float(intermediate_fp32_, down_input_, qlen * config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
     }
     nth = config_.hidden_size / config_.stride;
     backend->do_work_stealing_job(nth, [&](int task_id) {
         int ith = task_id;
         void* down_proj_ptr = (uint8_t*)down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
-        float* down_output_ptr = down_output_.data() + ith * config_.stride;
-        llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_.data(), config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
-        if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
-            void* output_ptr = (uint8_t*)output + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
-            from_float(down_output_ptr, output_ptr, config_.stride, config_.hidden_type);
-        }
+        float* down_output_ptr = down_output_ + ith * config_.stride;
+        llamafile_sgemm(config_.stride, qlen, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
+        if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
+            for (int i = 0; i < qlen; i++) {
+                float* output_fp32_ptr = down_output_ + i * config_.hidden_size + ith * config_.stride;
+                void* output_ptr = (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
+                from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
+            }
+        }
     });
     if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {
-        from_float(down_output_.data(), output, config_.hidden_size, config_.hidden_type);
+        from_float(down_output_, output, qlen * config_.hidden_size, config_.hidden_type);
     }
 }
+
+void MLP::forward(int qlen, const void* input, void* output, Backend* backend) {
+    if (qlen <= 0) {
+        return;
+    }
+    int forward_len = std::min(qlen, config_.group_max_len);
+    forward_many(forward_len, input, output, backend);
+    forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
+}
\ No newline at end of file
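act_fn above is SiLU, x·sigmoid(x), and the batched loop fills intermediate_fp32_ with silu(gate)·up token by token. A one-line equivalence check in Python (toy sizes, torch assumed available):

# act_fn(x) = x / (1 + exp(-x)) is SiLU; the kernel combines gate and up as
# intermediate = silu(gate_output) * up_output for each token.
import torch

gate, up = torch.randn(4, 8), torch.randn(4, 8)   # [qlen, intermediate_size] (illustrative)
intermediate = torch.nn.functional.silu(gate) * up
assert torch.allclose(intermediate, gate / (1 + torch.exp(-gate)) * up, atol=1e-6)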
ktransformers/ktransformers_ext/operators/llamafile/mlp.h
View file @ 77a34c28
...
@@ -3,7 +3,7 @@
 * @Author       : chenht2022
 * @Date         : 2024-07-12 10:07:58
 * @Version      : 1.0.0
 * @LastEditors  : chenht2022
 * @LastEditTime : 2024-07-25 10:35:06
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
...
@@ -22,11 +22,13 @@
 #include "llama.cpp/ggml-quants.h"
 #include "llama.cpp/ggml.h"
 #include "llamafile/sgemm.h"
+#include "shared_mem_buffer.h"

 struct MLPConfig {
     int hidden_size;
     int intermediate_size;
     int stride;
+    int group_max_len;
     void* gate_proj;
     void* up_proj;
     void* down_proj;
...
@@ -37,15 +39,17 @@ struct MLPConfig {
     MLPConfig() {}

-    MLPConfig(int hidden_size, int intermediate_size, int stride, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)
-        : hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}
+    MLPConfig(int hidden_size, int intermediate_size, int stride, int group_max_len, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)
+        : hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_max_len(group_max_len), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}
 };

 class MLP {
    public:
     MLP(MLPConfig);
+    ~MLP();
     void warm_up(Backend* backend);
-    void forward(const void* input, void* output, Backend* backend);
+    void forward_many(int qlen, const void* input, void* output, Backend* backend);
+    void forward(int qlen, const void* input, void* output, Backend* backend);

    private:
     MLPConfig config_;
...
@@ -53,14 +57,14 @@ class MLP {
     void* up_proj_;    // [intermediate_size * hidden_size ( /32 if quantized)]
     void* down_proj_;  // [hidden_size * intermediate_size ( /32 if quantized)]

-    std::vector<float> input_fp32_;        // [hidden_size]
+    float* input_fp32_;                    // [group_max_len * hidden_size]
-    std::vector<uint8_t> gate_input_;      // [hidden_size * 4]
+    uint8_t* gate_input_;                  // [group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]
-    std::vector<uint8_t> up_input_;        // [hidden_size * 4]
+    uint8_t* up_input_;                    // [group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]
-    std::vector<float> gate_output_;       // [intermediate_size]
+    float* gate_output_;                   // [group_max_len * intermediate_size]
-    std::vector<float> up_output_;         // [intermediate_size]
+    float* up_output_;                     // [group_max_len * intermediate_size]
-    std::vector<float> intermediate_fp32_; // [intermediate_size]
+    float* intermediate_fp32_;             // [group_max_len * intermediate_size]
-    std::vector<uint8_t> down_input_;      // [intermediate_size * 4]
+    uint8_t* down_input_;                  // [group_max_len * intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]
-    std::vector<float> down_output_;       // [hidden_size]
+    float* down_output_;                   // [group_max_len * hidden_size]
 };

 #endif
\ No newline at end of file
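The buffer-size comments in the new header all use the same ggml convention: a row of n elements stored as type t occupies n * ggml_type_size(t) / ggml_blck_size(t) bytes, because elements are packed in blocks of ggml_blck_size(t) elements that each take ggml_type_size(t) bytes. A small hedged sketch of that arithmetic follows; the 32-element / 18-byte block is an assumed Q4_0-style layout used only for illustration, and row_bytes is not a ktransformers helper.

#include <cstddef>
#include <cstdio>

// Illustrative only: bytes needed for n elements of a block-quantized type,
// mirroring the n * ggml_type_size(t) / ggml_blck_size(t) expressions in mlp.h.
// Assumes n is a multiple of the block size, as the header comments do.
constexpr std::size_t row_bytes(std::size_t n, std::size_t type_size, std::size_t block_size) {
    return n * type_size / block_size;
}

int main() {
    constexpr std::size_t hidden_size = 4096;
    std::printf("fp32 row : %zu bytes\n", row_bytes(hidden_size, sizeof(float), 1));
    // Hypothetical Q4_0-style layout: 32 elements per 18-byte block.
    std::printf("quant row: %zu bytes\n", row_bytes(hidden_size, 18, 32));
}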
ktransformers/ktransformers_ext/operators/llamafile/moe.cpp
View file @ 77a34c28
/**
 * @Description  :
 * @Author       : chenht2022
 * @Date         : 2024-07-22 02:03:22
 * @Version      : 1.0.0
 * @LastEditors  : chenht2022
 * @LastEditTime : 2024-07-25 10:35:07
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
 #include "moe.h"
 #include <iostream>
 #include <cstdint>

-uint8_t* MOE::buffer_ = nullptr;
 MOE::MOE(MOEConfig config) {
     config_ = config;
     gate_proj_ = config_.gate_proj;
     up_proj_ = config_.up_proj;
     down_proj_ = config_.down_proj;

-    if (MOE::buffer_ == nullptr) {
-        uint64_t buffer_size = 0;
-        buffer_size += sizeof(float) * config_.group_max_len * config_.hidden_size;
-        buffer_size += config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
-        buffer_size += config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
-        buffer_size += config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
-        buffer_size += config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
-        buffer_size += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
-        buffer_size += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
-        buffer_size += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
-        buffer_size += config_.routed_expert_num * config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
-        buffer_size += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.hidden_size;
-        buffer_size += sizeof(float) * config_.group_max_len * config_.hidden_size;
-        buffer_ = (uint8_t*)malloc(buffer_size);
-    }
-    uint64_t offset = 0;
-    s_input_fp32_ = (float*)(buffer_ + offset);
-    offset += sizeof(float) * config_.hidden_size;
-    s_gate_input_ = (uint8_t*)(buffer_ + offset);
-    offset += config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
-    s_up_input_ = (uint8_t*)(buffer_ + offset);
-    offset += config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
+    std::vector<std::pair<void**, uint64_t>> s_mem_requests;
+    s_mem_requests.push_back({(void**)&s_input_fp32_, sizeof(float) * config_.hidden_size});
+    s_mem_requests.push_back({(void**)&s_gate_input_, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});
+    s_mem_requests.push_back({(void**)&s_up_input_, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});
     s_gate_output_.resize(config_.routed_expert_num);
     s_up_output_.resize(config_.routed_expert_num);
     s_intermediate_fp32_.resize(config_.routed_expert_num);
     s_down_input_.resize(config_.routed_expert_num);
     s_down_output_.resize(config_.routed_expert_num);
     for (int i = 0; i < config_.routed_expert_num; i++) {
-        s_gate_output_[i] = (float*)(buffer_ + offset);
-        offset += sizeof(float) * config_.intermediate_size;
-        s_up_output_[i] = (float*)(buffer_ + offset);
-        offset += sizeof(float) * config_.intermediate_size;
-        s_intermediate_fp32_[i] = (float*)(buffer_ + offset);
-        offset += sizeof(float) * config_.intermediate_size;
-        s_down_input_[i] = (uint8_t*)(buffer_ + offset);
-        offset += config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
-        s_down_output_[i] = (float*)(buffer_ + offset);
-        offset += sizeof(float) * config_.hidden_size;
+        s_mem_requests.push_back({(void**)&s_gate_output_[i], sizeof(float) * config_.intermediate_size});
+        s_mem_requests.push_back({(void**)&s_up_output_[i], sizeof(float) * config_.intermediate_size});
+        s_mem_requests.push_back({(void**)&s_intermediate_fp32_[i], sizeof(float) * config_.intermediate_size});
+        s_mem_requests.push_back({(void**)&s_down_input_[i], config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});
+        s_mem_requests.push_back({(void**)&s_down_output_[i], sizeof(float) * config_.hidden_size});
     }
-    s_output_fp32_ = (float*)(buffer_ + offset);
+    s_mem_requests.push_back({(void**)&s_output_fp32_, sizeof(float) * config_.hidden_size});
+    shared_mem_buffer.alloc(this, s_mem_requests);

-    offset = 0;
+    std::vector<std::pair<void**, uint64_t>> m_mem_requests;
     m_input_fp32_.resize(config_.group_max_len);
     m_gate_input_.resize(config_.group_max_len);
     m_up_input_.resize(config_.group_max_len);
     for (int i = 0; i < config_.group_max_len; i++) {
-        m_input_fp32_[i] = (float*)(buffer_ + offset);
-        offset += sizeof(float) * config_.hidden_size;
-        m_gate_input_[i] = (uint8_t*)(buffer_ + offset);
-        offset += config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
-        m_up_input_[i] = (uint8_t*)(buffer_ + offset);
-        offset += config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
+        m_mem_requests.push_back({(void**)&m_input_fp32_[i], sizeof(float) * config_.hidden_size});
+        m_mem_requests.push_back({(void**)&m_gate_input_[i], config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});
+        m_mem_requests.push_back({(void**)&m_up_input_[i], config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});
     }
-    m_local_gate_input_ = (uint8_t*)(buffer_ + offset);
-    offset += config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
-    m_local_up_input_ = (uint8_t*)(buffer_ + offset);
-    offset += config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
-    m_local_gate_output_ = (float*)(buffer_ + offset);
-    offset += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
-    m_local_up_output_ = (float*)(buffer_ + offset);
-    offset += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
-    m_local_intermediate_fp32_ = (float*)(buffer_ + offset);
-    offset += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
-    m_local_down_input_ = (uint8_t*)(buffer_ + offset);
-    offset += config_.routed_expert_num * config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
-    m_local_down_output_ = (float*)(buffer_ + offset);
-    offset += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.hidden_size;
+    m_mem_requests.push_back({(void**)&m_local_gate_input_, config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});
+    m_mem_requests.push_back({(void**)&m_local_up_input_, config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});
+    m_mem_requests.push_back({(void**)&m_local_gate_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});
+    m_mem_requests.push_back({(void**)&m_local_up_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});
+    m_mem_requests.push_back({(void**)&m_local_intermediate_fp32_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});
+    m_mem_requests.push_back({(void**)&m_local_down_input_, config_.routed_expert_num * config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});
+    m_mem_requests.push_back({(void**)&m_local_down_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.hidden_size});
     m_output_fp32_.resize(config_.group_max_len);
     for (int i = 0; i < config_.group_max_len; i++) {
-        m_output_fp32_[i] = (float*)(buffer_ + offset);
-        offset += sizeof(float) * config_.hidden_size;
+        m_mem_requests.push_back({(void**)&m_output_fp32_[i], sizeof(float) * config_.hidden_size});
     }
+    shared_mem_buffer.alloc(this, m_mem_requests);

     m_local_pos_.resize(config_.group_max_len);
     for (int i = 0; i < config_.group_max_len; i++) {
...
@@ -107,6 +72,10 @@ MOE::MOE(MOEConfig config) {
     m_local_down_output_ptr_.resize(config_.expert_num);
 }

+MOE::~MOE() {
+    shared_mem_buffer.dealloc(this);
+}
+
 void MOE::warm_up(Backend* backend) {
     std::vector<float> input_fp32(config_.hidden_size);
     std::vector<uint8_t> input(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
...
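Instead of carving fixed offsets out of a static malloc'd buffer_, the rewritten constructor registers (pointer-slot, byte-count) pairs and lets shared_mem_buffer.alloc(this, requests) bind them, with the new MOE::~MOE() releasing the registration through dealloc(this). The shared_mem_buffer.h that mlp.h and moe.cpp now include is not part of this excerpt, so the following is only a hedged sketch of an allocator with that call shape, inferred from the call sites; the real implementation may size, align, and reuse memory differently. The sketch assumes all registered request groups may alias one grow-only arena, i.e. that only one owner's scratch buffers are live at a time.

#include <cstdint>
#include <cstdlib>
#include <map>
#include <utility>
#include <vector>

// Hedged sketch only: a shared arena matching the alloc(this, requests) /
// dealloc(this) call sites in moe.cpp; the real shared_mem_buffer.h may differ.
class SharedMemBufferSketch {
   public:
    ~SharedMemBufferSketch() { std::free(buffer_); }

    // Remember this owner's requests, grow the arena if the new total is larger,
    // and (re)bind every registered pointer slot to an offset inside the arena.
    void alloc(void* owner, std::vector<std::pair<void**, uint64_t>> requests) {
        uint64_t total = 0;
        for (const auto& r : requests) total += r.second;
        if (total > size_) {
            buffer_ = std::realloc(buffer_, total);  // may move, so rebind below
            size_ = total;
        }
        requests_[owner].push_back(std::move(requests));
        rebind();
    }

    // Drop an owner's registrations; remaining owners keep valid bindings.
    void dealloc(void* owner) {
        requests_.erase(owner);
        rebind();
    }

   private:
    // Assumption: request groups alias the same arena (offsets restart at 0 per
    // group), i.e. at most one group is actively used at any moment.
    void rebind() {
        for (auto& entry : requests_) {
            for (auto& group : entry.second) {
                uint64_t offset = 0;
                for (auto& req : group) {
                    *req.first = static_cast<uint8_t*>(buffer_) + offset;
                    offset += req.second;
                }
            }
        }
    }

    void* buffer_ = nullptr;
    uint64_t size_ = 0;
    std::map<void*, std::vector<std::vector<std::pair<void**, uint64_t>>>> requests_;
};

int main() {
    SharedMemBufferSketch shared;
    float* a = nullptr;
    uint8_t* b = nullptr;
    shared.alloc(&a, {{(void**)&a, sizeof(float) * 16}, {(void**)&b, 64}});
    a[0] = 1.0f;  // both pointers now land inside the shared arena
    shared.dealloc(&a);
}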