Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
ktransformers
Commits
77a34c28
Unverified
Commit
77a34c28
authored
Aug 15, 2024
by
UnicornChan
Committed by
GitHub
Aug 15, 2024
Browse files
Merge pull request #36 from kvcache-ai/develop-0.1.2
Release v0.1.2
parents
44f57270
395cd3e7
Changes
69
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
617 additions
and
1453 deletions
+617
-1453
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
+8
-0
ktransformers/ktransformers_ext/cuda/custom_gguf/custom_ggml.h
...sformers/ktransformers_ext/cuda/custom_gguf/custom_ggml.h
+0
-39
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
+180
-2
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
+6
-3
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu
...formers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu
+2
-2
ktransformers/ktransformers_ext/cuda/setup.py
ktransformers/ktransformers_ext/cuda/setup.py
+19
-11
ktransformers/ktransformers_ext/examples/test_linear.py
ktransformers/ktransformers_ext/examples/test_linear.py
+22
-43
ktransformers/ktransformers_ext/examples/test_mlp.py
ktransformers/ktransformers_ext/examples/test_mlp.py
+35
-51
ktransformers/ktransformers_ext/examples/test_moe.py
ktransformers/ktransformers_ext/examples/test_moe.py
+73
-65
ktransformers/ktransformers_ext/ext_bindings.cpp
ktransformers/ktransformers_ext/ext_bindings.cpp
+118
-203
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq.py
...transformers_ext/operators/custom_marlin/quantize/gptq.py
+0
-206
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq_marlin.py
...rmers_ext/operators/custom_marlin/quantize/gptq_marlin.py
+0
-458
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/quantizer.py
...formers_ext/operators/custom_marlin/quantize/quantizer.py
+0
-140
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/repack.py
...ansformers_ext/operators/custom_marlin/quantize/repack.py
+0
-99
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py
...xt/operators/custom_marlin/quantize/utils/marlin_utils.py
+2
-2
ktransformers/ktransformers_ext/operators/llamafile/linear.cpp
...sformers/ktransformers_ext/operators/llamafile/linear.cpp
+38
-12
ktransformers/ktransformers_ext/operators/llamafile/linear.h
ktransformers/ktransformers_ext/operators/llamafile/linear.h
+11
-7
ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp
ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp
+55
-35
ktransformers/ktransformers_ext/operators/llamafile/mlp.h
ktransformers/ktransformers_ext/operators/llamafile/mlp.h
+16
-12
ktransformers/ktransformers_ext/operators/llamafile/moe.cpp
ktransformers/ktransformers_ext/operators/llamafile/moe.cpp
+32
-63
No files found.
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
View file @
77a34c28
...
@@ -12,14 +12,22 @@ int test(){
...
@@ -12,14 +12,22 @@ int test(){
}
}
torch
::
Tensor
dequantize_q6_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q6_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q5_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q2_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
PYBIND11_MODULE
(
cudaops
,
m
)
{
PYBIND11_MODULE
(
cudaops
,
m
)
{
m
.
def
(
"dequantize_q8_0"
,
&
dequantize_q8_0
,
"Function to dequantize q8_0 data."
,
m
.
def
(
"dequantize_q8_0"
,
&
dequantize_q8_0
,
"Function to dequantize q8_0 data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"dequantize_q6_k"
,
&
dequantize_q6_k
,
"Function to dequantize q6_k data."
,
m
.
def
(
"dequantize_q6_k"
,
&
dequantize_q6_k
,
"Function to dequantize q6_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"dequantize_q5_k"
,
&
dequantize_q5_k
,
"Function to dequantize q5_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"dequantize_q4_k"
,
&
dequantize_q4_k
,
"Function to dequantize q4_k data."
,
m
.
def
(
"dequantize_q4_k"
,
&
dequantize_q4_k
,
"Function to dequantize q4_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"dequantize_q3_k"
,
&
dequantize_q3_k
,
"Function to dequantize q3_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"dequantize_q2_k"
,
&
dequantize_q2_k
,
"Function to dequantize q2_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"test"
,
&
test
,
"Function to test."
);
m
.
def
(
"test"
,
&
test
,
"Function to test."
);
}
}
ktransformers/ktransformers_ext/cuda/custom_gguf/custom_ggml.h
deleted
100644 → 0
View file @
44f57270
#include <cuda_fp16.h>
__device__
float
ggml_compute_fp16_to_fp32
(
uint16_t
h
)
{
return
__uint2float_rd
(
h
);
}
static
inline
float
ggml_compute_fp16_to_fp32
(
uint16_t
h
)
{
uint16_t
tmp
;
memcpy
(
&
tmp
,
&
h
,
sizeof
(
ggml_fp16_t
));
return
(
float
)
tmp
;
}
// define the global table for fp16 to fp32 conversion
__device__
float
ggml_table_f32_f16
[
1
<<
16
];
// CUDA Kernel to init the table
__global__
void
init_fp16_to_fp32_table
()
{
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
auto
blk_id
=
idx
;
blk_id
<
(
1
<<
16
);
blk_id
+=
blockDim
.
x
*
gridDim
.
x
){
ggml_table_f32_f16
[
blk_id
]
=
GGML_COMPUTE_FP16_TO_FP32
(
blk_id
);
}
}
#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
extern
__device__
float
ggml_table_f32_f16
[
1
<<
16
];
// Declare as __device__ if used within device code
// This version of the function is designed to be called from within a CUDA kernel
#if !defined(GGML_FP16_TO_FP32)
__device__
float
ggml_lookup_fp16_to_fp32
(
uint16_t
f
)
{
return
ggml_table_f32_f16
[
f
];
}
#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
#endif
\ No newline at end of file
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
View file @
77a34c28
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
* @Author : Azure-Tang, Boxin Zhang
* @Author : Azure-Tang, Boxin Zhang
* @Date : 2024-07-25 13:38:30
* @Date : 2024-07-25 13:38:30
* @Version : 1.0.0
* @Version : 1.0.0
* @LastEditors :
Azure
* @LastEditors :
kkk1nak0
* @LastEditTime : 2024-0
7-26 11:58:50
* @LastEditTime : 2024-0
8-12 04:18:04
* Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c
* Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c
* Copyright (c) 2023-2024 The ggml authors
* Copyright (c) 2023-2024 The ggml authors
* Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
* Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include <torch/extension.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include <torch/torch.h>
#include <cstdint>
#include <cstdint>
#include <c10/cuda/CUDAGuard.h>
__global__
void
dequantize_q8_0_kernel
(
float
*
output
,
const
float
*
scales
,
const
int8_t
*
qs
,
int
num_blocks
,
int
blk_size
)
{
__global__
void
dequantize_q8_0_kernel
(
float
*
output
,
const
float
*
scales
,
const
int8_t
*
qs
,
int
num_blocks
,
int
blk_size
)
{
int
global_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
global_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
...
@@ -35,6 +36,97 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
...
@@ -35,6 +36,97 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
}
}
}
}
/**
 * Dequantize Q2_K super-blocks into float32.
 *
 * Each quantized block occupies `blk_size` bytes of `data` and expands to 256
 * floats in `output`. Byte layout used by this kernel (relative to the block
 * start):
 *   [ 0, 16)  16 packed scale bytes: low nibble = scale, high nibble = min
 *   [16, 80)  64 bytes of 2-bit quants, four values per byte
 *   [80, 82)  fp16 super-block scale `d`
 *   [82, 84)  fp16 super-block min   `min`
 *
 * One thread handles one whole block at a time and strides over the blocks by
 * the total number of launched threads, so any launch configuration covers
 * all `num_blocks` blocks.
 *
 * @param data       raw quantized bytes, `num_blocks * blk_size` long
 * @param output     destination buffer, `num_blocks * 256` floats
 * @param blk_size   size in bytes of one quantized block
 * @param num_blocks number of quantized blocks in `data`
 */
__global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
    const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (int block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
        // Widen before multiplying: block_id * 256 (and block_id * blk_size)
        // overflow a 32-bit int for very large tensors.
        const int64_t blk = block_id;
        const int8_t* blk_data = data + blk * blk_size;
        float* __restrict__ output_blk = output + blk * 256;

        const float d   = __half2float(*reinterpret_cast<const half*>(blk_data + 80));
        const float min = __half2float(*reinterpret_cast<const half*>(blk_data + 82));

        const uint8_t* scales = reinterpret_cast<const uint8_t*>(blk_data);
        const uint8_t* __restrict__ q = reinterpret_cast<const uint8_t*>(blk_data + 16);

        int is = 0;  // index of the next scale byte to consume
        for (int n = 0; n < 256; n += 128) {
            // Each 32-byte slice of q yields 128 outputs: four 2-bit planes
            // (shift = 0, 2, 4, 6), each split into two 16-value groups that
            // carry their own scale/min byte.
            int shift = 0;
            for (int j = 0; j < 4; ++j) {
                for (int group = 0; group < 2; ++group) {
                    const uint8_t sc = scales[is++];
                    const float dl = d * (sc & 0xF);
                    const float ml = min * (sc >> 4);
                    const uint8_t* q_group = q + 16 * group;
                    for (int l = 0; l < 16; ++l) {
                        *output_blk++ = dl * ((int8_t)((q_group[l] >> shift) & 3)) - ml;
                    }
                }
                shift += 2;
            }
            q += 32;
        }
    }
}
/**
 * Dequantize Q3_K super-blocks into float32.
 *
 * Each quantized block occupies `blk_size` bytes of `data` and expands to 256
 * floats in `output`. Byte layout used by this kernel (relative to the block
 * start):
 *   [  0,  32)  high-bit mask `hm`: bit k of hm[l] is the third bit of the
 *               value produced in the k-th 32-value plane
 *   [ 32,  96)  64 bytes of low 2-bit quants, four values per byte
 *   [ 96, 108)  12 bytes packing sixteen 6-bit scales
 *   [108, 110)  fp16 super-block scale `d_all`
 *
 * One thread handles one whole block at a time and strides over the blocks by
 * the total number of launched threads.
 *
 * @param data       raw quantized bytes, `num_blocks * blk_size` long
 * @param output     destination buffer, `num_blocks * 256` floats
 * @param blk_size   size in bytes of one quantized block
 * @param num_blocks number of quantized blocks in `data`
 */
__global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
    const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const uint32_t kmask1 = 0x03030303;  // two high scale bits, per byte lane
    const uint32_t kmask2 = 0x0f0f0f0f;  // four low scale bits, per byte lane

    for (int block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
        // Widen before multiplying: block_id * 256 (and block_id * blk_size)
        // overflow a 32-bit int for very large tensors.
        const int64_t blk = block_id;
        const int8_t* blk_data = data + blk * blk_size;
        float* __restrict__ output_blk = output + blk * 256;

        const float d_all = __half2float(*reinterpret_cast<const half*>(blk_data + 108));

        const uint8_t* __restrict__ hm = reinterpret_cast<const uint8_t*>(blk_data);
        const uint8_t* __restrict__ q  = reinterpret_cast<const uint8_t*>(blk_data + 32);
        const uint8_t* block_scales    = reinterpret_cast<const uint8_t*>(blk_data + 96);

        // Assemble the 12 scale bytes into three 32-bit words byte by byte
        // (lowest byte first), which avoids an unaligned 32-bit load.
        uint32_t packed[3];
        for (int i = 0; i < 3; i++) {
            packed[i] = 0;
            for (int j = 0; j < 4; j++) {
                packed[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);
            }
        }

        // Expand to sixteen one-byte scales: the low/high nibbles of words 0
        // and 1 supply the low 4 bits, word 2 supplies the upper 2 bits.
        const uint32_t hi = packed[2];
        uint32_t aux[4];
        aux[0] = (packed[0] & kmask2)        | (((hi >> 0) & kmask1) << 4);
        aux[1] = (packed[1] & kmask2)        | (((hi >> 2) & kmask1) << 4);
        aux[2] = ((packed[0] >> 4) & kmask2) | (((hi >> 4) & kmask1) << 4);
        aux[3] = ((packed[1] >> 4) & kmask2) | (((hi >> 6) & kmask1) << 4);
        const int8_t* scales = reinterpret_cast<const int8_t*>(aux);

        int is = 0;     // index of the next scale to consume
        uint8_t m = 1;  // bit of hm[] selecting the current plane's high bit
        for (int n = 0; n < 256; n += 128) {
            int shift = 0;
            for (int j = 0; j < 4; ++j) {
                // Two 16-value groups per plane, each with its own scale.
                for (int group = 0; group < 2; ++group) {
                    const float dl = d_all * (scales[is++] - 32);
                    const uint8_t* q_group  = q + 16 * group;
                    const uint8_t* hm_group = hm + 16 * group;
                    for (int l = 0; l < 16; ++l) {
                        // A cleared high-bit flag means the value is offset by -4.
                        *output_blk++ = dl * ((int8_t)((q_group[l] >> shift) & 3) - ((hm_group[l] & m) ? 0 : 4));
                    }
                }
                shift += 2;
                m <<= 1;
            }
            q += 32;
        }
    }
}
__global__
void
dequantize_q4_k_kernel
(
int8_t
*
data
,
float
*
output
,
int
blk_size
,
int
num_blocks
)
{
__global__
void
dequantize_q4_k_kernel
(
int8_t
*
data
,
float
*
output
,
int
blk_size
,
int
num_blocks
)
{
int
global_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
global_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
auto
block_id
=
global_idx
;
block_id
<
num_blocks
;
block_id
+=
blockDim
.
x
*
gridDim
.
x
){
for
(
auto
block_id
=
global_idx
;
block_id
<
num_blocks
;
block_id
+=
blockDim
.
x
*
gridDim
.
x
){
...
@@ -59,6 +151,35 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
...
@@ -59,6 +151,35 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
}
}
}
}
/**
 * Dequantize Q5_K super-blocks into float32.
 *
 * Each quantized block occupies `blk_size` bytes of `data` and expands to 256
 * floats in `output`. Byte offsets used by this kernel (relative to the block
 * start):
 *   +0   fp16 super-block scale `d`
 *   +2   fp16 super-block min   `min`
 *   +4   packed scales/mins, decoded by get_scale_min_k4()
 *   +16  32 bytes `qh`: one high (fifth) bit per value, one bit plane per
 *        32-value group
 *   +48  128 bytes `ql`: low 4 bits, two values per byte
 *
 * One thread handles one whole block at a time and strides over the blocks by
 * the total number of launched threads.
 *
 * @param data       raw quantized bytes, `num_blocks * blk_size` long
 * @param output     destination buffer, `num_blocks * 256` floats
 * @param blk_size   size in bytes of one quantized block
 * @param num_blocks number of quantized blocks in `data`
 */
__global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
    const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (int block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
        // Widen before multiplying: block_id * 256 (and block_id * blk_size)
        // overflow a 32-bit int for very large tensors.
        const int64_t blk = block_id;
        const int8_t* blk_data = data + blk * blk_size;
        float* __restrict__ output_blk = output + blk * 256;

        const float d   = __half2float(*reinterpret_cast<const half*>(blk_data + 0));
        const float min = __half2float(*reinterpret_cast<const half*>(blk_data + 2));

        const uint8_t* scales          = reinterpret_cast<const uint8_t*>(blk_data + 4);
        const uint8_t* __restrict__ qh = reinterpret_cast<const uint8_t*>(blk_data + 16);
        const uint8_t* __restrict__ ql = reinterpret_cast<const uint8_t*>(blk_data + 48);

        int is = 0;            // index of the next scale/min pair
        uint8_t low_bit = 1;   // qh bit holding the 5th bit of the low-nibble values
        uint8_t high_bit = 2;  // qh bit holding the 5th bit of the high-nibble values
        uint8_t sc, m;

        for (int j = 0; j < 256; j += 64) {
            // Each iteration emits 64 values from 32 bytes of ql: first the
            // low nibbles, then the high nibbles, each with its own scale/min.
            get_scale_min_k4(is + 0, scales, &sc, &m);
            const float d1 = d * sc;
            const float m1 = min * m;
            get_scale_min_k4(is + 1, scales, &sc, &m);
            const float d2 = d * sc;
            const float m2 = min * m;

            for (int l = 0; l < 32; ++l) {
                *output_blk++ = d1 * ((ql[l] & 0xF) + (qh[l] & low_bit ? 16 : 0)) - m1;
            }
            for (int l = 0; l < 32; ++l) {
                *output_blk++ = d2 * ((ql[l] >> 4) + (qh[l] & high_bit ? 16 : 0)) - m2;
            }

            ql += 32;
            is += 2;
            low_bit <<= 2;
            high_bit <<= 2;
        }
    }
}
__global__
void
dequantize_q6_k_kernel
(
int8_t
*
data
,
float
*
output
,
int
blk_size
,
int
num_blocks
)
{
__global__
void
dequantize_q6_k_kernel
(
int8_t
*
data
,
float
*
output
,
int
blk_size
,
int
num_blocks
)
{
int
global_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
global_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
auto
block_id
=
global_idx
;
block_id
<
num_blocks
;
block_id
+=
blockDim
.
x
*
gridDim
.
x
){
for
(
auto
block_id
=
global_idx
;
block_id
<
num_blocks
;
block_id
+=
blockDim
.
x
*
gridDim
.
x
){
...
@@ -94,6 +215,7 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
...
@@ -94,6 +215,7 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
torch
::
Tensor
dequantize_q8_0
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
)
{
torch
::
Tensor
dequantize_q8_0
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
)
{
int
num_blocks
=
data
.
numel
()
/
blk_size
;
int
num_blocks
=
data
.
numel
()
/
blk_size
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device
);
// create gpu
// create gpu
auto
options_scales
=
torch
::
TensorOptions
().
dtype
(
torch
::
kFloat32
).
device
(
device
).
memory_format
(
torch
::
MemoryFormat
::
Contiguous
);
auto
options_scales
=
torch
::
TensorOptions
().
dtype
(
torch
::
kFloat32
).
device
(
device
).
memory_format
(
torch
::
MemoryFormat
::
Contiguous
);
auto
options_qs
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt8
).
device
(
device
).
memory_format
(
torch
::
MemoryFormat
::
Contiguous
);
auto
options_qs
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt8
).
device
(
device
).
memory_format
(
torch
::
MemoryFormat
::
Contiguous
);
...
@@ -128,6 +250,7 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
...
@@ -128,6 +250,7 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
// data.numel%blk_size should be 0, else raise err
// data.numel%blk_size should be 0, else raise err
int
num_blocks
=
data
.
numel
()
/
blk_size
;
int
num_blocks
=
data
.
numel
()
/
blk_size
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device
);
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt8
).
device
(
device
).
memory_format
(
torch
::
MemoryFormat
::
Contiguous
);
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt8
).
device
(
device
).
memory_format
(
torch
::
MemoryFormat
::
Contiguous
);
auto
data_gpu
=
torch
::
empty
({
data
.
numel
()},
options
);
auto
data_gpu
=
torch
::
empty
({
data
.
numel
()},
options
);
...
@@ -144,9 +267,28 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
...
@@ -144,9 +267,28 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
return
output
;
return
output
;
}
}
/**
 * Dequantize a buffer of Q5_K blocks to float32 on `device`.
 *
 * The input bytes are copied into an int8 tensor on `device`, expanded by
 * dequantize_q5_k_kernel, and returned as a `{num_blocks, 256}` float32
 * tensor. `data.numel()` is expected to be a multiple of `blk_size`; any
 * trailing partial block is ignored by the integer division below.
 *
 * @param data     raw quantized bytes (presumably a CPU int8 tensor — confirm
 *                 against the Python caller)
 * @param blk_size size in bytes of one quantized block
 * @param device   CUDA device that receives the data and runs the kernel
 * @return float32 tensor of shape `{num_blocks, 256}` on `device`
 */
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
    const int num_blocks = data.numel() / blk_size;

    // Make `device` current for the allocations, the kernel launch and the
    // synchronize below. Without this the kernel runs on whatever device is
    // current, which differs from `device` on multi-GPU hosts. This matches
    // the other dequantize_* wrappers in this file.
    const at::cuda::OptionalCUDAGuard device_guard(device);

    const auto options = torch::TensorOptions()
                             .dtype(torch::kInt8)
                             .device(device)
                             .memory_format(torch::MemoryFormat::Contiguous);
    auto data_gpu = torch::empty({data.numel()}, options);
    data_gpu.copy_(data, false);  // blocking copy

    // Create output tensor
    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));

    // Launch kernel; the kernel strides over blocks, so a fixed grid covers any num_blocks.
    dequantize_q5_k_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);

    cudaDeviceSynchronize();
    return output;
}
torch
::
Tensor
dequantize_q4_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
)
{
torch
::
Tensor
dequantize_q4_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
)
{
// data.numel%blk_size should be 0, else raise err
// data.numel%blk_size should be 0, else raise err
int
num_blocks
=
data
.
numel
()
/
blk_size
;
int
num_blocks
=
data
.
numel
()
/
blk_size
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device
);
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt8
).
device
(
device
).
memory_format
(
torch
::
MemoryFormat
::
Contiguous
);
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt8
).
device
(
device
).
memory_format
(
torch
::
MemoryFormat
::
Contiguous
);
auto
data_gpu
=
torch
::
empty
({
data
.
numel
()},
options
);
auto
data_gpu
=
torch
::
empty
({
data
.
numel
()},
options
);
...
@@ -162,3 +304,39 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device de
...
@@ -162,3 +304,39 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device de
cudaDeviceSynchronize
();
cudaDeviceSynchronize
();
return
output
;
return
output
;
}
}
/**
 * Dequantize a buffer of Q3_K blocks to float32 on `device`.
 *
 * The input bytes are copied into an int8 tensor on `device`, expanded by
 * dequantize_q3_k_kernel, and returned as a `{num_blocks, 256}` float32
 * tensor. `data.numel()` is expected to be a multiple of `blk_size`; any
 * trailing partial block is ignored by the integer division below.
 *
 * @param data     raw quantized bytes (presumably a CPU int8 tensor — confirm
 *                 against the Python caller)
 * @param blk_size size in bytes of one quantized block
 * @param device   CUDA device that receives the data and runs the kernel
 * @return float32 tensor of shape `{num_blocks, 256}` on `device`
 */
torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device) {
    const int num_blocks = data.numel() / blk_size;

    // Make `device` current for the allocations, the kernel launch and the
    // synchronize below. Without this the kernel runs on whatever device is
    // current, which differs from `device` on multi-GPU hosts. This matches
    // the other dequantize_* wrappers in this file.
    const at::cuda::OptionalCUDAGuard device_guard(device);

    const auto options = torch::TensorOptions()
                             .dtype(torch::kInt8)
                             .device(device)
                             .memory_format(torch::MemoryFormat::Contiguous);
    auto data_gpu = torch::empty({data.numel()}, options);
    data_gpu.copy_(data, false);  // blocking copy

    // Create output tensor
    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));

    // Launch kernel; the kernel strides over blocks, so a fixed grid covers any num_blocks.
    dequantize_q3_k_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);

    cudaDeviceSynchronize();
    return output;
}
/**
 * Dequantize a buffer of Q2_K blocks to float32 on `device`.
 *
 * The input bytes are copied into an int8 tensor on `device`, expanded by
 * dequantize_q2_k_kernel, and returned as a `{num_blocks, 256}` float32
 * tensor. `data.numel()` is expected to be a multiple of `blk_size`; any
 * trailing partial block is ignored by the integer division below.
 *
 * @param data     raw quantized bytes (presumably a CPU int8 tensor — confirm
 *                 against the Python caller)
 * @param blk_size size in bytes of one quantized block
 * @param device   CUDA device that receives the data and runs the kernel
 * @return float32 tensor of shape `{num_blocks, 256}` on `device`
 */
torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device) {
    const int num_blocks = data.numel() / blk_size;

    // Make `device` current for the allocations, the kernel launch and the
    // synchronize below. Without this the kernel runs on whatever device is
    // current, which differs from `device` on multi-GPU hosts. This matches
    // the other dequantize_* wrappers in this file.
    const at::cuda::OptionalCUDAGuard device_guard(device);

    const auto options = torch::TensorOptions()
                             .dtype(torch::kInt8)
                             .device(device)
                             .memory_format(torch::MemoryFormat::Contiguous);
    auto data_gpu = torch::empty({data.numel()}, options);
    data_gpu.copy_(data, false);  // blocking copy

    // Create output tensor
    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));

    // Launch kernel; the kernel strides over blocks, so a fixed grid covers any num_blocks.
    dequantize_q2_k_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);

    cudaDeviceSynchronize();
    return output;
}
\ No newline at end of file
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
View file @
77a34c28
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
* @Author : Azure-Tang
* @Author : Azure-Tang
* @Date : 2024-07-22 09:27:55
* @Date : 2024-07-22 09:27:55
* @Version : 1.0.0
* @Version : 1.0.0
* @LastEditors :
Azure
* @LastEditors :
kkk1nak0
* @LastEditTime : 2024-0
7-26 08:38:20
* @LastEditTime : 2024-0
8-12 03:48:46
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
**/
#pragma once
#pragma once
...
@@ -15,4 +15,7 @@
...
@@ -15,4 +15,7 @@
torch
::
Tensor
dequantize_q8_0
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q8_0
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q6_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q6_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q5_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q4_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q4_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q3_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q2_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
\ No newline at end of file
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu
View file @
77a34c28
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
*/
*/
#include "gptq_marlin.cuh"
#include "gptq_marlin.cuh"
#include "gptq_marlin_dtypes.cuh"
#include "gptq_marlin_dtypes.cuh"
#include <c10/cuda/CUDAGuard.h>
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
std::is_same<scalar_t, nv_bfloat16>::value, \
...
@@ -1774,6 +1774,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1774,6 +1774,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
)
{
int64_t
size_k
,
bool
is_k_full
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
a
));
// Verify num_bits
// Verify num_bits
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
"num_bits must be 4 or 8. Got = "
,
num_bits
);
...
@@ -1816,7 +1817,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1816,7 +1817,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
TORCH_CHECK
(
perm
.
is_contiguous
(),
"perm is not contiguous"
);
TORCH_CHECK
(
perm
.
is_contiguous
(),
"perm is not contiguous"
);
// Alloc buffers
// Alloc buffers
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
a
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
a
.
dtype
()).
device
(
a
.
device
());
auto
options
=
torch
::
TensorOptions
().
dtype
(
a
.
dtype
()).
device
(
a
.
device
());
torch
::
Tensor
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
torch
::
Tensor
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
torch
::
Tensor
a_tmp
=
torch
::
empty
({
size_m
,
size_k
},
options
);
torch
::
Tensor
a_tmp
=
torch
::
empty
({
size_m
,
size_k
},
options
);
...
...
ktransformers/ktransformers_ext/cuda/setup.py
View file @
77a34c28
...
@@ -2,17 +2,25 @@
...
@@ -2,17 +2,25 @@
from
setuptools
import
setup
,
Extension
from
setuptools
import
setup
,
Extension
from
torch.utils
import
cpp_extension
from
torch.utils
import
cpp_extension
from
torch.utils.cpp_extension
import
BuildExtension
,
CUDAExtension
from
torch.utils.cpp_extension
import
BuildExtension
,
CUDAExtension
setup
(
# setup marlin gemm
name
=
'KTransformersOps'
,
setup
(
name
=
'KTransformersOps'
,
ext_modules
=
[
ext_modules
=
[
CUDAExtension
(
'KTransformersOps'
,
[
CUDAExtension
(
'KTransformersOps'
,
[
'custom_gguf/dequant.cu'
,
'custom_gguf/dequant.cu'
,
'binding.cpp'
,
'binding.cpp'
,
'gptq_marlin/gptq_marlin.cu'
,
'gptq_marlin/gptq_marlin.cu'
,
# 'gptq_marlin_repack.cu',
# 'gptq_marlin_repack.cu',
])
],
],
cmdclass
=
{
'build_ext'
:
BuildExtension
extra_compile_args
=
{
})
'cxx'
:
[
'-O3'
],
'nvcc'
:
[
'-O3'
,
'--use_fast_math'
,
'-Xcompiler'
,
'-fPIC'
,
]
},
)
],
cmdclass
=
{
'build_ext'
:
BuildExtension
}
)
\ No newline at end of file
ktransformers/ktransformers_ext/examples/test_linear.py
View file @
77a34c28
...
@@ -6,7 +6,7 @@ Author : chenht2022
...
@@ -6,7 +6,7 @@ Author : chenht2022
Date : 2024-07-25 10:32:05
Date : 2024-07-25 10:32:05
Version : 1.0.0
Version : 1.0.0
LastEditors : chenht2022
LastEditors : chenht2022
LastEditTime : 2024-0
7-25
10:3
4:00
LastEditTime : 2024-0
8-06
10:3
6:59
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
'''
import
os
,
sys
import
os
,
sys
...
@@ -15,23 +15,23 @@ sys.path.append(os.path.dirname(__file__) + '/../build')
...
@@ -15,23 +15,23 @@ sys.path.append(os.path.dirname(__file__) + '/../build')
import
cpuinfer_ext
import
cpuinfer_ext
import
torch
import
torch
with
torch
.
inference_mode
(
mode
=
True
):
input_size
=
16384
input_size
=
16384
output_size
=
5120
output_size
=
5120
stride
=
32
stride
=
32
group_max_len
=
1024
proj_type
=
1
# ggml_type::GGML_TYPE_F16
proj_type
=
1
# ggml_type::GGML_TYPE_F16
hidden_type
=
1
# ggml_type::GGML_TYPE_F16
hidden_type
=
1
# ggml_type::GGML_TYPE_F16
layer_num
=
10
qlen
=
30
CPUInfer
=
cpuinfer_ext
.
CPUInfer
(
48
)
layer_num
=
10
validation_iter
=
100
CPUInfer
=
cpuinfer_ext
.
CPUInfer
(
48
)
warm_up_iter
=
1000
validation_iter
=
100
test_iter
=
10000
with
torch
.
inference_mode
(
mode
=
True
):
linears
=
[]
linears
=
[]
projs
=
[]
projs
=
[]
for
_
in
range
(
layer_num
):
for
_
in
range
(
layer_num
):
proj
=
torch
.
randn
((
output_size
,
input_size
),
dtype
=
torch
.
float16
,
device
=
"cuda"
).
to
(
"cpu"
).
contiguous
()
proj
=
torch
.
randn
((
output_size
,
input_size
),
dtype
=
torch
.
float16
,
device
=
"cuda"
).
to
(
"cpu"
).
contiguous
()
config
=
cpuinfer_ext
.
linear
.
LinearConfig
(
input_size
,
output_size
,
stride
,
proj
.
data_ptr
(),
proj_type
,
hidden_type
)
config
=
cpuinfer_ext
.
linear
.
LinearConfig
(
input_size
,
output_size
,
stride
,
group_max_len
,
proj
.
data_ptr
(),
proj_type
,
hidden_type
)
linear
=
cpuinfer_ext
.
linear
.
Linear
(
config
)
linear
=
cpuinfer_ext
.
linear
.
Linear
(
config
)
projs
.
append
(
proj
)
projs
.
append
(
proj
)
linears
.
append
(
linear
)
linears
.
append
(
linear
)
...
@@ -39,11 +39,17 @@ with torch.inference_mode(mode=True):
...
@@ -39,11 +39,17 @@ with torch.inference_mode(mode=True):
# validation
# validation
for
i
in
range
(
validation_iter
):
for
i
in
range
(
validation_iter
):
linear
=
linears
[
i
%
layer_num
]
linear
=
linears
[
i
%
layer_num
]
input
=
torch
.
randn
((
1
,
input_size
),
dtype
=
torch
.
float16
).
contiguous
()
input
=
torch
.
randn
((
qlen
,
input_size
),
dtype
=
torch
.
float16
).
contiguous
()
output
=
torch
.
empty
((
1
,
output_size
),
dtype
=
torch
.
float16
).
contiguous
()
output
=
torch
.
empty
((
qlen
,
output_size
),
dtype
=
torch
.
float16
).
contiguous
()
input
=
input
/
100
input
=
input
/
100
CPUInfer
.
submit
(
linear
.
forward
,
input
.
data_ptr
(),
output
.
data_ptr
())
CPUInfer
.
submit
(
linear
.
forward
(
qlen
,
input
.
data_ptr
(),
output
.
data_ptr
()
)
)
CPUInfer
.
sync
()
CPUInfer
.
sync
()
# print('cpuinfer output', output)
# print('cpuinfer output', output)
...
@@ -54,30 +60,3 @@ with torch.inference_mode(mode=True):
...
@@ -54,30 +60,3 @@ with torch.inference_mode(mode=True):
diff
=
torch
.
mean
(
torch
.
abs
(
output
-
t_output
))
/
torch
.
mean
(
torch
.
abs
(
t_output
))
diff
=
torch
.
mean
(
torch
.
abs
(
output
-
t_output
))
/
torch
.
mean
(
torch
.
abs
(
t_output
))
print
(
'diff = '
,
diff
)
print
(
'diff = '
,
diff
)
assert
(
diff
<
0.001
)
assert
(
diff
<
0.001
)
# warm up
for
i
in
range
(
warm_up_iter
):
linear
=
linears
[
i
%
layer_num
]
input
=
torch
.
randn
((
1
,
input_size
),
dtype
=
torch
.
float16
).
contiguous
()
output
=
torch
.
empty
((
1
,
output_size
),
dtype
=
torch
.
float16
).
contiguous
()
input
=
input
/
100
CPUInfer
.
submit
(
linear
.
forward
,
input
.
data_ptr
(),
output
.
data_ptr
())
CPUInfer
.
sync
()
# test
total_time
=
0
for
i
in
range
(
test_iter
):
linear
=
linears
[
i
%
layer_num
]
input
=
torch
.
randn
((
1
,
input_size
),
dtype
=
torch
.
float16
).
contiguous
()
output
=
torch
.
empty
((
1
,
output_size
),
dtype
=
torch
.
float16
).
contiguous
()
input
=
input
/
100
start
=
time
.
perf_counter
()
CPUInfer
.
submit
(
linear
.
forward
,
input
.
data_ptr
(),
output
.
data_ptr
())
CPUInfer
.
sync
()
end
=
time
.
perf_counter
()
total_time
+=
end
-
start
print
(
'Time: '
,
total_time
)
print
(
'Iteration: '
,
test_iter
)
print
(
'Time per iteration: '
,
total_time
/
test_iter
)
print
(
'Bandwidth: '
,
input_size
*
output_size
*
2
*
test_iter
/
total_time
/
1000
/
1000
/
1000
,
'GB/s'
)
print
(
"All tasks completed."
)
\ No newline at end of file
ktransformers/ktransformers_ext/examples/test_mlp.py
View file @
77a34c28
...
@@ -6,7 +6,7 @@ Author : chenht2022
...
@@ -6,7 +6,7 @@ Author : chenht2022
Date : 2024-07-25 10:32:05
Date : 2024-07-25 10:32:05
Version : 1.0.0
Version : 1.0.0
LastEditors : chenht2022
LastEditors : chenht2022
LastEditTime : 2024-0
7-25
10:3
4:03
LastEditTime : 2024-0
8-06
10:3
7:28
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
'''
import
os
,
sys
import
os
,
sys
...
@@ -15,20 +15,30 @@ sys.path.append(os.path.dirname(__file__) + '/../build')
...
@@ -15,20 +15,30 @@ sys.path.append(os.path.dirname(__file__) + '/../build')
import
cpuinfer_ext
import
cpuinfer_ext
import
torch
import
torch
with
torch
.
inference_mode
(
mode
=
True
):
hidden_size
=
5120
hidden_size
=
5120
intermediate_size
=
3072
intermediate_size
=
3072
stride
=
32
stride
=
32
group_max_len
=
1024
gate_type
=
1
# ggml_type::GGML_TYPE_F16
gate_type
=
1
# ggml_type::GGML_TYPE_F16
up_type
=
1
# ggml_type::GGML_TYPE_F16
up_type
=
1
# ggml_type::GGML_TYPE_F16
down_type
=
1
# ggml_type::GGML_TYPE_F16
down_type
=
1
# ggml_type::GGML_TYPE_F16
hidden_type
=
1
# ggml_type::GGML_TYPE_F16
hidden_type
=
1
# ggml_type::GGML_TYPE_F16
layer_num
=
10
qlen
=
30
CPUInfer
=
cpuinfer_ext
.
CPUInfer
(
48
)
layer_num
=
10
validation_iter
=
100
CPUInfer
=
cpuinfer_ext
.
CPUInfer
(
48
)
warm_up_iter
=
1000
validation_iter
=
100
test_iter
=
10000
def
act_fn
(
x
):
return
x
/
(
1.0
+
torch
.
exp
(
-
x
))
def
mlp_torch
(
input
,
gate_proj
,
up_proj
,
down_proj
):
gate_buf
=
torch
.
mm
(
input
,
gate_proj
.
t
())
up_buf
=
torch
.
mm
(
input
,
up_proj
.
t
())
intermediate
=
act_fn
(
gate_buf
)
*
up_buf
ret
=
torch
.
mm
(
intermediate
,
down_proj
.
t
())
return
ret
with
torch
.
inference_mode
(
mode
=
True
):
mlps
=
[]
mlps
=
[]
gate_projs
=
[]
gate_projs
=
[]
up_projs
=
[]
up_projs
=
[]
...
@@ -37,7 +47,7 @@ with torch.inference_mode(mode=True):
...
@@ -37,7 +47,7 @@ with torch.inference_mode(mode=True):
gate_proj
=
torch
.
randn
((
intermediate_size
,
hidden_size
),
dtype
=
torch
.
float16
,
device
=
"cuda"
).
to
(
"cpu"
).
contiguous
()
gate_proj
=
torch
.
randn
((
intermediate_size
,
hidden_size
),
dtype
=
torch
.
float16
,
device
=
"cuda"
).
to
(
"cpu"
).
contiguous
()
up_proj
=
torch
.
randn
((
intermediate_size
,
hidden_size
),
dtype
=
torch
.
float16
,
device
=
"cuda"
).
to
(
"cpu"
).
contiguous
()
up_proj
=
torch
.
randn
((
intermediate_size
,
hidden_size
),
dtype
=
torch
.
float16
,
device
=
"cuda"
).
to
(
"cpu"
).
contiguous
()
down_proj
=
torch
.
randn
((
hidden_size
,
intermediate_size
),
dtype
=
torch
.
float16
,
device
=
"cuda"
).
to
(
"cpu"
).
contiguous
()
down_proj
=
torch
.
randn
((
hidden_size
,
intermediate_size
),
dtype
=
torch
.
float16
,
device
=
"cuda"
).
to
(
"cpu"
).
contiguous
()
config
=
cpuinfer_ext
.
mlp
.
MLPConfig
(
hidden_size
,
intermediate_size
,
stride
,
gate_proj
.
data_ptr
(),
up_proj
.
data_ptr
(),
down_proj
.
data_ptr
(),
gate_type
,
up_type
,
down_type
,
hidden_type
)
config
=
cpuinfer_ext
.
mlp
.
MLPConfig
(
hidden_size
,
intermediate_size
,
stride
,
group_max_len
,
gate_proj
.
data_ptr
(),
up_proj
.
data_ptr
(),
down_proj
.
data_ptr
(),
gate_type
,
up_type
,
down_type
,
hidden_type
)
mlp
=
cpuinfer_ext
.
mlp
.
MLP
(
config
)
mlp
=
cpuinfer_ext
.
mlp
.
MLP
(
config
)
gate_projs
.
append
(
gate_proj
)
gate_projs
.
append
(
gate_proj
)
up_projs
.
append
(
up_proj
)
up_projs
.
append
(
up_proj
)
...
@@ -47,52 +57,26 @@ with torch.inference_mode(mode=True):
...
@@ -47,52 +57,26 @@ with torch.inference_mode(mode=True):
# validation
# validation
for
i
in
range
(
validation_iter
):
for
i
in
range
(
validation_iter
):
mlp
=
mlps
[
i
%
layer_num
]
mlp
=
mlps
[
i
%
layer_num
]
input
=
torch
.
randn
((
1
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
input
=
torch
.
randn
((
qlen
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
output
=
torch
.
empty
((
1
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
output
=
torch
.
empty
((
qlen
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
input
=
input
/
100
input
=
input
/
100
CPUInfer
.
submit
(
mlp
.
forward
,
input
.
data_ptr
(),
output
.
data_ptr
())
CPUInfer
.
submit
(
mlp
.
forward
(
qlen
,
input
.
data_ptr
(),
output
.
data_ptr
()
)
)
CPUInfer
.
sync
()
CPUInfer
.
sync
()
# print('cpuinfer output', output)
# print('cpuinfer output', output)
def
act_fn
(
x
):
return
x
/
(
1.0
+
torch
.
exp
(
-
x
))
gate_proj
=
gate_projs
[
i
%
layer_num
]
gate_proj
=
gate_projs
[
i
%
layer_num
]
up_proj
=
up_projs
[
i
%
layer_num
]
up_proj
=
up_projs
[
i
%
layer_num
]
down_proj
=
down_projs
[
i
%
layer_num
]
down_proj
=
down_projs
[
i
%
layer_num
]
gate_buf
=
torch
.
mm
(
input
,
gate_proj
.
t
())
t_output
=
mlp_torch
(
input
,
gate_proj
,
up_proj
,
down_proj
)
up_buf
=
torch
.
mm
(
input
,
up_proj
.
t
())
intermediate
=
act_fn
(
gate_buf
)
*
up_buf
t_output
=
torch
.
mm
(
intermediate
,
down_proj
.
t
())
# print('torch output', t_output)
# print('torch output', t_output)
diff
=
torch
.
mean
(
torch
.
abs
(
output
-
t_output
))
/
torch
.
mean
(
torch
.
abs
(
t_output
))
diff
=
torch
.
mean
(
torch
.
abs
(
output
-
t_output
))
/
torch
.
mean
(
torch
.
abs
(
t_output
))
print
(
'diff = '
,
diff
)
print
(
'diff = '
,
diff
)
assert
(
diff
<
0.001
)
assert
(
diff
<
0.001
)
# warm up
for
i
in
range
(
warm_up_iter
):
mlp
=
mlps
[
i
%
layer_num
]
input
=
torch
.
randn
((
1
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
output
=
torch
.
empty
((
1
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
input
=
input
/
100
CPUInfer
.
submit
(
mlp
.
forward
,
input
.
data_ptr
(),
output
.
data_ptr
())
CPUInfer
.
sync
()
# test
total_time
=
0
for
i
in
range
(
test_iter
):
mlp
=
mlps
[
i
%
layer_num
]
input
=
torch
.
randn
((
1
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
output
=
torch
.
empty
((
1
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
input
=
input
/
100
start
=
time
.
time
()
CPUInfer
.
submit
(
mlp
.
forward
,
input
.
data_ptr
(),
output
.
data_ptr
())
CPUInfer
.
sync
()
end
=
time
.
time
()
total_time
+=
end
-
start
print
(
'Time: '
,
total_time
)
print
(
'Iteration: '
,
test_iter
)
print
(
'Time per iteration: '
,
total_time
/
test_iter
)
print
(
'Bandwidth: '
,
hidden_size
*
intermediate_size
*
3
*
2
*
test_iter
/
total_time
/
1024
/
1024
/
1024
,
'GB/s'
)
print
(
"All tasks completed."
)
\ No newline at end of file
ktransformers/ktransformers_ext/examples/test_moe.py
View file @
77a34c28
...
@@ -6,7 +6,7 @@ Author : chenht2022
...
@@ -6,7 +6,7 @@ Author : chenht2022
Date : 2024-07-25 10:32:05
Date : 2024-07-25 10:32:05
Version : 1.0.0
Version : 1.0.0
LastEditors : chenht2022
LastEditors : chenht2022
LastEditTime : 2024-0
7-25
10:3
4
:0
6
LastEditTime : 2024-0
8-06
10:3
8
:0
5
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
'''
import
os
,
sys
import
os
,
sys
...
@@ -15,25 +15,64 @@ sys.path.append(os.path.dirname(__file__) + '/../build')
...
@@ -15,25 +15,64 @@ sys.path.append(os.path.dirname(__file__) + '/../build')
import
cpuinfer_ext
import
cpuinfer_ext
import
torch
import
torch
with
torch
.
inference_mode
(
mode
=
True
):
expert_num
=
160
expert_num
=
10
hidden_size
=
5120
hidden_size
=
5120
intermediate_size
=
1536
intermediate_size
=
1536
stride
=
32
stride
=
32
group_min_len
=
10
group_min_len
=
10
group_max_len
=
1024
group_max_len
=
1024
gate_type
=
1
# ggml_type::GGML_TYPE_F16
gate_type
=
1
# ggml_type::GGML_TYPE_F16
up_type
=
1
# ggml_type::GGML_TYPE_F16
up_type
=
1
# ggml_type::GGML_TYPE_F16
down_type
=
1
# ggml_type::GGML_TYPE_F16
down_type
=
1
# ggml_type::GGML_TYPE_F16
hidden_type
=
1
# ggml_type::GGML_TYPE_F16
hidden_type
=
1
# ggml_type::GGML_TYPE_F16
n_routed_experts
=
6
n_routed_experts
=
6
qlen
=
30
qlen
=
30
layer_num
=
10
layer_num
=
10
CPUInfer
=
cpuinfer_ext
.
CPUInfer
(
48
)
CPUInfer
=
cpuinfer_ext
.
CPUInfer
(
48
)
validation_iter
=
100
validation_iter
=
100
warm_up_iter
=
1000
def
act_fn
(
x
):
test_iter
=
10000
return
x
/
(
1.0
+
torch
.
exp
(
-
x
))
def
mlp_torch
(
input
,
gate_proj
,
up_proj
,
down_proj
):
gate_buf
=
torch
.
mm
(
input
,
gate_proj
.
t
())
up_buf
=
torch
.
mm
(
input
,
up_proj
.
t
())
intermediate
=
act_fn
(
gate_buf
)
*
up_buf
ret
=
torch
.
mm
(
intermediate
,
down_proj
.
t
())
return
ret
def
moe_torch
(
input
,
expert_ids
,
weights
,
gate_proj
,
up_proj
,
down_proj
):
cnts
=
expert_ids
.
new_zeros
((
expert_ids
.
shape
[
0
],
expert_num
))
cnts
.
scatter_
(
1
,
expert_ids
,
1
)
tokens_per_expert
=
cnts
.
sum
(
dim
=
0
)
idxs
=
expert_ids
.
view
(
-
1
).
argsort
()
sorted_tokens
=
input
[
idxs
//
expert_ids
.
shape
[
1
]]
outputs
=
[]
start_idx
=
0
for
i
,
num_tokens
in
enumerate
(
tokens_per_expert
):
end_idx
=
start_idx
+
num_tokens
if
num_tokens
==
0
:
continue
tokens_for_this_expert
=
sorted_tokens
[
start_idx
:
end_idx
]
expert_out
=
mlp_torch
(
tokens_for_this_expert
,
gate_proj
[
i
],
up_proj
[
i
],
down_proj
[
i
])
outputs
.
append
(
expert_out
)
start_idx
=
end_idx
outs
=
torch
.
cat
(
outputs
,
dim
=
0
)
if
len
(
outputs
)
else
sorted_tokens
.
new_empty
(
0
)
new_x
=
torch
.
empty_like
(
outs
)
new_x
[
idxs
]
=
outs
t_output
=
(
new_x
.
view
(
*
expert_ids
.
shape
,
-
1
)
.
type
(
weights
.
dtype
)
.
mul_
(
weights
.
unsqueeze
(
dim
=-
1
))
.
sum
(
dim
=
1
)
.
type
(
new_x
.
dtype
)
)
return
t_output
with
torch
.
inference_mode
(
mode
=
True
):
moes
=
[]
moes
=
[]
gate_projs
=
[]
gate_projs
=
[]
up_projs
=
[]
up_projs
=
[]
...
@@ -51,63 +90,32 @@ with torch.inference_mode(mode=True):
...
@@ -51,63 +90,32 @@ with torch.inference_mode(mode=True):
# validation
# validation
for
i
in
range
(
validation_iter
):
for
i
in
range
(
validation_iter
):
moe
=
moes
[
i
%
layer_num
]
expert_ids
=
torch
.
stack
([
torch
.
randperm
(
expert_num
)[:
n_routed_experts
]
for
_
in
range
(
qlen
)]).
contiguous
()
expert_ids
=
torch
.
randint
(
0
,
expert_num
,
(
qlen
,
n_routed_experts
),
dtype
=
torch
.
int64
).
contiguous
()
weights
=
torch
.
rand
((
qlen
,
n_routed_experts
),
dtype
=
torch
.
float32
).
contiguous
()
weights
=
torch
.
rand
((
qlen
,
n_routed_experts
),
dtype
=
torch
.
float32
).
contiguous
()
input
=
torch
.
randn
((
qlen
,
1
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
input
=
torch
.
randn
((
qlen
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
output
=
torch
.
empty
((
qlen
,
1
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
output
=
torch
.
empty
((
qlen
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
input
=
input
/
100
input
=
input
/
100
CPUInfer
.
submit
(
moe
.
forward
,
qlen
,
n_routed_experts
,
expert_ids
.
data_ptr
(),
weights
.
data_ptr
(),
input
.
data_ptr
(),
output
.
data_ptr
())
moe
=
moes
[
i
%
layer_num
]
CPUInfer
.
submit
(
moe
.
forward
(
qlen
,
n_routed_experts
,
expert_ids
.
data_ptr
(),
weights
.
data_ptr
(),
input
.
data_ptr
(),
output
.
data_ptr
()
)
)
CPUInfer
.
sync
()
CPUInfer
.
sync
()
# print('cpuinfer output', output)
# print('cpuinfer output', output)
def
act_fn
(
x
):
return
x
/
(
1.0
+
torch
.
exp
(
-
x
))
t_output
=
torch
.
zeros
((
qlen
,
1
,
hidden_size
),
dtype
=
torch
.
float32
).
contiguous
()
gate_proj
=
gate_projs
[
i
%
layer_num
]
gate_proj
=
gate_projs
[
i
%
layer_num
]
up_proj
=
up_projs
[
i
%
layer_num
]
up_proj
=
up_projs
[
i
%
layer_num
]
down_proj
=
down_projs
[
i
%
layer_num
]
down_proj
=
down_projs
[
i
%
layer_num
]
for
token_idx
in
range
(
qlen
):
t_output
=
moe_torch
(
input
,
expert_ids
,
weights
,
gate_proj
,
up_proj
,
down_proj
)
for
i
,
expert_id
in
enumerate
(
expert_ids
[
token_idx
]):
gate_buf
=
torch
.
mm
(
input
[
token_idx
],
gate_proj
[
expert_id
].
t
())
up_buf
=
torch
.
mm
(
input
[
token_idx
],
up_proj
[
expert_id
].
t
())
intermediate
=
act_fn
(
gate_buf
)
*
up_buf
expert_output
=
torch
.
mm
(
intermediate
,
down_proj
[
expert_id
].
t
())
t_output
[
token_idx
]
+=
weights
[
token_idx
][
i
]
*
expert_output
# print('torch output', t_output)
# print('torch output', t_output)
diff
=
torch
.
mean
(
torch
.
abs
(
output
-
t_output
))
/
torch
.
mean
(
torch
.
abs
(
t_output
))
diff
=
torch
.
mean
(
torch
.
abs
(
output
-
t_output
))
/
torch
.
mean
(
torch
.
abs
(
t_output
))
print
(
'diff = '
,
diff
)
print
(
'diff = '
,
diff
)
assert
(
diff
<
0.001
)
assert
(
diff
<
0.001
)
# warm up
for
i
in
range
(
warm_up_iter
):
moe
=
moes
[
i
%
layer_num
]
expert_ids
=
torch
.
randint
(
0
,
expert_num
,
(
qlen
,
n_routed_experts
),
dtype
=
torch
.
int64
).
contiguous
()
weights
=
torch
.
rand
((
qlen
,
n_routed_experts
),
dtype
=
torch
.
float32
).
contiguous
()
input
=
torch
.
randn
((
qlen
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
output
=
torch
.
empty
((
qlen
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
input
=
input
/
100
CPUInfer
.
submit
(
moe
.
forward
,
qlen
,
n_routed_experts
,
expert_ids
.
data_ptr
(),
weights
.
data_ptr
(),
input
.
data_ptr
(),
output
.
data_ptr
())
CPUInfer
.
sync
()
# test
total_time
=
0
for
i
in
range
(
test_iter
):
moe
=
moes
[
i
%
layer_num
]
expert_ids
=
torch
.
randint
(
0
,
expert_num
,
(
qlen
,
n_routed_experts
),
dtype
=
torch
.
int64
).
contiguous
()
weights
=
torch
.
rand
((
qlen
,
n_routed_experts
),
dtype
=
torch
.
float32
).
contiguous
()
input
=
torch
.
randn
((
qlen
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
output
=
torch
.
empty
((
qlen
,
hidden_size
),
dtype
=
torch
.
float16
).
contiguous
()
input
=
input
/
100
start
=
time
.
perf_counter
()
CPUInfer
.
submit
(
moe
.
forward
,
qlen
,
n_routed_experts
,
expert_ids
.
data_ptr
(),
weights
.
data_ptr
(),
input
.
data_ptr
(),
output
.
data_ptr
())
CPUInfer
.
sync
()
end
=
time
.
perf_counter
()
total_time
+=
end
-
start
print
(
'Time: '
,
total_time
)
print
(
'Iteration: '
,
test_iter
)
print
(
'Time per iteration: '
,
total_time
/
test_iter
)
print
(
'Bandwidth: '
,
hidden_size
*
intermediate_size
*
3
*
n_routed_experts
*
2
*
test_iter
/
total_time
/
1000
/
1000
/
1000
,
'GB/s'
)
print
(
"All tasks completed."
)
\ No newline at end of file
ktransformers/ktransformers_ext/ext_bindings.cpp
View file @
77a34c28
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
* @Date : 2024-07-22 02:03:22
* @Date : 2024-07-22 02:03:22
* @Version : 1.0.0
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditors : chenht2022
* @LastEditTime : 2024-0
7-25
10:3
4:23
* @LastEditTime : 2024-0
8-07
10:3
9:37
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
**/
// Python bindings
// Python bindings
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
#include <iostream>
#include <iostream>
#include <memory>
#include <memory>
#include "cpu_backend/cpuinfer.h"
#include "cpu_backend/cpuinfer.h"
#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include "device_launch_parameters.h"
#include "llamafile/flags.h"
#include "llamafile/flags.h"
#include "operators/llamafile/linear.h"
#include "operators/llamafile/linear.h"
...
@@ -26,239 +25,155 @@
...
@@ -26,239 +25,155 @@
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
using
namespace
pybind11
::
literals
;
using
namespace
pybind11
::
literals
;
// Binding functions for the Linear class
class
LinearBindings
{
class
LinearBindings
{
public:
public:
static
void
bind_forward
(
CPUInfer
&
cpuinfer
,
Linear
*
linear
,
py
::
args
args
,
py
::
kwargs
kwargs
)
{
class
WarmUpBindinds
{
auto
input
=
args
[
0
].
cast
<
intptr_t
>
();
public:
auto
output
=
args
[
1
].
cast
<
intptr_t
>
();
struct
Args
{
cpuinfer
.
submit
(
&
Linear
::
forward
,
linear
,
CPUInfer
*
cpuinfer
;
(
const
void
*
)
input
,
(
void
*
)
output
);
Linear
*
linear
;
}
};
static
void
inner
(
void
*
args
)
{
static
void
bind_warm_up
(
CPUInfer
&
cpuinfer
,
Linear
*
linear
,
py
::
args
args
,
py
::
kwargs
kwargs
)
{
Args
*
args_
=
(
Args
*
)
args
;
cpuinfer
.
submit
(
&
Linear
::
warm_up
,
linear
);
args_
->
cpuinfer
->
enqueue
(
&
Linear
::
warm_up
,
args_
->
linear
);
}
}
static
std
::
pair
<
intptr_t
,
intptr_t
>
cpuinfer_interface
(
Linear
&
linear
)
{
static
void
bind_functions
(
CPUInfer
&
cpuinfer
,
py
::
object
func
,
py
::
args
args
,
py
::
kwargs
kwargs
)
{
Args
*
args
=
new
Args
{
nullptr
,
&
linear
};
auto
linear
=
func
.
attr
(
"__self__"
).
cast
<
Linear
*>
();
return
std
::
make_pair
((
intptr_t
)
&
inner
,
(
intptr_t
)
args
);
std
::
string
func_name
=
py
::
str
(
func
.
attr
(
"__func__"
).
attr
(
"__name__"
));
}
};
if
(
func_name
==
"forward"
)
{
class
ForwardBindings
{
bind_forward
(
cpuinfer
,
linear
,
args
,
kwargs
);
public:
}
else
if
(
func_name
==
"warm_up"
)
{
struct
Args
{
bind_warm_up
(
cpuinfer
,
linear
,
args
,
kwargs
);
CPUInfer
*
cpuinfer
;
}
else
{
Linear
*
linear
;
throw
py
::
value_error
(
"Unsupported function: "
+
int
qlen
;
std
::
string
(
func_name
));
const
void
*
input
;
void
*
output
;
};
static
void
inner
(
void
*
args
)
{
Args
*
args_
=
(
Args
*
)
args
;
args_
->
cpuinfer
->
enqueue
(
&
Linear
::
forward
,
args_
->
linear
,
args_
->
qlen
,
args_
->
input
,
args_
->
output
);
}
}
static
std
::
pair
<
intptr_t
,
intptr_t
>
cpuinfer_interface
(
Linear
&
linear
,
int
qlen
,
intptr_t
input
,
intptr_t
output
)
{
Args
*
args
=
new
Args
{
nullptr
,
&
linear
,
qlen
,
(
const
void
*
)
input
,
(
void
*
)
output
};
return
std
::
make_pair
((
intptr_t
)
&
inner
,
(
intptr_t
)
args
);
}
}
};
};
};
// Binding functions for the MLP class
class
MLPBindings
{
class
MLPBindings
{
public:
public:
static
void
bind_forward
(
CPUInfer
&
cpuinfer
,
MLP
*
mlp
,
py
::
args
args
,
py
::
kwargs
kwargs
)
{
class
WarmUpBindinds
{
auto
input
=
args
[
0
].
cast
<
intptr_t
>
();
public:
auto
output
=
args
[
1
].
cast
<
intptr_t
>
();
struct
Args
{
cpuinfer
.
submit
(
&
MLP
::
forward
,
mlp
,
CPUInfer
*
cpuinfer
;
(
const
void
*
)
input
,
(
void
*
)
output
);
MLP
*
mlp
;
}
};
static
void
inner
(
void
*
args
)
{
static
void
bind_warm_up
(
CPUInfer
&
cpuinfer
,
MLP
*
mlp
,
py
::
args
args
,
py
::
kwargs
kwargs
)
{
Args
*
args_
=
(
Args
*
)
args
;
cpuinfer
.
submit
(
&
MLP
::
warm_up
,
mlp
);
args_
->
cpuinfer
->
enqueue
(
&
MLP
::
warm_up
,
args_
->
mlp
);
}
}
static
std
::
pair
<
intptr_t
,
intptr_t
>
cpuinfer_interface
(
MLP
&
mlp
)
{
static
void
bind_functions
(
CPUInfer
&
cpuinfer
,
py
::
object
func
,
py
::
args
args
,
py
::
kwargs
kwargs
)
{
Args
*
args
=
new
Args
{
nullptr
,
&
mlp
};
auto
mlp
=
func
.
attr
(
"__self__"
).
cast
<
MLP
*>
();
return
std
::
make_pair
((
intptr_t
)
&
inner
,
(
intptr_t
)
args
);
std
::
string
func_name
=
py
::
str
(
func
.
attr
(
"__func__"
).
attr
(
"__name__"
));
}
};
if
(
func_name
==
"forward"
)
{
class
ForwardBindings
{
bind_forward
(
cpuinfer
,
mlp
,
args
,
kwargs
);
public:
}
else
if
(
func_name
==
"warm_up"
)
{
struct
Args
{
bind_warm_up
(
cpuinfer
,
mlp
,
args
,
kwargs
);
CPUInfer
*
cpuinfer
;
}
else
{
MLP
*
mlp
;
throw
py
::
value_error
(
"Unsupported function: "
+
int
qlen
;
std
::
string
(
func_name
));
const
void
*
input
;
void
*
output
;
};
static
void
inner
(
void
*
args
)
{
Args
*
args_
=
(
Args
*
)
args
;
args_
->
cpuinfer
->
enqueue
(
&
MLP
::
forward
,
args_
->
mlp
,
args_
->
qlen
,
args_
->
input
,
args_
->
output
);
}
}
static
std
::
pair
<
intptr_t
,
intptr_t
>
cpuinfer_interface
(
MLP
&
mlp
,
int
qlen
,
intptr_t
input
,
intptr_t
output
)
{
Args
*
args
=
new
Args
{
nullptr
,
&
mlp
,
qlen
,
(
const
void
*
)
input
,
(
void
*
)
output
};
return
std
::
make_pair
((
intptr_t
)
&
inner
,
(
intptr_t
)
args
);
}
}
};
};
};
// Binding functions for the MOE class
class
MOEBindings
{
class
MOEBindings
{
public:
public:
static
void
bind_forward
(
CPUInfer
&
cpuinfer
,
MOE
*
moe
,
py
::
args
args
,
py
::
kwargs
kwargs
)
{
class
WarmUpBindinds
{
int
qlen
=
args
[
0
].
cast
<
int
>
();
public:
int
k
=
args
[
1
].
cast
<
int
>
();
struct
Args
{
auto
expert_ids
=
args
[
2
].
cast
<
intptr_t
>
();
CPUInfer
*
cpuinfer
;
auto
weights
=
args
[
3
].
cast
<
intptr_t
>
();
MOE
*
moe
;
auto
input
=
args
[
4
].
cast
<
intptr_t
>
();
};
auto
output
=
args
[
5
].
cast
<
intptr_t
>
();
static
void
inner
(
void
*
args
)
{
cpuinfer
.
submit
(
&
MOE
::
forward
,
moe
,
Args
*
args_
=
(
Args
*
)
args
;
qlen
,
k
,
(
const
uint64_t
*
)
expert_ids
,
(
const
float
*
)
weights
,
(
const
void
*
)
input
,
(
void
*
)
output
);
args_
->
cpuinfer
->
enqueue
(
&
MOE
::
warm_up
,
args_
->
moe
);
}
static
void
bind_warm_up
(
CPUInfer
&
cpuinfer
,
MOE
*
moe
,
py
::
args
args
,
py
::
kwargs
kwargs
)
{
cpuinfer
.
submit
(
&
MOE
::
warm_up
,
moe
);
}
static
void
bind_functions
(
CPUInfer
&
cpuinfer
,
py
::
object
func
,
py
::
args
args
,
py
::
kwargs
kwargs
)
{
auto
moe
=
func
.
attr
(
"__self__"
).
cast
<
MOE
*>
();
std
::
string
func_name
=
py
::
str
(
func
.
attr
(
"__func__"
).
attr
(
"__name__"
));
if
(
func_name
==
"forward"
)
{
bind_forward
(
cpuinfer
,
moe
,
args
,
kwargs
);
}
else
if
(
func_name
==
"warm_up"
)
{
bind_warm_up
(
cpuinfer
,
moe
,
args
,
kwargs
);
}
else
{
throw
py
::
value_error
(
"Unsupported function: "
+
std
::
string
(
func_name
));
}
}
static
std
::
pair
<
intptr_t
,
intptr_t
>
cpuinfer_interface
(
MOE
&
moe
)
{
Args
*
args
=
new
Args
{
nullptr
,
&
moe
};
return
std
::
make_pair
((
intptr_t
)
&
inner
,
(
intptr_t
)
args
);
}
}
};
};
class
ForwardBindings
{
struct
MOEForwardArgs
{
public:
struct
Args
{
CPUInfer
*
cpuinfer
;
CPUInfer
*
cpuinfer
;
MOE
*
moe
;
MOE
*
moe
;
int
qlen
;
int
qlen
;
int
k
;
int
k
;
uint64_t
*
expert_ids
;
const
uint64_t
*
expert_ids
;
float
*
weights
;
const
float
*
weights
;
void
*
input
;
const
void
*
input
;
void
*
output
;
void
*
output
;
};
static
void
inner
(
void
*
args
)
{
Args
*
args_
=
(
Args
*
)
args
;
args_
->
cpuinfer
->
enqueue
(
&
MOE
::
forward
,
args_
->
moe
,
args_
->
qlen
,
args_
->
k
,
args_
->
expert_ids
,
args_
->
weights
,
args_
->
input
,
args_
->
output
);
}
static
std
::
pair
<
intptr_t
,
intptr_t
>
cpuinfer_interface
(
MOE
&
moe
,
int
qlen
,
int
k
,
intptr_t
expert_ids
,
intptr_t
weights
,
intptr_t
input
,
intptr_t
output
)
{
Args
*
args
=
new
Args
{
nullptr
,
&
moe
,
qlen
,
k
,
(
const
uint64_t
*
)
expert_ids
,
(
const
float
*
)
weights
,
(
const
void
*
)
input
,
(
void
*
)
output
};
return
std
::
make_pair
((
intptr_t
)
&
inner
,
(
intptr_t
)
args
);
}
};
};
};
void
submit_moe_forward_with_host_args_ptr
(
void
*
host_args_ptr
)
{
MOEForwardArgs
*
host_args
=
(
MOEForwardArgs
*
)
host_args_ptr
;
host_args
->
cpuinfer
->
submit
(
&
MOE
::
forward
,
host_args
->
moe
,
host_args
->
qlen
,
host_args
->
k
,
host_args
->
expert_ids
,
host_args
->
weights
,
host_args
->
input
,
host_args
->
output
);
}
void
cpuinfer_sync
(
void
*
host_args_ptr
)
{
CPUInfer
*
cpuinfer
=
(
CPUInfer
*
)
host_args_ptr
;
cpuinfer
->
sync
();
}
PYBIND11_MODULE
(
cpuinfer_ext
,
m
)
{
PYBIND11_MODULE
(
cpuinfer_ext
,
m
)
{
auto
linear_module
=
m
.
def_submodule
(
"linear"
);
py
::
class_
<
CPUInfer
>
(
m
,
"CPUInfer"
)
.
def
(
py
::
init
<
int
>
())
.
def
(
"submit"
,
&
CPUInfer
::
submit
)
.
def
(
"submit_with_cuda_stream"
,
&
CPUInfer
::
submit_with_cuda_stream
)
.
def
(
"sync"
,
&
CPUInfer
::
sync
)
.
def
(
"sync_with_cuda_stream"
,
&
CPUInfer
::
sync_with_cuda_stream
);
auto
linear_module
=
m
.
def_submodule
(
"linear"
);
py
::
class_
<
LinearConfig
>
(
linear_module
,
"LinearConfig"
)
py
::
class_
<
LinearConfig
>
(
linear_module
,
"LinearConfig"
)
.
def
(
py
::
init
([](
int
hidden_size
,
int
intermediate_size
,
int
stride
,
intptr_t
proj
,
int
proj_type
,
int
hidden_type
)
{
.
def
(
py
::
init
([](
int
hidden_size
,
int
intermediate_size
,
int
stride
,
int
group_max_len
,
intptr_t
proj
,
int
proj_type
,
int
hidden_type
)
{
return
LinearConfig
(
hidden_size
,
intermediate_size
,
stride
,
(
void
*
)
proj
,
(
ggml_type
)
proj_type
,
(
ggml_type
)
hidden_type
);
return
LinearConfig
(
hidden_size
,
intermediate_size
,
stride
,
group_max_len
,
(
void
*
)
proj
,
(
ggml_type
)
proj_type
,
(
ggml_type
)
hidden_type
);
}));
}));
py
::
class_
<
Linear
>
(
linear_module
,
"Linear"
)
py
::
class_
<
Linear
>
(
linear_module
,
"Linear"
)
.
def
(
py
::
init
<
LinearConfig
>
())
.
def
(
py
::
init
<
LinearConfig
>
())
.
def
(
"warm_up"
,
[](
Linear
&
linear
)
{
.
def
(
"warm_up"
,
&
LinearBindings
::
WarmUpBindinds
::
cpuinfer_interface
)
throw
std
::
runtime_error
(
"!!! Doing nothing, please use CPUInfer.submit to call it!!!
\n
"
);
.
def
(
"forward"
,
&
LinearBindings
::
ForwardBindings
::
cpuinfer_interface
);
})
.
def
(
"forward"
,
[](
Linear
&
linear
,
intptr_t
input
,
intptr_t
output
)
{
throw
std
::
runtime_error
(
"!!! Doing nothing, please use CPUInfer.submit to call it!!!
\n
"
);
});
auto
mlp_module
=
m
.
def_submodule
(
"mlp"
);
auto
mlp_module
=
m
.
def_submodule
(
"mlp"
);
py
::
class_
<
MLPConfig
>
(
mlp_module
,
"MLPConfig"
)
py
::
class_
<
MLPConfig
>
(
mlp_module
,
"MLPConfig"
)
.
def
(
py
::
init
([](
int
hidden_size
,
int
intermediate_size
,
int
stride
,
intptr_t
gate_proj
,
intptr_t
up_proj
,
intptr_t
down_proj
,
int
gate_type
,
int
up_type
,
int
down_type
,
int
hidden_type
)
{
.
def
(
py
::
init
([](
int
hidden_size
,
int
intermediate_size
,
int
stride
,
int
group_max_len
,
intptr_t
gate_proj
,
intptr_t
up_proj
,
intptr_t
down_proj
,
int
gate_type
,
int
up_type
,
int
down_type
,
int
hidden_type
)
{
return
MLPConfig
(
hidden_size
,
intermediate_size
,
stride
,
(
void
*
)
gate_proj
,
(
void
*
)
up_proj
,
(
void
*
)
down_proj
,
(
ggml_type
)
gate_type
,
(
ggml_type
)
up_type
,
(
ggml_type
)
down_type
,
(
ggml_type
)
hidden_type
);
return
MLPConfig
(
hidden_size
,
intermediate_size
,
stride
,
group_max_len
,
(
void
*
)
gate_proj
,
(
void
*
)
up_proj
,
(
void
*
)
down_proj
,
(
ggml_type
)
gate_type
,
(
ggml_type
)
up_type
,
(
ggml_type
)
down_type
,
(
ggml_type
)
hidden_type
);
}));
}));
py
::
class_
<
MLP
>
(
mlp_module
,
"MLP"
)
py
::
class_
<
MLP
>
(
mlp_module
,
"MLP"
)
.
def
(
py
::
init
<
MLPConfig
>
())
.
def
(
py
::
init
<
MLPConfig
>
())
.
def
(
"warm_up"
,
[](
MLP
&
mlp
)
{
.
def
(
"warm_up"
,
&
MLPBindings
::
WarmUpBindinds
::
cpuinfer_interface
)
throw
std
::
runtime_error
(
"!!! Doing nothing, please use CPUInfer.submit to call it!!!
\n
"
);
.
def
(
"forward"
,
&
MLPBindings
::
ForwardBindings
::
cpuinfer_interface
);
})
.
def
(
"forward"
,
[](
MLP
&
mlp
,
intptr_t
input
,
intptr_t
output
)
{
throw
std
::
runtime_error
(
"!!! Doing nothing, please use CPUInfer.submit to call it!!!
\n
"
);
});
auto
moe_module
=
m
.
def_submodule
(
"moe"
);
auto
moe_module
=
m
.
def_submodule
(
"moe"
);
py
::
class_
<
MOEConfig
>
(
moe_module
,
"MOEConfig"
)
py
::
class_
<
MOEConfig
>
(
moe_module
,
"MOEConfig"
)
.
def
(
py
::
init
([](
int
expert_num
,
int
routed_expert_num
,
int
hidden_size
,
int
intermediate_size
,
int
stride
,
int
group_min_len
,
int
group_max_len
,
intptr_t
gate_proj
,
intptr_t
up_proj
,
intptr_t
down_proj
,
int
gate_type
,
int
up_type
,
int
down_type
,
int
hidden_type
)
{
.
def
(
py
::
init
([](
int
expert_num
,
int
routed_expert_num
,
int
hidden_size
,
int
intermediate_size
,
int
stride
,
int
group_min_len
,
int
group_max_len
,
intptr_t
gate_proj
,
intptr_t
up_proj
,
intptr_t
down_proj
,
int
gate_type
,
int
up_type
,
int
down_type
,
int
hidden_type
)
{
return
MOEConfig
(
expert_num
,
routed_expert_num
,
hidden_size
,
intermediate_size
,
stride
,
group_min_len
,
group_max_len
,
(
void
*
)
gate_proj
,
(
void
*
)
up_proj
,
(
void
*
)
down_proj
,
(
ggml_type
)
gate_type
,
(
ggml_type
)
up_type
,
(
ggml_type
)
down_type
,
(
ggml_type
)
hidden_type
);
return
MOEConfig
(
expert_num
,
routed_expert_num
,
hidden_size
,
intermediate_size
,
stride
,
group_min_len
,
group_max_len
,
(
void
*
)
gate_proj
,
(
void
*
)
up_proj
,
(
void
*
)
down_proj
,
(
ggml_type
)
gate_type
,
(
ggml_type
)
up_type
,
(
ggml_type
)
down_type
,
(
ggml_type
)
hidden_type
);
}));
}));
py
::
class_
<
MOE
>
(
moe_module
,
"MOE"
)
py
::
class_
<
MOE
>
(
moe_module
,
"MOE"
)
.
def
(
py
::
init
<
MOEConfig
>
())
.
def
(
py
::
init
<
MOEConfig
>
())
.
def
(
"warm_up"
,
[](
MOE
&
moe
)
{
.
def
(
"warm_up"
,
&
MOEBindings
::
WarmUpBindinds
::
cpuinfer_interface
)
throw
std
::
runtime_error
(
"!!! Doing nothing, please use CPUInfer.submit to call it!!!
\n
"
);
.
def
(
"forward"
,
&
MOEBindings
::
ForwardBindings
::
cpuinfer_interface
);
})
.
def
(
"forward"
,
[](
MOE
&
moe
,
int
k
,
uint64_t
expert_ids
,
intptr_t
weights
,
intptr_t
input
,
intptr_t
output
)
{
throw
std
::
runtime_error
(
"!!! Doing nothing, please use CPUInfer.submit to call it!!!
\n
"
);
});
py
::
class_
<
CPUInfer
>
(
m
,
"CPUInfer"
)
.
def
(
py
::
init
<
int
>
())
.
def
(
"submit"
,
[
linear_module
,
mlp_module
,
moe_module
](
CPUInfer
&
cpuinfer
,
py
::
object
func
,
py
::
args
args
,
py
::
kwargs
kwargs
)
{
if
(
py
::
hasattr
(
func
,
"__self__"
)
&&
py
::
hasattr
(
func
,
"__func__"
))
{
std
::
string
class_name
=
py
::
str
(
func
.
attr
(
"__self__"
)
.
attr
(
"__class__"
)
.
attr
(
"__name__"
));
if
(
class_name
==
"Linear"
)
{
LinearBindings
::
bind_functions
(
cpuinfer
,
func
,
args
,
kwargs
);
}
else
if
(
class_name
==
"MLP"
)
{
MLPBindings
::
bind_functions
(
cpuinfer
,
func
,
args
,
kwargs
);
}
else
if
(
class_name
==
"MOE"
)
{
MOEBindings
::
bind_functions
(
cpuinfer
,
func
,
args
,
kwargs
);
}
else
{
// handle other classes
throw
py
::
type_error
(
"Unsupported class type: "
+
class_name
);
}
}
else
{
// handle cases where func does not have __self__ or
// __func__
throw
py
::
type_error
(
"Invalid function object: missing "
"__self__ or __func__ attribute."
);
}
})
.
def
(
"submit_with_cuda_stream"
,
[
linear_module
,
mlp_module
,
moe_module
](
CPUInfer
&
cpuinfer
,
intptr_t
user_cuda_stream
,
py
::
object
func
,
py
::
args
args
,
py
::
kwargs
kwargs
)
{
if
(
py
::
hasattr
(
func
,
"__self__"
)
&&
py
::
hasattr
(
func
,
"__func__"
))
{
std
::
string
class_name
=
py
::
str
(
func
.
attr
(
"__self__"
)
.
attr
(
"__class__"
)
.
attr
(
"__name__"
));
if
(
class_name
==
"MOE"
)
{
std
::
string
func_name
=
py
::
str
(
func
.
attr
(
"__func__"
).
attr
(
"__name__"
));
if
(
func_name
==
"forward"
)
{
auto
moe
=
func
.
attr
(
"__self__"
).
cast
<
MOE
*>
();
int
qlen
=
args
[
0
].
cast
<
int
>
();
int
k
=
args
[
1
].
cast
<
int
>
();
auto
expert_ids
=
args
[
2
].
cast
<
intptr_t
>
();
auto
weights
=
args
[
3
].
cast
<
intptr_t
>
();
auto
input
=
args
[
4
].
cast
<
intptr_t
>
();
auto
output
=
args
[
5
].
cast
<
intptr_t
>
();
MOEForwardArgs
*
moe_forward_args
=
new
MOEForwardArgs
{
&
cpuinfer
,
moe
,
qlen
,
k
,
(
uint64_t
*
)
expert_ids
,
(
float
*
)
weights
,
(
void
*
)
input
,
(
void
*
)
output
};
// submit_moe_forward_with_host_args_ptr(moe_forward_args);
cudaLaunchHostFunc
((
cudaStream_t
)
user_cuda_stream
,
(
cudaHostFn_t
)
submit_moe_forward_with_host_args_ptr
,
moe_forward_args
);
}
else
{
throw
py
::
value_error
(
"Unsupported function: "
+
std
::
string
(
func_name
));
}
}
else
{
// handle other classes
throw
py
::
type_error
(
"Unsupported class type: "
+
class_name
);
}
}
else
{
// handle cases where func does not have __self__ or
// __func__
throw
py
::
type_error
(
"Invalid function object: missing "
"__self__ or __func__ attribute."
);
}
})
.
def
(
"sync_with_cuda_stream"
,
[](
CPUInfer
&
cpuinfer
,
intptr_t
user_cuda_stream
)
{
// cpuinfer_sync((void*)(&cpuinfer));
cudaLaunchHostFunc
((
cudaStream_t
)
user_cuda_stream
,
(
cudaHostFn_t
)
cpuinfer_sync
,
(
void
*
)(
&
cpuinfer
));
})
.
def
(
"sync"
,
&
CPUInfer
::
sync
);
}
}
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq.py
deleted
100644 → 0
View file @
44f57270
import
math
import
os
import
time
from
logging
import
getLogger
import
torch
import
torch.nn
as
nn
import
transformers
from
.quantizer
import
Quantizer
logger
=
getLogger
(
__name__
)
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
backends
.
cudnn
.
allow_tf32
=
False
class
GPTQ
:
def
__init__
(
self
,
layer
):
self
.
layer
=
layer
self
.
dev
=
self
.
layer
.
weight
.
device
W
=
layer
.
weight
.
data
.
clone
()
if
isinstance
(
self
.
layer
,
nn
.
Conv2d
):
W
=
W
.
flatten
(
1
)
if
isinstance
(
self
.
layer
,
transformers
.
pytorch_utils
.
Conv1D
):
W
=
W
.
t
()
self
.
rows
=
W
.
shape
[
0
]
self
.
columns
=
W
.
shape
[
1
]
self
.
H
=
torch
.
zeros
((
self
.
columns
,
self
.
columns
),
device
=
self
.
dev
)
self
.
nsamples
=
0
self
.
quantizer
=
Quantizer
()
def
add_batch
(
self
,
inp
,
out
):
if
os
.
environ
.
get
(
"DEBUG"
):
self
.
inp1
=
inp
self
.
out1
=
out
if
len
(
inp
.
shape
)
==
2
:
inp
=
inp
.
unsqueeze
(
0
)
tmp
=
inp
.
shape
[
0
]
if
isinstance
(
self
.
layer
,
nn
.
Linear
)
or
isinstance
(
self
.
layer
,
transformers
.
Conv1D
):
if
len
(
inp
.
shape
)
==
3
:
inp
=
inp
.
reshape
((
-
1
,
inp
.
shape
[
-
1
]))
inp
=
inp
.
t
()
if
isinstance
(
self
.
layer
,
nn
.
Conv2d
):
unfold
=
nn
.
Unfold
(
self
.
layer
.
kernel_size
,
dilation
=
self
.
layer
.
dilation
,
padding
=
self
.
layer
.
padding
,
stride
=
self
.
layer
.
stride
,
)
inp
=
unfold
(
inp
)
inp
=
inp
.
permute
([
1
,
0
,
2
])
inp
=
inp
.
flatten
(
1
)
self
.
H
*=
self
.
nsamples
/
(
self
.
nsamples
+
tmp
)
self
.
nsamples
+=
tmp
# inp = inp.float()
inp
=
math
.
sqrt
(
2
/
self
.
nsamples
)
*
inp
.
float
()
# self.H += 2 / self.nsamples * inp.matmul(inp.t())
self
.
H
+=
inp
.
matmul
(
inp
.
t
())
def
fasterquant
(
self
,
blocksize
=
128
,
percdamp
=
0.01
,
group_size
=-
1
,
actorder
=
False
,
static_groups
=
False
,
):
W
=
self
.
layer
.
weight
.
data
.
clone
()
if
isinstance
(
self
.
layer
,
nn
.
Conv2d
):
W
=
W
.
flatten
(
1
)
if
isinstance
(
self
.
layer
,
transformers
.
Conv1D
):
W
=
W
.
t
()
W
=
W
.
float
()
tick
=
time
.
time
()
if
not
self
.
quantizer
.
ready
():
self
.
quantizer
.
find_params
(
W
,
weight
=
True
)
H
=
self
.
H
del
self
.
H
dead
=
torch
.
diag
(
H
)
==
0
H
[
dead
,
dead
]
=
1
W
[:,
dead
]
=
0
g_idx
=
[]
scale
=
[]
zero
=
[]
now_idx
=
1
if
static_groups
:
import
copy
groups
=
[]
for
i
in
range
(
0
,
self
.
columns
,
group_size
):
quantizer
=
copy
.
deepcopy
(
self
.
quantizer
)
quantizer
.
find_params
(
W
[:,
i
:
(
i
+
group_size
)],
weight
=
True
)
scale
.
append
(
quantizer
.
scale
)
zero
.
append
(
quantizer
.
zero
)
groups
.
append
(
quantizer
)
if
actorder
:
perm
=
torch
.
argsort
(
torch
.
diag
(
H
),
descending
=
True
)
W
=
W
[:,
perm
]
H
=
H
[
perm
][:,
perm
]
invperm
=
torch
.
argsort
(
perm
)
Losses
=
torch
.
zeros_like
(
W
)
Q
=
torch
.
zeros_like
(
W
)
damp
=
percdamp
*
torch
.
mean
(
torch
.
diag
(
H
))
diag
=
torch
.
arange
(
self
.
columns
,
device
=
self
.
dev
)
H
[
diag
,
diag
]
+=
damp
H
=
torch
.
linalg
.
cholesky
(
H
)
H
=
torch
.
cholesky_inverse
(
H
)
H
=
torch
.
linalg
.
cholesky
(
H
,
upper
=
True
)
Hinv
=
H
for
i1
in
range
(
0
,
self
.
columns
,
blocksize
):
i2
=
min
(
i1
+
blocksize
,
self
.
columns
)
count
=
i2
-
i1
W1
=
W
[:,
i1
:
i2
].
clone
()
Q1
=
torch
.
zeros_like
(
W1
)
Err1
=
torch
.
zeros_like
(
W1
)
Losses1
=
torch
.
zeros_like
(
W1
)
Hinv1
=
Hinv
[
i1
:
i2
,
i1
:
i2
]
for
i
in
range
(
count
):
w
=
W1
[:,
i
]
d
=
Hinv1
[
i
,
i
]
if
group_size
!=
-
1
:
if
not
static_groups
:
if
(
i1
+
i
)
%
group_size
==
0
:
self
.
quantizer
.
find_params
(
W
[:,
(
i1
+
i
)
:
(
i1
+
i
+
group_size
)],
weight
=
True
)
if
((
i1
+
i
)
//
group_size
)
-
now_idx
==
-
1
:
scale
.
append
(
self
.
quantizer
.
scale
)
zero
.
append
(
self
.
quantizer
.
zero
)
now_idx
+=
1
else
:
idx
=
i1
+
i
if
actorder
:
idx
=
perm
[
idx
]
self
.
quantizer
=
groups
[
idx
//
group_size
]
q
=
self
.
quantizer
.
quantize
(
w
.
unsqueeze
(
1
)).
flatten
()
Q1
[:,
i
]
=
q
Losses1
[:,
i
]
=
(
w
-
q
)
**
2
/
d
**
2
err1
=
(
w
-
q
)
/
d
W1
[:,
i
:]
-=
err1
.
unsqueeze
(
1
).
matmul
(
Hinv1
[
i
,
i
:].
unsqueeze
(
0
))
Err1
[:,
i
]
=
err1
Q
[:,
i1
:
i2
]
=
Q1
Losses
[:,
i1
:
i2
]
=
Losses1
/
2
W
[:,
i2
:]
-=
Err1
.
matmul
(
Hinv
[
i1
:
i2
,
i2
:])
if
os
.
environ
.
get
(
"DEBUG"
):
self
.
layer
.
weight
.
data
[:,
:
i2
]
=
Q
[:,
:
i2
]
self
.
layer
.
weight
.
data
[:,
i2
:]
=
W
[:,
i2
:]
logger
.
debug
(
torch
.
sum
((
self
.
layer
(
self
.
inp1
)
-
self
.
out1
)
**
2
))
logger
.
debug
(
torch
.
sum
(
Losses
))
torch
.
cuda
.
synchronize
()
logger
.
info
(
f
"duration:
{
(
time
.
time
()
-
tick
)
}
"
)
logger
.
info
(
f
"avg loss:
{
torch
.
sum
(
Losses
).
item
()
/
self
.
nsamples
}
"
)
group_size
=
group_size
if
group_size
!=
-
1
else
self
.
columns
if
static_groups
and
actorder
:
g_idx
=
[
perm
[
i
]
//
group_size
for
i
in
range
(
self
.
columns
)]
else
:
g_idx
=
[
i
//
group_size
for
i
in
range
(
self
.
columns
)]
g_idx
=
torch
.
tensor
(
g_idx
,
dtype
=
torch
.
int32
,
device
=
Q
.
device
)
if
actorder
:
Q
=
Q
[:,
invperm
]
g_idx
=
g_idx
[
invperm
]
if
isinstance
(
self
.
layer
,
transformers
.
Conv1D
):
Q
=
Q
.
t
()
self
.
layer
.
weight
.
data
=
Q
.
reshape
(
self
.
layer
.
weight
.
shape
).
type_as
(
self
.
layer
.
weight
.
data
)
if
os
.
environ
.
get
(
"DEBUG"
):
logger
.
debug
(
torch
.
sum
((
self
.
layer
(
self
.
inp1
)
-
self
.
out1
)
**
2
))
if
scale
==
[]:
scale
.
append
(
self
.
quantizer
.
scale
)
zero
.
append
(
self
.
quantizer
.
zero
)
scale
=
torch
.
cat
(
scale
,
dim
=
1
)
zero
=
torch
.
cat
(
zero
,
dim
=
1
)
return
scale
,
zero
,
g_idx
def free(self):
    """Drop the references held by this object and release cached CUDA memory.

    ``H``, ``Losses`` and ``Trace`` are always reset to ``None``.  The debug
    captures ``inp1`` and ``out1`` are reset only when the ``DEBUG``
    environment variable is set to a non-empty value.
    """
    cleared = ["H", "Losses", "Trace"]
    if os.environ.get("DEBUG"):
        # Debug captures are cleared first, as in the original ordering.
        cleared = ["inp1", "out1"] + cleared
    for attr_name in cleared:
        setattr(self, attr_name, None)
    torch.cuda.empty_cache()
__all__
=
[
"GPTQ"
]
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq_marlin.py
deleted
100644 → 0
View file @
44f57270
import
enum
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
logger
=
init_logger
(
__name__
)
# Kernel constants copied onto every GPTQMarlinConfig instance
# (tile_size, min_thread_n, min_thread_k, max_parallel).
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

# Quantization settings accepted by the validation checks in this module.
# A group size of -1 is handled elsewhere in this file as "one group
# spanning the whole input dimension".
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits: int):
    """Build the two index permutations used to shuffle Marlin scales.

    Args:
        num_bits: Accepted for signature symmetry with the other helpers;
            the returned permutations do not depend on it.

    Returns:
        A ``(scale_perm, scale_perm_single)`` tuple of lists.  The first is
        a permutation of ``range(64)`` (an 8x8 index grid read column by
        column); the second is a permutation of ``range(32)``.
    """
    scale_perm: List[int] = [
        row + 8 * col for row in range(8) for col in range(8)
    ]

    pair_offsets = [0, 1, 8, 9, 16, 17, 24, 25]
    scale_perm_single: List[int] = [
        2 * block + offset for block in range(4) for offset in pair_offsets
    ]
    return scale_perm, scale_perm_single
def get_pack_factor(num_bits: int):
    """Return how many ``num_bits``-wide values fit in one 32-bit word.

    Raises:
        AssertionError: if ``num_bits`` is not listed in
            ``GPTQ_MARLIN_SUPPORTED_NUM_BITS``.
    """
    is_supported = num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
    assert is_supported, f"Unsupported num_bits = {num_bits}"
    word_bits = 32
    return word_bits // num_bits
def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                          group_size: int, num_bits: int):
    """Reorder a scales tensor with the Marlin scale permutations.

    The grouped permutation is applied when ``group_size`` is a real group
    size smaller than ``size_k``; otherwise (``group_size == -1`` or a group
    covering all of K) the single-group permutation is applied.

    Returns:
        A contiguous tensor reshaped to ``(-1, size_n)``.
    """
    grouped_perm, single_perm = get_scale_perms(num_bits)

    uses_groups = group_size != -1 and group_size < size_k
    perm = grouped_perm if uses_groups else single_perm

    # View the scales as rows of len(perm) entries and permute each row.
    shuffled = s.reshape((-1, len(perm)))[:, perm]
    return shuffled.reshape((-1, size_n)).contiguous()
class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin.

    Validates the checkpoint's quantization settings against the
    ``GPTQ_MARLIN_SUPPORTED_*`` lists and stores the kernel constants
    (pack factor, tile size, minimum thread tile, maximum parallelism).
    """

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool) -> None:
        """Store and validate the quantization settings.

        Args:
            weight_bits: Bit width of the quantized weights.
            group_size: Quantization group size, or -1 for a single group.
            desc_act: Whether the checkpoint uses activation reordering.
            is_sym: Whether quantization is symmetric.

        Raises:
            ValueError: if ``weight_bits``, ``group_size`` or ``is_sym`` is
                not in the corresponding supported list.
        """
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym

        # Verify
        if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"Marlin does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
                "are supported.")
        if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
            raise ValueError(
                f"Marlin does not support is_sym = {self.is_sym}. "
                f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")

        # Init
        self.pack_factor = get_pack_factor(weight_bits)
        self.tile_size = GPTQ_MARLIN_TILE
        self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
        self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
        self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability, encoded as major*10+minor."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the file names that may hold the quantization config."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        """Build a config from the ``bits``/``group_size``/``desc_act``/``sym``
        entries of ``config``."""
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        return cls(weight_bits, group_size, desc_act, is_sym)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Return this method's name when the checkpoint can be run with it.

        The override applies only when the checkpoint is Marlin-compatible
        and the user either gave no quantization or asked for ``"marlin"``.
        Returns ``None`` otherwise.
        """
        can_convert = cls.is_marlin_compatible(hf_quant_cfg)

        is_valid_user_quant = (user_quant is None or user_quant == "marlin")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_marlin"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_marlin for"
                        " faster inference")
        return None

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
        """Return the Marlin linear method for ``LinearBase`` layers, else
        ``None``."""
        if isinstance(layer, LinearBase):
            return GPTQMarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
        """Tell whether a GPTQ quantization config can run on Marlin kernels.

        Returns ``False`` when a required key is missing, when no CUDA
        device is available, when the device capability is below
        ``get_min_capability()``, or when a setting is unsupported.
        """
        # Extract data from quant config.
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        sym = quant_config.get("sym", None)
        desc_act = quant_config.get("desc_act", None)

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        # Without a usable CUDA device the capability query below raises;
        # the kernels cannot run there, so report "not convertible" instead.
        if not torch.cuda.is_available():
            return False

        # If the capability of the device is too low, cannot convert.
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < cls.get_min_capability():
            return False

        # Otherwise, can convert if model satisfies marlin constraints.
        return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
                and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
                and sym in GPTQ_MARLIN_SUPPORTED_SYM)
class GPTQMarlinState(Enum):
    """Per-layer state recorded in ``layer.marlin_state``."""

    # Set by create_weights: tensors still need the one-time repack.
    REPACK = enum.auto()
    # Set by apply once the repack has been performed.
    READY = enum.auto()
class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Validate the layer geometry and register the quantized tensors.

        Registers ``qweight``, ``g_idx``, ``scales`` and ``qzeros`` as
        parameters on ``layer`` and attaches the sort-index buffer, the
        Marlin workspace, the partition sizes, ``is_k_full`` and the
        initial ``marlin_state`` (``REPACK``).

        Raises:
            ValueError: if ``params_dtype`` is not float16/bfloat16 or a
                partition size violates a divisibility requirement.
        """
        del output_size
        cfg = self.quant_config

        # A group size of -1 means one group covering the whole input dim.
        group_size = cfg.group_size if cfg.group_size != -1 else input_size

        if params_dtype not in (torch.float16, torch.bfloat16):
            raise ValueError(f"The params dtype must be float16 "
                             f"or bfloat16, but got {params_dtype}")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % cfg.min_thread_n != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f" min_thread_n = {cfg.min_thread_n}.")

        if input_size_per_partition % cfg.min_thread_k != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible "
                f"by min_thread_k = {cfg.min_thread_k}.")

        if (group_size < input_size
                and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition}"
                f" is not divisible by group_size = {group_size}.")

        # Scales / zero-points are unsharded along the input dim by default.
        scales_and_zp_size = input_size // group_size
        scales_and_zp_input_dim = None

        if cfg.desc_act:
            # Act-order: K is full only when this partition holds all of it.
            assert cfg.group_size != -1
            is_k_full = input_size_per_partition == input_size
        else:
            # No act-order: K is always full due to full alignment with
            # group-size and shard of scales/zp.
            is_k_full = True
            row_parallel = input_size != input_size_per_partition
            if row_parallel and cfg.group_size != -1:
                scales_and_zp_size = input_size_per_partition // group_size
                scales_and_zp_input_dim = 0

        def frozen(tensor):
            return Parameter(tensor, requires_grad=False)

        def annotate(param, **attrs):
            # Caller-supplied attributes first, method-specific ones after.
            set_weight_attrs(param, {**extra_weight_attrs, **attrs})

        # Quantized weights, packed along the input dimension.
        qweight = frozen(
            torch.empty(
                input_size_per_partition // cfg.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ))
        annotate(qweight,
                 input_dim=0,
                 output_dim=1,
                 packed_dim=0,
                 pack_factor=cfg.pack_factor)

        # Activation order.
        g_idx = frozen(torch.empty(input_size_per_partition,
                                   dtype=torch.int32))
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        annotate(g_idx, input_dim=0, ignore_warning=True)

        g_idx_sort_indices = torch.empty(g_idx.shape, dtype=torch.int32)

        # Scales.
        scales = frozen(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ))
        annotate(scales, input_dim=scales_and_zp_input_dim, output_dim=1)

        # Quantized zero-points, created on the "meta" device.
        qzeros = frozen(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition // cfg.pack_factor,
                dtype=torch.int32,
                device="meta",
            ))
        annotate(qzeros,
                 input_dim=scales_and_zp_input_dim,
                 output_dim=1,
                 packed_dim=1,
                 pack_factor=cfg.pack_factor)

        # Marlin workspace.
        tiles_n = output_size_per_partition // cfg.min_thread_n
        max_workspace_size = tiles_n * cfg.max_parallel
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                requires_grad=False)

        for param_name, param in (("qweight", qweight), ("g_idx", g_idx),
                                  ("scales", scales), ("qzeros", qzeros)):
            layer.register_parameter(param_name, param)
        layer.g_idx_sort_indices = g_idx_sort_indices
        layer.workspace = workspace
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.is_k_full = is_k_full
        layer.marlin_state = GPTQMarlinState.REPACK

    def _repack_for_marlin(self, layer: torch.nn.Module) -> None:
        """One-time conversion of the layer's tensors for the Marlin ops."""
        cfg = self.quant_config
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition

        # Newly generated tensors need to replace existing tensors that are
        # already registered as parameters by vLLM (and won't be freed).
        def overwrite(name, fresh):
            # resize_() + copy_() keeps the existing tensor object in use.
            target = getattr(layer, name)
            target.resize_(fresh.shape)
            target.copy_(fresh)

        device = layer.qweight.device

        if cfg.desc_act:
            # Sort g_idx and remember the sorting permutation.
            order = torch.argsort(layer.g_idx).to(torch.int)
            overwrite("g_idx", layer.g_idx[order])
            overwrite("g_idx_sort_indices", order)
        else:
            # Reset g_idx related tensors to empty placeholders.
            layer.g_idx = Parameter(
                torch.empty(0, dtype=torch.int, device=device),
                requires_grad=False,
            )
            layer.g_idx_sort_indices = Parameter(
                torch.empty(0, dtype=torch.int, device=device),
                requires_grad=False,
            )

        # Repack weights.
        overwrite(
            "qweight",
            ops.gptq_marlin_repack(
                layer.qweight,
                layer.g_idx_sort_indices,
                part_size_k,
                part_size_n,
                cfg.weight_bits,
            ))

        # Permute scales; with act-order the full K size is used.
        scales_size_k = layer.input_size if cfg.desc_act else part_size_k
        overwrite(
            "scales",
            marlin_permute_scales(
                layer.scales,
                scales_size_k,
                part_size_n,
                cfg.group_size,
                cfg.weight_bits,
            ))

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the quantized matmul, repacking the layer on first use.

        Args:
            layer: Layer prepared by ``create_weights``.
            x: Input whose last dimension is the feature dimension.
            bias: Optional bias added in place to the GEMM output.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``layer.output_size_per_partition``.
        """
        flat_x = x.reshape(-1, x.shape[-1])
        size_m = flat_x.shape[0]
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition
        out_shape = x.shape[:-1] + (part_size_n, )

        if layer.marlin_state == GPTQMarlinState.REPACK:
            # Flip the state first so the repack happens exactly once.
            layer.marlin_state = GPTQMarlinState.READY
            self._repack_for_marlin(layer)

        output = ops.gptq_marlin_gemm(
            flat_x,
            layer.qweight,
            layer.scales,
            layer.g_idx,
            layer.g_idx_sort_indices,
            layer.workspace,
            self.quant_config.weight_bits,
            size_m,
            part_size_n,
            part_size_k,
            layer.is_k_full,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/quantizer.py
deleted
100644 → 0
View file @
44f57270
from
logging
import
getLogger
import
torch
import
torch.nn
as
nn
logger
=
getLogger
(
__name__
)
def quantize(x, scale, zero, maxq):
    """Quantize ``x`` and return the de-quantized (fake-quantized) values.

    For ``maxq < 0`` each element becomes ``scale`` if it exceeds
    ``scale / 2``, ``zero`` if it is below ``zero / 2``, and 0 otherwise.
    For ``maxq >= 0`` the value is rounded onto the integer grid
    ``[0, maxq]`` (after dividing by ``scale`` and adding ``zero``) and
    mapped back with ``scale * (q - zero)``.
    """
    if maxq < 0:
        upper = (x > scale / 2).float() * scale
        lower = (x < zero / 2).float() * zero
        return upper + lower

    levels = torch.round(x / scale) + zero
    levels = torch.clamp(levels, 0, maxq)
    return scale * (levels - zero)
class Quantizer(nn.Module):
    """Min/max quantization-parameter finder with optional grid search.

    Holds ``maxq``, ``scale`` and ``zero`` buffers.  ``configure`` selects
    the grid, ``find_params`` derives ``scale``/``zero`` from a tensor, and
    ``quantize`` applies them through the module-level ``quantize``.
    """

    def __init__(self, shape=1):
        super().__init__()
        initial_buffers = (
            ("maxq", torch.tensor(0)),
            ("scale", torch.zeros(shape)),
            ("zero", torch.zeros(shape)),
        )
        for buffer_name, value in initial_buffers:
            self.register_buffer(buffer_name, value)

    def configure(
        self,
        bits,
        perchannel=False,
        sym=True,
        mse=False,
        norm=2.4,
        grid=100,
        maxshrink=0.8,
        trits=False,
    ):
        """Select the quantization grid and the search options.

        ``maxq`` becomes ``2**bits - 1``, or ``-1`` when ``trits`` is set.
        """
        self.maxq = torch.tensor(2**bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
            # A negative maxq selects the three-valued path.
            self.maxq = torch.tensor(-1)

    def find_params(self, x, weight=False):
        """Compute ``self.scale`` and ``self.zero`` from the values of ``x``.

        Args:
            x: Tensor to analyse.
            weight: Treat ``x`` as a weight (rows are the channels); the
                resulting parameters are shaped ``[-1, 1, ...]``.
        """
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        rank = len(shape)

        # Bring x to 2-D with one row per quantization channel.
        if not self.perchannel:
            x = x.flatten().unsqueeze(0)
        elif weight:
            x = x.flatten(1)
        elif rank == 4:
            x = x.permute([1, 0, 2, 3]).flatten(1)
        elif rank == 3:
            x = x.reshape((-1, shape[-1])).t()
        elif rank == 2:
            x = x.t()

        # The range always includes zero.
        zeros = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], zeros)
        xmax = torch.maximum(x.max(1)[0], zeros)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            negative = xmin < 0
            if torch.any(negative):
                xmin[negative] = -xmax[negative]

        # All-zero rows get the range [-1, 1] so the scale is never zero.
        all_zero = (xmin == 0) & (xmax == 0)
        xmin[all_zero] = -1
        xmax[all_zero] = +1

        if self.maxq < 0:
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            # Shrink the range step by step and keep, per row, the
            # parameters with the smallest summed |error| ** norm.
            best = torch.full([x.shape[0]], float("inf"), device=dev)
            for step in range(int(self.maxshrink * self.grid)):
                shrink = 1 - step / self.grid
                cand_min = shrink * xmin
                cand_max = shrink * xmax
                cand_scale = (cand_max - cand_min) / self.maxq
                if self.sym:
                    cand_zero = self.zero
                else:
                    cand_zero = torch.round(-cand_min / cand_scale)
                q = quantize(x, cand_scale.unsqueeze(1),
                             cand_zero.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                improved = err < best
                if torch.any(improved):
                    best[improved] = err[improved]
                    self.scale[improved] = cand_scale[improved]
                    self.zero[improved] = cand_zero[improved]

        if not self.perchannel:
            # Repeat the single parameter pair once per channel.
            if weight:
                repeats = shape[0]
            elif rank != 3:
                repeats = shape[1]
            else:
                repeats = shape[2]
            self.scale = self.scale.repeat(repeats)
            self.zero = self.zero.repeat(repeats)

        if weight:
            target = [-1] + [1] * (rank - 1)
            self.scale = self.scale.reshape(target)
            self.zero = self.zero.reshape(target)
            return

        if rank == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        elif rank == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        elif rank == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        """Fake-quantize ``x`` once parameters are ready; else return ``x``."""
        if not self.ready():
            return x
        return quantize(x, self.scale, self.zero, self.maxq)

    def enabled(self):
        """Return the comparison ``maxq > 0``."""
        return self.maxq > 0

    def ready(self):
        """Return whether every entry of ``scale`` is non-zero."""
        return torch.all(self.scale != 0)
__all__
=
[
"Quantizer"
]
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/repack.py
deleted
100644 → 0
View file @
44f57270
import
torch
import
enum
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
from
torch.nn.parameter
import
Parameter
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the GPTQ-Marlin matmul for ``layer``, repacking it on first use.

    Args:
        self: Object exposing ``quant_config`` (``desc_act``,
            ``group_size``, ``weight_bits``).
        layer: Layer carrying the quantized tensors and ``marlin_state``.
        x: Input whose last dimension is the feature dimension.
        bias: Optional bias added in place to the GEMM output.

    Returns:
        Tensor shaped like ``x`` with the last dimension replaced by
        ``layer.output_size_per_partition``.

    NOTE(review): ``GPTQMarlinState``, ``ops`` and ``marlin_permute_scales``
    are not among the imports shown at the top of this file; confirm they
    are in scope wherever this function is used.
    """
    flat_x = x.reshape(-1, x.shape[-1])

    size_m = flat_x.shape[0]
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    full_size_k = layer.input_size

    out_shape = x.shape[:-1] + (part_size_n, )

    needs_repack = layer.marlin_state == GPTQMarlinState.REPACK
    if needs_repack:
        # Flip the state first so the repack happens exactly once.
        layer.marlin_state = GPTQMarlinState.READY
        act_order = self.quant_config.desc_act

        # Newly generated tensors need to replace existing tensors that are
        # already registered as parameters by vLLM (and won't be freed).
        def overwrite(name, fresh):
            # resize_() + copy_() keeps the existing tensor object in use.
            target = getattr(layer, name)
            target.resize_(fresh.shape)
            target.copy_(fresh)

        device = layer.qweight.device

        if act_order:
            # Sort g_idx and remember the sorting permutation.
            order = torch.argsort(layer.g_idx).to(torch.int)
            overwrite("g_idx", layer.g_idx[order])
            overwrite("g_idx_sort_indices", order)
        else:
            # Reset g_idx related tensors to empty placeholders.
            layer.g_idx = Parameter(
                torch.empty(0, dtype=torch.int, device=device),
                requires_grad=False,
            )
            layer.g_idx_sort_indices = Parameter(
                torch.empty(0, dtype=torch.int, device=device),
                requires_grad=False,
            )

        # Repack weights.
        repacked = ops.gptq_marlin_repack(
            layer.qweight,
            layer.g_idx_sort_indices,
            part_size_k,
            part_size_n,
            self.quant_config.weight_bits,
        )
        overwrite("qweight", repacked)

        # Permute scales; with act-order the full K size is used.
        scales_size_k = full_size_k if self.quant_config.desc_act else part_size_k
        permuted_scales = marlin_permute_scales(
            layer.scales,
            scales_size_k,
            part_size_n,
            self.quant_config.group_size,
            self.quant_config.weight_bits,
        )
        overwrite("scales", permuted_scales)

    output = ops.gptq_marlin_gemm(
        flat_x,
        layer.qweight,
        layer.scales,
        layer.g_idx,
        layer.g_idx_sort_indices,
        layer.workspace,
        self.quant_config.weight_bits,
        size_m,
        part_size_n,
        part_size_k,
        layer.is_k_full,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py
View file @
77a34c28
...
@@ -220,7 +220,7 @@ def compute_max_diff(output, output_ref):
...
@@ -220,7 +220,7 @@ def compute_max_diff(output, output_ref):
class
MarlinWorkspace
:
class
MarlinWorkspace
:
def
__init__
(
self
,
out_features
,
min_thread_n
,
max_parallel
):
def
__init__
(
self
,
out_features
,
min_thread_n
,
max_parallel
,
device
):
assert
(
out_features
%
min_thread_n
==
0
),
(
assert
(
out_features
%
min_thread_n
==
0
),
(
"out_features = {} is undivisible by min_thread_n = {}"
.
format
(
"out_features = {} is undivisible by min_thread_n = {}"
.
format
(
out_features
,
min_thread_n
))
out_features
,
min_thread_n
))
...
@@ -229,4 +229,4 @@ class MarlinWorkspace:
...
@@ -229,4 +229,4 @@ class MarlinWorkspace:
self
.
scratch
=
torch
.
zeros
(
max_workspace_size
,
self
.
scratch
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
device
=
"cuda"
)
device
=
device
)
ktransformers/ktransformers_ext/operators/llamafile/linear.cpp
View file @
77a34c28
...
@@ -13,9 +13,15 @@ Linear::Linear(LinearConfig config) {
...
@@ -13,9 +13,15 @@ Linear::Linear(LinearConfig config) {
config_
=
config
;
config_
=
config
;
proj_
=
config_
.
proj
;
proj_
=
config_
.
proj
;
input_fp32_
.
resize
(
config_
.
input_size
);
std
::
vector
<
std
::
pair
<
void
**
,
uint64_t
>>
mem_requests
;
proj_input_
.
resize
(
config_
.
input_size
*
4
);
mem_requests
.
push_back
({(
void
**
)
&
input_fp32_
,
sizeof
(
float
)
*
config_
.
group_max_len
*
config_
.
input_size
});
proj_output_
.
resize
(
config_
.
output_size
);
mem_requests
.
push_back
({(
void
**
)
&
proj_input_
,
config_
.
group_max_len
*
config_
.
input_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
proj_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
proj_type
).
vec_dot_type
)});
mem_requests
.
push_back
({(
void
**
)
&
proj_output_
,
sizeof
(
float
)
*
config_
.
group_max_len
*
config_
.
output_size
});
shared_mem_buffer
.
alloc
(
this
,
mem_requests
);
}
// Destructor: hands this instance back to shared_mem_buffer via dealloc(this).
// Presumably this releases the scratch regions requested through
// shared_mem_buffer.alloc(this, ...) in the constructor; confirm against
// shared_mem_buffer.h.
Linear
::~
Linear
()
{
shared_mem_buffer
.
dealloc
(
this
);
}
}
void
Linear
::
warm_up
(
Backend
*
backend
)
{
void
Linear
::
warm_up
(
Backend
*
backend
)
{
...
@@ -26,22 +32,42 @@ void Linear::warm_up(Backend* backend) {
...
@@ -26,22 +32,42 @@ void Linear::warm_up(Backend* backend) {
input_fp32
[
i
]
=
0
;
input_fp32
[
i
]
=
0
;
}
}
from_float
(
input_fp32
.
data
(),
input
.
data
(),
config_
.
input_size
,
config_
.
hidden_type
);
from_float
(
input_fp32
.
data
(),
input
.
data
(),
config_
.
input_size
,
config_
.
hidden_type
);
forward
(
input
.
data
(),
output
.
data
(),
backend
);
forward
_many
(
1
,
input
.
data
(),
output
.
data
(),
backend
);
}
}
void
Linear
::
forward
(
const
void
*
input
,
void
*
output
,
Backend
*
backend
)
{
void
Linear
::
forward
_many
(
int
qlen
,
const
void
*
input
,
void
*
output
,
Backend
*
backend
)
{
const
void
*
proj_input_ptr
;
const
void
*
proj_input_ptr
;
if
(
config_
.
hidden_type
==
ggml_internal_get_type_traits
(
config_
.
proj_type
).
vec_dot_type
)
{
if
(
config_
.
hidden_type
==
ggml_internal_get_type_traits
(
config_
.
proj_type
).
vec_dot_type
)
{
proj_input_ptr
=
input
;
proj_input_ptr
=
input
;
}
else
{
}
else
{
to_float
(
input
,
input_fp32_
.
data
(),
config_
.
input_size
,
config_
.
hidden_type
);
to_float
(
input
,
input_fp32_
,
qlen
*
config_
.
input_size
,
config_
.
hidden_type
);
from_float
(
input_fp32_
.
data
()
,
proj_input_
.
data
(),
config_
.
input_size
,
ggml_internal_get_type_traits
(
config_
.
proj_type
).
vec_dot_type
);
from_float
(
input_fp32_
,
proj_input_
,
qlen
*
config_
.
input_size
,
ggml_internal_get_type_traits
(
config_
.
proj_type
).
vec_dot_type
);
proj_input_ptr
=
proj_input_
.
data
()
;
proj_input_ptr
=
proj_input_
;
}
}
int
nth
=
config_
.
output_size
/
config_
.
stride
;
int
nth
=
config_
.
output_size
/
config_
.
stride
;
backend
->
do_work_stealing_job
(
nth
,
[
&
](
int
task_id
)
{
backend
->
do_work_stealing_job
(
nth
,
[
&
](
int
task_id
)
{
int
ith
=
task_id
%
nth
;
int
ith
=
task_id
;
llamafile_sgemm
(
config_
.
output_size
,
1
,
config_
.
input_size
/
ggml_blck_size
(
config_
.
proj_type
),
proj_
,
config_
.
input_size
/
ggml_blck_size
(
config_
.
proj_type
),
proj_input_ptr
,
config_
.
input_size
/
ggml_blck_size
(
config_
.
proj_type
),
proj_output_
.
data
(),
config_
.
output_size
,
ith
,
nth
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
proj_type
,
ggml_internal_get_type_traits
(
config_
.
proj_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
void
*
proj_ptr
=
(
uint8_t
*
)
proj_
+
ith
*
config_
.
stride
*
config_
.
input_size
*
ggml_type_size
(
config_
.
proj_type
)
/
ggml_blck_size
(
config_
.
proj_type
);
float
*
proj_output_ptr
=
proj_output_
+
ith
*
config_
.
stride
;
llamafile_sgemm
(
config_
.
stride
,
qlen
,
config_
.
input_size
/
ggml_blck_size
(
config_
.
proj_type
),
proj_ptr
,
config_
.
input_size
/
ggml_blck_size
(
config_
.
proj_type
),
proj_input_ptr
,
config_
.
input_size
/
ggml_blck_size
(
config_
.
proj_type
),
proj_output_ptr
,
config_
.
output_size
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
proj_type
,
ggml_internal_get_type_traits
(
config_
.
proj_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
if
(
config_
.
stride
%
ggml_blck_size
(
config_
.
hidden_type
)
==
0
)
{
for
(
int
i
=
0
;
i
<
qlen
;
i
++
)
{
float
*
output_fp32_ptr
=
proj_output_
+
i
*
config_
.
output_size
+
ith
*
config_
.
stride
;
void
*
output_ptr
=
(
uint8_t
*
)
output
+
i
*
config_
.
output_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
)
+
ith
*
config_
.
stride
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
);
from_float
(
output_fp32_ptr
,
output_ptr
,
config_
.
stride
,
config_
.
hidden_type
);
}
}
});
});
from_float
(
proj_output_
.
data
(),
output
,
config_
.
output_size
,
config_
.
hidden_type
);
if
(
config_
.
stride
%
ggml_blck_size
(
config_
.
hidden_type
)
!=
0
)
{
from_float
(
proj_output_
,
output
,
qlen
*
config_
.
output_size
,
config_
.
hidden_type
);
}
}
void
Linear
::
forward
(
int
qlen
,
const
void
*
input
,
void
*
output
,
Backend
*
backend
)
{
if
(
qlen
<=
0
)
{
return
;
}
int
forward_len
=
std
::
min
(
qlen
,
config_
.
group_max_len
);
forward_many
(
forward_len
,
input
,
output
,
backend
);
forward
(
qlen
-
forward_len
,
(
uint8_t
*
)
input
+
forward_len
*
config_
.
input_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
),
(
uint8_t
*
)
output
+
forward_len
*
config_
.
output_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
),
backend
);
}
}
\ No newline at end of file
ktransformers/ktransformers_ext/operators/llamafile/linear.h
View file @
77a34c28
...
@@ -22,34 +22,38 @@
...
@@ -22,34 +22,38 @@
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#include "llamafile/sgemm.h"
#include "shared_mem_buffer.h"
struct
LinearConfig
{
struct
LinearConfig
{
int
input_size
;
int
input_size
;
int
output_size
;
int
output_size
;
int
stride
;
int
stride
;
int
group_max_len
;
void
*
proj
;
void
*
proj
;
ggml_type
proj_type
;
ggml_type
proj_type
;
ggml_type
hidden_type
;
ggml_type
hidden_type
;
LinearConfig
()
{}
LinearConfig
()
{}
LinearConfig
(
int
input_size
,
int
output_size
,
int
stride
,
void
*
proj
,
ggml_type
proj_type
,
ggml_type
hidden_type
)
LinearConfig
(
int
input_size
,
int
output_size
,
int
stride
,
int
group_max_len
,
void
*
proj
,
ggml_type
proj_type
,
ggml_type
hidden_type
)
:
input_size
(
input_size
),
output_size
(
output_size
),
stride
(
stride
),
proj
(
proj
),
proj_type
(
proj_type
),
hidden_type
(
hidden_type
)
{}
:
input_size
(
input_size
),
output_size
(
output_size
),
stride
(
stride
),
group_max_len
(
group_max_len
),
proj
(
proj
),
proj_type
(
proj_type
),
hidden_type
(
hidden_type
)
{}
};
};
class
Linear
{
class
Linear
{
public:
public:
Linear
(
LinearConfig
);
Linear
(
LinearConfig
);
~
Linear
();
void
warm_up
(
Backend
*
backend
);
void
warm_up
(
Backend
*
backend
);
void
forward
(
const
void
*
input
,
void
*
output
,
Backend
*
backend
);
void
forward_many
(
int
qlen
,
const
void
*
input
,
void
*
output
,
Backend
*
backend
);
void
forward
(
int
qlen
,
const
void
*
input
,
void
*
output
,
Backend
*
backend
);
private:
private:
LinearConfig
config_
;
LinearConfig
config_
;
void
*
proj_
;
// [output_size * input_size ( /32 if quantized)]
void
*
proj_
;
// [output_size * input_size ( /32 if quantized)]
std
::
vector
<
float
>
input_fp32_
;
// [input_size]
float
*
input_fp32_
;
// [
group_max_len *
input_size]
std
::
vector
<
uint8_t
>
proj_input_
;
// [
input_size * 4
]
uint8_t
*
proj_input_
;
// [
group_max_len * input_size * ggml_type_size(ggml_internal_get_type_traits(proj_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(proj_type).vec_dot_type)
]
std
::
vector
<
float
>
proj_output_
;
// [output_size]
float
*
proj_output_
;
// [
group_max_len *
output_size]
};
};
#endif
#endif
\ No newline at end of file
ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp
View file @
77a34c28
...
@@ -15,14 +15,20 @@ MLP::MLP(MLPConfig config) {
...
@@ -15,14 +15,20 @@ MLP::MLP(MLPConfig config) {
up_proj_
=
config_
.
up_proj
;
up_proj_
=
config_
.
up_proj
;
down_proj_
=
config_
.
down_proj
;
down_proj_
=
config_
.
down_proj
;
input_fp32_
.
resize
(
config_
.
hidden_size
);
std
::
vector
<
std
::
pair
<
void
**
,
uint64_t
>>
mem_requests
;
gate_input_
.
resize
(
config_
.
hidden_size
*
4
);
mem_requests
.
push_back
({(
void
**
)
&
input_fp32_
,
sizeof
(
float
)
*
config_
.
group_max_len
*
config_
.
hidden_size
});
up_input_
.
resize
(
config_
.
hidden_size
*
4
);
mem_requests
.
push_back
({(
void
**
)
&
gate_input_
,
config_
.
group_max_len
*
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)});
gate_output_
.
resize
(
config_
.
intermediate_size
);
mem_requests
.
push_back
({(
void
**
)
&
up_input_
,
config_
.
group_max_len
*
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)});
up_output_
.
resize
(
config_
.
intermediate_size
);
mem_requests
.
push_back
({(
void
**
)
&
gate_output_
,
sizeof
(
float
)
*
config_
.
group_max_len
*
config_
.
intermediate_size
});
intermediate_fp32_
.
resize
(
config_
.
intermediate_size
);
mem_requests
.
push_back
({(
void
**
)
&
up_output_
,
sizeof
(
float
)
*
config_
.
group_max_len
*
config_
.
intermediate_size
});
down_input_
.
resize
(
config_
.
intermediate_size
*
4
);
mem_requests
.
push_back
({(
void
**
)
&
intermediate_fp32_
,
sizeof
(
float
)
*
config_
.
group_max_len
*
config_
.
intermediate_size
});
down_output_
.
resize
(
config_
.
hidden_size
);
mem_requests
.
push_back
({(
void
**
)
&
down_input_
,
config_
.
group_max_len
*
config_
.
intermediate_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)});
mem_requests
.
push_back
({(
void
**
)
&
down_output_
,
sizeof
(
float
)
*
config_
.
group_max_len
*
config_
.
hidden_size
});
shared_mem_buffer
.
alloc
(
this
,
mem_requests
);
}
// Destructor: hands this instance back to shared_mem_buffer via dealloc(this).
// Presumably this releases the scratch regions requested through
// shared_mem_buffer.alloc(this, ...) in the constructor; confirm against
// shared_mem_buffer.h.
MLP
::~
MLP
()
{
shared_mem_buffer
.
dealloc
(
this
);
}
}
void
MLP
::
warm_up
(
Backend
*
backend
)
{
void
MLP
::
warm_up
(
Backend
*
backend
)
{
...
@@ -33,33 +39,33 @@ void MLP::warm_up(Backend* backend) {
...
@@ -33,33 +39,33 @@ void MLP::warm_up(Backend* backend) {
input_fp32
[
i
]
=
0
;
input_fp32
[
i
]
=
0
;
}
}
from_float
(
input_fp32
.
data
(),
input
.
data
(),
config_
.
hidden_size
,
config_
.
hidden_type
);
from_float
(
input_fp32
.
data
(),
input
.
data
(),
config_
.
hidden_size
,
config_
.
hidden_type
);
forward
(
input
.
data
(),
output
.
data
(),
backend
);
forward
_many
(
1
,
input
.
data
(),
output
.
data
(),
backend
);
}
}
static
float
act_fn
(
float
x
)
{
static
float
act_fn
(
float
x
)
{
return
x
/
(
1.0
f
+
expf
(
-
x
));
return
x
/
(
1.0
f
+
expf
(
-
x
));
}
}
void
MLP
::
forward
(
const
void
*
input
,
void
*
output
,
Backend
*
backend
)
{
void
MLP
::
forward
_many
(
int
qlen
,
const
void
*
input
,
void
*
output
,
Backend
*
backend
)
{
const
void
*
gate_input_ptr
;
const
void
*
gate_input_ptr
;
const
void
*
up_input_ptr
;
const
void
*
up_input_ptr
;
if
(
config_
.
hidden_type
==
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
&&
config_
.
hidden_type
==
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
{
if
(
config_
.
hidden_type
==
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
&&
config_
.
hidden_type
==
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
{
gate_input_ptr
=
up_input_ptr
=
input
;
gate_input_ptr
=
up_input_ptr
=
input
;
}
else
{
}
else
{
to_float
(
input
,
input_fp32_
.
data
(),
config_
.
hidden_size
,
config_
.
hidden_type
);
to_float
(
input
,
input_fp32_
,
qlen
*
config_
.
hidden_size
,
config_
.
hidden_type
);
if
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
==
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
{
if
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
==
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
{
from_float
(
input_fp32_
.
data
()
,
gate_input_
.
data
(),
config_
.
hidden_size
,
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
from_float
(
input_fp32_
,
gate_input_
,
qlen
*
config_
.
hidden_size
,
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
gate_input_ptr
=
up_input_ptr
=
gate_input_
.
data
()
;
gate_input_ptr
=
up_input_ptr
=
gate_input_
;
}
else
{
}
else
{
if
(
config_
.
hidden_type
!=
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
{
if
(
config_
.
hidden_type
!=
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
{
from_float
(
input_fp32_
.
data
()
,
gate_input_
.
data
(),
config_
.
hidden_size
,
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
from_float
(
input_fp32_
,
gate_input_
,
qlen
*
config_
.
hidden_size
,
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
gate_input_ptr
=
gate_input_
.
data
()
;
gate_input_ptr
=
gate_input_
;
}
else
{
}
else
{
gate_input_ptr
=
input
;
gate_input_ptr
=
input
;
}
}
if
(
config_
.
hidden_type
!=
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
{
if
(
config_
.
hidden_type
!=
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
{
from_float
(
input_fp32_
.
data
()
,
up_input_
.
data
(),
config_
.
hidden_size
,
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
);
from_float
(
input_fp32_
,
up_input_
,
qlen
*
config_
.
hidden_size
,
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
);
up_input_ptr
=
up_input_
.
data
()
;
up_input_ptr
=
up_input_
;
}
else
{
}
else
{
up_input_ptr
=
input
;
up_input_ptr
=
input
;
}
}
...
@@ -69,35 +75,49 @@ void MLP::forward(const void* input, void* output, Backend* backend) {
...
@@ -69,35 +75,49 @@ void MLP::forward(const void* input, void* output, Backend* backend) {
backend
->
do_work_stealing_job
(
nth
,
[
&
](
int
task_id
)
{
backend
->
do_work_stealing_job
(
nth
,
[
&
](
int
task_id
)
{
int
ith
=
task_id
;
int
ith
=
task_id
;
void
*
gate_proj_ptr
=
(
uint8_t
*
)
gate_proj_
+
ith
*
config_
.
stride
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
gate_type
)
/
ggml_blck_size
(
config_
.
gate_type
);
void
*
gate_proj_ptr
=
(
uint8_t
*
)
gate_proj_
+
ith
*
config_
.
stride
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
gate_type
)
/
ggml_blck_size
(
config_
.
gate_type
);
float
*
gate_output_ptr
=
gate_output_
.
data
()
+
ith
*
config_
.
stride
;
float
*
gate_output_ptr
=
gate_output_
+
ith
*
config_
.
stride
;
llamafile_sgemm
(
config_
.
stride
,
1
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
gate_type
),
gate_proj_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
gate_type
),
gate_input_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
gate_type
),
gate_output_ptr
,
config_
.
strid
e
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
gate_type
,
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
llamafile_sgemm
(
config_
.
stride
,
qlen
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
gate_type
),
gate_proj_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
gate_type
),
gate_input_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
gate_type
),
gate_output_ptr
,
config_
.
intermediate_siz
e
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
gate_type
,
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
void
*
up_proj_ptr
=
(
uint8_t
*
)
up_proj_
+
ith
*
config_
.
stride
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
up_type
)
/
ggml_blck_size
(
config_
.
up_type
);
void
*
up_proj_ptr
=
(
uint8_t
*
)
up_proj_
+
ith
*
config_
.
stride
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
up_type
)
/
ggml_blck_size
(
config_
.
up_type
);
float
*
up_output_ptr
=
up_output_
.
data
()
+
ith
*
config_
.
stride
;
float
*
up_output_ptr
=
up_output_
+
ith
*
config_
.
stride
;
llamafile_sgemm
(
config_
.
stride
,
1
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
up_type
),
up_proj_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
up_type
),
up_input_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
up_type
),
up_output_ptr
,
config_
.
stride
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
up_type
,
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
llamafile_sgemm
(
config_
.
stride
,
qlen
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
up_type
),
up_proj_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
up_type
),
up_input_ptr
,
config_
.
hidden_size
/
ggml_blck_size
(
config_
.
up_type
),
up_output_ptr
,
config_
.
intermediate_size
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
up_type
,
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
for
(
int
i
=
ith
*
config_
.
stride
;
i
<
(
ith
+
1
)
*
config_
.
stride
;
i
++
)
{
for
(
int
i
=
0
;
i
<
qlen
;
i
++
)
{
intermediate_fp32_
[
i
]
=
act_fn
(
gate_output_
[
i
])
*
up_output_
[
i
];
for
(
int
j
=
ith
*
config_
.
stride
;
j
<
(
ith
+
1
)
*
config_
.
stride
;
j
++
)
{
intermediate_fp32_
[
i
*
config_
.
intermediate_size
+
j
]
=
act_fn
(
gate_output_
[
i
*
config_
.
intermediate_size
+
j
])
*
up_output_
[
i
*
config_
.
intermediate_size
+
j
];
}
}
if
(
config_
.
stride
%
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
==
0
)
{
if
(
config_
.
stride
%
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
==
0
)
{
float
*
intermediate_fp32_ptr
=
intermediate_fp32_
.
data
()
+
ith
*
config_
.
stride
;
float
*
intermediate_fp32_ptr
=
intermediate_fp32_
+
i
*
config_
.
intermediate_size
+
ith
*
config_
.
stride
;
void
*
down_input_ptr
=
(
uint8_t
*
)
down_input_
.
data
(
)
+
ith
*
config_
.
stride
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
void
*
down_input_ptr
=
(
uint8_t
*
)
down_input_
+
i
*
config_
.
intermediate_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
+
ith
*
config_
.
stride
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
from_float
(
intermediate_fp32_ptr
,
down_input_ptr
,
config_
.
stride
,
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
from_float
(
intermediate_fp32_ptr
,
down_input_ptr
,
config_
.
stride
,
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
}
}
}
});
});
if
(
config_
.
stride
%
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
!=
0
)
{
if
(
config_
.
stride
%
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
!=
0
)
{
from_float
(
intermediate_fp32_
.
data
()
,
down_input_
.
data
(),
config_
.
intermediate_size
,
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
from_float
(
intermediate_fp32_
,
down_input_
,
qlen
*
config_
.
intermediate_size
,
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
}
}
nth
=
config_
.
hidden_size
/
config_
.
stride
;
nth
=
config_
.
hidden_size
/
config_
.
stride
;
backend
->
do_work_stealing_job
(
nth
,
[
&
](
int
task_id
)
{
backend
->
do_work_stealing_job
(
nth
,
[
&
](
int
task_id
)
{
int
ith
=
task_id
;
int
ith
=
task_id
;
void
*
down_proj_ptr
=
(
uint8_t
*
)
down_proj_
+
ith
*
config_
.
stride
*
config_
.
intermediate_size
*
ggml_type_size
(
config_
.
down_type
)
/
ggml_blck_size
(
config_
.
down_type
);
void
*
down_proj_ptr
=
(
uint8_t
*
)
down_proj_
+
ith
*
config_
.
stride
*
config_
.
intermediate_size
*
ggml_type_size
(
config_
.
down_type
)
/
ggml_blck_size
(
config_
.
down_type
);
float
*
down_output_ptr
=
down_output_
.
data
()
+
ith
*
config_
.
stride
;
float
*
down_output_ptr
=
down_output_
+
ith
*
config_
.
stride
;
llamafile_sgemm
(
config_
.
stride
,
1
,
config_
.
intermediate_size
/
ggml_blck_size
(
config_
.
down_type
),
down_proj_ptr
,
config_
.
intermediate_size
/
ggml_blck_size
(
config_
.
down_type
),
down_input_
.
data
()
,
config_
.
intermediate_size
/
ggml_blck_size
(
config_
.
down_type
),
down_output_ptr
,
config_
.
strid
e
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
down_type
,
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
llamafile_sgemm
(
config_
.
stride
,
qlen
,
config_
.
intermediate_size
/
ggml_blck_size
(
config_
.
down_type
),
down_proj_ptr
,
config_
.
intermediate_size
/
ggml_blck_size
(
config_
.
down_type
),
down_input_
,
config_
.
intermediate_size
/
ggml_blck_size
(
config_
.
down_type
),
down_output_ptr
,
config_
.
hidden_siz
e
,
0
,
1
,
GGML_TASK_TYPE_COMPUTE
,
config_
.
down_type
,
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
,
GGML_TYPE_F32
,
GGML_PREC_DEFAULT
);
if
(
config_
.
stride
%
ggml_blck_size
(
config_
.
hidden_type
)
==
0
)
{
if
(
config_
.
stride
%
ggml_blck_size
(
config_
.
hidden_type
)
==
0
)
{
void
*
output_ptr
=
(
uint8_t
*
)
output
+
ith
*
config_
.
stride
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
);
for
(
int
i
=
0
;
i
<
qlen
;
i
++
)
{
from_float
(
down_output_ptr
,
output_ptr
,
config_
.
stride
,
config_
.
hidden_type
);
float
*
output_fp32_ptr
=
down_output_
+
i
*
config_
.
hidden_size
+
ith
*
config_
.
stride
;
void
*
output_ptr
=
(
uint8_t
*
)
output
+
i
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
)
+
ith
*
config_
.
stride
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
);
from_float
(
output_fp32_ptr
,
output_ptr
,
config_
.
stride
,
config_
.
hidden_type
);
}
}
}
});
});
if
(
config_
.
stride
%
ggml_blck_size
(
config_
.
hidden_type
)
!=
0
)
{
if
(
config_
.
stride
%
ggml_blck_size
(
config_
.
hidden_type
)
!=
0
)
{
from_float
(
down_output_
.
data
(),
output
,
config_
.
hidden_size
,
config_
.
hidden_type
);
from_float
(
down_output_
,
output
,
qlen
*
config_
.
hidden_size
,
config_
.
hidden_type
);
}
}
void
MLP
::
forward
(
int
qlen
,
const
void
*
input
,
void
*
output
,
Backend
*
backend
)
{
if
(
qlen
<=
0
)
{
return
;
}
}
int
forward_len
=
std
::
min
(
qlen
,
config_
.
group_max_len
);
forward_many
(
forward_len
,
input
,
output
,
backend
);
forward
(
qlen
-
forward_len
,
(
uint8_t
*
)
input
+
forward_len
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
),
(
uint8_t
*
)
output
+
forward_len
*
config_
.
hidden_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
),
backend
);
}
}
\ No newline at end of file
ktransformers/ktransformers_ext/operators/llamafile/mlp.h
View file @
77a34c28
...
@@ -22,11 +22,13 @@
...
@@ -22,11 +22,13 @@
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#include "llamafile/sgemm.h"
#include "shared_mem_buffer.h"
struct
MLPConfig
{
struct
MLPConfig
{
int
hidden_size
;
int
hidden_size
;
int
intermediate_size
;
int
intermediate_size
;
int
stride
;
int
stride
;
int
group_max_len
;
void
*
gate_proj
;
void
*
gate_proj
;
void
*
up_proj
;
void
*
up_proj
;
void
*
down_proj
;
void
*
down_proj
;
...
@@ -37,15 +39,17 @@ struct MLPConfig {
...
@@ -37,15 +39,17 @@ struct MLPConfig {
MLPConfig
()
{}
MLPConfig
()
{}
MLPConfig
(
int
hidden_size
,
int
intermediate_size
,
int
stride
,
void
*
gate_proj
,
void
*
up_proj
,
void
*
down_proj
,
ggml_type
gate_type
,
ggml_type
up_type
,
ggml_type
down_type
,
ggml_type
hidden_type
)
MLPConfig
(
int
hidden_size
,
int
intermediate_size
,
int
stride
,
int
group_max_len
,
void
*
gate_proj
,
void
*
up_proj
,
void
*
down_proj
,
ggml_type
gate_type
,
ggml_type
up_type
,
ggml_type
down_type
,
ggml_type
hidden_type
)
:
hidden_size
(
hidden_size
),
intermediate_size
(
intermediate_size
),
stride
(
stride
),
gate_proj
(
gate_proj
),
up_proj
(
up_proj
),
down_proj
(
down_proj
),
gate_type
(
gate_type
),
up_type
(
up_type
),
down_type
(
down_type
),
hidden_type
(
hidden_type
)
{}
:
hidden_size
(
hidden_size
),
intermediate_size
(
intermediate_size
),
stride
(
stride
),
group_max_len
(
group_max_len
),
gate_proj
(
gate_proj
),
up_proj
(
up_proj
),
down_proj
(
down_proj
),
gate_type
(
gate_type
),
up_type
(
up_type
),
down_type
(
down_type
),
hidden_type
(
hidden_type
)
{}
};
};
class
MLP
{
class
MLP
{
public:
public:
MLP
(
MLPConfig
);
MLP
(
MLPConfig
);
~
MLP
();
void
warm_up
(
Backend
*
backend
);
void
warm_up
(
Backend
*
backend
);
void
forward
(
const
void
*
input
,
void
*
output
,
Backend
*
backend
);
void
forward_many
(
int
qlen
,
const
void
*
input
,
void
*
output
,
Backend
*
backend
);
void
forward
(
int
qlen
,
const
void
*
input
,
void
*
output
,
Backend
*
backend
);
private:
private:
MLPConfig
config_
;
MLPConfig
config_
;
...
@@ -53,14 +57,14 @@ class MLP {
...
@@ -53,14 +57,14 @@ class MLP {
void
*
up_proj_
;
// [intermediate_size * hidden_size ( /32 if quantized)]
void
*
up_proj_
;
// [intermediate_size * hidden_size ( /32 if quantized)]
void
*
down_proj_
;
// [hidden_size * intermediate_size ( /32 if quantized)]
void
*
down_proj_
;
// [hidden_size * intermediate_size ( /32 if quantized)]
std
::
vector
<
float
>
input_fp32_
;
// [hidden_size]
float
*
input_fp32_
;
// [
group_max_len *
hidden_size]
std
::
vector
<
uint8_t
>
gate_input_
;
// [
hidden_size * 4
]
uint8_t
*
gate_input_
;
// [
group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)
]
std
::
vector
<
uint8_t
>
up_input_
;
// [
hidden_size * 4
]
uint8_t
*
up_input_
;
// [
group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)
]
std
::
vector
<
float
>
gate_output_
;
// [intermediate_size]
float
*
gate_output_
;
// [
group_max_len *
intermediate_size]
std
::
vector
<
float
>
up_output_
;
// [intermediate_size]
float
*
up_output_
;
// [
group_max_len *
intermediate_size]
std
::
vector
<
float
>
intermediate_fp32_
;
// [intermediate_size]
float
*
intermediate_fp32_
;
// [
group_max_len *
intermediate_size]
std
::
vector
<
uint8_t
>
down_input_
;
// [intermediate_size *
4
]
uint8_t
*
down_input_
;
// [
group_max_len *
intermediate_size *
ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)
]
std
::
vector
<
float
>
down_output_
;
// [hidden_size]
float
*
down_output_
;
// [
group_max_len *
hidden_size]
};
};
#endif
#endif
\ No newline at end of file
ktransformers/ktransformers_ext/operators/llamafile/moe.cpp
View file @
77a34c28
...
@@ -6,92 +6,57 @@
...
@@ -6,92 +6,57 @@
* @LastEditors : chenht2022
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:07
* @LastEditTime : 2024-07-25 10:35:07
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
**/
#include "moe.h"
#include "moe.h"
#include <iostream>
#include <iostream>
#include <cstdint>
#include <cstdint>
uint8_t
*
MOE
::
buffer_
=
nullptr
;
MOE
::
MOE
(
MOEConfig
config
)
{
MOE
::
MOE
(
MOEConfig
config
)
{
config_
=
config
;
config_
=
config
;
gate_proj_
=
config_
.
gate_proj
;
gate_proj_
=
config_
.
gate_proj
;
up_proj_
=
config_
.
up_proj
;
up_proj_
=
config_
.
up_proj
;
down_proj_
=
config_
.
down_proj
;
down_proj_
=
config_
.
down_proj
;
if
(
MOE
::
buffer_
==
nullptr
)
{
std
::
vector
<
std
::
pair
<
void
**
,
uint64_t
>>
s_mem_requests
;
uint64_t
buffer_size
=
0
;
s_mem_requests
.
push_back
({(
void
**
)
&
s_input_fp32_
,
sizeof
(
float
)
*
config_
.
hidden_size
});
buffer_size
+=
sizeof
(
float
)
*
config_
.
group_max_len
*
config_
.
hidden_size
;
s_mem_requests
.
push_back
({(
void
**
)
&
s_gate_input_
,
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)});
buffer_size
+=
config_
.
group_max_len
*
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
s_mem_requests
.
push_back
({(
void
**
)
&
s_up_input_
,
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)});
buffer_size
+=
config_
.
group_max_len
*
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
);
buffer_size
+=
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
buffer_size
+=
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
);
buffer_size
+=
sizeof
(
float
)
*
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
intermediate_size
;
buffer_size
+=
sizeof
(
float
)
*
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
intermediate_size
;
buffer_size
+=
sizeof
(
float
)
*
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
intermediate_size
;
buffer_size
+=
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
intermediate_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
buffer_size
+=
sizeof
(
float
)
*
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
hidden_size
;
buffer_size
+=
sizeof
(
float
)
*
config_
.
group_max_len
*
config_
.
hidden_size
;
buffer_
=
(
uint8_t
*
)
malloc
(
buffer_size
);
}
uint64_t
offset
=
0
;
s_input_fp32_
=
(
float
*
)(
buffer_
+
offset
);
offset
+=
sizeof
(
float
)
*
config_
.
hidden_size
;
s_gate_input_
=
(
uint8_t
*
)(
buffer_
+
offset
);
offset
+=
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
s_up_input_
=
(
uint8_t
*
)(
buffer_
+
offset
);
offset
+=
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
);
s_gate_output_
.
resize
(
config_
.
routed_expert_num
);
s_gate_output_
.
resize
(
config_
.
routed_expert_num
);
s_up_output_
.
resize
(
config_
.
routed_expert_num
);
s_up_output_
.
resize
(
config_
.
routed_expert_num
);
s_intermediate_fp32_
.
resize
(
config_
.
routed_expert_num
);
s_intermediate_fp32_
.
resize
(
config_
.
routed_expert_num
);
s_down_input_
.
resize
(
config_
.
routed_expert_num
);
s_down_input_
.
resize
(
config_
.
routed_expert_num
);
s_down_output_
.
resize
(
config_
.
routed_expert_num
);
s_down_output_
.
resize
(
config_
.
routed_expert_num
);
for
(
int
i
=
0
;
i
<
config_
.
routed_expert_num
;
i
++
)
{
for
(
int
i
=
0
;
i
<
config_
.
routed_expert_num
;
i
++
)
{
s_gate_output_
[
i
]
=
(
float
*
)(
buffer_
+
offset
);
s_mem_requests
.
push_back
({(
void
**
)
&
s_gate_output_
[
i
],
sizeof
(
float
)
*
config_
.
intermediate_size
});
offset
+=
sizeof
(
float
)
*
config_
.
intermediate_size
;
s_mem_requests
.
push_back
({(
void
**
)
&
s_up_output_
[
i
],
sizeof
(
float
)
*
config_
.
intermediate_size
});
s_up_output_
[
i
]
=
(
float
*
)(
buffer_
+
offset
);
s_mem_requests
.
push_back
({(
void
**
)
&
s_intermediate_fp32_
[
i
],
sizeof
(
float
)
*
config_
.
intermediate_size
});
offset
+=
sizeof
(
float
)
*
config_
.
intermediate_size
;
s_mem_requests
.
push_back
({(
void
**
)
&
s_down_input_
[
i
],
config_
.
intermediate_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)});
s_intermediate_fp32_
[
i
]
=
(
float
*
)(
buffer_
+
offset
);
s_mem_requests
.
push_back
({(
void
**
)
&
s_down_output_
[
i
],
sizeof
(
float
)
*
config_
.
hidden_size
});
offset
+=
sizeof
(
float
)
*
config_
.
intermediate_size
;
}
s_down_input_
[
i
]
=
(
uint8_t
*
)(
buffer_
+
offset
);
s_mem_requests
.
push_back
({(
void
**
)
&
s_output_fp32_
,
sizeof
(
float
)
*
config_
.
hidden_size
});
offset
+=
config_
.
intermediate_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
shared_mem_buffer
.
alloc
(
this
,
s_mem_requests
);
s_down_output_
[
i
]
=
(
float
*
)(
buffer_
+
offset
);
offset
+=
sizeof
(
float
)
*
config_
.
hidden_size
;
}
s_output_fp32_
=
(
float
*
)(
buffer_
+
offset
);
offset
=
0
;
std
::
vector
<
std
::
pair
<
void
**
,
uint64_t
>>
m_mem_requests
;
m_input_fp32_
.
resize
(
config_
.
group_max_len
);
m_input_fp32_
.
resize
(
config_
.
group_max_len
);
m_gate_input_
.
resize
(
config_
.
group_max_len
);
m_gate_input_
.
resize
(
config_
.
group_max_len
);
m_up_input_
.
resize
(
config_
.
group_max_len
);
m_up_input_
.
resize
(
config_
.
group_max_len
);
for
(
int
i
=
0
;
i
<
config_
.
group_max_len
;
i
++
)
{
for
(
int
i
=
0
;
i
<
config_
.
group_max_len
;
i
++
)
{
m_input_fp32_
[
i
]
=
(
float
*
)(
buffer_
+
offset
);
m_mem_requests
.
push_back
({(
void
**
)
&
m_input_fp32_
[
i
],
sizeof
(
float
)
*
config_
.
hidden_size
});
offset
+=
sizeof
(
float
)
*
config_
.
hidden_size
;
m_mem_requests
.
push_back
({(
void
**
)
&
m_gate_input_
[
i
],
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)});
m_gate_input_
[
i
]
=
(
uint8_t
*
)(
buffer_
+
offset
);
m_mem_requests
.
push_back
({(
void
**
)
&
m_up_input_
[
i
],
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)});
offset
+=
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
}
m_up_input_
[
i
]
=
(
uint8_t
*
)(
buffer_
+
offset
);
m_mem_requests
.
push_back
({(
void
**
)
&
m_local_gate_input_
,
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)});
offset
+=
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
);
m_mem_requests
.
push_back
({(
void
**
)
&
m_local_up_input_
,
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)});
}
m_mem_requests
.
push_back
({(
void
**
)
&
m_local_gate_output_
,
sizeof
(
float
)
*
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
intermediate_size
});
m_local_gate_input_
=
(
uint8_t
*
)(
buffer_
+
offset
);
m_mem_requests
.
push_back
({(
void
**
)
&
m_local_up_output_
,
sizeof
(
float
)
*
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
intermediate_size
});
offset
+=
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
gate_type
).
vec_dot_type
);
m_mem_requests
.
push_back
({(
void
**
)
&
m_local_intermediate_fp32_
,
sizeof
(
float
)
*
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
intermediate_size
});
m_local_up_input_
=
(
uint8_t
*
)(
buffer_
+
offset
);
m_mem_requests
.
push_back
({(
void
**
)
&
m_local_down_input_
,
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
intermediate_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)});
offset
+=
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
hidden_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
up_type
).
vec_dot_type
);
m_mem_requests
.
push_back
({(
void
**
)
&
m_local_down_output_
,
sizeof
(
float
)
*
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
hidden_size
});
m_local_gate_output_
=
(
float
*
)(
buffer_
+
offset
);
offset
+=
sizeof
(
float
)
*
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
intermediate_size
;
m_local_up_output_
=
(
float
*
)(
buffer_
+
offset
);
offset
+=
sizeof
(
float
)
*
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
intermediate_size
;
m_local_intermediate_fp32_
=
(
float
*
)(
buffer_
+
offset
);
offset
+=
sizeof
(
float
)
*
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
intermediate_size
;
m_local_down_input_
=
(
uint8_t
*
)(
buffer_
+
offset
);
offset
+=
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
intermediate_size
*
ggml_type_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
)
/
ggml_blck_size
(
ggml_internal_get_type_traits
(
config_
.
down_type
).
vec_dot_type
);
m_local_down_output_
=
(
float
*
)(
buffer_
+
offset
);
offset
+=
sizeof
(
float
)
*
config_
.
routed_expert_num
*
config_
.
group_max_len
*
config_
.
hidden_size
;
m_output_fp32_
.
resize
(
config_
.
group_max_len
);
m_output_fp32_
.
resize
(
config_
.
group_max_len
);
for
(
int
i
=
0
;
i
<
config_
.
group_max_len
;
i
++
)
{
for
(
int
i
=
0
;
i
<
config_
.
group_max_len
;
i
++
)
{
m_output_fp32_
[
i
]
=
(
float
*
)(
buffer_
+
offset
);
m_mem_requests
.
push_back
({(
void
**
)
&
m_output_fp32_
[
i
],
sizeof
(
float
)
*
config_
.
hidden_size
});
offset
+=
sizeof
(
float
)
*
config_
.
hidden_size
;
}
}
shared_mem_buffer
.
alloc
(
this
,
m_mem_requests
);
m_local_pos_
.
resize
(
config_
.
group_max_len
);
m_local_pos_
.
resize
(
config_
.
group_max_len
);
for
(
int
i
=
0
;
i
<
config_
.
group_max_len
;
i
++
)
{
for
(
int
i
=
0
;
i
<
config_
.
group_max_len
;
i
++
)
{
...
@@ -107,6 +72,10 @@ MOE::MOE(MOEConfig config) {
...
@@ -107,6 +72,10 @@ MOE::MOE(MOEConfig config) {
m_local_down_output_ptr_
.
resize
(
config_
.
expert_num
);
m_local_down_output_ptr_
.
resize
(
config_
.
expert_num
);
}
}
MOE
::~
MOE
()
{
shared_mem_buffer
.
dealloc
(
this
);
}
void
MOE
::
warm_up
(
Backend
*
backend
)
{
void
MOE
::
warm_up
(
Backend
*
backend
)
{
std
::
vector
<
float
>
input_fp32
(
config_
.
hidden_size
);
std
::
vector
<
float
>
input_fp32
(
config_
.
hidden_size
);
std
::
vector
<
uint8_t
>
input
(
config_
.
hidden_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
));
std
::
vector
<
uint8_t
>
input
(
config_
.
hidden_size
*
ggml_type_size
(
config_
.
hidden_type
)
/
ggml_blck_size
(
config_
.
hidden_type
));
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment