OpenDAS / Lmdeploy · Commits

Commit d26f4c73, authored May 27, 2024 by gaoqiong
Parent: 2326380c
Commit message: Add the AWQ module (增加awq模块)

Changes: 32. Showing 12 changed files with 607 additions and 35 deletions (+607 -35).
src/turbomind/models/llama/LlamaLinear.h                      +214  -18
src/turbomind/models/llama/LlamaWeight.cc                      +19   -1
src/turbomind/models/llama/LlamaWeight.h                        +1   -0
src/turbomind/models/llama/awq_sugon/gemm_w4_dequation.cu     +143   -0
src/turbomind/models/llama/awq_sugon/gemm_w4_dequation.cuh     +36   -0
src/turbomind/models/llama/awq_sugon/lmdeploy_sugon.cuh        +25   -0
src/turbomind/python/bind.cpp                                  +43  -14
src/turbomind/triton_backend/llama/LlamaTritonModel.cc          +2   -0
src/turbomind/triton_backend/llama/LlamaTritonModel.h           +2   -1
src/turbomind/utils/cublasMMWrapper.cc                        +108   -0
src/turbomind/utils/cublasMMWrapper.h                           +6   -0
src/turbomind/utils/cuda_utils.h                                +8   -1
src/turbomind/models/llama/LlamaLinear.h

@@ -2,7 +2,8 @@
#pragma once
// #include "src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h"
#include "src/turbomind/models/llama/awq_sugon/gemm_w4_dequation.cuh"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
...
@@ -41,7 +42,27 @@ public:
                FT_CHECK(0);
        }
    }

    void forward_ffn(T*                         output_data,
                     T*                         output_tmp,
                     const T*                   input_data,
                     int                        batch_size,
                     const LlamaDenseWeight<T>& weight,
                     Type                       type = kGemm)
    {
        switch (weight.type) {
            case WeightType::kFP16:
            case WeightType::kFP32:
            case WeightType::kBF16:
                forwardFp(output_data, input_data, batch_size, weight, type);
                break;
            case WeightType::kINT4: {
                if (type == kFusedSiluFfn)
                    forwardInt4_ffn(output_data, output_tmp, input_data, batch_size, weight, type);
                else
                    forwardInt4(output_data, input_data, batch_size, weight, type);
                break;
            }
            default:
                FT_CHECK(0);
        }
    }

private:
    void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
    {
...
@@ -62,23 +83,198 @@ private:
    void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
    {
        // Previous implementation (gemm_s4_f16 kernel), superseded by the AWQ path below:
        // if constexpr (std::is_same_v<T, half>) {
        //     gemm_s4_f16_.Run(output_data,
        //                      (const uint*)weight.kernel,
        //                      input_data,
        //                      (const half2*)weight.scales_and_zeros,
        //                      weight.output_dims,
        //                      batch_size,
        //                      weight.input_dims,
        //                      weight.group_size,
        //                      type == kFusedSiluFfn ? GemmS4F16::kFusedSiluFfn : GemmS4F16::kGemm,
        //                      -1,
        //                      stream_);
        //     sync_check_cuda_error();
        // }
        // else {
        if constexpr (std::is_same_v<T, half>) {
            if (weight.w4_weight_layout == 0) {  // plain NN layout, rocBLAS
                // check that the dequantized-weight workspace is large enough
                if (batch_size * weight.output_dims > M_max * N_max) {
                    FT_CHECK_WITH_INFO(0, "error! batch_size>N_max ||weight.output_dims>N_max");
                }
                dequant_w4_gemm(stream_,
                                reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),
                                (const uint32_t*)weight.kernel,
                                (const half2*)weight.scales_and_zeros,
                                weight.input_dims,
                                weight.output_dims,
                                weight.group_size);
                cublas_wrapper_->Gemm(CUBLAS_OP_N,
                                      CUBLAS_OP_N,
                                      weight.output_dims,  // m
                                      batch_size,          // n
                                      weight.input_dims,   // k
                                      (const T*)cublas_wrapper_->deweight_workspace_,  // []
                                      weight.output_dims,  // m
                                      input_data,
                                      weight.input_dims,   // k
                                      output_data,
                                      weight.output_dims); // m
            }
            else if (weight.w4_weight_layout == 1) {  // TN layout with padding, rocBLAS
                // check that the dequantized-weight workspace is large enough
                if (batch_size * weight.output_dims > M_max * N_max) {
                    FT_CHECK_WITH_INFO(0, "error! batch_size>N_max ||weight.output_dims>N_max");
                }
                // check whether the x-padding workspace is needed
                if (weight.input_dims % 4096 == 0) {  // padding required
                    int pad_group_count = 2;
                    input_padding(stream_,
                                  reinterpret_cast<half*>(cublas_wrapper_->xpading_workspace_),
                                  (const T*)input_data,
                                  batch_size,
                                  weight.input_dims,
                                  weight.group_size,
                                  pad_group_count);
                    dequant_w4_gemm_colmajor(stream_,
                                             reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),
                                             (const uint32_t*)weight.kernel,
                                             (const half2*)weight.scales_and_zeros,
                                             weight.input_dims + pad_group_count * weight.group_size,
                                             weight.output_dims,
                                             weight.group_size);
                    cublas_wrapper_->Gemm(CUBLAS_OP_T,
                                          CUBLAS_OP_N,
                                          weight.output_dims,  // m
                                          batch_size,          // n
                                          weight.input_dims + pad_group_count * weight.group_size,  // k
                                          (const T*)reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),  // []
                                          weight.input_dims + pad_group_count * weight.group_size,  // k
                                          (const T*)cublas_wrapper_->xpading_workspace_,
                                          weight.input_dims + pad_group_count * weight.group_size,  // k
                                          output_data,
                                          weight.output_dims); // m
                }
                else {  // no padding required
                    dequant_w4_gemm_colmajor(stream_,
                                             reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),
                                             (const uint32_t*)weight.kernel,
                                             (const half2*)weight.scales_and_zeros,
                                             weight.input_dims,
                                             weight.output_dims,
                                             weight.group_size);
                    cublas_wrapper_->Gemm(CUBLAS_OP_T,
                                          CUBLAS_OP_N,
                                          weight.output_dims,  // m
                                          batch_size,          // n
                                          weight.input_dims,   // k
                                          (const T*)reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),  // []
                                          weight.input_dims,   // k
                                          input_data,
                                          weight.input_dims,   // k
                                          output_data,
                                          weight.output_dims); // m
                }
            }
            else if (weight.w4_weight_layout == 2) {  // TN layout with padding, CK
                // check that the CK workspace is large enough
                if (weight.input_dims % 4096 == 0) {
                    int pad_groupcount = 2;
                    run_weight_only_gemm(reinterpret_cast<const void*>(input_data),
                                         reinterpret_cast<const void*>(weight.kernel),
                                         reinterpret_cast<const void*>(weight.scales_and_zeros),
                                         reinterpret_cast<void*>(output_data),
                                         batch_size,
                                         weight.output_dims,
                                         (weight.input_dims),
                                         (weight.input_dims),
                                         (weight.input_dims),
                                         (weight.input_dims + pad_groupcount * weight.group_size),
                                         weight.output_dims,
                                         weight.group_size,
                                         reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),
                                         CK_WORKSPACE_SIZE,
                                         (hipStream_t)stream_);
                }
                // A B0 B1 C M N K strideA strideB strideBpad strideC group_size
                else {
                    run_weight_only_gemm(reinterpret_cast<const void*>(input_data),
                                         reinterpret_cast<const void*>(weight.kernel),
                                         reinterpret_cast<const void*>(weight.scales_and_zeros),
                                         reinterpret_cast<void*>(output_data),
                                         batch_size,
                                         weight.output_dims,
                                         (weight.input_dims),
                                         (weight.input_dims),
                                         (weight.input_dims),
                                         (weight.input_dims),
                                         weight.output_dims,
                                         weight.group_size,
                                         reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),
                                         CK_WORKSPACE_SIZE,
                                         (hipStream_t)stream_);
                }
            }
            if (cublas_wrapper_->m_dump_switch == 2) {
                std::cout << " m: " << batch_size << " n: " << weight.output_dims
                          << " k: " << weight.input_dims << std::endl;
                PrintScale<T>(stream_, output_data, 36, 0, 0, 0);
            }
            sync_check_cuda_error();
        }
        else {
            FT_CHECK_WITH_INFO(0, "Not implemented");
        }
    }
    void forwardInt4_ffn(T*                         output_data,
                         T*                         output_tmp,
                         const T*                   input_data,
                         int                        batch_size,
                         const LlamaDenseWeight<T>& weight,
                         Type                       type)
    {
        if constexpr (std::is_same_v<T, half>) {
            if (weight.w4_weight_layout == 0) {  // plain NN layout, rocBLAS
                // check that the dequantized-weight workspace is large enough
                if (batch_size * weight.output_dims > M_max * N_max) {
                    FT_CHECK_WITH_INFO(0, "error! batch_size>N_max ||weight.output_dims>N_max");
                }
                dequant_w4_gemm(stream_,
                                reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),
                                (const uint32_t*)weight.kernel,
                                (const half2*)weight.scales_and_zeros,
                                weight.input_dims,
                                weight.output_dims,
                                weight.group_size);
                cublas_wrapper_->Gemm(CUBLAS_OP_N,
                                      CUBLAS_OP_N,
                                      weight.output_dims,  // m
                                      batch_size,          // n
                                      weight.input_dims,   // k
                                      (const T*)cublas_wrapper_->deweight_workspace_,  // []
                                      weight.output_dims,  // m
                                      input_data,
                                      weight.input_dims,   // k
                                      output_tmp,
                                      weight.output_dims); // m
            }
            else if (weight.w4_weight_layout == 1) {  // TN layout with padding, rocBLAS
                // check that the dequantized-weight workspace is large enough
                if (batch_size * weight.output_dims > M_max * N_max) {
                    FT_CHECK_WITH_INFO(0, "error! batch_size>N_max ||weight.output_dims>N_max");
                }
                // check whether the x-padding workspace is needed
                if (weight.input_dims % 4096 == 0) {  // padding required
                    int pad_group_count = 2;
                    input_padding<T>(stream_,
                                     reinterpret_cast<half*>(cublas_wrapper_->xpading_workspace_),
                                     (const T*)input_data,
                                     batch_size,
                                     weight.input_dims,
                                     weight.group_size,
                                     pad_group_count);
                    dequant_w4_gemm_colmajor(stream_,
                                             reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),
                                             (const uint32_t*)weight.kernel,
                                             (const half2*)weight.scales_and_zeros,
                                             weight.input_dims + pad_group_count * weight.group_size,
                                             weight.output_dims,
                                             weight.group_size);
                    cublas_wrapper_->Gemm(CUBLAS_OP_T,
                                          CUBLAS_OP_N,
                                          weight.output_dims,  // m
                                          batch_size,          // n
                                          weight.input_dims + pad_group_count * weight.group_size,  // k
                                          (const T*)reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),  // []
                                          weight.input_dims + pad_group_count * weight.group_size,  // k
                                          (const T*)cublas_wrapper_->xpading_workspace_,
                                          weight.input_dims + pad_group_count * weight.group_size,  // k
                                          output_tmp,
                                          weight.output_dims); // m
                }
                else {  // no padding required
                    dequant_w4_gemm_colmajor(stream_,
                                             reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),
                                             (const uint32_t*)weight.kernel,
                                             (const half2*)weight.scales_and_zeros,
                                             weight.input_dims,
                                             weight.output_dims,
                                             weight.group_size);
                    cublas_wrapper_->Gemm(CUBLAS_OP_T,
                                          CUBLAS_OP_N,
                                          weight.output_dims,  // m
                                          batch_size,          // n
                                          weight.input_dims,   // k
                                          (const T*)reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),  // []
                                          weight.input_dims,   // k
                                          input_data,
                                          weight.input_dims,   // k
                                          output_tmp,
                                          weight.output_dims); // m
                }
            }
            else if (weight.w4_weight_layout == 2) {  // TN layout with padding, CK
                // check that the CK workspace is large enough
                if (batch_size * weight.output_dims > M_max * N_max) {
                    FT_CHECK_WITH_INFO(0, "error! ck workspace is not enough");
                }
                if (weight.input_dims % 4096 == 0) {
                    int pad_groupcount = 2;
                    run_weight_only_gemm(reinterpret_cast<const void*>(input_data),
                                         reinterpret_cast<const void*>(weight.kernel),
                                         reinterpret_cast<const void*>(weight.scales_and_zeros),
                                         reinterpret_cast<void*>(output_tmp),
                                         batch_size,
                                         weight.output_dims,
                                         (weight.input_dims),
                                         (weight.input_dims),
                                         (weight.input_dims),
                                         (weight.input_dims + pad_groupcount * weight.group_size),
                                         weight.output_dims,
                                         weight.group_size,
                                         reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),
                                         CK_WORKSPACE_SIZE,
                                         (hipStream_t)stream_);
                }
                // A B0 B1 C M N K strideA strideB strideBpad strideC group_size
                else {
                    run_weight_only_gemm(reinterpret_cast<const void*>(input_data),
                                         reinterpret_cast<const void*>(weight.kernel),
                                         reinterpret_cast<const void*>(weight.scales_and_zeros),
                                         reinterpret_cast<void*>(output_tmp),
                                         batch_size,
                                         weight.output_dims,
                                         (weight.input_dims),
                                         (weight.input_dims),
                                         (weight.input_dims),
                                         (weight.input_dims),
                                         weight.output_dims,
                                         weight.group_size,
                                         reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),
                                         CK_WORKSPACE_SIZE,
                                         (hipStream_t)stream_);
                }
            }
            addFusedSiluActivation(stream_, output_data, output_tmp, batch_size, weight.output_dims, 1);
            if (cublas_wrapper_->m_dump_switch == 2) {
                std::cout << " m: " << batch_size << " n: " << weight.output_dims
                          << " k: " << weight.input_dims << std::endl;
                PrintScale<T>(stream_, output_data, 36, 0, 0, 0);
            }
            sync_check_cuda_error();
        }
        else {
            FT_CHECK_WITH_INFO(0, "Not implemented");
            //
        }
    }

private:
...
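Both INT4 paths above first expand the packed AWQ weights into deweight_workspace_ with dequant_w4_gemm / dequant_w4_gemm_colmajor and only then run an ordinary FP16 GEMM. The implementation of those dequantization kernels is not among the files shown on this page; the following is only a minimal sketch of group-wise W4 dequantization under a common AWQ convention (eight nibbles packed per uint32_t, one scale/zero pair per group), with all names and the packing order being assumptions rather than the commit's actual kernel.

// Sketch only: group-wise INT4 dequantization, w = (q - zero) * scale.
// NOT the dequant_w4_gemm kernel from this commit; packing order is assumed.
#include <cstdint>
#include <cuda_fp16.h>

__global__ void dequant_w4_sketch(half*           out,          // [k, n] dequantized weights
                                  const uint32_t* qweight,      // packed 4-bit weights, 8 per word
                                  const half2*    scales_zeros, // (scale, zero) per group and column
                                  int k, int n, int group_size)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;  // one thread per output element
    if (idx >= k * n)
        return;
    int row = idx / n;
    int col = idx % n;

    uint32_t word  = qweight[idx / 8];
    int      q     = (word >> (4 * (idx % 8))) & 0xF;            // extract the 4-bit value
    half2    sz    = scales_zeros[(row / group_size) * n + col]; // group-wise scale/zero
    float    scale = __half2float(sz.x);
    float    zero  = __half2float(sz.y);

    out[idx] = __float2half((q - zero) * scale);
}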
src/turbomind/models/llama/LlamaWeight.cc

...
@@ -32,6 +32,7 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
                            bool       attn_bias,
                            WeightType weight_type,
                            int        group_size,
                            int        w4_weight_layout,
                            size_t     tensor_para_size,
                            size_t     tensor_para_rank):
    hidden_units_(head_num * size_per_head),
...
@@ -55,11 +56,28 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
                                                       inter_size_,
                                                       weight_type_,
                                                       group_size,
                                                       w4_weight_layout,
                                                       attn_bias,
                                                       tensor_para_size_,
                                                       tensor_para_rank_));
    }

    // setenv sets the environment variable; the last argument (1) controls whether an existing
    // variable is overwritten: 1 replaces the current value, 0 keeps the existing value unchanged.
    const char* env_name = "LMDEPLOY_WEIGHTLAYOUT_SWITCH";
    if (weight_type_ == WeightType::kINT4) {
        std::string str_w4_weight_layout = std::to_string(w4_weight_layout);
        const char* env_value            = str_w4_weight_layout.c_str();
        setenv(env_name, env_value, 1);
        printf("set LMDEPLOY_WEIGHTLAYOUT_SWITCH env: %d\n", w4_weight_layout);
    }
    else {
        std::string str_w4_weight_layout = std::to_string(-1);
        const char* env_value            = str_w4_weight_layout.c_str();
        setenv(env_name, env_value, 1);
        printf("set LMDEPLOY_WEIGHTLAYOUT_SWITCH env: %d\n", w4_weight_layout);
    }

    mallocWeights();
}
...
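The w4_weight_layout value travels from LlamaWeight to cublasMMWrapper through the LMDEPLOY_WEIGHTLAYOUT_SWITCH environment variable rather than through a constructor argument. A minimal round-trip sketch of that handoff (not part of the commit; the numeric value here is just an example):

// Sketch: how the env-var handoff between LlamaWeight and cublasMMWrapper behaves.
#include <cstdio>
#include <cstdlib>
#include <string>

int main()
{
    int         w4_weight_layout = 2;  // e.g. the TN-padded CK path
    std::string value            = std::to_string(w4_weight_layout);
    setenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH", value.c_str(), 1);  // 1 = overwrite if already set

    // later, cublasMMWrapper reads the same variable back:
    const char* s      = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
    int         layout = (s != nullptr) ? std::stoi(s) : 1;    // the wrapper's member default is 1
    printf("layout switch = %d\n", layout);
    return 0;
}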
src/turbomind/models/llama/LlamaWeight.h

...
@@ -37,6 +37,7 @@ struct LlamaWeight {
                bool       attn_bias,
                WeightType weight_type,
                int        group_size,
                int        w4_weight_layout,
                size_t     tensor_para_size,
                size_t     tensor_para_rank);
...
src/turbomind/models/llama/awq_sugon/gemm_w4_dequation.cu (new file, 0 → 100644)

#include "src/turbomind/models/llama/awq_sugon/lmdeploy_sugon.cuh"
#include "src/turbomind/models/llama/awq_sugon/gemm_w4_dequation.cuh"

template<typename T>
__global__ void add_kernel(int n, T* A, const T* B)
{
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    if (id >= n)
        return;
    A[id] = A[id] + B[id];
}

template<typename T>
__global__ void assign_kernel(int n, T* A, const T* B)
{
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    if (id >= n)
        return;
    A[id] = B[id];
}

template<typename T>
void assign_fun(cudaStream_t stream, T* A, const T* B, int size)
{
    int num_kernels = size;
    assign_kernel<<<(num_kernels + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(num_kernels, A, B);
}

#define INSTANTIATEASSIGN(T)                                                                                           \
    template void assign_fun(cudaStream_t stream, T* A, const T* B, int size);
INSTANTIATEASSIGN(__half)
INSTANTIATEASSIGN(float)
INSTANTIATEASSIGN(half2)
INSTANTIATEASSIGN(uint)

template<typename T>
void PrintScale(cudaStream_t stream, const T* data, int size, int flag, int m, int n)
{
    printf("start printf ****\n");
    int input_size = size;
    T*  h_data;
    h_data = new T[input_size];
    T* d_data;
    cudaMalloc((void**)&d_data, input_size * sizeof(T));
    // initialization (kept for reference)
    // for(int i=0;i<input_size;i++)
    // {
    //     h_data[i] = __float2half(2.0f);
    // }
    // cudaMemcpy(d_data, h_data, input_size * sizeof(T), cudaMemcpyHostToDevice);
    assign_fun<T>(stream, d_data, data, input_size);
    cudaStreamSynchronize(stream);
    cudaMemcpy(h_data, d_data, input_size * sizeof(T), cudaMemcpyDeviceToHost);
    if (flag != 0) {
        std::string   file_name = "/FrameWork/nvidia_file/elsetest/data" + std::to_string(flag) + ".bin";
        std::ofstream outfile(file_name, std::ios::binary);
        if (!outfile) {
            std::cerr << "Failed to open the file for writing." << std::endl;
        }
        outfile.write(reinterpret_cast<const char*>(h_data), m * n * sizeof(T));
        outfile.close();
    }
    if constexpr (std::is_same_v<T, half>) {
        for (int i = 0; i < input_size; i++) {
            printf("%f ", __half2float(h_data[i]));
        }
    }
    else if constexpr (std::is_same_v<T, half2>) {
        for (int i = 0; i < input_size; i++) {
            printf("x:%f y:%f ", __half2float(h_data[i].data[0]), __half2float(h_data[i].data[1]));
        }
    }
    else if constexpr (std::is_same_v<T, uint>) {
        for (int i = 0; i < input_size; i++) {
            printf(" %u ", h_data[i]);
        }
    }
    printf("\n");
    delete[] h_data;
    cudaFree(d_data);
    return;
}

#define INSTANTIATEPRINT(T)                                                                                            \
    template void PrintScale(cudaStream_t stream, const T* data, int size, int flag, int m, int n);
INSTANTIATEPRINT(__half)
INSTANTIATEPRINT(float)
INSTANTIATEPRINT(half2)
INSTANTIATEPRINT(uint32_t)

template<typename T>
__global__ void input_padding_kernel(int num_kernels, T* output, const T* input, int m, int k, int group_size, int count)
{
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    if (id >= num_kernels)
        return;
    int j = id % (k + count * group_size);
    int i = id / (k + count * group_size);
    if (j < k) {
        output[i * (k + count * group_size) + j] = input[i * (k) + j];
    }
    else {
        output[i * (k + count * group_size) + j] = 0.f;
    }
}

template<typename T>
void input_padding(cudaStream_t stream, T* output, const T* input, int m, int k, int group_size, int pad_groupcount)
{
    // input has shape [m, k]; output has shape [m, k + pad_groupcount * group_size], zero-filled past column k
    int num_kernels = m * (k + pad_groupcount * group_size);
    input_padding_kernel<<<(num_kernels + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(
        num_kernels, output, input, m, k, group_size, pad_groupcount);
}

#define INSTANTIATEINPUTPADING(T)                                                                                      \
    template void input_padding(cudaStream_t stream, T* output, const T* input, int m, int k, int group_size, int pad_groupcount);
INSTANTIATEINPUTPADING(__half)
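The padded TN path widens each activation row from k to k + pad_groupcount * group_size and zero-fills the tail, so that the transposed GEMM and the column-major dequantized weight agree on the padded K. A small host-side sketch of the size bookkeeping as LlamaLinear.h uses it (concrete numbers are illustrative only):

// Sketch: buffer sizes for the padded TN path; mirrors the input_padding() call sites.
#include <cstdio>

int main()
{
    int batch_size      = 64;    // m
    int input_dims      = 4096;  // k (a multiple of 4096 triggers the padded path)
    int group_size      = 128;   // example AWQ quantization group size
    int pad_group_count = 2;     // fixed to 2 in this commit

    int padded_k = input_dims + pad_group_count * group_size;  // 4096 + 256 = 4352

    // input  x    : [m, k]        -> batch_size * input_dims half elements
    // output xpad : [m, padded_k] -> batch_size * padded_k half elements, tail zero-filled
    printf("x:    %d halves\n", batch_size * input_dims);
    printf("xpad: %d halves\n", batch_size * padded_k);
    return 0;
}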
src/turbomind/models/llama/awq_sugon/gemm_w4_dequation.cuh (new file, 0 → 100644)

#pragma once
#include "src/turbomind/models/llama/awq_sugon/lmdeploy_sugon.cuh"
#include <string>
#include <iostream>
#include <fstream>

typedef struct ihipStream_t* hipStream_t;

// template <typename T>
// void dequant_w4_gemm(cudaStream_t stream, T* output,const uint32_t* weight,const half2* zeros_and_scales,int k,int n,int group_size);

template<typename T>
void PrintScale(cudaStream_t stream, const T* data, int size, int flag, int m, int n);

template<typename T>
void assign_fun(cudaStream_t stream, T* A, const T* B, int size);

extern void run_weight_only_gemm(const void* A,
                                 const void* B0,
                                 const void* B1,
                                 void*       C,
                                 int         M,
                                 int         N,
                                 int         K,
                                 int         StrideA,
                                 int         StrideB,
                                 int         StrideB_padded,  // K of the input weight matrix after padding
                                 int         StrideC,
                                 int         Group,
                                 void*       splitK_padA_workspace,  // device memory used for split-K and for padding tensor A
                                 int         splitK_padA_workspace_elementSize,  // number of bits in the workspace
                                 hipStream_t stream_id = 0);

template<typename T>
void input_padding(cudaStream_t stream, T* output, const T* input, int m, int k, int group_size, int pad_groupcount);
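run_weight_only_gemm is the external composable-kernel (CK) entry point; its argument order matches the "A B0 B1 C M N K strideA strideB strideBpad strideC group_size" comment in LlamaLinear.h. The following wrapper only restates that mapping for the unpadded case; it is a hedged sketch, not part of the commit, and the CK implementation itself lives outside the files shown on this page.

// Sketch: argument roles for run_weight_only_gemm on the CK path (unpadded case).
#include "src/turbomind/models/llama/awq_sugon/gemm_w4_dequation.cuh"

void ck_w4_gemm_sketch(const half* x, const uint32_t* qweight, const half2* scales_zeros,
                       half* y, int batch_size, int input_dims, int output_dims,
                       int group_size, void* ck_workspace, int ck_workspace_size, hipStream_t stream)
{
    run_weight_only_gemm(/*A*/ x,                            // activations [batch_size, input_dims]
                         /*B0*/ qweight,                     // packed INT4 weights
                         /*B1*/ scales_zeros,                // per-group scales and zeros
                         /*C*/ y,                            // output [batch_size, output_dims]
                         /*M*/ batch_size,
                         /*N*/ output_dims,
                         /*K*/ input_dims,
                         /*StrideA*/ input_dims,
                         /*StrideB*/ input_dims,
                         /*StrideB_padded*/ input_dims,      // equals K when no padding is needed
                         /*StrideC*/ output_dims,
                         /*Group*/ group_size,
                         ck_workspace,
                         ck_workspace_size,
                         stream);
}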
src/turbomind/models/llama/awq_sugon/lmdeploy_sugon.cuh (new file, 0 → 100644)

#define BLOCKSIZE 256
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cassert>
#include <cstdint>
#include <type_traits>
#include <sys/time.h>
#pragma once

struct my_timer {
    timeval ts, te;  // start time, end time
    float   dt;      // elapsed time in milliseconds (ms)
    void start()
    {
        gettimeofday(&ts, NULL);
    }
    void stop()
    {
        gettimeofday(&te, NULL);
        long int dt_sec  = te.tv_sec - ts.tv_sec;
        long int dt_usec = te.tv_usec - ts.tv_usec;
        dt               = dt_sec * 1.0e3 + dt_usec / 1.0e3;
    }
};
\ No newline at end of file
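my_timer is a plain wall-clock timer built on gettimeofday, so any device work must be synchronized before stop() for the measurement to be meaningful. A hypothetical usage sketch (the copy being timed and its size are illustrative, not from the commit):

// Sketch: timing a device-to-device copy with my_timer.
#include <cstdio>
#include "src/turbomind/models/llama/awq_sugon/lmdeploy_sugon.cuh"

int main()
{
    const int n = 1 << 20;
    half *a, *b;
    cudaMalloc(&a, n * sizeof(half));
    cudaMalloc(&b, n * sizeof(half));

    my_timer timer;
    timer.start();
    cudaMemcpy(b, a, n * sizeof(half), cudaMemcpyDeviceToDevice);
    cudaDeviceSynchronize();  // wall-clock timing requires the work to have finished
    timer.stop();

    printf("copy took %.3f ms\n", timer.dt);
    cudaFree(a);
    cudaFree(b);
    return 0;
}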
src/turbomind/python/bind.cpp

// #include "src/turbomind/kernels/gemm_s_f16/format.h"
#include "src/turbomind/kernels/gemm_s_f16/format.h"
#include "src/turbomind/python/dlpack.h"
#include "src/turbomind/triton_backend/llama/LlamaTritonModel.h"
#include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
...
@@ -457,15 +457,15 @@ PYBIND11_MODULE(_turbomind, m)
        auto src_tensor = GetDLTensor(src);
        auto dst_tensor = GetDLTensor(dst);
        // turbomind::transpose_qk_s4_k_m8_hf(
        //     (uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, size_per_head, nullptr);
        turbomind::transpose_qk_s4_k_m8_hf(
            (uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, size_per_head, nullptr);
    });
    m.def("fuse_w1_w3_s4_k_m8", [](py::object src, py::object dst, int m, int k) {
        auto src_tensor = GetDLTensor(src);
        auto dst_tensor = GetDLTensor(dst);
        // turbomind::fuse_w1_w3_s4_k_m8((uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, nullptr);
        turbomind::fuse_w1_w3_s4_k_m8((uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, nullptr);
    });
    m.def("convert_s4_k_m8",
...
@@ -485,16 +485,45 @@ PYBIND11_MODULE(_turbomind, m)
        auto s  = GetDLTensor(scales);
        auto qz = GetDLTensor(qzeros);
        // turbomind::convert_s4_k_m8((uint32_t*)a_dst.data,
        //                            (half2*)q_dst.data,
        //                            (half*)w.data,
        //                            (const uint32_t*)a_src.data,
        //                            (const half*)s.data,
        //                            (const uint32_t*)qz.data,
        //                            m,
        //                            k,
        //                            group_size,
        //                            nullptr);
        turbomind::convert_s4_k_m8((uint32_t*)a_dst.data,
                                   (half2*)q_dst.data,
                                   (half*)w.data,
                                   (const uint32_t*)a_src.data,
                                   (const half*)s.data,
                                   (const uint32_t*)qz.data,
                                   m,
                                   k,
                                   group_size,
                                   nullptr);
    });
    m.def("convert_s4_k_m8_",
          [](py::object A_dst,
             py::object Q_dst,
             py::object ws,
             py::object A_src,
             py::object scales,
             py::object qzeros,
             int        m,
             int        k,
             int        group_size) {
              auto a_dst = GetDLTensor(A_dst);
              auto q_dst = GetDLTensor(Q_dst);
              auto w     = GetDLTensor(ws);
              auto a_src = GetDLTensor(A_src);
              auto s     = GetDLTensor(scales);
              auto qz    = GetDLTensor(qzeros);
              turbomind::convert_s4_k_m8_((uint32_t*)a_dst.data,
                                          (half2*)q_dst.data,
                                          (half*)w.data,
                                          (const uint32_t*)a_src.data,
                                          (const half*)s.data,
                                          (const uint32_t*)qz.data,
                                          m,
                                          k,
                                          group_size,
                                          nullptr);
          });
    m.def("dequantize_s4", [](py::object src, py::object dst) {
...
src/turbomind/triton_backend/llama/LlamaTritonModel.cc

...
@@ -186,6 +186,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
    attn_bias_        = reader.GetInteger("llama", "attn_bias", 0);
    quant_policy_     = reader.GetInteger("llama", "quant_policy", 0);
    group_size_       = reader.GetInteger("llama", "group_size", 0);
    w4_weight_layout_ = reader.GetInteger("llama", "w4_weight_layout", 2);

    // rotary embedding parameters
    attn_params_.rotary_embedding_dim = reader.GetInteger("llama", "rotary_embedding");
...
@@ -381,6 +382,7 @@ void LlamaTritonModel<T>::createSharedWeights(int device_id, int rank)
                                                      attn_bias_,
                                                      weight_type_,
                                                      group_size_,
                                                      w4_weight_layout_,
                                                      tensor_para_size_,
                                                      tensor_para_rank);
    // model inited with model_dir
...
src/turbomind/triton_backend/llama/LlamaTritonModel.h

...
@@ -101,7 +101,8 @@ private:
    bool attn_bias_;
    int  quant_policy_;
    int  group_size_;
    int  w4_weight_layout_;

    // shared weights for each device
    std::vector<std::shared_ptr<ft::LlamaWeight<T>>> shared_weights_;
...
src/turbomind/utils/cublasMMWrapper.cc

...
@@ -36,9 +36,43 @@ cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle,
    mu_(mu),
    allocator_(allocator)
{
    // read the environment variable before allocating memory to determine the weight-layout format
    // m_weightlayout_switch = 0 --> NN-layout rocBLAS
    // m_weightlayout_switch = 1 --> TN-padded rocBLAS
    // m_weightlayout_switch = 2 --> TN-padded CK
    const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
    if (env_weightlayout_str != nullptr) {
        m_weightlayout_switch = std::stoi(env_weightlayout_str);
    }
    const char* env_dump_str = std::getenv("LMDEPLOY_DUMP_SWITCH");
    if (env_dump_str != nullptr) {
        m_dump_switch = std::stoi(env_dump_str);
    }
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    if (allocator_ != nullptr) {
        cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false);
        // the dequantization buffers are needed when rocBLAS is used, or when CK is used with dumping enabled
        if (m_weightlayout_switch == 1 || m_weightlayout_switch == 0
            || (m_weightlayout_switch == 2 && m_dump_switch == 1)) {
            // temporary storage for the dequantized weights
            printf("alloc space for deqeight\n");
            deweight_workspace_ = allocator_->reMalloc(deweight_workspace_, DEQ_WORKSPACE_SIZE, false);
            if (m_weightlayout_switch == 1 || (m_weightlayout_switch == 2 && m_dump_switch == 1)) {
                printf("alloc space for xpading\n");
                printf("weight layout is tn pading rocblas\n");
                xpading_workspace_ = allocator_->reMalloc(xpading_workspace_, XPAD_WORKSPACE_SIZE, false);
            }
        }
        else if (m_weightlayout_switch == 2) {
            printf("alloc space for ck workspace\n");
            printf("weight layout is tn pading ck\n");
            ck_workspace_ = allocator_->reMalloc(ck_workspace_, CK_WORKSPACE_SIZE, false);
        }
    }
    // hgemm-switch 0:fp32r,1:fp16r-fp32r,2:fp16r ----xzhou 20240427
    m_ihgemm_switch = 0;
...
@@ -70,9 +104,38 @@ cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle,
    mu_(mu),
    allocator_(allocator)
{
    const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
    if (env_weightlayout_str != nullptr) {
        m_weightlayout_switch = std::stoi(env_weightlayout_str);
    }
    const char* env_dump_str = std::getenv("LMDEPLOY_DUMP_SWITCH");
    if (env_dump_str != nullptr) {
        m_dump_switch = std::stoi(env_dump_str);
    }
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    if (allocator_ != nullptr) {
        cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false);
        // the dequantization buffers are needed when rocBLAS is used, or when CK is used with dumping enabled
        if (m_weightlayout_switch == 1 || m_weightlayout_switch == 0
            || (m_weightlayout_switch == 2 && m_dump_switch == 1)) {
            // temporary storage for the dequantized weights
            printf("alloc space for deqeight\n");
            deweight_workspace_ = allocator_->reMalloc(deweight_workspace_, DEQ_WORKSPACE_SIZE, false);
            if (m_weightlayout_switch == 1 || (m_weightlayout_switch == 2 && m_dump_switch == 1)) {  // note: fixed '=' to '==', matching the other constructors
                printf("alloc space for xpading\n");
                printf("weight layout is tn pading rocblas\n");
                xpading_workspace_ = allocator_->reMalloc(xpading_workspace_, XPAD_WORKSPACE_SIZE, false);
            }
        }
        else if (m_weightlayout_switch == 2) {
            printf("alloc space for ck workspace\n");
            printf("weight layout is tn pading ck\n");
            ck_workspace_ = allocator_->reMalloc(ck_workspace_, CK_WORKSPACE_SIZE, false);
        }
    }
}
#endif
...
@@ -83,6 +146,22 @@ cublasMMWrapper::~cublasMMWrapper()
    mu_ = nullptr;
    if (allocator_ != nullptr) {
        allocator_->free((void**)(&cublas_workspace_));
        if (m_weightlayout_switch == 1 || m_weightlayout_switch == 0
            || (m_weightlayout_switch == 2 && m_dump_switch == 1)) {
            // free the temporary storage for the dequantized weights
            printf("free space for deqeight\n");
            allocator_->free((void**)(&deweight_workspace_));
            if (m_weightlayout_switch == 1 || (m_weightlayout_switch == 2 && m_dump_switch == 1)) {
                printf("free space for xpading\n");
                allocator_->free((void**)(&xpading_workspace_));
            }
        }
        else if (m_weightlayout_switch == 2) {
            printf("free space for ck workspace\n");
            allocator_->free((void**)(&ck_workspace_));
        }
        allocator_ = nullptr;
    }
}
...
@@ -98,9 +177,38 @@ cublasMMWrapper::cublasMMWrapper(const cublasMMWrapper& wrapper):
    mu_(wrapper.mu_),
    allocator_(wrapper.allocator_)
{
    const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
    if (env_weightlayout_str != nullptr) {
        m_weightlayout_switch = std::stoi(env_weightlayout_str);
    }
    const char* env_dump_str = std::getenv("LMDEPLOY_DUMP_SWITCH");
    if (env_dump_str != nullptr) {
        m_dump_switch = std::stoi(env_dump_str);
    }
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    if (allocator_ != nullptr) {
        cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false);
        // the dequantization buffers are needed when rocBLAS is used, or when CK is used with dumping enabled
        if (m_weightlayout_switch == 1 || m_weightlayout_switch == 0
            || (m_weightlayout_switch == 2 && m_dump_switch == 1)) {
            // temporary storage for the dequantized weights
            printf("alloc space for deqeight\n");
            deweight_workspace_ = allocator_->reMalloc(deweight_workspace_, DEQ_WORKSPACE_SIZE, false);
            if (m_weightlayout_switch == 1 || (m_weightlayout_switch == 2 && m_dump_switch == 1)) {
                printf("alloc space for xpading\n");
                printf("weight layout is tn pading rocblas\n");
                xpading_workspace_ = allocator_->reMalloc(xpading_workspace_, XPAD_WORKSPACE_SIZE, false);
            }
        }
        else if (m_weightlayout_switch == 2) {
            printf("alloc space for ck workspace\n");
            printf("weight layout is tn pading ck\n");
            ck_workspace_ = allocator_->reMalloc(ck_workspace_, CK_WORKSPACE_SIZE, false);
        }
    }
}
...
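All three constructors apply the same workspace-selection rule from the two environment variables. The helper below restates that rule in one place; it is a clarifying sketch, not code from the commit.

// Sketch: which W4 workspaces each mode allocates, as implemented by the constructors above.
#include <cstdlib>

struct W4Workspaces {
    bool deweight;  // dequantized-weight buffer (DEQ_WORKSPACE_SIZE)
    bool xpad;      // padded-activation buffer  (XPAD_WORKSPACE_SIZE)
    bool ck;        // composable-kernel buffer  (CK_WORKSPACE_SIZE)
};

inline W4Workspaces select_workspaces(int layout_switch, int dump_switch)
{
    W4Workspaces w{false, false, false};
    if (layout_switch == 0 || layout_switch == 1 || (layout_switch == 2 && dump_switch == 1)) {
        // rocBLAS paths (and CK with dumping) dequantize the weights into a buffer
        w.deweight = true;
        // the TN-padded rocBLAS path (and CK with dumping) also needs the padded-x buffer
        w.xpad = (layout_switch == 1) || (layout_switch == 2 && dump_switch == 1);
    }
    else if (layout_switch == 2) {
        // the pure CK path only needs the CK workspace
        w.ck = true;
    }
    return w;
}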
src/turbomind/utils/cublasMMWrapper.h

...
@@ -70,6 +70,12 @@ protected:
                         const bool  per_column_scaling);

public:
    void* ck_workspace_       = nullptr;
    // padding buffer for x
    void* xpading_workspace_  = nullptr;
    void* deweight_workspace_ = nullptr;
    int   m_weightlayout_switch = 1;
    int   m_dump_switch         = 0;

    cublasMMWrapper(cublasHandle_t   cublas_handle_,
                    cublasLtHandle_t cublaslt_handle_,
                    cudaStream_t     stream,
...
src/turbomind/utils/cuda_utils.h

...
@@ -38,7 +38,14 @@ namespace turbomind {
#define COL32_ 32
// workspace for cublas gemm : 32MB
#define CUBLAS_WORKSPACE_SIZE 33554432
#define CK_WORKSPACE_SIZE 1056768000
#define N_max 3000
#define M_max 22016
#define XPAD_WORKSPACE_SIZE 132096000
#define DEQ_WORKSPACE_SIZE 232096000
// workspace for ck gemm : 3000*22016*8*2 = 1,056,768,000
// XPAD_WORKSPACE_SIZE   : 3000*22016*2   = 132,096,000
// DEQ_WORKSPACE_SIZE    : 4096*22016*2   = 180,355,072 < 232,096,000
typedef struct __align__(4)
{
    half x, y, z, w;
...
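The comments above give the sizing arithmetic behind the new workspace constants. A compile-time restatement of that arithmetic (a sketch, not part of the commit; the factors come directly from the comments):

// Sketch: the sizing comments above, checked at compile time.
static_assert(3000LL * 22016 * 8 * 2 == 1056768000LL, "CK_WORKSPACE_SIZE   = N_max * M_max * 8 * 2");
static_assert(3000LL * 22016 * 2     == 132096000LL,  "XPAD_WORKSPACE_SIZE = N_max * M_max * sizeof(half)");
static_assert(4096LL * 22016 * 2     == 180355072LL,  "largest dequantized weight (4096 x 22016 halves) ...");
static_assert(180355072LL < 232096000LL,              "... fits inside DEQ_WORKSPACE_SIZE with headroom");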