ox696c / ktransformers

Commit f5f79f5c (parent f2938031), authored Aug 12, 2024 by chenxl

    [ADD] support multi-gpu qlen>1 q5_k

The commit touches 63 files in total; this page of the diff shows 20 changed files, with 241 additions and 1010 deletions (+241 −1010).
Changed files shown on this page:

- .gitignore (+4 −1)
- install.bat (+16 −0)
- install.sh (+1 −1)
- ktransformers/ktransformers_ext/CMakeLists.txt (+18 −3)
- ktransformers/ktransformers_ext/cpu_backend/task_queue.h (+41 −3)
- ktransformers/ktransformers_ext/cuda/binding.cpp (+4 −2)
- ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp (+3 −0)
- ktransformers/ktransformers_ext/cuda/custom_gguf/custom_ggml.h (+0 −39)
- ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu (+54 −2)
- ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h (+3 −2)
- ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu (+56 −21)
- ktransformers/ktransformers_ext/cuda/setup.py (+18 −10)
- ktransformers/ktransformers_ext/ext_bindings.cpp (+12 −12)
- ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq.py (+0 −206)
- ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq_marlin.py (+0 −458)
- ktransformers/ktransformers_ext/operators/custom_marlin/quantize/quantizer.py (+0 −140)
- ktransformers/ktransformers_ext/operators/custom_marlin/quantize/repack.py (+0 −99)
- ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py (+2 −2)
- ktransformers/ktransformers_ext/operators/llamafile/linear.cpp (+3 −3)
- ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp (+6 −6)
.gitignore

```diff
@@ -14,4 +14,7 @@ node_modules
 .DS_Store
 compile_commands.json
 *.egg-info*
-*dist/
\ No newline at end of file
+*dist/
+ktransformers/server/local_store/
+ktransformers/server_test1.db
+*.patch
\ No newline at end of file
```
install.bat (new file, 0 → 100644)

```bat
@echo off
REM clear build dirs
rmdir /S /Q ktransformers\ktransformers_ext\build
rmdir /S /Q ktransformers\ktransformers_ext\cuda\build
rmdir /S /Q ktransformers\ktransformers_ext\cuda\dist
rmdir /S /Q ktransformers\ktransformers_ext\out
del /F /Q ktransformers\ktransformers_ext\cuda\*.egg-info

echo Installing python dependencies from requirements.txt
pip install -r requirements-local_chat.txt

echo Installing ktransformers
set KTRANSFORMERS_FORCE_BUILD=TRUE
pip install . --no-build-isolation
echo Installation completed successfully
```
(no newline at end of file)
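The new Windows script, like the install.sh change just below, sets KTRANSFORMERS_FORCE_BUILD=TRUE before reinstalling the package. As a point of orientation, the sketch below shows one way a setup script can consume such an environment switch at build time; the variable name comes from the scripts above, but the surrounding logic is illustrative, not the project's actual setup.py.

```python
# Illustrative only: shows how a build-time environment switch such as
# KTRANSFORMERS_FORCE_BUILD=TRUE can be read by a setup script. The real
# setup.py of the project may handle this differently.
import os

def force_build_requested() -> bool:
    # Treat any non-empty value other than "0"/"FALSE" as a request to
    # rebuild the native extensions instead of reusing a prebuilt wheel.
    value = os.environ.get("KTRANSFORMERS_FORCE_BUILD", "")
    return value.upper() not in ("", "0", "FALSE")

if __name__ == "__main__":
    print("force build:", force_build_requested())
```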
install.sh

```diff
@@ -11,5 +11,5 @@ echo "Installing python dependencies from requirements.txt"
 pip install -r requirements-local_chat.txt
 
 echo "Installing ktransformers"
-pip install . --no-build-isolation
+KTRANSFORMERS_FORCE_BUILD=TRUE pip install . --no-build-isolation
 echo "Installation completed successfully"
\ No newline at end of file
```
ktransformers/ktransformers_ext/CMakeLists.txt

```diff
@@ -189,7 +189,13 @@ else()
     message(STATUS "Unknown architecture")
 endif()
 
 find_package(CUDA REQUIRED)
+# message(STATUS "CUDAToolkit_ROOT:${CUDAToolkit_ROOT}")
+# find_package(FindCUDAToolkit REQUIRED)
+# if(CUDAToolkit_FOUND)
+#     message(STATUS "Found CUDA cudart lib at:${CUDAToolkit_LIBRARY_DIR}")
+# else()
+#     message(STATUS "Can't found CUDA lib")
+# endif()
 add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
 add_compile_options("$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")
@@ -198,7 +204,12 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/pybind11 ${CMAKE_
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llama.cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/llama.cpp)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
-include_directories("${CUDA_INCLUDE_DIRS}")
+if (WIN32)
+    include_directories("$ENV{CUDA_PATH}/include")
+elseif (UNIX)
+    find_package(CUDA REQUIRED)
+    include_directories("${CUDA_INCLUDE_DIRS}")
+endif()
 aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
 aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)
@@ -209,4 +220,8 @@ message(STATUS "ALL_SOURCES: ${ALL_SOURCES}")
 pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES})
 target_link_libraries(${PROJECT_NAME} PRIVATE llama)
-target_link_libraries(${PROJECT_NAME} PRIVATE "/usr/local/cuda/lib64/libcudart.so")
\ No newline at end of file
+if (WIN32)
+    target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib") #CUDA::cudart
+elseif (UNIX)
+    target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
+endif()
\ No newline at end of file
```
ktransformers/ktransformers_ext/cpu_backend/task_queue.h

```diff
@@ -3,8 +3,8 @@
  * @Author       : chenht2022
  * @Date         : 2024-07-16 10:43:18
  * @Version      : 1.0.0
- * @LastEditors  : chenht2022
- * @LastEditTime : 2024-07-25 10:33:47
+ * @LastEditors  : chenxl
+ * @LastEditTime : 2024-08-08 04:23:51
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
 #ifndef CPUINFER_TASKQUEUE_H
@@ -17,6 +17,44 @@
 #include <queue>
 #include <thread>
 #include <vector>
+#ifdef _WIN32
+#include <windows.h>
+#endif
+
+class custom_mutex {
+   private:
+#ifdef _WIN32
+    HANDLE global_mutex;
+#else
+    std::mutex global_mutex;
+#endif
+
+   public:
+    custom_mutex() {
+#ifdef _WIN32
+        HANDLE global_mutex;
+#endif
+    }
+
+    void lock() {
+#ifdef _WIN32
+        WaitForSingleObject(global_mutex, INFINITE);
+#else
+        global_mutex.lock();
+#endif
+    }
+
+    void unlock() {
+#ifdef _WIN32
+        ReleaseMutex(global_mutex);
+#else
+        global_mutex.lock();
+#endif
+    }
+};
+
 class TaskQueue {
    public:
@@ -32,7 +70,7 @@ class TaskQueue {
     std::queue<std::function<void()>> tasks;
     std::thread worker;
-    std::mutex mutex;
+    custom_mutex mutex;
     std::atomic<bool> sync_flag;
     std::atomic<bool> exit_flag;
 };
```
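The change wraps the queue's std::mutex in a custom_mutex so the same header builds on Windows (Win32 HANDLE) and elsewhere (std::mutex); note that, as committed above, the non-Windows unlock() path calls global_mutex.lock(). For orientation only, a rough Python analogue of the TaskQueue members visible in the diff (task queue, worker thread, lock, exit flag) is sketched below; it is not the project's API.

```python
# Rough, hedged Python analogue of the TaskQueue shape shown in the diff.
# Purely illustrative; the real implementation is the C++ header above.
import queue
import threading

class TaskQueueSketch:
    def __init__(self):
        self.tasks = queue.Queue()        # std::queue<std::function<void()>> tasks
        self.mutex = threading.Lock()     # custom_mutex mutex
        self.exit_flag = False            # std::atomic<bool> exit_flag
        self.worker = threading.Thread(target=self._run, daemon=True)
        self.worker.start()

    def enqueue(self, fn):
        # Serialize producers, as the C++ class does with its mutex.
        with self.mutex:
            self.tasks.put(fn)

    def _run(self):
        # Worker thread drains the queue until asked to exit.
        while not self.exit_flag:
            try:
                fn = self.tasks.get(timeout=0.1)
            except queue.Empty:
                continue
            fn()
```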
ktransformers/ktransformers_ext/cuda/binding.cpp

```diff
@@ -3,8 +3,8 @@
  * @Author       : Azure-Tang
  * @Date         : 2024-07-25 13:38:30
  * @Version      : 1.0.0
- * @LastEditors  : Azure
- * @LastEditTime : 2024-07-26 08:36:03
+ * @LastEditors  : kkk1nak0
+ * @LastEditTime : 2024-08-09 01:45:02
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
@@ -23,6 +23,8 @@ PYBIND11_MODULE(KTransformersOps, m) {
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
```
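With this binding in place, the q5_k path is callable from Python exactly like the existing q6_k and q4_k ops. A hedged usage sketch follows; it assumes the KTransformersOps extension has been built and that blk_size for Q5_K is 176 bytes per 256-weight super-block (the standard GGUF layout, which is not spelled out in this diff).

```python
# Hedged usage sketch for the newly bound op. Assumes a built KTransformersOps
# extension and the standard 176-byte GGUF Q5_K super-block; adjust if the
# project uses a different block size.
import torch
import KTransformersOps

blk_size = 176                                        # assumed Q5_K block size in bytes
raw = torch.randint(-128, 128, (10 * blk_size,), dtype=torch.int8)  # dummy raw bytes
device = torch.device("cuda:0")

out = KTransformersOps.dequantize_q5_k(raw, blk_size, device)
print(out.shape)                                      # expected: (10, 256) float32 on the GPU
```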
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp

```diff
@@ -12,12 +12,15 @@ int test(){
 }
 
 torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
 
 PYBIND11_MODULE(cudaops, m) {
     m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("test", &test, "Function to test.");
```
ktransformers/ktransformers_ext/cuda/custom_gguf/custom_ggml.h (deleted, 100644 → 0)

```cuda
#include <cuda_fp16.h>

__device__ float ggml_compute_fp16_to_fp32(uint16_t h) {
    return __uint2float_rd(h);
}

static inline float ggml_compute_fp16_to_fp32(uint16_t h) {
    uint16_t tmp;
    memcpy(&tmp, &h, sizeof(ggml_fp16_t));
    return (float)tmp;
}

// define the global table for fp16 to fp32 conversion
__device__ float ggml_table_f32_f16[1 << 16];

// CUDA Kernel to init the table
__global__ void init_fp16_to_fp32_table() {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (auto blk_id = idx; blk_id < (1 << 16); blk_id += blockDim.x * gridDim.x){
        ggml_table_f32_f16[blk_id] = GGML_COMPUTE_FP16_TO_FP32(blk_id);
    }
}

#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)

extern __device__ float ggml_table_f32_f16[1 << 16]; // Declare as __device__ if used within device code

// This version of the function is designed to be called from within a CUDA kernel
#if !defined(GGML_FP16_TO_FP32)
__device__ float ggml_lookup_fp16_to_fp32(uint16_t f) {
    return ggml_table_f32_f16[f];
}

#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
#endif
```
(no newline at end of file)
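The deleted header implemented fp16→fp32 conversion through a 65,536-entry device lookup table; the surviving kernels convert directly instead (for example __half2float in dequant.cu below). A small NumPy sketch of the equivalence that table encoded, given here only to make the removal easier to follow:

```python
# Sketch of the fp16 -> fp32 mapping the deleted lookup table precomputed:
# every 16-bit pattern corresponds to one float16 value, widened to float32.
import numpy as np

bits = np.arange(1 << 16, dtype=np.uint16)        # all 65,536 half-precision bit patterns
table = bits.view(np.float16).astype(np.float32)  # what ggml_table_f32_f16 held

# Direct conversion of one pattern matches the table entry.
h = 0x3C00  # bit pattern of float16 1.0
direct = np.array(h, dtype=np.uint16).view(np.float16).astype(np.float32)
assert table[h] == direct
print(direct)  # 1.0
```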
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu

```diff
@@ -3,8 +3,8 @@
  * @Author       : Azure-Tang, Boxin Zhang
  * @Date         : 2024-07-25 13:38:30
  * @Version      : 1.0.0
- * @LastEditors  : Azure
- * @LastEditTime : 2024-07-26 11:58:50
+ * @LastEditors  : kkk1nak0
+ * @LastEditTime : 2024-08-09 07:57:06
  * Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c
  * Copyright (c) 2023-2024 The ggml authors
  * Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
@@ -14,6 +14,7 @@
 #include <torch/extension.h>
 #include <torch/torch.h>
 #include <cstdint>
+#include <c10/cuda/CUDAGuard.h>
 
 __global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) {
     int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -59,6 +60,35 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
     }
 }
 
+__global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+
+        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 0)));
+        const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 2)));
+
+        const uint8_t* __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);
+        const uint8_t* __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);
+
+        int is = 0;
+        uint8_t sc, m;
+        uint8_t u1 = 1, u2 = 2;
+        uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);
+
+        for (int j = 0; j < 256; j += 64) {
+            get_scale_min_k4(is + 0, scales, &sc, &m);
+            const float d1 = d * sc; const float m1 = min * m;
+            get_scale_min_k4(is + 1, scales, &sc, &m);
+            const float d2 = d * sc; const float m2 = min * m;
+            for (int l = 0; l < 32; ++l) *output_blk++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
+            for (int l = 0; l < 32; ++l) *output_blk++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
+            ql += 32; is += 2;
+            u1 <<= 2; u2 <<= 2;
+        }
+    }
+}
+
 __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
     int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
@@ -94,6 +124,7 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
 torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device) {
     int num_blocks = data.numel() / blk_size;
+    const at::cuda::OptionalCUDAGuard device_guard(device);
     // create gpu
     auto options_scales = torch::TensorOptions().dtype(torch::kFloat32).device(device).memory_format(torch::MemoryFormat::Contiguous);
     auto options_qs = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
@@ -128,6 +159,7 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
     // data.numel%blk_size should be 0, else raise err
     int num_blocks = data.numel() / blk_size;
+    const at::cuda::OptionalCUDAGuard device_guard(device);
     auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
     auto data_gpu = torch::empty({data.numel()}, options);
@@ -147,6 +179,7 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
 torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device) {
     // data.numel%blk_size should be 0, else raise err
     int num_blocks = data.numel() / blk_size;
+    const at::cuda::OptionalCUDAGuard device_guard(device);
     auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
     auto data_gpu = torch::empty({data.numel()}, options);
@@ -162,3 +195,22 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device de
     cudaDeviceSynchronize();
     return output;
 }
+
+torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({data.numel()}, options);
+
+    data_gpu.copy_(data, false);
+
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+
+    // Launch kernel
+    dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+
+    cudaDeviceSynchronize();
+    return output;
+}
\ No newline at end of file
```
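For reference, the super-block the new kernel walks reads two fp16 values (d and min) at byte offsets 0 and 2, packed scale/min bytes starting at offset 4, the high bits (qh) at offset 16 and the low nibbles (ql) at offset 48, producing 256 floats per block. A hedged NumPy reference of the same arithmetic is sketched below, intended only for sanity-checking the CUDA output; get_scale_min_k4 follows the standard ggml helper, which the kernel calls but this diff does not show, and the 176-byte block size is the standard GGUF Q5_K layout, assumed rather than stated in the diff.

```python
# Hedged NumPy reference for Q5_K dequantization, mirroring the kernel above.
# Assumes the standard GGUF Q5_K layout (176-byte super-block of 256 weights)
# and the standard ggml get_scale_min_k4 packing.
import numpy as np

def get_scale_min_k4(j, scales):
    # 6-bit scale/min pairs packed into 12 bytes, as in ggml-quants.c.
    if j < 4:
        return scales[j] & 63, scales[j + 4] & 63
    sc = (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4)
    m = (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4)
    return sc, m

def dequantize_q5_k_block(block: np.ndarray) -> np.ndarray:
    assert block.dtype == np.uint8 and block.size == 176
    d = block[0:2].view(np.float16)[0].astype(np.float32)     # super-block scale
    dmin = block[2:4].view(np.float16)[0].astype(np.float32)  # super-block min
    scales = block[4:16]
    qh = block[16:48]   # high bits, one bit pair per group selected via u1/u2
    ql = block[48:176]  # low nibbles, two weights per byte

    out = np.empty(256, dtype=np.float32)
    o, is_, u1, u2 = 0, 0, 1, 2
    for j in range(0, 256, 64):
        sc, m = get_scale_min_k4(is_ + 0, scales)
        d1, m1 = d * sc, dmin * m
        sc, m = get_scale_min_k4(is_ + 1, scales)
        d2, m2 = d * sc, dmin * m
        lo = ql[(j // 64) * 32:(j // 64) * 32 + 32]
        hi = qh[:32]
        out[o:o + 32] = d1 * ((lo & 0xF) + np.where(hi & u1, 16, 0)) - m1
        o += 32
        out[o:o + 32] = d2 * ((lo >> 4) + np.where(hi & u2, 16, 0)) - m2
        o += 32
        is_ += 2
        u1 <<= 2
        u2 <<= 2
    return out
```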
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h

```diff
@@ -3,8 +3,8 @@
  * @Author       : Azure-Tang
  * @Date         : 2024-07-22 09:27:55
  * @Version      : 1.0.0
- * @LastEditors  : Azure
- * @LastEditTime : 2024-07-26 08:38:20
+ * @LastEditors  : kkk1nak0
+ * @LastEditTime : 2024-08-09 01:44:21
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
 #pragma once
@@ -15,4 +15,5 @@
 torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
\ No newline at end of file
```
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu

```diff
@@ -23,7 +23,7 @@
 */
 
 #include "gptq_marlin.cuh"
 #include "gptq_marlin_dtypes.cuh"
+#include <c10/cuda/CUDAGuard.h>
 
 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)               \
   static_assert(std::is_same<scalar_t, half>::value ||          \
                     std::is_same<scalar_t, nv_bfloat16>::value, \
@@ -1703,28 +1703,63 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
     thread_m_blocks = exec_cfg.max_m_blocks;
   }
 
   // Define kernel configurations
-  if (false) {
-  }
-  CALL_IF(4, 32, 2, 256)
-  CALL_IF(4, 16, 4, 256)
-  CALL_IF(4, 8, 8, 256)
-  CALL_IF(4, 8, 4, 128)
-  CALL_IF(4, 4, 8, 128)
-  CALL_IF(8, 32, 2, 256)
-  CALL_IF(8, 16, 4, 256)
-  CALL_IF(8, 8, 8, 256)
-  CALL_IF(8, 8, 4, 128)
-  CALL_IF(8, 4, 8, 128)
+#define undefined_error TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + \
+                           str(prob_n) + ", " + str(prob_k) + "]" + \
+                           ", has_act_order = " + str(has_act_order) + \
+                           ", num_groups = " + str(num_groups) + \
+                           ", group_size = " + str(group_size) + \
+                           ", thread_m_blocks = " + str(thread_m_blocks) + \
+                           ", thread_n_blocks = " + str(thread_n_blocks) + \
+                           ", thread_k_blocks = " + str(thread_k_blocks));
+
+  if (num_bits == 4 && num_threads == 256) {
+    if (false) {
+    }
+    CALL_IF(4, 32, 2, 256)
+    CALL_IF(4, 16, 4, 256)
+    CALL_IF(4, 8, 8, 256)
+    else {
+      undefined_error
+    }
+  }
+  else if (num_bits == 4 && num_threads == 128) {
+    if (false) {
+    }
+    CALL_IF(4, 8, 4, 128)
+    CALL_IF(4, 4, 8, 128)
+    else {
+      undefined_error
+    }
+  }
+  else if (num_bits == 8 && num_threads == 256) {
+    if (false) {
+    }
+    CALL_IF(8, 32, 2, 256)
+    CALL_IF(8, 16, 4, 256)
+    CALL_IF(8, 8, 8, 256)
+    else {
+      undefined_error
+    }
+  }
+  else if (num_bits == 8 && num_threads == 128) {
+    if (false) {
+    }
+    CALL_IF(8, 8, 4, 128)
+    CALL_IF(8, 4, 8, 128)
+    else {
+      undefined_error
+    }
+  }
   else {
-    TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
-                           str(prob_n) + ", " + str(prob_k) + "]" +
-                           ", has_act_order = " + str(has_act_order) +
-                           ", num_groups = " + str(num_groups) +
-                           ", group_size = " + str(group_size) +
-                           ", thread_m_blocks = " + str(thread_m_blocks) +
-                           ", thread_n_blocks = " + str(thread_n_blocks) +
-                           ", thread_k_blocks = " + str(thread_k_blocks));
+    undefined_error
   }
 
   A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
@@ -1739,6 +1774,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& perm, torch::Tensor& workspace,
                                int64_t num_bits, int64_t size_m, int64_t size_n,
                                int64_t size_k, bool is_k_full) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   // Verify num_bits
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
@@ -1781,7 +1817,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
 
   // Alloc buffers
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
   torch::Tensor c = torch::empty({size_m, size_n}, options);
   torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
```
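The CUDAGuard include and the device guard now taken at the top of gptq_marlin_gemm, together with the guards added to the dequantize_* entry points above, are what make the "multi-gpu" part of the commit title work: kernels launch on the device that owns the input tensors rather than whatever device happens to be current. A hedged illustration from the Python side, assuming a machine with at least two GPUs, a built KTransformersOps extension, and the standard 210-byte GGUF Q6_K block size (an assumption, not stated in the diff):

```python
# Hedged illustration of the device-guard behaviour added in this commit:
# ops run on the device of their arguments, so tensors bound to cuda:1 are
# processed there even while cuda:0 is the current device. Assumes >= 2 GPUs,
# a built KTransformersOps extension, and blk_size=210 for Q6_K.
import torch
import KTransformersOps

torch.cuda.set_device(0)                 # current device is cuda:0
dev1 = torch.device("cuda:1")

blk_size = 210                           # assumed Q6_K super-block size in bytes
raw = torch.randint(-128, 128, (4 * blk_size,), dtype=torch.int8)

out = KTransformersOps.dequantize_q6_k(raw, blk_size, dev1)
print(out.device)                        # expected: cuda:1
```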
ktransformers/ktransformers_ext/cuda/setup.py

```diff
@@ -2,17 +2,25 @@
 from setuptools import setup, Extension
 from torch.utils import cpp_extension
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
 # setup marlin gemm
-setup(
-    name='KTransformersOps',
-    ext_modules=[
-        CUDAExtension('KTransformersOps', [
+setup(
+    name='KTransformersOps',
+    ext_modules=[
+        CUDAExtension('KTransformersOps', [
             'custom_gguf/dequant.cu',
             'binding.cpp',
             'gptq_marlin/gptq_marlin.cu',
             # 'gptq_marlin_repack.cu',
-        ])
-    ],
-    cmdclass={'build_ext': BuildExtension})
+        ],
+        extra_compile_args={
+            'cxx': ['-O3'],
+            'nvcc': [
+                '-O3',
+                '--use_fast_math',
+                '-Xcompiler', '-fPIC',
+            ]
+        },
+        )
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    }
+)
\ No newline at end of file
```
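After rebuilding the extension (for example via install.sh or install.bat above), a quick smoke test from Python confirms the new op was compiled in. A hedged sketch:

```python
# Hedged post-build smoke test: checks that the rebuilt KTransformersOps
# extension exposes the ops registered in binding.cpp, including the new
# dequantize_q5_k. Purely illustrative.
import KTransformersOps as ops

for name in ("dequantize_q8_0", "dequantize_q6_k", "dequantize_q5_k",
             "dequantize_q4_k", "gptq_marlin_gemm"):
    print(name, "present:", hasattr(ops, name))
```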
ktransformers/ktransformers_ext/ext_bindings.cpp

```diff
@@ -37,7 +37,7 @@ class LinearBindings {
             Args* args_ = (Args*)args;
             args_->cpuinfer->enqueue(&Linear::warm_up, args_->linear);
         }
-        static std::pair<intptr_t, intptr_t> interface(Linear& linear) {
+        static std::pair<intptr_t, intptr_t> cpuinfer_interface(Linear& linear) {
             Args* args = new Args{nullptr, &linear};
             return std::make_pair((intptr_t)&inner, (intptr_t)args);
         }
@@ -55,7 +55,7 @@ class LinearBindings {
             Args* args_ = (Args*)args;
             args_->cpuinfer->enqueue(&Linear::forward, args_->linear, args_->qlen, args_->input, args_->output);
         }
-        static std::pair<intptr_t, intptr_t> interface(Linear& linear, int qlen, intptr_t input, intptr_t output) {
+        static std::pair<intptr_t, intptr_t> cpuinfer_interface(Linear& linear, int qlen, intptr_t input, intptr_t output) {
             Args* args = new Args{nullptr, &linear, qlen, (const void*)input, (void*)output};
             return std::make_pair((intptr_t)&inner, (intptr_t)args);
         }
@@ -74,7 +74,7 @@ class MLPBindings {
             Args* args_ = (Args*)args;
             args_->cpuinfer->enqueue(&MLP::warm_up, args_->mlp);
         }
-        static std::pair<intptr_t, intptr_t> interface(MLP& mlp) {
+        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP& mlp) {
             Args* args = new Args{nullptr, &mlp};
             return std::make_pair((intptr_t)&inner, (intptr_t)args);
         }
@@ -92,7 +92,7 @@ class MLPBindings {
             Args* args_ = (Args*)args;
             args_->cpuinfer->enqueue(&MLP::forward, args_->mlp, args_->qlen, args_->input, args_->output);
         }
-        static std::pair<intptr_t, intptr_t> interface(MLP& mlp, int qlen, intptr_t input, intptr_t output) {
+        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP& mlp, int qlen, intptr_t input, intptr_t output) {
             Args* args = new Args{nullptr, &mlp, qlen, (const void*)input, (void*)output};
             return std::make_pair((intptr_t)&inner, (intptr_t)args);
         }
@@ -111,7 +111,7 @@ class MOEBindings {
             Args* args_ = (Args*)args;
             args_->cpuinfer->enqueue(&MOE::warm_up, args_->moe);
         }
-        static std::pair<intptr_t, intptr_t> interface(MOE& moe) {
+        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE& moe) {
            Args* args = new Args{nullptr, &moe};
            return std::make_pair((intptr_t)&inner, (intptr_t)args);
        }
@@ -132,7 +132,7 @@ class MOEBindings {
             Args* args_ = (Args*)args;
             args_->cpuinfer->enqueue(&MOE::forward, args_->moe, args_->qlen, args_->k, args_->expert_ids, args_->weights, args_->input, args_->output);
         }
-        static std::pair<intptr_t, intptr_t> interface(MOE& moe, int qlen, int k, intptr_t expert_ids, intptr_t weights, intptr_t input, intptr_t output) {
+        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE& moe, int qlen, int k, intptr_t expert_ids, intptr_t weights, intptr_t input, intptr_t output) {
             Args* args = new Args{nullptr, &moe, qlen, k, (const uint64_t*)expert_ids, (const float*)weights, (const void*)input, (void*)output};
             return std::make_pair((intptr_t)&inner, (intptr_t)args);
         }
@@ -154,8 +154,8 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
         }));
     py::class_<Linear>(linear_module, "Linear")
         .def(py::init<LinearConfig>())
-        .def("warm_up", &LinearBindings::WarmUpBindinds::interface)
-        .def("forward", &LinearBindings::ForwardBindings::interface);
+        .def("warm_up", &LinearBindings::WarmUpBindinds::cpuinfer_interface)
+        .def("forward", &LinearBindings::ForwardBindings::cpuinfer_interface);
 
     auto mlp_module = m.def_submodule("mlp");
     py::class_<MLPConfig>(mlp_module, "MLPConfig")
@@ -164,8 +164,8 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
         }));
     py::class_<MLP>(mlp_module, "MLP")
         .def(py::init<MLPConfig>())
-        .def("warm_up", &MLPBindings::WarmUpBindinds::interface)
-        .def("forward", &MLPBindings::ForwardBindings::interface);
+        .def("warm_up", &MLPBindings::WarmUpBindinds::cpuinfer_interface)
+        .def("forward", &MLPBindings::ForwardBindings::cpuinfer_interface);
 
     auto moe_module = m.def_submodule("moe");
     py::class_<MOEConfig>(moe_module, "MOEConfig")
@@ -174,6 +174,6 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
         }));
     py::class_<MOE>(moe_module, "MOE")
        .def(py::init<MOEConfig>())
-        .def("warm_up", &MOEBindings::WarmUpBindinds::interface)
-        .def("forward", &MOEBindings::ForwardBindings::interface);
+        .def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface)
+        .def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface);
 }
```
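The rename from interface to cpuinfer_interface is internal to the C++ bindings; the names exported to Python ("warm_up" and "forward" on Linear, MLP and MOE, via the .def lines above) do not change. A hedged check from the Python side, assuming the cpuinfer_ext extension is built:

```python
# Hedged check that the Python-facing method names survive the internal
# rename to cpuinfer_interface. Assumes the cpuinfer_ext extension is built.
import cpuinfer_ext

for cls in (cpuinfer_ext.linear.Linear, cpuinfer_ext.mlp.MLP, cpuinfer_ext.moe.MOE):
    print(cls.__name__, hasattr(cls, "warm_up"), hasattr(cls, "forward"))
```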
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq.py (deleted, 100644 → 0)

```python
import math
import os
import time
from logging import getLogger

import torch
import torch.nn as nn
import transformers

from .quantizer import Quantizer

logger = getLogger(__name__)

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


class GPTQ:
    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.pytorch_utils.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.quantizer = Quantizer()

    def add_batch(self, inp, out):
        if os.environ.get("DEBUG"):
            self.inp1 = inp
            self.out1 = out
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        # inp = inp.float()
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        # self.H += 2 / self.nsamples * inp.matmul(inp.t())
        self.H += inp.matmul(inp.t())

    def fasterquant(
        self,
        blocksize=128,
        percdamp=0.01,
        group_size=-1,
        actorder=False,
        static_groups=False,
    ):
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        del self.H
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        g_idx = []
        scale = []
        zero = []
        now_idx = 1

        if static_groups:
            import copy

            groups = []
            for i in range(0, self.columns, group_size):
                quantizer = copy.deepcopy(self.quantizer)
                quantizer.find_params(W[:, i : (i + group_size)], weight=True)
                scale.append(quantizer.scale)
                zero.append(quantizer.zero)
                groups.append(quantizer)

        if actorder:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]
            invperm = torch.argsort(perm)

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if group_size != -1:
                    if not static_groups:
                        if (i1 + i) % group_size == 0:
                            self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + group_size)], weight=True)

                        if ((i1 + i) // group_size) - now_idx == -1:
                            scale.append(self.quantizer.scale)
                            zero.append(self.quantizer.zero)
                            now_idx += 1
                    else:
                        idx = i1 + i
                        if actorder:
                            idx = perm[idx]
                        self.quantizer = groups[idx // group_size]

                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d ** 2

                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

            if os.environ.get("DEBUG"):
                self.layer.weight.data[:, :i2] = Q[:, :i2]
                self.layer.weight.data[:, i2:] = W[:, i2:]
                logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
                logger.debug(torch.sum(Losses))

        torch.cuda.synchronize()
        logger.info(f"duration: {(time.time() - tick)}")
        logger.info(f"avg loss: {torch.sum(Losses).item() / self.nsamples}")

        group_size = group_size if group_size != -1 else self.columns
        if static_groups and actorder:
            g_idx = [perm[i] // group_size for i in range(self.columns)]
        else:
            g_idx = [i // group_size for i in range(self.columns)]
        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
        if actorder:
            Q = Q[:, invperm]
            g_idx = g_idx[invperm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()
        self.layer.weight.data = Q.reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
        if os.environ.get("DEBUG"):
            logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))

        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero, g_idx

    def free(self):
        if os.environ.get("DEBUG"):
            self.inp1 = None
            self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()


__all__ = ["GPTQ"]
```
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/gptq_marlin.py (deleted, 100644 → 0)

```python
import enum
from enum import Enum
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)

logger = init_logger(__name__)

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]


# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits: int):
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


def get_pack_factor(num_bits: int):
    assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
            ), f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                          group_size: int, num_bits: int):
    scale_perm, scale_perm_single = get_scale_perms(num_bits)
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()

    return s


class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool) -> None:
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym

        # Verify
        if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"Marlin does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
                "are supported.")
        if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
            raise ValueError(
                f"Marlin does not support is_sym = {self.is_sym}. "
                f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")

        # Init
        self.pack_factor = get_pack_factor(weight_bits)
        self.tile_size = GPTQ_MARLIN_TILE
        self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
        self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
        self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        return cls(weight_bits, group_size, desc_act, is_sym)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        can_convert = cls.is_marlin_compatible(hf_quant_cfg)

        is_valid_user_quant = (user_quant is None or user_quant == "marlin")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_marlin"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_marlin for"
                        " faster inference")
        return None

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
        if isinstance(layer, LinearBase):
            return GPTQMarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
        # Extract data from quant config.
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        sym = quant_config.get("sym", None)
        desc_act = quant_config.get("desc_act", None)

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        # If the capability of the device is too low, cannot convert.
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < cls.get_min_capability():
            return False

        # Otherwise, can convert if model satisfies marlin constraints.
        return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
                and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
                and sym in GPTQ_MARLIN_SUPPORTED_SYM)


class GPTQMarlinState(Enum):
    REPACK = enum.auto()
    READY = enum.auto()


class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Validate dtype
        if params_dtype not in [torch.float16, torch.bfloat16]:
            raise ValueError(f"The params dtype must be float16 "
                             f"or bfloat16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_thread_n != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f" min_thread_n = {self.quant_config.min_thread_n}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_thread_k != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible "
                f"by min_thread_k = {self.quant_config.min_thread_k}.")

        if (group_size < input_size
                and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition}"
                f" is not divisible by group_size = {group_size}.")

        # Detect sharding of scales/zp

        # By default, no sharding over "input dim"
        scales_and_zp_size = input_size // group_size
        scales_and_zp_input_dim = None

        if self.quant_config.desc_act:
            # Act-order case
            assert self.quant_config.group_size != -1

            is_k_full = input_size_per_partition == input_size

        else:
            # No act-order case

            # K is always full due to full alignment with
            # group-size and shard of scales/zp
            is_k_full = True

            # If this is a row-parallel case, then shard scales/zp
            if (input_size != input_size_per_partition
                    and self.quant_config.group_size != -1):
                scales_and_zp_size = input_size_per_partition // group_size
                scales_and_zp_input_dim = 0

        # Init buffers

        # Quantized weights
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                **extra_weight_attrs,
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            },
        )

        # Activation order
        g_idx = Parameter(
            torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(
            g_idx,
            {
                **extra_weight_attrs, "input_dim": 0,
                "ignore_warning": True
            },
        )

        g_idx_sort_indices = torch.empty(
            g_idx.shape,
            dtype=torch.int32,
        )

        # Scales
        scales = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
            },
        )

        # Quantized zero-points
        qzeros = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
                device="meta",
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros,
            {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            },
        )

        # Allocate marlin workspace
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_thread_n) * self.quant_config.max_parallel
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                requires_grad=False)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)
        layer.g_idx_sort_indices = g_idx_sort_indices
        layer.workspace = workspace
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.is_k_full = is_k_full
        layer.marlin_state = GPTQMarlinState.REPACK

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        reshaped_x = x.reshape(-1, x.shape[-1])

        size_m = reshaped_x.shape[0]
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition
        full_size_k = layer.input_size

        out_shape = x.shape[:-1] + (part_size_n, )

        if layer.marlin_state == GPTQMarlinState.REPACK:
            layer.marlin_state = GPTQMarlinState.READY

            # Newly generated tensors need to replace existing tensors that are
            # already registered as parameters by vLLM (and won't be freed)
            def replace_tensor(name, new_t):
                # It is important to use resize_() here since it ensures
                # the same buffer is reused
                getattr(layer, name).resize_(new_t.shape)
                getattr(layer, name).copy_(new_t)
                del new_t

            cur_device = layer.qweight.device

            # Process act_order
            if self.quant_config.desc_act:
                # Get sorting based on g_idx
                g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)

                sorted_g_idx = layer.g_idx[g_idx_sort_indices]

                replace_tensor("g_idx", sorted_g_idx)
                replace_tensor("g_idx_sort_indices", g_idx_sort_indices)

            else:
                # Reset g_idx related tensors
                layer.g_idx = Parameter(
                    torch.empty(0, dtype=torch.int, device=cur_device),
                    requires_grad=False,
                )
                layer.g_idx_sort_indices = Parameter(
                    torch.empty(0, dtype=torch.int, device=cur_device),
                    requires_grad=False,
                )

            # Repack weights
            marlin_qweight = ops.gptq_marlin_repack(
                layer.qweight,
                layer.g_idx_sort_indices,
                part_size_k,
                part_size_n,
                self.quant_config.weight_bits,
            )
            replace_tensor("qweight", marlin_qweight)

            # Permute scales
            scales_size_k = part_size_k
            scales_size_n = part_size_n
            if self.quant_config.desc_act:
                scales_size_k = full_size_k

            marlin_scales = marlin_permute_scales(
                layer.scales,
                scales_size_k,
                scales_size_n,
                self.quant_config.group_size,
                self.quant_config.weight_bits,
            )
            replace_tensor("scales", marlin_scales)

        output = ops.gptq_marlin_gemm(
            reshaped_x,
            layer.qweight,
            layer.scales,
            layer.g_idx,
            layer.g_idx_sort_indices,
            layer.workspace,
            self.quant_config.weight_bits,
            size_m,
            part_size_n,
            part_size_k,
            layer.is_k_full,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
```
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/quantizer.py (deleted, 100644 → 0)

```python
from logging import getLogger

import torch
import torch.nn as nn

logger = getLogger(__name__)


def quantize(x, scale, zero, maxq):
    if maxq < 0:
        return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)


class Quantizer(nn.Module):
    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        self.register_buffer("maxq", torch.tensor(0))
        self.register_buffer("scale", torch.zeros(shape))
        self.register_buffer("zero", torch.zeros(shape))

    def configure(
        self,
        bits,
        perchannel=False,
        sym=True,
        mse=False,
        norm=2.4,
        grid=100,
        maxshrink=0.8,
        trits=False,
    ):
        self.maxq = torch.tensor(2**bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
            self.maxq = torch.tensor(-1)

    def find_params(self, x, weight=False):
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        if self.maxq < 0:
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            best = torch.full([x.shape[0]], float("inf"), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        if not self.perchannel:
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        if self.ready():
            return quantize(x, self.scale, self.zero, self.maxq)
        return x

    def enabled(self):
        return self.maxq > 0

    def ready(self):
        return torch.all(self.scale != 0)


__all__ = ["Quantizer"]
```
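For reference, the removed Quantizer was driven as a configure → find_params → quantize round trip (this is also how the deleted gptq.py uses it). A short, hedged sketch against the code above; the parameter values are arbitrary examples.

```python
# Round-trip sketch of the removed Quantizer module (configure, find_params,
# quantize), matching how the deleted gptq.py used it. Illustrative only.
import torch

q = Quantizer()
q.configure(bits=4, perchannel=True, sym=False, mse=False)

W = torch.randn(128, 256)          # a fake weight matrix
q.find_params(W, weight=True)      # per-output-channel scale and zero point
W_q = q.quantize(W)                # fake-quantized (dequantized) weights

print((W - W_q).abs().mean())      # average quantization error
```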
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/repack.py (deleted, 100644 → 0)

```python
import torch

import enum
from enum import Enum
from typing import Any, Dict, List, Optional
from torch.nn.parameter import Parameter


def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    reshaped_x = x.reshape(-1, x.shape[-1])

    size_m = reshaped_x.shape[0]
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    full_size_k = layer.input_size

    out_shape = x.shape[:-1] + (part_size_n, )

    if layer.marlin_state == GPTQMarlinState.REPACK:
        layer.marlin_state = GPTQMarlinState.READY

        # Newly generated tensors need to replace existing tensors that are
        # already registered as parameters by vLLM (and won't be freed)
        def replace_tensor(name, new_t):
            # It is important to use resize_() here since it ensures
            # the same buffer is reused
            getattr(layer, name).resize_(new_t.shape)
            getattr(layer, name).copy_(new_t)
            del new_t

        cur_device = layer.qweight.device

        # Process act_order
        if self.quant_config.desc_act:
            # Get sorting based on g_idx
            g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)

            sorted_g_idx = layer.g_idx[g_idx_sort_indices]

            replace_tensor("g_idx", sorted_g_idx)
            replace_tensor("g_idx_sort_indices", g_idx_sort_indices)

        else:
            # Reset g_idx related tensors
            layer.g_idx = Parameter(
                torch.empty(0, dtype=torch.int, device=cur_device),
                requires_grad=False,
            )
            layer.g_idx_sort_indices = Parameter(
                torch.empty(0, dtype=torch.int, device=cur_device),
                requires_grad=False,
            )

        # Repack weights
        marlin_qweight = ops.gptq_marlin_repack(
            layer.qweight,
            layer.g_idx_sort_indices,
            part_size_k,
            part_size_n,
            self.quant_config.weight_bits,
        )
        replace_tensor("qweight", marlin_qweight)

        # Permute scales
        scales_size_k = part_size_k
        scales_size_n = part_size_n
        if self.quant_config.desc_act:
            scales_size_k = full_size_k

        marlin_scales = marlin_permute_scales(
            layer.scales,
            scales_size_k,
            scales_size_n,
            self.quant_config.group_size,
            self.quant_config.weight_bits,
        )
        replace_tensor("scales", marlin_scales)

    output = ops.gptq_marlin_gemm(
        reshaped_x,
        layer.qweight,
        layer.scales,
        layer.g_idx,
        layer.g_idx_sort_indices,
        layer.workspace,
        self.quant_config.weight_bits,
        size_m,
        part_size_n,
        part_size_k,
        layer.is_k_full,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
```
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py

```diff
@@ -220,7 +220,7 @@ def compute_max_diff(output, output_ref):
 class MarlinWorkspace:
 
-    def __init__(self, out_features, min_thread_n, max_parallel):
+    def __init__(self, out_features, min_thread_n, max_parallel, device):
         assert (out_features % min_thread_n == 0), (
             "out_features = {} is undivisible by min_thread_n = {}".format(
                 out_features, min_thread_n))
@@ -229,4 +229,4 @@ class MarlinWorkspace:
 
         self.scratch = torch.zeros(max_workspace_size,
                                    dtype=torch.int,
-                                   device="cuda")
+                                   device=device)
```
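The MarlinWorkspace change follows the same theme as the rest of the commit: the scratch buffer is placed on an explicit device instead of being pinned to the default "cuda". Hedged usage of the updated constructor, with the thread/parallel constants taken from the removed gptq_marlin.py above and out_features=4096 chosen arbitrarily:

```python
# Hedged usage of the updated MarlinWorkspace constructor, which now takes the
# target device explicitly. Constants are the GPTQ_MARLIN_* values from the
# removed gptq_marlin.py; out_features=4096 is an arbitrary example.
import torch

GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MAX_PARALLEL = 16

workspace = MarlinWorkspace(
    out_features=4096,
    min_thread_n=GPTQ_MARLIN_MIN_THREAD_N,
    max_parallel=GPTQ_MARLIN_MAX_PARALLEL,
    device=torch.device("cuda:1"),   # scratch buffer now lands on this GPU
)
```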
ktransformers/ktransformers_ext/operators/llamafile/linear.cpp

```diff
@@ -47,13 +47,13 @@ void Linear::forward_many(int qlen, const void* input, void* output, Backend* ba
     int nth = config_.output_size / config_.stride;
     backend->do_work_stealing_job(nth, [&](int task_id) {
         int ith = task_id;
-        void* proj_ptr = proj_ + ith * config_.stride * config_.input_size * ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type);
+        void* proj_ptr = (uint8_t*)proj_ + ith * config_.stride * config_.input_size * ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type);
         float* proj_output_ptr = proj_output_ + ith * config_.stride;
         llamafile_sgemm(config_.stride, qlen, config_.input_size / ggml_blck_size(config_.proj_type), proj_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_input_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_output_ptr, config_.output_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.proj_type, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
         if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
             for (int i = 0; i < qlen; i++) {
                 float* output_fp32_ptr = proj_output_ + i * config_.output_size + ith * config_.stride;
-                void* output_ptr = output + i * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
+                void* output_ptr = (uint8_t*)output + i * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
                 from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
             }
         }
@@ -69,5 +69,5 @@ void Linear::forward(int qlen, const void* input, void* output, Backend* backend
     }
     int forward_len = std::min(qlen, config_.group_max_len);
     forward_many(forward_len, input, output, backend);
-    forward(qlen - forward_len, input + forward_len * config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), output + forward_len * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
+    forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
 }
\ No newline at end of file
```
ktransformers/ktransformers_ext/operators/llamafile/mlp.cpp

```diff
@@ -74,10 +74,10 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe
     int nth = config_.intermediate_size / config_.stride;
     backend->do_work_stealing_job(nth, [&](int task_id) {
         int ith = task_id;
-        void* gate_proj_ptr = gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
+        void* gate_proj_ptr = (uint8_t*)gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
         float* gate_output_ptr = gate_output_ + ith * config_.stride;
         llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
-        void* up_proj_ptr = up_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
+        void* up_proj_ptr = (uint8_t*)up_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
         float* up_output_ptr = up_output_ + ith * config_.stride;
         llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
         for (int i = 0; i < qlen; i++) {
@@ -86,7 +86,7 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe
             }
             if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {
                 float* intermediate_fp32_ptr = intermediate_fp32_ + i * config_.intermediate_size + ith * config_.stride;
-                void* down_input_ptr = down_input_ + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
+                void* down_input_ptr = (uint8_t*)down_input_ + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
                 from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
             }
         }
@@ -97,13 +97,13 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe
     nth = config_.hidden_size / config_.stride;
     backend->do_work_stealing_job(nth, [&](int task_id) {
         int ith = task_id;
-        void* down_proj_ptr = down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
+        void* down_proj_ptr = (uint8_t*)down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
         float* down_output_ptr = down_output_ + ith * config_.stride;
         llamafile_sgemm(config_.stride, qlen, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
         if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
             for (int i = 0; i < qlen; i++) {
                 float* output_fp32_ptr = down_output_ + i * config_.hidden_size + ith * config_.stride;
-                void* output_ptr = output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
+                void* output_ptr = (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
                 from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
             }
         }
@@ -119,5 +119,5 @@ void MLP::forward(int qlen, const void* input, void* output, Backend* backend) {
     }
     int forward_len = std::min(qlen, config_.group_max_len);
     forward_many(forward_len, input, output, backend);
-    forward(qlen - forward_len, input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
+    forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
 }
\ No newline at end of file
```