Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9533a172
Unverified
Commit
9533a172
authored
Dec 02, 2024
by
Illia Silin
Committed by
GitHub
Dec 02, 2024
Browse files
Merge branch 'develop' into codegen-enable-hiprtc
parents
c2cf0733
50ee4267
Changes
503
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1176 additions
and
8 deletions
+1176
-8
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
+14
-0
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
+14
-0
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
...e/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
+73
-0
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+603
-0
example/ck_tile/15_fused_moe/misc/moe-0.png
example/ck_tile/15_fused_moe/misc/moe-0.png
+0
-0
example/ck_tile/15_fused_moe/misc/moe-1.png
example/ck_tile/15_fused_moe/misc/moe-1.png
+0
-0
example/ck_tile/15_fused_moe/misc/moe-2.png
example/ck_tile/15_fused_moe/misc/moe-2.png
+0
-0
example/ck_tile/15_fused_moe/misc/moe-3.png
example/ck_tile/15_fused_moe/misc/moe-3.png
+0
-0
example/ck_tile/16_batched_gemm/CMakeLists.txt
example/ck_tile/16_batched_gemm/CMakeLists.txt
+1
-0
example/ck_tile/16_batched_gemm/README.md
example/ck_tile/16_batched_gemm/README.md
+37
-0
example/ck_tile/16_batched_gemm/batched_gemm.cpp
example/ck_tile/16_batched_gemm/batched_gemm.cpp
+103
-0
example/ck_tile/16_batched_gemm/batched_gemm.hpp
example/ck_tile/16_batched_gemm/batched_gemm.hpp
+63
-0
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+253
-0
example/ck_tile/CMakeLists.txt
example/ck_tile/CMakeLists.txt
+5
-0
include/ck/ck.hpp
include/ck/ck.hpp
+5
-3
include/ck/library/utility/algorithm.hpp
include/ck/library/utility/algorithm.hpp
+0
-0
include/ck/library/utility/check_err.hpp
include/ck/library/utility/check_err.hpp
+5
-5
include/ck/library/utility/conv_common.hpp
include/ck/library/utility/conv_common.hpp
+0
-0
include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp
...ary/utility/convolution_host_tensor_descriptor_helper.hpp
+0
-0
include/ck/library/utility/convolution_parameter.hpp
include/ck/library/utility/convolution_parameter.hpp
+0
-0
No files found.
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
0 → 100644
View file @
9533a172
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"

// Explicit instantiation of the fused moe-gemm entry point for the all-bf16
// data-type configuration with fp32 accumulation/scales. The S<...> packs are
// tile-shape parameters — presumably block tile / warp layout / warp tile
// (see fused_moegemm_api_traits.hpp for the authoritative meaning).
// clang-format off
template float fused_moegemm_<
    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t,
          float, float, float, float,
          S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>>(
    const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
0 → 100644
View file @
9533a172
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"

// Explicit instantiation of the fused moe-gemm entry point for the all-fp16
// data-type configuration with fp32 accumulation/scales. The S<...> packs are
// tile-shape parameters — presumably block tile / warp layout / warp tile
// (see fused_moegemm_api_traits.hpp for the authoritative meaning).
// clang-format off
template float fused_moegemm_<
    fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t,
          float, float, float, float,
          S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>>(
    const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
0 → 100644
View file @
9533a172
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "fused_moesorting.hpp"

// Instantiate the moe-sorting kernel for one compile-time smem unroll factor
// and launch it, returning the measured average time. The macro ends with a
// `return`, so every switch case below terminates without a `break`.
#define MOE_SORTING_DISPATCH(unroll_num_)                                               \
    constexpr ck_tile::index_t unroll_num = unroll_num_;                                \
    using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
    using kernel     = ck_tile::MoeSortingKernel<ms_problem>;                           \
    auto kargs       = kernel::MakeKargs(a);                                            \
    const dim3 grids  = kernel::GridSize(a);                                            \
    const dim3 blocks = kernel::BlockSize(a);                                           \
    const auto lds_bytes = kernel::GetSmemSize(a);                                      \
    float ave_time       = ck_tile::launch_kernel(                                      \
        s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs));            \
    return ave_time;

// Launch the moe-sorting kernel for the supported (fp32 weight, int32 index)
// configuration.
//
// @param t  requested weight/index data types (strings)
// @param a  kernel arguments (token/expert counts, buffer sizes, ...)
// @param s  stream configuration used to launch and time the kernel
// @return   average kernel time as reported by launch_kernel, or -1 on
//           unsupported configuration / invalid arguments
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
{
    if(t.weight_type == "fp32" && t.index_type == "int32")
    {
        if(a.num_experts > 127)
        {
            // The guard rejects > 127, i.e. 127 itself is still supported;
            // the original message said "<127" which contradicted the check.
            printf("lds size exceed, only support experts <= 127\n");
            return -1;
        }
        if(a.moe_buf_bytes % 16)
        {
            printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes);
            return -1;
        }
        using index_t        = ck_tile::index_t;
        using ms_weight_type = float;
        // Number of 64-element chunks of (tokens * topk) — selects the
        // compile-time unroll factor for the kernel's smem IO.
        index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);
        switch(smem_io_unroll_num)
        {
        case(1): { MOE_SORTING_DISPATCH(1); }
        case(2): { MOE_SORTING_DISPATCH(2); }
        case(3): { MOE_SORTING_DISPATCH(3); }
        case(5): { MOE_SORTING_DISPATCH(5); }
        case(6): { MOE_SORTING_DISPATCH(6); }
        case(7): { MOE_SORTING_DISPATCH(7); }
        case(8): { MOE_SORTING_DISPATCH(8); }
        case(9): { MOE_SORTING_DISPATCH(9); }
        case(10): { MOE_SORTING_DISPATCH(10); }
        case(11): { MOE_SORTING_DISPATCH(11); }
        // 4 and everything >= 12 fall back to unroll factor 4
        default: { MOE_SORTING_DISPATCH(4); }
        }
    }
    return -1;
}
example/ck_tile/15_fused_moe/main.cpp
0 → 100644
View file @
9533a172
#include <algorithm>
#include <cstring>
#include <unordered_set>
#include <vector>
#include <set>
#include "ck_tile/host.hpp"
#include "fused_moe.hpp"
// different threshold for different dtype
// Default elementwise error limits (rtol, atol) used when validating the GPU
// result against the host reference.
template <typename DataType>
auto get_elimit()
{
    return ck_tile::make_tuple(1e-2, 1e-2);
}
// bf16 specialization — currently identical to the default tolerance, kept as
// a separate specialization so it can be tightened/loosened independently.
template <>
auto get_elimit<ck_tile::bf16_t>()
{
    return ck_tile::make_tuple(1e-2, 1e-2);
}
// mfma_type, 0:32x32, 1:16x16
// TODO: padding?
//
// Re-layout an (experts, N, K) weight tensor into the permuted order the moe
// gemm kernel expects. The original implementation spelled out four branches
// (bf16/fp16 x mfma 0/1, int8/fp8 x mfma 0/1) that differed only in three
// tiling factors; they are consolidated here. All branches reshape to
// {b, n/nb, nb, k/(kc*kd), kc, kd} and permute axes {0,1,3,4,2,5}.
// Unsupported dtypes/mfma types return the tensor unchanged (as before).
template <typename T>
auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
{
    assert(t.get_lengths().size() == 3);
    int b_ = t.get_lengths()[0];
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[2];

    const bool is_16bit = (mfma_dtype == "bf16" || mfma_dtype == "fp16");
    const bool is_8bit  = (mfma_dtype == "int8" || mfma_dtype == "fp8");

    if((is_16bit || is_8bit) && (mfma_type == 0 || mfma_type == 1))
    {
        // Innermost contiguous run: 8 elements for 16-bit types, 16 for 8-bit
        // (same number of bytes in both cases).
        const int kd = is_16bit ? 8 : 16;
        // mfma 32x32 -> n blocked by 32, k middle factor 2;
        // mfma 16x16 -> n blocked by 16, k middle factor 4.
        const int nb = (mfma_type == 0) ? 32 : 16;
        const int kc = (mfma_type == 0) ? 2 : 4;
        // assumes n_ % nb == 0 and k_ % (kc*kd) == 0 — same implicit
        // requirement as the original code; TODO confirm padding handling
        ck_tile::HostTensor<T> t_view({b_, n_ / nb, nb, k_ / (kc * kd), kc, kd});
        std::copy(t.begin(), t.end(), t_view.begin());
        return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
    }
    return t;
}
// Fill `host_tensor` (size tokens * topk) with random expert ids such that
// each token's `topk` consecutive entries are distinct values in
// [0, num_expert). Uses std::srand/std::rand, so the sequence is
// deterministic for a given seed on a given C library.
//
// @param host_tensor  output buffer; must hold at least tokens*topk elements
// @param tokens       number of tokens
// @param topk         ids generated per token (must be <= num_expert)
// @param num_expert   exclusive upper bound for generated ids
// @param seed         srand seed
template <typename IndexType>
void topid_unique_gen(
    std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
    // The rejection sampling below cannot terminate if one token needs more
    // unique ids than there are experts.
    assert(topk <= num_expert);
    // Widen before multiplying so tokens*topk cannot overflow int.
    size_t total_size = static_cast<size_t>(topk) * tokens;
    std::srand(seed);
    std::set<IndexType> unique_set;
    IndexType current_v;
    for(size_t i = 0; i < total_size; i++)
    {
        if(i % topk == 0)
        {
            // new token -> fresh uniqueness constraint
            unique_set.clear();
        }
        current_v = std::rand() % num_expert;
        // redraw until we hit an expert this token has not used yet
        while(unique_set.find(current_v) != unique_set.end())
        {
            current_v = std::rand() % num_expert;
        }
        unique_set.insert(current_v);
        host_tensor[i] = current_v;
    }
}
// Build the command-line argument parser for the fused-moe example.
// Returns (parse_succeeded, parser).
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("t", "128", "num input tokens")
        .insert("e", "32", "num of experts")
        .insert("k", "5", "topk")
        .insert("h", "8192", "hidden_size of this model")
        .insert("i", "8192", "intermediate_size between 2 gemms of FFN")
        .insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
        .insert("bm", "32", "blocking factor for sorted tokens")
        .insert("tp", "8", "tensor parallel size")
        .insert("v", "1", "cpu validation or not")
        .insert("kname", "1", "print kernel name or not")
        .insert("prec_i", "bf16", "input precision")
        .insert("prec_w", "bf16", "weight precision")
        .insert("prec_o", "bf16", "output precision")
        .insert("prec_st", "auto", "token scale data type. auto will set to fp32")
        .insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
        .insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
        .insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
        .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
        .insert(
            "gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
        .insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
        .insert("balance",
                "0",
                "if set to 1, will try balance the expert in topk-ids(convenient for testing)")
        // Fix: the original adjacent string literals concatenated to
        // "...2:rand normalizednormalized(slow)" — the duplicated word is removed.
        .insert("init",
                "2",
                "init method. 0:random stepped float(fast). 1: random uniform, 2:rand "
                "normalized(slow)")
        .insert("seed", "11939", "seed used to do random")
        .insert("warmup", "5", "cold iter")
        .insert("repeat", "20", "hot iter");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type,
// SQ:smooth-quant-type, KW:topk-weight-type
//
// Run one benchmark configuration: allocate/initialize host tensors, move them
// to the device, launch either the fully fused moe (api==0) or the moe-gemm
// only path (api==1), report timing, and (optionally) validate against the
// host reference. Returns true when the run (and validation, if enabled)
// succeeded.
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
bool run(const ck_tile::ArgParser& arg_parser)
{
    // ---- command-line options ------------------------------------------------
    ck_tile::index_t tokens            = arg_parser.get_int("t");
    ck_tile::index_t experts           = arg_parser.get_int("e");
    ck_tile::index_t topk              = arg_parser.get_int("k");
    ck_tile::index_t hidden_size       = arg_parser.get_int("h");
    ck_tile::index_t intermediate_size = arg_parser.get_int("i");
    ck_tile::index_t stride            = arg_parser.get_int("stride");
    ck_tile::index_t block_m           = arg_parser.get_int("bm");
    if(stride < 0)
        stride = hidden_size;

    std::string prec_i  = arg_parser.get_str("prec_i");
    std::string prec_w  = arg_parser.get_str("prec_w");
    std::string prec_o  = arg_parser.get_str("prec_o");
    std::string prec_st = arg_parser.get_str("prec_st");
    std::string prec_sw = arg_parser.get_str("prec_sw");
    std::string prec_sq = arg_parser.get_str("prec_sq");
    std::string prec_kw = arg_parser.get_str("prec_kw");
    // "auto" resolves to fp32 for every scale/weight type
    prec_st = (prec_st == "auto") ? "fp32" : prec_st;
    prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
    prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
    prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;

    int kname         = arg_parser.get_int("kname");
    int do_validation = arg_parser.get_int("v");
    int warmup        = arg_parser.get_int("warmup");
    int repeat        = arg_parser.get_int("repeat");
    int fused_quant   = arg_parser.get_int("fquant");
    int gate_only     = arg_parser.get_int("gate_only");
    int api           = arg_parser.get_int("api");
    int balance       = arg_parser.get_int("balance");
    int tp            = arg_parser.get_int("tp");
    int init          = arg_parser.get_int("init");
    uint32_t seed     = arg_parser.get_uint32("seed");

    // w0 (Gate+Up or Gate only, N size)
    ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp;
    // w1 (Down, N size)
    ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp;

    // ---- pretty strings for the banner line ----------------------------------
    auto prec_str = [&]() {
        auto base_str = prec_i;
        if(prec_i != prec_w)
            base_str += "x" + prec_w;
        if(prec_i != prec_o)
            base_str += "=" + prec_o;
        if(fused_quant != 0)
        {
            base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")";
        }
        return base_str;
    }();

    auto api_str = [&]() {
        if(api == 0)
            return std::string("fmoe");
        else if(api == 1)
            return std::string("moeg");
        else if(api == 2)
            return std::string("moes");
        return std::string("");
    }();

    auto stride_str = [&]() {
        if(stride == hidden_size)
            return std::string("");
        else
            return std::string(", st:") + std::to_string(stride);
    }();

    std::cout << "[" << api_str << "|" << prec_str << "]"
              << " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
              << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
              << ", shrd_interm:" << shared_intermediate_size_0 << "|"
              << shared_intermediate_size_1 << ", go:" << gate_only << ", q:" << fused_quant
              << std::flush;

    // ---- resolve concrete element types --------------------------------------
    using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;

    using ADataType            = typename TypeConfig::ADataType;
    using GDataType            = typename TypeConfig::GDataType;
    using DDataType            = typename TypeConfig::DDataType;
    using AccDataType          = typename TypeConfig::AccDataType;
    using ODataType            = typename TypeConfig::ODataType;
    using AScaleDataType       = typename TypeConfig::AScaleDataType;
    using GScaleDataType       = typename TypeConfig::GScaleDataType;
    using DScaleDataType       = typename TypeConfig::DScaleDataType;
    using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType;
    using TopkWeightDataType   = typename TypeConfig::TopkWeightDataType;
    using IndexDataType        = typename TypeConfig::IndexDataType;

    // host verify
    ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
    ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
    ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
    ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});

    ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
    ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size_0});
    ck_tile::HostTensor<DScaleDataType> sd_host({shared_intermediate_size_1});
    ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant

    ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk});         // to be sort
    ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort

    // worst-case padded length after grouping tokens into block_m tiles
    int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
    ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
    ck_tile::HostTensor<TopkWeightDataType> sorted_weight_host({max_num_tokens_padded});
    ck_tile::HostTensor<IndexDataType> sorted_expert_ids_host(
        {(max_num_tokens_padded + block_m - 1) / block_m});
    ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});

    // ---- host tensor initialization (selected by --init) ---------------------
    if(init == 0)
    {
        ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
        ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host);
        ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host);
        ck_tile::FillStepRange<AScaleDataType>{0.f, 1.f, 0.01f}(sa_host);
        ck_tile::FillStepRange<GScaleDataType>{0.f, 1.f, 0.01f}(sg_host);
        ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host);
        ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host);
        ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host);
    }
    else if(init == 1)
    {
        ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed, true}(a_host);
        ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f, seed, true}(g_host);
        ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f, seed, true}(d_host);
        ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed, true}(sa_host);
        ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
        ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
        ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed, true}(sy_host);
        ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f, seed, true}(
            topk_weight_host);
    }
    else if(init == 2)
    {
        ck_tile::FillNormalDistribution<ADataType>{0.f, 1.f, seed, true}(a_host);
        ck_tile::FillNormalDistribution<GDataType>{0.f, 1.f, seed, true}(g_host);
        ck_tile::FillNormalDistribution<DDataType>{0.f, 1.f, seed, true}(d_host);
        ck_tile::FillNormalDistribution<AScaleDataType>{0.f, 1.f, seed, true}(sa_host);
        ck_tile::FillNormalDistribution<GScaleDataType>{0.f, 1.f, seed, true}(sg_host);
        ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed, true}(sd_host);
        ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host);
        ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(
            topk_weight_host);
    }

    // permute weight
    ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
    ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);

    // do moe sorting
    if(balance)
    {
        // round-robin assignment — every expert gets an even share
        int e_cnt = 0;
        for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
        {
            topk_ids_host.mData[i] = e_cnt;
            e_cnt++;
            if(e_cnt >= experts)
                e_cnt = 0;
        }
    }
    else
    {
        topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
    }

    // leave it here for future debug purpose
#if 0
    a_host.loadtxt("../../ater/input_torch.txt");
    topk_ids_host.loadtxt("../../ater/topk_ids_torch.txt", "int");
    // topk_ids_host.savetxt("topk_ids_2.txt");
    topk_weight_host.loadtxt("../../ater/topk_weights_torch.txt", "float");
    std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
    g_host.loadtxt("../../ater/w1_torch.txt", "float");
    std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
    d_host.loadtxt("../../ater/w2_torch.txt", "float");
    std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
    ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
    std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
    ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
    std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
#endif
#if 0
    std::cout << "sorted_token_ids_host:" << sorted_token_ids_host << std::endl;
    std::cout << "num_sorted_tiles_host:" << num_sorted_tiles_host << std::endl;
    std::cout << "sorted_expert_ids_host:" << sorted_expert_ids_host << std::endl;
    std::cout << "topk_weight_host:" << topk_weight_host << std::endl;
    std::cout << "sorted_weight_host:" << sorted_weight_host << std::endl;
#endif

    // ---- perf metric helpers -------------------------------------------------
    auto cal_tflops = [&](auto ms) {
        double flop_gemm_0 =
            2 * static_cast<double>(tokens) * topk * shared_intermediate_size_0 * hidden_size;
        double flop_gemm_1 =
            2 * static_cast<double>(tokens) * topk * shared_intermediate_size_1 * hidden_size;
        return (flop_gemm_0 + flop_gemm_1) / (static_cast<double>(ms) * 1e-3) / 1e12;
    };

    // TODO: this method we use expert-by-expert view, just for reference
    auto cal_tbps = [&](auto ms) {
        double token_bytes =
            static_cast<double>(tokens) * topk / experts * hidden_size * sizeof(ADataType);
        double w0_bytes = static_cast<double>(shared_intermediate_size_0) * experts * hidden_size *
                          sizeof(GDataType);
        double w1_bytes = static_cast<double>(shared_intermediate_size_1) * experts * hidden_size *
                          sizeof(DDataType);
        double o_bytes =
            static_cast<double>(tokens) * topk / experts * hidden_size * sizeof(ODataType);
        double topk_weights_bytes =
            static_cast<double>(tokens) * topk * sizeof(TopkWeightDataType);
        // ignore index, they are too small
        return (token_bytes + w0_bytes + w1_bytes + o_bytes + topk_weights_bytes) /
               (static_cast<double>(ms) * 1e-3) / 1e12;
    };

    if(api == 0)
    {
        // fully fused path: sorting happens on device, so sorted_* buffers are
        // allocated but not uploaded
        ck_tile::DeviceMem a_buf(a_host);
        ck_tile::DeviceMem g_perm_buf(g_perm_host);
        ck_tile::DeviceMem d_perm_buf(d_perm_host);
        ck_tile::DeviceMem sa_buf(sa_host);
        ck_tile::DeviceMem sg_buf(sg_host);
        ck_tile::DeviceMem sd_buf(sd_host);
        ck_tile::DeviceMem sy_buf(sy_host);
        ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem topk_ids_buf(topk_ids_host);
        ck_tile::DeviceMem topk_weight_buf(topk_weight_host);
        ck_tile::DeviceMem sorted_token_ids_buf(
            sorted_token_ids_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem sorted_weight_buf(
            sorted_weight_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem sorted_expert_ids_buf(
            sorted_expert_ids_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem num_sorted_tiles_buf(
            num_sorted_tiles_host.get_element_space_size_in_bytes());

        fused_moe_traits traits{prec_i,
                                prec_w,
                                prec_o,
                                prec_st,
                                prec_sw,
                                prec_sq,
                                prec_kw,
                                block_m,
                                gate_only,
                                fused_quant};

        // scale pointers are only passed when fused quantization is requested
        fused_moe_args args{a_buf.GetDeviceBuffer(),
                            fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
                            g_perm_buf.GetDeviceBuffer(),
                            d_perm_buf.GetDeviceBuffer(),
                            fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
                            fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
                            fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
                            o_buf.GetDeviceBuffer(),
                            topk_ids_buf.GetDeviceBuffer(),
                            topk_weight_buf.GetDeviceBuffer(),
                            sorted_token_ids_buf.GetDeviceBuffer(),
                            sorted_weight_buf.GetDeviceBuffer(),
                            sorted_expert_ids_buf.GetDeviceBuffer(),
                            num_sorted_tiles_buf.GetDeviceBuffer(),
                            block_m,
                            hidden_size,
                            shared_intermediate_size_0,
                            tokens,
                            experts,
                            topk,
                            stride};

        float ave_time = fused_moe(
            traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});

        if(ave_time < 0)
        {
            std::cout << " not supported!" << std::endl << std::flush;
            return false;
        }

        // float gb_per_sec = num_byte / 1.E6 / ave_time;
        std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, "
                  << cal_tbps(ave_time) << " TB/s" << std::flush;

        bool pass = true;
        if(do_validation)
        {
            ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
                topk_ids_host,
                topk_weight_host,
                sorted_token_ids_host,
                sorted_weight_host,
                sorted_expert_ids_host,
                num_sorted_tiles_host.mData[0],
                experts,
                block_m);

            ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
                a_host,
                g_host,
                d_host,
                sa_host,
                sg_host,
                sd_host,
                sy_host,
                o_host,
                sorted_token_ids_host,
                sorted_weight_host,
                sorted_expert_ids_host,
                num_sorted_tiles_host,
                topk_ids_host,
                block_m,
                tokens,
                experts,
                hidden_size,
                shared_intermediate_size_0,
                topk,
                gate_only);

            auto o_dev = o_buf.ToHost<ODataType>();
            // o_dev.savetxt("gpu-out.txt", "float");
            auto [rtol, atol] = get_elimit<ADataType>();
            pass &= ck_tile::check_err(
                o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
            std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
        }
        std::cout << std::flush << std::endl;
        return pass;
    }
    else if(api == 1)
    {
        // gemm-only path: do the sorting on the host first, then upload
        ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
            topk_ids_host,
            topk_weight_host,
            sorted_token_ids_host,
            sorted_weight_host,
            sorted_expert_ids_host,
            num_sorted_tiles_host.mData[0],
            experts,
            block_m);

        // done, preparing GPU buffer
        ck_tile::DeviceMem a_buf(a_host);
        ck_tile::DeviceMem g_perm_buf(g_perm_host);
        ck_tile::DeviceMem d_perm_buf(d_perm_host);
        ck_tile::DeviceMem sa_buf(sa_host);
        ck_tile::DeviceMem sg_buf(sg_host);
        ck_tile::DeviceMem sd_buf(sd_host);
        ck_tile::DeviceMem sy_buf(sy_host);
        ck_tile::DeviceMem o_buf(o_host);
        // manually clear output buffer for atomic
        o_buf.SetZero();
        //
        ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host);
        ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host);
        ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host);
        ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host);

        fused_moegemm_traits traits{prec_i,
                                    prec_w,
                                    prec_o,
                                    prec_st,
                                    prec_sw,
                                    prec_sq,
                                    prec_kw,
                                    block_m,
                                    gate_only,
                                    fused_quant};

        fused_moegemm_args args{a_buf.GetDeviceBuffer(),
                                fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
                                g_perm_buf.GetDeviceBuffer(),
                                d_perm_buf.GetDeviceBuffer(),
                                fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
                                fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
                                fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
                                o_buf.GetDeviceBuffer(),
                                sorted_token_ids_buf.GetDeviceBuffer(),
                                sorted_weight_buf.GetDeviceBuffer(),
                                sorted_expert_ids_buf.GetDeviceBuffer(),
                                num_sorted_tiles_buf.GetDeviceBuffer(),
                                hidden_size,
                                shared_intermediate_size_0,
                                tokens,
                                experts,
                                topk,
                                stride};

        float ave_time = fused_moegemm(
            traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});

        if(ave_time < 0)
        {
            std::cout << " not supported!" << std::endl << std::flush;
            return false;
        }

        // float gb_per_sec = num_byte / 1.E6 / ave_time;
        std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, "
                  << cal_tbps(ave_time) << " TB/s" << std::flush;

        bool pass = true;
        if(do_validation)
        {
            ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
                a_host,
                g_host,
                d_host,
                sa_host,
                sg_host,
                sd_host,
                sy_host,
                o_host,
                sorted_token_ids_host,
                sorted_weight_host,
                sorted_expert_ids_host,
                num_sorted_tiles_host,
                topk_ids_host,
                block_m,
                tokens,
                experts,
                hidden_size,
                shared_intermediate_size_0,
                topk,
                gate_only);

            auto o_dev = o_buf.ToHost<ODataType>();
            // o_dev.savetxt("gpu-out.txt", "float");
            auto [rtol, atol] = get_elimit<ADataType>();
            pass &= ck_tile::check_err(
                o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
            std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
        }
        std::cout << std::flush << std::endl;
        return pass;
    }
    return false;
}
// Entry point: parse arguments, resolve "auto" precisions to fp32, and
// dispatch to the matching `run` instantiation.
// Exit codes: 0 success, -1 bad arguments, -2 run/validation failure,
// -3 unsupported precision combination.
int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    std::string prec_i  = arg_parser.get_str("prec_i");
    std::string prec_w  = arg_parser.get_str("prec_w");
    std::string prec_o  = arg_parser.get_str("prec_o");
    std::string prec_st = arg_parser.get_str("prec_st");
    std::string prec_sw = arg_parser.get_str("prec_sw");
    std::string prec_sq = arg_parser.get_str("prec_sq");
    std::string prec_kw = arg_parser.get_str("prec_kw");

    // "auto" resolves to fp32 for all scale/weight types
    prec_st = (prec_st == "auto") ? "fp32" : prec_st;
    prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
    prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
    prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;

    // no dynamic quant case
    if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
    {
        return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float>(
                   arg_parser)
                   ? 0
                   : -2;
    }
    else if(prec_i == "fp16" && prec_w == "fp16" && prec_o == "fp16" && prec_kw == "fp32")
    {
        return run<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float>(
                   arg_parser)
                   ? 0
                   : -2;
    }
    return -3;
}
example/ck_tile/15_fused_moe/misc/moe-0.png
0 → 100644
View file @
9533a172
75 KB
example/ck_tile/15_fused_moe/misc/moe-1.png
0 → 100644
View file @
9533a172
90.4 KB
example/ck_tile/15_fused_moe/misc/moe-2.png
0 → 100644
View file @
9533a172
124 KB
example/ck_tile/15_fused_moe/misc/moe-3.png
0 → 100644
View file @
9533a172
18.2 KB
example/ck_tile/16_batched_gemm/CMakeLists.txt
0 → 100644
View file @
9533a172
add_executable
(
tile_example_batched_gemm EXCLUDE_FROM_ALL batched_gemm.cpp
)
example/ck_tile/16_batched_gemm/README.md
0 → 100644
View file @
9533a172
# Batched GEMM
This folder contains an example of batched GEMM using the ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_batched_gemm -j
```
This will result in an executable
`build/bin/tile_example_batched_gemm`
## example
```
args:
-m m dimension (default:256)
-n n dimension (default:128)
-k k dimension (default:128)
-a_layout A tensor data layout (default:R) (R for Row, C for Col)
-b_layout B tensor data layout (default:R) (R for Row, C for Col)
-c_layout C tensor data layout (default:R) (R for Row, C for Col)
-stride_a Tensor A stride (default:128)
-stride_b Tensor B stride (default:128)
-stride_c Tensor C stride (default:128)
-batch_stride_a Batch A stride (default:32768)
-batch_stride_b Batch B stride (default:16384)
-batch_stride_c Batch C stride (default:32768)
-batch_count Batch count (default:16)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
```
\ No newline at end of file
example/ck_tile/16_batched_gemm/batched_gemm.cpp
0 → 100644
View file @
9533a172
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "batched_gemm.hpp"
// Configure and launch one batched-GEMM kernel for the given layouts.
// Tile sizes below are currently hard-coded placeholders for values that the
// codegen layer is expected to supply. Returns the average time reported by
// launch_kernel.
template <typename ALayout, typename BLayout, typename CLayout>
float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s)
{
    // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
    constexpr bool kPadM        = false;
    constexpr bool kPadN        = false;
    constexpr bool kPadK        = false;
    constexpr bool kTilePermute = false;
    // The rank and permutation will also be generate out by the CodeGen part.
    constexpr ck_tile::index_t kOutputRank = 2;
    constexpr int kBlockPerCu              = 1;

    // This part comes from the Codegen
    constexpr ck_tile::index_t M_Tile = 128;
    constexpr ck_tile::index_t N_Tile = 128;
    constexpr ck_tile::index_t K_Tile = 32;

    constexpr ck_tile::index_t M_Warp = 2;
    constexpr ck_tile::index_t N_Warp = 2;
    constexpr ck_tile::index_t K_Warp = 1;

    constexpr ck_tile::index_t M_Warp_Tile = 32;
    constexpr ck_tile::index_t N_Warp_Tile = 32;
    constexpr ck_tile::index_t K_Warp_Tile = 8;

    // Whether doing the CShuffle (transpose before the global memory), depending on the output
    // layout.
    constexpr bool CShuffleEpilogue =
        std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;

    using CodegenGemmShape =
        ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                               ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                               ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

    using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;

    using GemmEpilogue = std::conditional_t<
        CShuffleEpilogue,
        ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
                                                                   CDataType,
                                                                   kPadM,
                                                                   kPadN,
                                                                   kTilePermute,
                                                                   kOutputRank,
                                                                   1,
                                                                   0,
                                                                   TilePartitioner::kM,
                                                                   TilePartitioner::kN>>,
        ck_tile::Default2DEpilogue<
            ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;

    using CodegenGemmTraits =
        ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;

    using CodegenPipelineProblem = ck_tile::
        GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;

    using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;

    // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
    // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
    using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;

    auto kargs = Kernel::MakeKargs(args);

    const dim3 grids      = Kernel::GridSize(args);
    constexpr dim3 blocks = Kernel::BlockSize();

    if(s.log_level_ > 0)
    {
        std::cout << "Launching kernel with args:"
                  << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                  << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                  << std::endl;
    }

    float ave_time = ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

    return ave_time;
}
#include "run_batched_gemm_example.inc"
// Entry point: run_batched_gemm_example returns a truthy value on success,
// so the result is negated to produce the conventional 0-on-success exit code.
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
example/ck_tile/16_batched_gemm/batched_gemm.hpp
0 → 100644
View file @
9533a172
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
// Maps a single element type to the A/B/accumulator/C data types used by the
// batched GEMM example. Only the fp16 specialization is provided here.
template <typename DataType>
struct BatchedGemmTypeConfig;

template <>
struct BatchedGemmTypeConfig<ck_tile::half_t>
{
    using ADataType   = ck_tile::half_t;
    using BDataType   = ck_tile::half_t;
    using AccDataType = float; // accumulate in fp32
    using CDataType   = ck_tile::half_t;
};

using Types = BatchedGemmTypeConfig<ck_tile::half_t>;

// Specific type aliases for easy access
using ADataType   = Types::ADataType;
using BDataType   = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType   = Types::CDataType;
// Host-side kernel arguments for the example; inherits every field (device
// pointers, M/N/K, strides, batch strides, batch_count) unchanged from
// ck_tile::BatchedGemmHostArgs.
struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
{
};
auto
create_args
(
int
argc
,
char
*
argv
[])
{
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"m"
,
"256"
,
"m dimension"
)
.
insert
(
"n"
,
"128"
,
"n dimension"
)
.
insert
(
"k"
,
"128"
,
"k dimension"
)
.
insert
(
"stride_a"
,
"0"
,
"Tensor A stride"
)
.
insert
(
"stride_b"
,
"0"
,
"Tensor B stride"
)
.
insert
(
"stride_c"
,
"0"
,
"Tensor C stride"
)
.
insert
(
"a_layout"
,
"R"
,
"A tensor data layout - Row by default"
)
.
insert
(
"b_layout"
,
"R"
,
"B tensor data layout - Row by default"
)
.
insert
(
"c_layout"
,
"R"
,
"C tensor data layout - Row by default"
)
.
insert
(
"batch_stride_a"
,
"32768"
,
"Batch A stride"
)
.
insert
(
"batch_stride_b"
,
"16384"
,
"Batch B stride"
)
.
insert
(
"batch_stride_c"
,
"32768"
,
"Batch C stride"
)
.
insert
(
"batch_count"
,
"16"
,
"Batch count"
)
.
insert
(
"v"
,
"2"
,
"0. No validation, 1. Validation on CPU, 2. Validation on GPU"
)
.
insert
(
"prec"
,
"fp16"
,
"data type. fp16/bf16/fp8/bf8"
)
.
insert
(
"warmup"
,
"50"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
}
// host API
// NOTE(review): this declares a non-template overload taking the args by value,
// while the example .cpp defines a *template* batched_gemm<ALayout, BLayout,
// CLayout>(const batched_gemm_kargs&, ...). This declaration appears to have no
// matching definition in the example — confirm it is intentional.
float batched_gemm(batched_gemm_kargs args, const ck_tile::stream_config& s);
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
0 → 100644
View file @
9533a172
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
/// @brief Fills the host-args struct, runs the batched GEMM and reports performance.
///
/// @tparam ALayout/BLayout/CLayout  Tensor data layouts forwarded to batched_gemm.
/// @param a_m_k_dev_buf/b_k_n_dev_buf/c_m_n_dev_buf  Device buffers for A, B and C.
/// @param M,N,K                     GEMM problem dimensions.
/// @param stride_A/B/C              Leading strides of the three tensors.
/// @param batch_stride_A/B/C        Element offset between consecutive batches.
/// @param batch_count               Number of GEMMs in the batch.
/// @param n_warmup,n_repeat         Benchmark warmup/measurement iteration counts.
/// @return Average kernel time in milliseconds.
template <typename ALayout, typename BLayout, typename CLayout>
float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                          ck_tile::DeviceMem& b_k_n_dev_buf,
                          ck_tile::DeviceMem& c_m_n_dev_buf,
                          ck_tile::index_t M,
                          ck_tile::index_t N,
                          ck_tile::index_t K,
                          ck_tile::index_t stride_A,
                          ck_tile::index_t stride_B,
                          ck_tile::index_t stride_C,
                          ck_tile::index_t batch_stride_A,
                          ck_tile::index_t batch_stride_B,
                          ck_tile::index_t batch_stride_C,
                          ck_tile::index_t batch_count,
                          int n_warmup,
                          int n_repeat)
{
    batched_gemm_kargs args;
    args.a_ptr          = a_m_k_dev_buf.GetDeviceBuffer();
    args.b_ptr          = b_k_n_dev_buf.GetDeviceBuffer();
    args.c_ptr          = c_m_n_dev_buf.GetDeviceBuffer();
    args.M              = M;
    args.N              = N;
    args.K              = K;
    args.stride_A       = stride_A;
    args.stride_B       = stride_B;
    args.stride_C       = stride_C;
    args.batch_stride_A = batch_stride_A;
    args.batch_stride_B = batch_stride_B;
    args.batch_stride_C = batch_stride_C;
    args.batch_count    = batch_count;

    float ave_time = batched_gemm<ALayout, BLayout, CLayout>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});

    std::string op_name{"Batched Gemm"};

    // 2*M*N*K multiply-adds per GEMM, times the batch count.
    std::size_t flop = std::size_t(2) * batch_count * M * N * K;
    // Bytes moved: read A and B once, write C once, per batch.
    std::size_t num_byte = sizeof(ADataType) * batch_count * M * K +
                           sizeof(BDataType) * batch_count * N * K +
                           sizeof(CDataType) * batch_count * M * N;

    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;

    // Fixed: missing space between the operation name and "kernel" in the log line.
    std::cout << "Run " << op_name << " kernel with M =" << M << " N =" << N << " K =" << K
              << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
              << " batch_stride_A =" << batch_stride_A << " batch_stride_B =" << batch_stride_B
              << " batch_stride_C =" << batch_stride_C << " batch_count =" << batch_count << " : "
              << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << std::endl;

    return ave_time;
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
int
run_batched_gemm_example_with_layouts
(
int
argc
,
char
*
argv
[],
const
ALayout
a_layout
=
ALayout
{},
const
BLayout
b_layout
=
BLayout
{},
[[
maybe_unused
]]
const
CLayout
c_layout
=
CLayout
{})
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
K
=
arg_parser
.
get_int
(
"k"
);
ck_tile
::
index_t
stride_A
=
arg_parser
.
get_int
(
"stride_a"
);
ck_tile
::
index_t
stride_B
=
arg_parser
.
get_int
(
"stride_b"
);
ck_tile
::
index_t
stride_C
=
arg_parser
.
get_int
(
"stride_c"
);
ck_tile
::
index_t
batch_stride_A
=
arg_parser
.
get_int
(
"batch_stride_a"
);
ck_tile
::
index_t
batch_stride_B
=
arg_parser
.
get_int
(
"batch_stride_b"
);
ck_tile
::
index_t
batch_stride_C
=
arg_parser
.
get_int
(
"batch_stride_c"
);
ck_tile
::
index_t
batch_count
=
arg_parser
.
get_int
(
"batch_count"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
using
namespace
ck_tile
::
literals
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
batch_count_
,
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
std
::
size_t
batch_stride
,
auto
layout
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
ck_tile
::
HostTensorDescriptor
({
batch_count_
,
row
,
col
},
{
batch_stride
,
stride
,
1_
uz
});
}
else
{
return
ck_tile
::
HostTensorDescriptor
({
batch_count_
,
row
,
col
},
{
batch_stride
,
1_
uz
,
stride
});
}
};
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
stride
==
0
)
{
// give a chance if stride is zero, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
col
;
}
else
{
return
row
;
}
}
else
return
stride
;
};
stride_A
=
f_get_default_stride
(
M
,
K
,
stride_A
,
a_layout
);
stride_B
=
f_get_default_stride
(
K
,
N
,
stride_B
,
b_layout
);
stride_C
=
f_get_default_stride
(
M
,
N
,
stride_C
,
c_layout
);
ck_tile
::
HostTensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
batch_count
,
M
,
K
,
stride_A
,
batch_stride_A
,
a_layout
));
ck_tile
::
HostTensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
batch_count
,
K
,
N
,
stride_B
,
batch_stride_B
,
b_layout
));
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_dev_result
(
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
c_layout
));
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
ck_tile
::
DeviceMem
a_m_k_dev_buf
(
a_m_k
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b_k_n_dev_buf
(
b_k_n
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c_m_n_dev_buf
(
c_m_n_dev_result
.
get_element_space_size_in_bytes
());
a_m_k_dev_buf
.
ToDevice
(
a_m_k
.
data
());
b_k_n_dev_buf
.
ToDevice
(
b_k_n
.
data
());
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
invoke_batched_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_dev_buf
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
batch_stride_A
,
batch_stride_B
,
batch_stride_C
,
batch_count
,
n_warmup
,
n_repeat
);
c_m_n_dev_buf
.
FromDevice
(
c_m_n_dev_result
.
data
());
bool
pass
=
true
;
if
(
arg_parser
.
get_int
(
"v"
)
==
1
)
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
CLayout
{}));
c_m_n_host_ref
.
SetZero
();
const
auto
b_n_k
=
b_k_n
.
transpose
({
0
,
2
,
1
});
ck_tile
::
reference_batched_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
a_m_k
,
b_n_k
,
c_m_n_host_ref
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
);
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_gpu_ref
(
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
CLayout
{}));
ck_tile
::
DeviceMem
c_m_n_gpu_buf_ref
(
c_m_n_gpu_ref
.
get_element_space_size_in_bytes
());
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
ck_tile
::
reference_batched_gemm_gpu
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_gpu_buf_ref
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
batch_stride_A
,
batch_stride_B
,
batch_stride_C
,
batch_count
);
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
std
::
cout
<<
"The GPU verification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
return
pass
;
}
/// @brief Dispatches the batched GEMM example to the layout combination chosen
///        on the command line (a_layout/b_layout, C is always RowMajor).
///
/// Fix: returns 0 (not -1) on argument-parse failure. main() returns
/// !run_batched_gemm_example(...), so the previous -1 made a parse failure
/// exit with status 0 (success).
///
/// @return Truthy when the selected run passed, 0 on failure.
/// @throws std::runtime_error for unsupported layout combinations.
int run_batched_gemm_example(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return 0; // parse failure must read as "not passed" to main()

    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    std::string a_layout = arg_parser.get_str("a_layout");
    std::string b_layout = arg_parser.get_str("b_layout");

    if(a_layout == "R" && b_layout == "R")
    {
        return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
    }
    else if(a_layout == "R" && b_layout == "C")
    {
        return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
    }
    // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
    // work else if(a_layout == "C" && b_layout == "C")
    // {
    //     return run_batched_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
    // }
    // else if(a_layout == "C" && b_layout == "R")
    // {
    //     return run_batched_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
    // }
    else
    {
        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
    }
}
example/ck_tile/CMakeLists.txt
View file @
9533a172
...
@@ -11,3 +11,8 @@ add_subdirectory(06_permute)
...
@@ -11,3 +11,8 @@ add_subdirectory(06_permute)
add_subdirectory
(
09_topk_softmax
)
add_subdirectory
(
09_topk_softmax
)
add_subdirectory
(
10_rmsnorm2d
)
add_subdirectory
(
10_rmsnorm2d
)
add_subdirectory
(
11_add_rmsnorm2d_rdquant
)
add_subdirectory
(
11_add_rmsnorm2d_rdquant
)
add_subdirectory
(
12_smoothquant
)
add_subdirectory
(
13_moe_sorting
)
add_subdirectory
(
14_moe_smoothquant
)
add_subdirectory
(
15_fused_moe
)
add_subdirectory
(
16_batched_gemm
)
include/ck/ck.hpp
View file @
9533a172
...
@@ -61,13 +61,15 @@
...
@@ -61,13 +61,15 @@
#define __gfx101__
#define __gfx101__
#endif
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__)
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \
defined(__gfx10_3_generic__)
#define __gfx103__
#define __gfx103__
#endif
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
defined(__gfx1103__) || defined(__gfx11_generic__)
#define __gfx11__
#define __gfx11__
#endif
#endif
#if defined(__gfx1200__) || defined(__gfx1201__)
#if defined(__gfx1200__) || defined(__gfx1201__)
|| defined(__gfx12_generic__)
#define __gfx12__
#define __gfx12__
#endif
#endif
...
...
library/
include/ck/library/utility/algorithm.hpp
→
include/ck/library/utility/algorithm.hpp
View file @
9533a172
File moved
library/
include/ck/library/utility/check_err.hpp
→
include/ck/library/utility/check_err.hpp
View file @
9533a172
...
@@ -24,7 +24,7 @@ namespace ck {
...
@@ -24,7 +24,7 @@ namespace ck {
namespace
utils
{
namespace
utils
{
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_relative_threshold
(
const
int
number
OfA
ccumulations
=
1
)
double
get_relative_threshold
(
const
int
number
_of_a
ccumulations
=
1
)
{
{
using
F8
=
ck
::
f8_t
;
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
...
@@ -79,13 +79,13 @@ double get_relative_threshold(const int numberOfAccumulations = 1)
...
@@ -79,13 +79,13 @@ double get_relative_threshold(const int numberOfAccumulations = 1)
}
}
else
else
{
{
acc_error
=
std
::
pow
(
2
,
-
NumericUtils
<
AccDataType
>::
mant
)
*
0.5
*
number
OfA
ccumulations
;
acc_error
=
std
::
pow
(
2
,
-
NumericUtils
<
AccDataType
>::
mant
)
*
0.5
*
number
_of_a
ccumulations
;
}
}
return
std
::
max
(
acc_error
,
midway_error
);
return
std
::
max
(
acc_error
,
midway_error
);
}
}
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_absolute_threshold
(
const
double
max_possible_num
,
const
int
number
OfA
ccumulations
=
1
)
double
get_absolute_threshold
(
const
double
max_possible_num
,
const
int
number
_of_a
ccumulations
=
1
)
{
{
using
F8
=
ck
::
f8_t
;
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
...
@@ -142,7 +142,7 @@ double get_absolute_threshold(const double max_possible_num, const int numberOfA
...
@@ -142,7 +142,7 @@ double get_absolute_threshold(const double max_possible_num, const int numberOfA
else
else
{
{
acc_error
=
acc_error
=
std
::
pow
(
2
,
expo
-
NumericUtils
<
AccDataType
>::
mant
)
*
0.5
*
number
OfA
ccumulations
;
std
::
pow
(
2
,
expo
-
NumericUtils
<
AccDataType
>::
mant
)
*
0.5
*
number
_of_a
ccumulations
;
}
}
return
std
::
max
(
acc_error
,
midway_error
);
return
std
::
max
(
acc_error
,
midway_error
);
}
}
...
@@ -206,7 +206,7 @@ typename std::enable_if<
...
@@ -206,7 +206,7 @@ typename std::enable_if<
check_err
(
const
Range
&
out
,
check_err
(
const
Range
&
out
,
const
RefRange
&
ref
,
const
RefRange
&
ref
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
double
rtol
=
1e-
3
,
double
rtol
=
1e-
1
,
double
atol
=
1e-3
)
double
atol
=
1e-3
)
{
{
if
(
out
.
size
()
!=
ref
.
size
())
if
(
out
.
size
()
!=
ref
.
size
())
...
...
library/
include/ck/library/utility/conv_common.hpp
→
include/ck/library/utility/conv_common.hpp
View file @
9533a172
File moved
library/
include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp
→
include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp
View file @
9533a172
File moved
library/
include/ck/library/utility/convolution_parameter.hpp
→
include/ck/library/utility/convolution_parameter.hpp
View file @
9533a172
File moved
Prev
1
…
4
5
6
7
8
9
10
11
12
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment