Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b74918bc
Commit
b74918bc
authored
Jan 06, 2025
by
ThomasNing
Browse files
compiled version of cross gpu connection
parents
3fcad951
1c45ca35
Changes
486
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1516 additions
and
0 deletions
+1516
-0
example/ck_tile/15_fused_moe/fused_moegemm.hpp
example/ck_tile/15_fused_moe/fused_moegemm.hpp
+84
-0
example/ck_tile/15_fused_moe/fused_moesorting.hpp
example/ck_tile/15_fused_moe/fused_moesorting.hpp
+20
-0
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
+80
-0
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
+33
-0
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
...ile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
+60
-0
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
..._tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
+53
-0
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
+14
-0
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
+14
-0
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
...e/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
+73
-0
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+603
-0
example/ck_tile/15_fused_moe/misc/moe-0.png
example/ck_tile/15_fused_moe/misc/moe-0.png
+0
-0
example/ck_tile/15_fused_moe/misc/moe-1.png
example/ck_tile/15_fused_moe/misc/moe-1.png
+0
-0
example/ck_tile/15_fused_moe/misc/moe-2.png
example/ck_tile/15_fused_moe/misc/moe-2.png
+0
-0
example/ck_tile/15_fused_moe/misc/moe-3.png
example/ck_tile/15_fused_moe/misc/moe-3.png
+0
-0
example/ck_tile/16_batched_gemm/CMakeLists.txt
example/ck_tile/16_batched_gemm/CMakeLists.txt
+1
-0
example/ck_tile/16_batched_gemm/README.md
example/ck_tile/16_batched_gemm/README.md
+37
-0
example/ck_tile/16_batched_gemm/batched_gemm.cpp
example/ck_tile/16_batched_gemm/batched_gemm.cpp
+103
-0
example/ck_tile/16_batched_gemm/batched_gemm.hpp
example/ck_tile/16_batched_gemm/batched_gemm.hpp
+59
-0
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+280
-0
example/ck_tile/17_grouped_gemm/CMakeLists.txt
example/ck_tile/17_grouped_gemm/CMakeLists.txt
+2
-0
No files found.
example/ck_tile/15_fused_moe/fused_moegemm.hpp
0 → 100644
View file @
b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include <string>
// this is only a convenient structure for creating an example
// this is not part of the host API
// Type bundle selected by the element types of the example.
//   I : input type        W : weight type       O  : output type
//   ST: token-scale type  SW: weight-scale type SQ : smooth-quant scale type
//   KW: topk-weight type
// The primary template is only declared; each supported (I, W, O) combination
// is provided as a partial specialization.
template <typename I,
          typename W,
          typename O,
          typename ST,
          typename SW,
          typename SQ,
          typename KW>
struct FusedMoeGemmTypeConfig;
// bf16 input / bf16 weight / bf16 output configuration; accumulation is done in float.
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
    // element types of the A / G / D operands and of the output
    using ADataType   = ck_tile::bf16_t;
    using GDataType   = ck_tile::bf16_t;
    using DDataType   = ck_tile::bf16_t;
    using AccDataType = float;
    using ODataType   = ck_tile::bf16_t;

    // scale types (G and D both use the weight-scale type SW)
    using AScaleDataType       = ck_tile::remove_cvref_t<ST>;
    using GScaleDataType       = ck_tile::remove_cvref_t<SW>;
    using DScaleDataType       = ck_tile::remove_cvref_t<SW>;
    using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;

    using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
    using IndexDataType      = ck_tile::index_t;
};
// fp16 input / fp16 weight / fp16 output configuration; accumulation is done in float.
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, ST, SW, SQ, KW>
{
    // element types of the A / G / D operands and of the output
    using ADataType   = ck_tile::fp16_t;
    using GDataType   = ck_tile::fp16_t;
    using DDataType   = ck_tile::fp16_t;
    using AccDataType = float;
    using ODataType   = ck_tile::fp16_t;

    // scale types (G and D both use the weight-scale type SW)
    using AScaleDataType       = ck_tile::remove_cvref_t<ST>;
    using GScaleDataType       = ck_tile::remove_cvref_t<SW>;
    using DScaleDataType       = ck_tile::remove_cvref_t<SW>;
    using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;

    using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
    using IndexDataType      = ck_tile::index_t;
};
// int8 input / int8 weight / bf16 output configuration; accumulation is done in int32.
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
    // element types of the A / G / D operands and of the output
    using ADataType   = ck_tile::int8_t;
    using GDataType   = ck_tile::int8_t;
    using DDataType   = ck_tile::int8_t;
    using AccDataType = int32_t;
    using ODataType   = ck_tile::bf16_t;

    // scale types (G and D both use the weight-scale type SW)
    using AScaleDataType       = ck_tile::remove_cvref_t<ST>;
    using GScaleDataType       = ck_tile::remove_cvref_t<SW>;
    using DScaleDataType       = ck_tile::remove_cvref_t<SW>;
    using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;

    using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
    using IndexDataType      = ck_tile::index_t;
};
// runtime args
// Adds no members of its own: the example API takes exactly the fields of
// ck_tile::FusedMoeGemmHostArgs (struct inheritance is public by default).
struct fused_moegemm_args : ck_tile::FusedMoeGemmHostArgs
{
};
// This is the public API, will be generated by script
// Runtime description of the kernel variant requested by the caller.
// Plain aggregate: member order is part of the interface because callers
// brace-initialize it positionally.
struct fused_moegemm_traits
{
    std::string prec_i;  // input precision
    std::string prec_w;  // weight precision
    std::string prec_o;  // output precision
    std::string prec_st; // token scale data type
    std::string prec_sw; // weight scale data type
    std::string prec_sq; // smooth quant scale
    std::string prec_kw; // topk-weight data type
    int block_m;
    int gate_only;
    int fused_quant; // 0:none, 1:smooth-dynamic-quant, 2:dynamic-quant
};
// Entry point of the MoE GEMM example API.
//   t: requested kernel variant, a: runtime arguments, s: stream configuration.
// A negative return value signals an unsupported variant.
float fused_moegemm(fused_moegemm_traits t,
                    fused_moegemm_args a,
                    const ck_tile::stream_config& s);
example/ck_tile/15_fused_moe/fused_moesorting.hpp
0 → 100644
View file @
b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/fused_moe.hpp"
// Runtime description of the sorting kernel variant requested by the caller.
// Plain aggregate, initialized positionally as {index_type, weight_type}.
struct fused_moesorting_trait
{
    std::string index_type;
    std::string weight_type; // currently always float
};
// Adds no members of its own: the example API takes exactly the fields of
// ck_tile::MoeSortingHostArgs (struct inheritance is public by default).
struct fused_moesorting_args : ck_tile::MoeSortingHostArgs
{
};
// Entry point of the MoE sorting example API.
//   t: requested kernel variant, a: runtime arguments, s: stream configuration.
// A negative return value signals an unsupported variant or invalid arguments.
float fused_moesorting(fused_moesorting_trait t,
                       fused_moesorting_args a,
                       ck_tile::stream_config s);
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
0 → 100644
View file @
b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fused_moe.hpp"
// Fused-MoE example entry point: builds the traits/arguments of the sorting
// stage and of the MoE GEMM stage from the combined `t` / `a`, then hands both
// stages as callables to ck_tile::launch_kernel using the caller's stream
// config `s`.
// Returns whatever launch_kernel returns, or -1 if either stage reported a
// negative (unsupported) result.
float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
{
    // Stream config used inside the two stages: same stream id and log level as
    // the caller's config, remaining fields fixed to {false, 0, 1}.
    // NOTE(review): presumably "no timing, 0 warmup, 1 repeat" - confirm against
    // ck_tile::stream_config.
    auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1};

    // Bytes per output element, derived from the output precision string.
    // "int8", "fp8" and any unrecognized string all map to 1.
    const int o_data_bytes = (t.prec_o == "fp32")                         ? 4
                             : (t.prec_o == "fp16" || t.prec_o == "bf16") ? 2
                                                                          : 1;

    // ---- stage 0: sorting -------------------------------------------------
    auto sort_traits = fused_moesorting_trait{"int32", "fp32"};
    auto sort_args   = fused_moesorting_args{
        a.topk_ids_ptr,                               // const void* p_topk_ids;
        a.topk_weight_ptr,                            // const void* p_weights;
        a.sorted_token_ids_ptr,                       // void* p_sorted_token_ids;
        a.sorted_weight_ptr,                          // void* p_sorted_weights;
        a.sorted_expert_ids_ptr,                      // void* p_sorted_expert_ids;
        a.num_sorted_tiles_ptr,                       // void* p_total_tokens_post_pad;
        a.o_ptr,                                      // void* p_moe_buf;
        a.num_tokens,                                 // index_t tokens;
        a.block_m,                                    // index_t unit_size;
        a.num_experts,                                // index_t num_experts;
        a.topk,                                       // index_t topk;
        a.num_tokens * a.stride_token * o_data_bytes  // index_t moe_buf_bytes;
    };

    // ---- stage 1: MoE GEMM ------------------------------------------------
    auto gemm_traits = fused_moegemm_traits{t.prec_i,
                                            t.prec_w,
                                            t.prec_o,
                                            t.prec_st,
                                            t.prec_sw,
                                            t.prec_sq,
                                            t.prec_kw,
                                            t.block_m,
                                            t.gate_only,
                                            t.fused_quant};
    auto gemm_args   = fused_moegemm_args{
        a.a_ptr,                 // const void* a_ptr;
        a.a_scale_ptr,           // const void* a_scale_ptr;
        a.g_ptr,                 // const void* g_ptr;
        a.d_ptr,                 // const void* d_ptr;
        a.g_scale_ptr,           // const void* g_scale_ptr;
        a.d_scale_ptr,           // const void* d_scale_ptr;
        a.y_smooth_scale_ptr,    // const void* y_smooth_scale_ptr;
        a.o_ptr,                 // void* o_ptr;
        a.sorted_token_ids_ptr,  // const void* sorted_token_ids_ptr;
        a.sorted_weight_ptr,     // const void* sorted_weight_ptr;
        a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr;
        a.num_sorted_tiles_ptr,  // const void* num_sorted_tiles_ptr;
        a.hidden_size,           // index_t hidden_size;
        a.intermediate_size,     // index_t intermediate_size;
        a.num_tokens,            // index_t num_tokens;
        a.num_experts,           // index_t num_experts;
        a.topk,                  // index_t topk;
        a.stride_token           // index_t stride_token;
    };

    // Per-stage results; they stay negative if a stage is never run or reports
    // an unsupported configuration.
    float sort_result = -1;
    float gemm_result = -1;

    const float total = ck_tile::launch_kernel(
        s,
        [=, &sort_result](const ck_tile::stream_config&) {
            sort_result = fused_moesorting(sort_traits, sort_args, s_sub);
        },
        [=, &gemm_result](const ck_tile::stream_config&) {
            gemm_result = fused_moegemm(gemm_traits, gemm_args, s_sub);
        });

    // keep unsupported case return negative
    if(sort_result < 0 || gemm_result < 0)
        return -1;

    return total;
}
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
0 → 100644
View file @
b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
// Note: this internal API is only declared here, not defined; defining it here would block `make -j`
// Declaration only: each supported Traits_ is explicitly instantiated in its
// own translation unit.
template <typename Traits_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a);
// Shorthand for a compile-time integer sequence.
template <ck_tile::index_t... Is>
using S = ck_tile::sequence<Is...>;
// Dispatches the runtime traits `t` to one of the pre-instantiated MoE GEMM
// kernels and launches it with arguments `a` on stream config `s`.
//
// Supported variants: prec_i/prec_w/prec_o all "bf16" or all "fp16", with fp32
// scale and topk-weight types, block_m == 32, gate_only == 1 and
// fused_quant == 0.
//
// Returns the result of the selected fused_moegemm_<> instance, or -1 when no
// instance matches. fused_quant is part of the match: both instances are built
// with FusedQuant_ = 0, so a quantized request must be rejected rather than
// silently served by a non-quantizing kernel.
float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s)
{
    // clang-format off
    // Settings shared by every instance below.
    const bool shared_ok = t.prec_st == "fp32" && t.prec_sw == "fp32" &&
                           t.prec_sq == "fp32" && t.prec_kw == "fp32" &&
                           t.block_m == 32 && t.gate_only == 1 &&
                           t.fused_quant == 0;
    if(!shared_ok)
        return -1;

    if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16")
    {
        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float,
                         S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
        return fused_moegemm_<t_>(s, a);
    }

    if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16")
    {
        using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float,
                         S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
        return fused_moegemm_<t_>(s, a);
    }
    // clang-format on

    // no matching instance
    return -1;
}
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
0 → 100644
View file @
b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "fused_moegemm_api_traits.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include <iostream>
// Shorthand for a compile-time integer sequence.
template <ck_tile::index_t... Is>
using S = ck_tile::sequence<Is...>;
// do not put the definition of this template function inside the _api.cpp, otherwise it will block make -j
// Builds the concrete MoE GEMM kernel type described by the traits bundle
// `Ts_`, creates its kernel arguments from `a`, and launches it through
// ck_tile::launch_kernel with stream config `s`. Returns launch_kernel's result.
//
// When s.log_level_ > 0 the kernel name is printed once per instantiation
// (guarded by a function-local static flag).
template <typename Ts_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
{
    using traits_t =
        ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;

    // NOTE(review): the second GEMM reuses WarpPerBlock_0 / WarpTile_0 although
    // the traits also expose *_1 aliases; kept as in the original.
    using shape_t = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
                                               typename Ts_::WarpPerBlock_0,
                                               typename Ts_::WarpTile_0,
                                               typename Ts_::BlockTile_1,
                                               typename Ts_::WarpPerBlock_0,
                                               typename Ts_::WarpTile_0>;

    using problem_t =
        ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
                                             typename Ts_::GDataType,
                                             typename Ts_::DDataType,
                                             typename Ts_::AccDataType,
                                             typename Ts_::ODataType,
                                             typename Ts_::AScaleDataType,
                                             typename Ts_::GScaleDataType,
                                             typename Ts_::DScaleDataType,
                                             typename Ts_::YSmoothScaleDataType,
                                             typename Ts_::TopkWeightDataType,
                                             typename Ts_::IndexDataType,
                                             ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
                                             shape_t,
                                             traits_t>;

    // alternative pipeline: ck_tile::FusedMoeGemmPipeline_FlatmmEx<problem_t>
    using pipeline_t    = ck_tile::FusedMoeGemmPipeline_FlatmmUk<problem_t>;
    using partitioner_t = ck_tile::FusedMoeGemmTilePartitioner_Linear<shape_t>;
    using kernel_t      = ck_tile::FusedMoeGemmKernel<partitioner_t, pipeline_t, void>;

    // launch configuration
    constexpr dim3 blocks                = kernel_t::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = 1;
    const dim3 grids                     = kernel_t::GridSize(a);
    auto kargs                           = kernel_t::MakeKargs(a);

    // print the kernel name only on the first logged call
    static int printed = 0;
    if(s.log_level_ > 0 && printed == 0)
    {
        std::cout << ", " << kernel_t::GetName() << std::flush;
        printed = 1;
    }

    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(kernel_t{}, grids, blocks, 0, kargs));
}
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
0 → 100644
View file @
b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
// Compile-time traits bundle: combines the data types from
// FusedMoeGemmTypeConfig with the tile shapes of the two GEMMs and the
// GateOnly / FusedQuant switches.
template <typename I,
          typename W,
          typename O,
          typename ST,
          typename SW,
          typename SQ,
          typename KW,
          typename BlockTile_, // seq<b_token, b_interm, b_hidden, b_down>
          typename WarpPerBlock_,
          typename WarpTile_, // seq<*,*,*>, used to select mfma
          ck_tile::index_t GateOnly_   = 0,
          ck_tile::index_t FusedQuant_ = 0>
struct fmoe_ // traits, ugly name, only used for internal
{
    using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;

    // data types, forwarded from TypeConfig with cv/ref stripped
    using ADataType            = ck_tile::remove_cvref_t<typename TypeConfig::ADataType>;
    using GDataType            = ck_tile::remove_cvref_t<typename TypeConfig::GDataType>;
    using DDataType            = ck_tile::remove_cvref_t<typename TypeConfig::DDataType>;
    using AccDataType          = ck_tile::remove_cvref_t<typename TypeConfig::AccDataType>;
    using ODataType            = ck_tile::remove_cvref_t<typename TypeConfig::ODataType>;
    using AScaleDataType       = ck_tile::remove_cvref_t<typename TypeConfig::AScaleDataType>;
    using GScaleDataType       = ck_tile::remove_cvref_t<typename TypeConfig::GScaleDataType>;
    using DScaleDataType       = ck_tile::remove_cvref_t<typename TypeConfig::DScaleDataType>;
    using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
    using TopkWeightDataType   = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
    using IndexDataType        = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;

    // the four entries of the block-tile sequence
    static constexpr ck_tile::index_t BT_ = BlockTile_::at(ck_tile::number<0>{}); // block token
    static constexpr ck_tile::index_t BI_ = BlockTile_::at(ck_tile::number<1>{}); // block intermediate
    static constexpr ck_tile::index_t BH_ = BlockTile_::at(ck_tile::number<2>{}); // block hidden
    static constexpr ck_tile::index_t BD_ = BlockTile_::at(ck_tile::number<3>{}); // block down

    // first GEMM
    using BlockTile_0    = ck_tile::sequence<BT_, BI_, BH_>;
    using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
    using WarpTile_0     = ck_tile::remove_cvref_t<WarpTile_>;

    // second GEMM; its last extent is BI_ when gate-only, otherwise BI_ / 2
    using BlockTile_1    = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
    using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
    using WarpTile_1     = ck_tile::remove_cvref_t<WarpTile_>;

    static constexpr ck_tile::index_t GateOnly   = GateOnly_;
    static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
};
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
0 → 100644
View file @
b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"
// clang-format off
// Explicit instantiation of the bf16 kernel (block_m 32, GateOnly 1, FusedQuant 0).
// The traits type must match the one dispatched by fused_moegemm().
template float fused_moegemm_<
    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float,
          S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
0 → 100644
View file @
b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"
// clang-format off
// Explicit instantiation of the fp16 kernel (block_m 32, GateOnly 1, FusedQuant 0).
// The traits type must match the one dispatched by fused_moegemm().
template float fused_moegemm_<
    fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float,
          S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
0 → 100644
View file @
b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fused_moesorting.hpp"
// Body of one dispatch case in fused_moesorting(): instantiates the sorting
// kernel for the compile-time unroll factor `unroll_num_`, launches it and
// returns the value from ck_tile::launch_kernel.
// Requirements at the expansion site:
//   - `a` (kernel host arguments) and `s` (stream config) are in scope;
//   - the type aliases `index_t` and `ms_weight_type` are declared;
//   - the expansion sits in its own scope, since it declares local names.
// The expansion ends with `return`, so control never continues past it.
#define MOE_SORTING_DISPATCH(unroll_num_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
// Validates the request and launches the MoE sorting kernel.
//
// Only weight_type "fp32" with index_type "int32" is supported. Returns -1 for
// any other trait, when a.num_experts > 127, or when a.moe_buf_bytes is not a
// multiple of 16; otherwise returns the value produced by MOE_SORTING_DISPATCH.
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
{
    const bool supported = (t.weight_type == "fp32") && (t.index_type == "int32");
    if(supported)
    {
        if(a.num_experts > 127)
        {
            printf("lds size exceed, only support experts <127\n");
            return -1;
        }
        if(a.moe_buf_bytes % 16 != 0)
        {
            printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes);
            return -1;
        }

        // These two aliases are referenced by MOE_SORTING_DISPATCH.
        using index_t        = ck_tile::index_t;
        using ms_weight_type = float;

        // unroll factor = ceil(tokens * topk / 64)
        const index_t unroll = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);

        // Every case returns from inside the macro. Factors 1-3 and 5-11 have
        // explicit cases; every other value (including 4) takes the default
        // branch, which dispatches with unroll factor 4.
        switch(unroll)
        {
        case 1: {
            MOE_SORTING_DISPATCH(1);
        }
        case 2: {
            MOE_SORTING_DISPATCH(2);
        }
        case 3: {
            MOE_SORTING_DISPATCH(3);
        }
        case 5: {
            MOE_SORTING_DISPATCH(5);
        }
        case 6: {
            MOE_SORTING_DISPATCH(6);
        }
        case 7: {
            MOE_SORTING_DISPATCH(7);
        }
        case 8: {
            MOE_SORTING_DISPATCH(8);
        }
        case 9: {
            MOE_SORTING_DISPATCH(9);
        }
        case 10: {
            MOE_SORTING_DISPATCH(10);
        }
        case 11: {
            MOE_SORTING_DISPATCH(11);
        }
        default: {
            MOE_SORTING_DISPATCH(4);
        }
        }
    }
    // unsupported trait combination
    return -1;
}
example/ck_tile/15_fused_moe/main.cpp
0 → 100644
View file @
b74918bc
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <set>
#include <stdexcept>
#include <unordered_set>
#include <vector>
#include "ck_tile/host.hpp"
#include "fused_moe.hpp"
// different threshold for different dtype
// Validation thresholds for `DataType`, returned as a (rtol, atol) tuple.
// The generic case uses 1e-2 for both.
template <typename DataType>
auto get_elimit()
{
    double rel_tol = 1e-2;
    double abs_tol = 1e-2;
    return ck_tile::make_tuple(rel_tol, abs_tol);
}
// bf16 thresholds, returned as a (rtol, atol) tuple; currently the same
// values as the generic case.
template <>
auto get_elimit<ck_tile::bf16_t>()
{
    double rel_tol = 1e-2;
    double abs_tol = 1e-2;
    return ck_tile::make_tuple(rel_tol, abs_tol);
}
// mfma_type, 0:32x32, 1:16x16
// TODO: padding?
// Re-lays out a 3-D weight tensor `t` of lengths (b, n, k).
//
// The tensor is viewed as 6-D {b, n / n_lane, n_lane, k / (k_sub * k_vec), k_sub, k_vec}
// and permuted with order {0, 1, 3, 4, 2, 5}. The split factors depend on the
// element width and on mfma_type:
//
//   mfma_dtype      mfma_type   n_lane  k_sub  k_vec
//   bf16 / fp16         0         32      2      8
//   bf16 / fp16         1         16      4      8
//   int8 / fp8          0         32      2     16
//   int8 / fp8          1         16      4     16
//
// Any other combination returns a copy of `t` unchanged. n and k are divided
// with integer division; no padding is applied.
template <typename T>
auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
{
    assert(t.get_lengths().size() == 3);

    const bool is_16bit = (mfma_dtype == "bf16" || mfma_dtype == "fp16");
    const bool is_8bit  = (mfma_dtype == "int8" || mfma_dtype == "fp8");
    const bool known    = (is_16bit || is_8bit) && (mfma_type == 0 || mfma_type == 1);
    if(!known)
        return t;

    const int batch = static_cast<int>(t.get_lengths()[0]);
    const int n     = static_cast<int>(t.get_lengths()[1]);
    const int k     = static_cast<int>(t.get_lengths()[2]);

    const int n_lane  = (mfma_type == 0) ? 32 : 16;
    const int k_sub   = (mfma_type == 0) ? 2 : 4;
    const int k_vec   = is_16bit ? 8 : 16;
    const int k_chunk = k_sub * k_vec;

    ck_tile::HostTensor<T> t_view({batch, n / n_lane, n_lane, k / k_chunk, k_sub, k_vec});
    std::copy(t.begin(), t.end(), t_view.begin());
    return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
// Fills the first tokens * topk entries of `host_tensor` with pseudo-random
// expert ids in [0, num_expert), such that the `topk` ids belonging to one
// token (one consecutive run of `topk` entries) are pairwise distinct.
//
// host_tensor: destination, must hold at least tokens * topk elements
// tokens     : number of tokens (rows)
// topk       : ids generated per token
// num_expert : number of experts to draw from
// seed       : passed to std::srand, so a given seed reproduces the same ids
//
// Does nothing besides seeding when tokens <= 0 or topk <= 0.
// Throws std::invalid_argument when topk > num_expert (distinct ids are then
// impossible and the rejection loop would never terminate) or when
// host_tensor is too small.
template <typename IndexType>
void topid_unique_gen(
    std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
    std::srand(seed);

    if(tokens <= 0 || topk <= 0)
        return;

    if(topk > num_expert)
        throw std::invalid_argument("topid_unique_gen: topk must not exceed num_expert");

    // Widen before multiplying so the product cannot overflow int.
    const std::size_t ids_per_token = static_cast<std::size_t>(topk);
    const std::size_t total_size    = ids_per_token * static_cast<std::size_t>(tokens);

    if(host_tensor.size() < total_size)
        throw std::invalid_argument("topid_unique_gen: host_tensor smaller than tokens * topk");

    std::set<IndexType> used; // ids already assigned to the current token
    for(std::size_t i = 0; i < total_size; ++i)
    {
        if(i % ids_per_token == 0)
            used.clear(); // start of a new token

        // Rejection sampling: redraw until the id is new for this token.
        IndexType candidate = static_cast<IndexType>(std::rand() % num_expert);
        while(used.count(candidate) != 0)
            candidate = static_cast<IndexType>(std::rand() % num_expert);

        used.insert(candidate);
        host_tensor[i] = candidate;
    }
}
// Registers every command-line option of the example together with its
// default value and help text, then parses argc/argv.
// Returns a tuple (parse_ok, arg_parser).
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser
        // problem size
        .insert("t", "128", "num input tokens")
        .insert("e", "32", "num of experts")
        .insert("k", "5", "topk")
        .insert("h", "8192", "hidden_size of this model")
        .insert("i", "8192", "intermediate_size between 2 gemms of FFN")
        .insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
        .insert("bm", "32", "blocking factor for sorted tokens")
        .insert("tp", "8", "tensor parallel size")
        // run control
        .insert("v", "1", "cpu validation or not")
        .insert("kname", "1", "print kernel name or not")
        // precisions
        .insert("prec_i", "bf16", "input precision")
        .insert("prec_w", "bf16", "weight precision")
        .insert("prec_o", "bf16", "output precision")
        .insert("prec_st", "auto", "token scale data type. auto will set to fp32")
        .insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
        .insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
        .insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
        // kernel variant
        .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
        .insert(
            "gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
        .insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
        // data generation
        .insert("balance",
                "0",
                "if set to 1, will try balance the expert in topk-ids(convenient for testing)")
        // The two adjacent literals used to concatenate to
        // "2:rand normalizednormalized(slow)".
        .insert("init",
                "2",
                "init method. 0:random stepped float(fast). 1: random uniform, "
                "2:rand normalized(slow)")
        .insert("seed", "11939", "seed used to do random")
        // timing
        .insert("warmup", "5", "cold iter")
        .insert("repeat", "20", "hot iter");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// I:input-type, W:weight-type, O:output-type, ST:token-scale-type, SW:weight-scale-type,
// SQ:smooth-quant-type, KW:topk-weight-type
template
<
typename
I
,
typename
W
,
typename
O
,
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
ck_tile
::
index_t
tokens
=
arg_parser
.
get_int
(
"t"
);
ck_tile
::
index_t
experts
=
arg_parser
.
get_int
(
"e"
);
ck_tile
::
index_t
topk
=
arg_parser
.
get_int
(
"k"
);
ck_tile
::
index_t
hidden_size
=
arg_parser
.
get_int
(
"h"
);
ck_tile
::
index_t
intermediate_size
=
arg_parser
.
get_int
(
"i"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
ck_tile
::
index_t
block_m
=
arg_parser
.
get_int
(
"bm"
);
if
(
stride
<
0
)
stride
=
hidden_size
;
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_w
=
arg_parser
.
get_str
(
"prec_w"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
std
::
string
prec_st
=
arg_parser
.
get_str
(
"prec_st"
);
std
::
string
prec_sw
=
arg_parser
.
get_str
(
"prec_sw"
);
std
::
string
prec_sq
=
arg_parser
.
get_str
(
"prec_sq"
);
std
::
string
prec_kw
=
arg_parser
.
get_str
(
"prec_kw"
);
prec_st
=
(
prec_st
==
"auto"
)
?
"fp32"
:
prec_st
;
prec_sw
=
(
prec_sw
==
"auto"
)
?
"fp32"
:
prec_sw
;
prec_sq
=
(
prec_sq
==
"auto"
)
?
"fp32"
:
prec_sq
;
prec_kw
=
(
prec_kw
==
"auto"
)
?
"fp32"
:
prec_kw
;
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
int
gate_only
=
arg_parser
.
get_int
(
"gate_only"
);
int
api
=
arg_parser
.
get_int
(
"api"
);
int
balance
=
arg_parser
.
get_int
(
"balance"
);
int
tp
=
arg_parser
.
get_int
(
"tp"
);
int
init
=
arg_parser
.
get_int
(
"init"
);
uint32_t
seed
=
arg_parser
.
get_uint32
(
"seed"
);
// w0 (Gate+Up or Gate only, N size)
ck_tile
::
index_t
shared_intermediate_size_0
=
intermediate_size
*
(
gate_only
?
1
:
2
)
/
tp
;
// w1 (Down, N size)
ck_tile
::
index_t
shared_intermediate_size_1
=
intermediate_size
/
tp
;
auto
prec_str
=
[
&
]()
{
auto
base_str
=
prec_i
;
if
(
prec_i
!=
prec_w
)
base_str
+=
"x"
+
prec_w
;
if
(
prec_i
!=
prec_o
)
base_str
+=
"="
+
prec_o
;
if
(
fused_quant
!=
0
)
{
base_str
+=
std
::
string
(
"("
)
+
prec_st
+
"|"
+
prec_sw
+
"|"
+
prec_sq
+
")"
;
}
return
base_str
;
}();
auto
api_str
=
[
&
]()
{
if
(
api
==
0
)
return
std
::
string
(
"fmoe"
);
else
if
(
api
==
1
)
return
std
::
string
(
"moeg"
);
else
if
(
api
==
2
)
return
std
::
string
(
"moes"
);
return
std
::
string
(
""
);
}();
auto
stride_str
=
[
&
]()
{
if
(
stride
==
hidden_size
)
return
std
::
string
(
""
);
else
return
std
::
string
(
", st:"
)
+
std
::
to_string
(
stride
);
}();
std
::
cout
<<
"["
<<
api_str
<<
"|"
<<
prec_str
<<
"]"
<<
" t:"
<<
tokens
<<
", e:"
<<
experts
<<
", k:"
<<
topk
<<
stride_str
<<
", hidden:"
<<
hidden_size
<<
", interm:"
<<
intermediate_size
<<
", tp:"
<<
tp
<<
", shrd_interm:"
<<
shared_intermediate_size_0
<<
"|"
<<
shared_intermediate_size_1
<<
", go:"
<<
gate_only
<<
", q:"
<<
fused_quant
<<
std
::
flush
;
using
TypeConfig
=
FusedMoeGemmTypeConfig
<
I
,
W
,
O
,
ST
,
SW
,
SQ
,
KW
>
;
using
ADataType
=
typename
TypeConfig
::
ADataType
;
using
GDataType
=
typename
TypeConfig
::
GDataType
;
using
DDataType
=
typename
TypeConfig
::
DDataType
;
using
AccDataType
=
typename
TypeConfig
::
AccDataType
;
using
ODataType
=
typename
TypeConfig
::
ODataType
;
using
AScaleDataType
=
typename
TypeConfig
::
AScaleDataType
;
using
GScaleDataType
=
typename
TypeConfig
::
GScaleDataType
;
using
DScaleDataType
=
typename
TypeConfig
::
DScaleDataType
;
using
YSmoothScaleDataType
=
typename
TypeConfig
::
YSmoothScaleDataType
;
using
TopkWeightDataType
=
typename
TypeConfig
::
TopkWeightDataType
;
using
IndexDataType
=
typename
TypeConfig
::
IndexDataType
;
// host verify
ck_tile
::
HostTensor
<
ADataType
>
a_host
({
tokens
,
hidden_size
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
GDataType
>
g_host
({
experts
,
shared_intermediate_size_0
,
hidden_size
});
ck_tile
::
HostTensor
<
DDataType
>
d_host
({
experts
,
hidden_size
,
shared_intermediate_size_1
});
ck_tile
::
HostTensor
<
ODataType
>
o_host
({
tokens
,
hidden_size
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
AScaleDataType
>
sa_host
({
tokens
});
ck_tile
::
HostTensor
<
GScaleDataType
>
sg_host
({
shared_intermediate_size_0
});
ck_tile
::
HostTensor
<
DScaleDataType
>
sd_host
({
shared_intermediate_size_1
});
ck_tile
::
HostTensor
<
YSmoothScaleDataType
>
sy_host
({
shared_intermediate_size_1
});
// smooth-quant
ck_tile
::
HostTensor
<
IndexDataType
>
topk_ids_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
TopkWeightDataType
>
topk_weight_host
({
tokens
,
topk
});
// to be sort
int
max_num_tokens_padded
=
topk
*
tokens
+
experts
*
block_m
-
topk
;
ck_tile
::
HostTensor
<
IndexDataType
>
sorted_token_ids_host
({
max_num_tokens_padded
});
ck_tile
::
HostTensor
<
TopkWeightDataType
>
sorted_weight_host
({
max_num_tokens_padded
});
ck_tile
::
HostTensor
<
IndexDataType
>
sorted_expert_ids_host
(
{(
max_num_tokens_padded
+
block_m
-
1
)
/
block_m
});
ck_tile
::
HostTensor
<
IndexDataType
>
num_sorted_tiles_host
({
1
});
if
(
init
==
0
)
{
ck_tile
::
FillStepRange
<
ADataType
>
{
-
.5
f
,
.5
f
,
0.01
f
}(
a_host
);
ck_tile
::
FillStepRange
<
GDataType
>
{
-
.5
f
,
.5
f
,
0.01
f
}(
g_host
);
ck_tile
::
FillStepRange
<
DDataType
,
false
>
{
.5
f
,
-
.5
f
,
-
0.01
f
}(
d_host
);
ck_tile
::
FillStepRange
<
AScaleDataType
>
{
0.
f
,
1.
f
,
0.01
f
}(
sa_host
);
ck_tile
::
FillStepRange
<
GScaleDataType
>
{
0.
f
,
1.
f
,
0.01
f
}(
sg_host
);
ck_tile
::
FillStepRange
<
DScaleDataType
>
{
0.
f
,
1.
f
,
0.01
f
}(
sd_host
);
ck_tile
::
FillStepRange
<
YSmoothScaleDataType
>
{
0.
f
,
1.
f
,
0.01
f
}(
sy_host
);
ck_tile
::
FillStepRange
<
TopkWeightDataType
>
{
-
.5
f
,
.5
f
,
0.01
f
}(
topk_weight_host
);
}
else
if
(
init
==
1
)
{
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
a_host
);
ck_tile
::
FillUniformDistribution
<
GDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
g_host
);
ck_tile
::
FillUniformDistribution
<
DDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
d_host
);
ck_tile
::
FillUniformDistribution
<
AScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sa_host
);
ck_tile
::
FillUniformDistribution
<
GScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sg_host
);
ck_tile
::
FillUniformDistribution
<
DScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sd_host
);
ck_tile
::
FillUniformDistribution
<
YSmoothScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sy_host
);
ck_tile
::
FillUniformDistribution
<
TopkWeightDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
topk_weight_host
);
}
else
if
(
init
==
2
)
{
ck_tile
::
FillNormalDistribution
<
ADataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
a_host
);
ck_tile
::
FillNormalDistribution
<
GDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
g_host
);
ck_tile
::
FillNormalDistribution
<
DDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
d_host
);
ck_tile
::
FillNormalDistribution
<
AScaleDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
sa_host
);
ck_tile
::
FillNormalDistribution
<
GScaleDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
sg_host
);
ck_tile
::
FillNormalDistribution
<
DScaleDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
sd_host
);
ck_tile
::
FillNormalDistribution
<
YSmoothScaleDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
sy_host
);
ck_tile
::
FillNormalDistribution
<
TopkWeightDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
topk_weight_host
);
}
// permute weight
ck_tile
::
HostTensor
<
GDataType
>
g_perm_host
=
shuffle_moe_weight
(
g_host
,
prec_w
,
1
);
ck_tile
::
HostTensor
<
DDataType
>
d_perm_host
=
shuffle_moe_weight
(
d_host
,
prec_w
,
1
);
// do moe sorting
if
(
balance
)
{
int
e_cnt
=
0
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
topk_ids_host
.
mData
.
size
());
i
++
)
{
topk_ids_host
.
mData
[
i
]
=
e_cnt
;
e_cnt
++
;
if
(
e_cnt
>=
experts
)
e_cnt
=
0
;
}
}
else
{
topid_unique_gen
<
IndexDataType
>
(
topk_ids_host
.
mData
,
tokens
,
topk
,
experts
,
11913
);
}
// leave it here for future debug purpose
#if 0
a_host.loadtxt("../../ater/input_torch.txt");
topk_ids_host.loadtxt("../../ater/topk_ids_torch.txt", "int");
// topk_ids_host.savetxt("topk_ids_2.txt");
topk_weight_host.loadtxt("../../ater/topk_weights_torch.txt", "float");
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
g_host.loadtxt("../../ater/w1_torch.txt", "float");
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
d_host.loadtxt("../../ater/w2_torch.txt", "float");
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
#endif
#if 0
std::cout << "sorted_token_ids_host:" << sorted_token_ids_host << std::endl;
std::cout << "num_sorted_tiles_host:" << num_sorted_tiles_host << std::endl;
std::cout << "sorted_expert_ids_host:" << sorted_expert_ids_host << std::endl;
std::cout << "topk_weight_host:" << topk_weight_host << std::endl;
std::cout << "sorted_weight_host:" << sorted_weight_host << std::endl;
#endif
auto
cal_tflops
=
[
&
](
auto
ms
)
{
double
flop_gemm_0
=
2
*
static_cast
<
double
>
(
tokens
)
*
topk
*
shared_intermediate_size_0
*
hidden_size
;
double
flop_gemm_1
=
2
*
static_cast
<
double
>
(
tokens
)
*
topk
*
shared_intermediate_size_1
*
hidden_size
;
return
(
flop_gemm_0
+
flop_gemm_1
)
/
(
static_cast
<
double
>
(
ms
)
*
1e-3
)
/
1e12
;
};
// TODO: this method we use expert-by-expert view, just for reference
auto
cal_tbps
=
[
&
](
auto
ms
)
{
double
token_bytes
=
static_cast
<
double
>
(
tokens
)
*
topk
/
experts
*
hidden_size
*
sizeof
(
ADataType
);
double
w0_bytes
=
static_cast
<
double
>
(
shared_intermediate_size_0
)
*
experts
*
hidden_size
*
sizeof
(
GDataType
);
double
w1_bytes
=
static_cast
<
double
>
(
shared_intermediate_size_1
)
*
experts
*
hidden_size
*
sizeof
(
DDataType
);
double
o_bytes
=
static_cast
<
double
>
(
tokens
)
*
topk
/
experts
*
hidden_size
*
sizeof
(
ODataType
);
double
topk_weights_bytes
=
static_cast
<
double
>
(
tokens
)
*
topk
*
sizeof
(
TopkWeightDataType
);
// ignore index, they are too small
return
(
token_bytes
+
w0_bytes
+
w1_bytes
+
o_bytes
+
topk_weights_bytes
)
/
(
static_cast
<
double
>
(
ms
)
*
1e-3
)
/
1e12
;
};
if
(
api
==
0
)
{
ck_tile
::
DeviceMem
a_buf
(
a_host
);
ck_tile
::
DeviceMem
g_perm_buf
(
g_perm_host
);
ck_tile
::
DeviceMem
d_perm_buf
(
d_perm_host
);
ck_tile
::
DeviceMem
sa_buf
(
sa_host
);
ck_tile
::
DeviceMem
sg_buf
(
sg_host
);
ck_tile
::
DeviceMem
sd_buf
(
sd_host
);
ck_tile
::
DeviceMem
sy_buf
(
sy_host
);
ck_tile
::
DeviceMem
o_buf
(
o_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
topk_ids_buf
(
topk_ids_host
);
ck_tile
::
DeviceMem
topk_weight_buf
(
topk_weight_host
);
ck_tile
::
DeviceMem
sorted_token_ids_buf
(
sorted_token_ids_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
sorted_weight_buf
(
sorted_weight_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
sorted_expert_ids_buf
(
sorted_expert_ids_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
num_sorted_tiles_buf
(
num_sorted_tiles_host
.
get_element_space_size_in_bytes
());
fused_moe_traits
traits
{
prec_i
,
prec_w
,
prec_o
,
prec_st
,
prec_sw
,
prec_sq
,
prec_kw
,
block_m
,
gate_only
,
fused_quant
};
fused_moe_args
args
{
a_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
?
sa_buf
.
GetDeviceBuffer
()
:
nullptr
,
g_perm_buf
.
GetDeviceBuffer
(),
d_perm_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
?
sg_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sd_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
==
1
?
sy_buf
.
GetDeviceBuffer
()
:
nullptr
,
o_buf
.
GetDeviceBuffer
(),
topk_ids_buf
.
GetDeviceBuffer
(),
topk_weight_buf
.
GetDeviceBuffer
(),
sorted_token_ids_buf
.
GetDeviceBuffer
(),
sorted_weight_buf
.
GetDeviceBuffer
(),
sorted_expert_ids_buf
.
GetDeviceBuffer
(),
num_sorted_tiles_buf
.
GetDeviceBuffer
(),
block_m
,
hidden_size
,
shared_intermediate_size_0
,
tokens
,
experts
,
topk
,
stride
};
float
ave_time
=
fused_moe
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
if
(
ave_time
<
0
)
{
std
::
cout
<<
" not supported!"
<<
std
::
endl
<<
std
::
flush
;
return
false
;
}
// float gb_per_sec = num_byte / 1.E6 / ave_time;
std
::
cout
<<
", "
<<
ave_time
*
1.E3
<<
" us, "
<<
cal_tflops
(
ave_time
)
<<
" tflops, "
<<
cal_tbps
(
ave_time
)
<<
" TB/s"
<<
std
::
flush
;
bool
pass
=
true
;
if
(
do_validation
)
{
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
topk_ids_host
,
topk_weight_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
.
mData
[
0
],
experts
,
block_m
);
ck_tile
::
reference_fused_moe
<
AccDataType
,
ck_tile
::
element_wise
::
Gelu
>
(
a_host
,
g_host
,
d_host
,
sa_host
,
sg_host
,
sd_host
,
sy_host
,
o_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
,
topk_ids_host
,
block_m
,
tokens
,
experts
,
hidden_size
,
shared_intermediate_size_0
,
topk
,
gate_only
);
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
// o_dev.savetxt("gpu-out.txt", "float");
auto
[
rtol
,
atol
]
=
get_elimit
<
ADataType
>
();
pass
&=
ck_tile
::
check_err
(
o_dev
,
o_host
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
;
}
std
::
cout
<<
std
::
flush
<<
std
::
endl
;
return
pass
;
}
else
if
(
api
==
1
)
{
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
topk_ids_host
,
topk_weight_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
.
mData
[
0
],
experts
,
block_m
);
// done, preparing GPU buffer
ck_tile
::
DeviceMem
a_buf
(
a_host
);
ck_tile
::
DeviceMem
g_perm_buf
(
g_perm_host
);
ck_tile
::
DeviceMem
d_perm_buf
(
d_perm_host
);
ck_tile
::
DeviceMem
sa_buf
(
sa_host
);
ck_tile
::
DeviceMem
sg_buf
(
sg_host
);
ck_tile
::
DeviceMem
sd_buf
(
sd_host
);
ck_tile
::
DeviceMem
sy_buf
(
sy_host
);
ck_tile
::
DeviceMem
o_buf
(
o_host
);
// manually clear output buffer for atomic
o_buf
.
SetZero
();
//
ck_tile
::
DeviceMem
sorted_token_ids_buf
(
sorted_token_ids_host
);
ck_tile
::
DeviceMem
sorted_weight_buf
(
sorted_weight_host
);
ck_tile
::
DeviceMem
sorted_expert_ids_buf
(
sorted_expert_ids_host
);
ck_tile
::
DeviceMem
num_sorted_tiles_buf
(
num_sorted_tiles_host
);
fused_moegemm_traits
traits
{
prec_i
,
prec_w
,
prec_o
,
prec_st
,
prec_sw
,
prec_sq
,
prec_kw
,
block_m
,
gate_only
,
fused_quant
};
fused_moegemm_args
args
{
a_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
?
sa_buf
.
GetDeviceBuffer
()
:
nullptr
,
g_perm_buf
.
GetDeviceBuffer
(),
d_perm_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
?
sg_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sd_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
==
1
?
sy_buf
.
GetDeviceBuffer
()
:
nullptr
,
o_buf
.
GetDeviceBuffer
(),
sorted_token_ids_buf
.
GetDeviceBuffer
(),
sorted_weight_buf
.
GetDeviceBuffer
(),
sorted_expert_ids_buf
.
GetDeviceBuffer
(),
num_sorted_tiles_buf
.
GetDeviceBuffer
(),
hidden_size
,
shared_intermediate_size_0
,
tokens
,
experts
,
topk
,
stride
};
float
ave_time
=
fused_moegemm
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
if
(
ave_time
<
0
)
{
std
::
cout
<<
" not supported!"
<<
std
::
endl
<<
std
::
flush
;
return
false
;
}
// float gb_per_sec = num_byte / 1.E6 / ave_time;
std
::
cout
<<
", "
<<
ave_time
*
1.E3
<<
" us, "
<<
cal_tflops
(
ave_time
)
<<
" tflops, "
<<
cal_tbps
(
ave_time
)
<<
" TB/s"
<<
std
::
flush
;
bool
pass
=
true
;
if
(
do_validation
)
{
ck_tile
::
reference_fused_moe
<
AccDataType
,
ck_tile
::
element_wise
::
Gelu
>
(
a_host
,
g_host
,
d_host
,
sa_host
,
sg_host
,
sd_host
,
sy_host
,
o_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
,
topk_ids_host
,
block_m
,
tokens
,
experts
,
hidden_size
,
shared_intermediate_size_0
,
topk
,
gate_only
);
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
// o_dev.savetxt("gpu-out.txt", "float");
auto
[
rtol
,
atol
]
=
get_elimit
<
ADataType
>
();
pass
&=
ck_tile
::
check_err
(
o_dev
,
o_host
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
;
}
std
::
cout
<<
std
::
flush
<<
std
::
endl
;
return
pass
;
}
return
false
;
}
// Entry point of the fused-MoE example.
//
// Exit codes, as produced by the code below:
//    0 : run<...>() returned true
//   -1 : create_args() reported a parse failure
//   -2 : run<...>() returned false
//   -3 : no run<...>() instantiation matches the requested precisions
int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    // A precision given as "auto" is resolved to "fp32".
    const auto or_fp32 = [](const std::string& prec) -> std::string {
        return prec == "auto" ? "fp32" : prec;
    };

    const std::string prec_i = arg_parser.get_str("prec_i");
    const std::string prec_w = arg_parser.get_str("prec_w");
    const std::string prec_o = arg_parser.get_str("prec_o");

    // Only prec_kw takes part in the dispatch below; the other three are still
    // read and resolved exactly as before.
    [[maybe_unused]] const std::string prec_st = or_fp32(arg_parser.get_str("prec_st"));
    [[maybe_unused]] const std::string prec_sw = or_fp32(arg_parser.get_str("prec_sw"));
    [[maybe_unused]] const std::string prec_sq = or_fp32(arg_parser.get_str("prec_sq"));
    const std::string prec_kw                  = or_fp32(arg_parser.get_str("prec_kw"));

    // True when input, weight and output precision are all `prec` and the
    // prec_kw precision is fp32 (the "no dynamic quant" case).
    const auto is_plain_case = [&](const char* prec) {
        return prec_i == prec && prec_w == prec && prec_o == prec && prec_kw == "fp32";
    };

    if(is_plain_case("bf16"))
    {
        const bool ok = run<ck_tile::bf16_t,
                            ck_tile::bf16_t,
                            ck_tile::bf16_t,
                            float,
                            float,
                            float,
                            float>(arg_parser);
        return ok ? 0 : -2;
    }

    if(is_plain_case("fp16"))
    {
        const bool ok = run<ck_tile::fp16_t,
                            ck_tile::fp16_t,
                            ck_tile::fp16_t,
                            float,
                            float,
                            float,
                            float>(arg_parser);
        return ok ? 0 : -2;
    }

    return -3;
}
example/ck_tile/15_fused_moe/misc/moe-0.png
0 → 100644
View file @
b74918bc
75 KB
example/ck_tile/15_fused_moe/misc/moe-1.png
0 → 100644
View file @
b74918bc
90.4 KB
example/ck_tile/15_fused_moe/misc/moe-2.png
0 → 100644
View file @
b74918bc
124 KB
example/ck_tile/15_fused_moe/misc/moe-3.png
0 → 100644
View file @
b74918bc
18.2 KB
example/ck_tile/16_batched_gemm/CMakeLists.txt
0 → 100644
View file @
b74918bc
# Batched GEMM example executable built from batched_gemm.cpp.
# EXCLUDE_FROM_ALL keeps it out of the default "all" target, so it is only
# built when requested explicitly (e.g. `make tile_example_batched_gemm`).
add_executable
(
tile_example_batched_gemm EXCLUDE_FROM_ALL batched_gemm.cpp
)
example/ck_tile/16_batched_gemm/README.md
0 → 100644
View file @
b74918bc
# Batched GEMM
This folder contains an example of batched GEMM using the ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_batched_gemm -j
```
This will result in an executable
`build/bin/tile_example_batched_gemm`
## example
```
args:
-m m dimension (default:256)
-n n dimension (default:128)
-k k dimension (default:128)
-a_layout A tensor data layout (default:R) (R for Row, C for Col)
-b_layout B tensor data layout (default:R) (R for Row, C for Col)
-c_layout C tensor data layout (default:R) (R for Row, C for Col)
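-stride_a Tensor A stride (default:0, i.e. use the packed stride for the chosen layout)
-stride_b Tensor B stride (default:0, i.e. use the packed stride for the chosen layout)
-stride_c Tensor C stride (default:0, i.e. use the packed stride for the chosen layout)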
-batch_stride_a Batch A stride (default:32768)
-batch_stride_b Batch B stride (default:16384)
-batch_stride_c Batch C stride (default:32768)
-batch_count Batch count (default:16)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-warmup number of iterations before benchmark the kernel (default:50)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
```
\ No newline at end of file
example/ck_tile/16_batched_gemm/batched_gemm.cpp
0 → 100644
View file @
b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "batched_gemm.hpp"
// Configures and launches the ck_tile batched GEMM kernel for the given
// A/B/C layouts and returns whatever ck_tile::launch_kernel() reports
// (stored by callers as the average time).
//
// All tile / warp / padding parameters are compile-time constants here; the
// original comments say they are expected to come from a codegen step later.
template <typename ALayout, typename BLayout, typename CLayout>
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
{
    // Padding switches and occupancy hint (should eventually come from codegen).
    constexpr bool kPadM        = false;
    constexpr bool kPadN        = false;
    constexpr bool kPadK        = false;
    constexpr int  kBlockPerCu  = 1;

    // Output permutation settings (will also be produced by codegen).
    constexpr bool             kTilePermute = false;
    constexpr ck_tile::index_t kOutputRank  = 2;

    // Tile configuration (comes from codegen).
    constexpr ck_tile::index_t M_Tile = 128;
    constexpr ck_tile::index_t N_Tile = 128;
    constexpr ck_tile::index_t K_Tile = 32;

    constexpr ck_tile::index_t M_Warp = 2;
    constexpr ck_tile::index_t N_Warp = 2;
    constexpr ck_tile::index_t K_Warp = 1;

    constexpr ck_tile::index_t M_Warp_Tile = 32;
    constexpr ck_tile::index_t N_Warp_Tile = 32;
    constexpr ck_tile::index_t K_Warp_Tile = 8;

    // The CShuffle epilogue (transpose before writing to global memory) is
    // selected only for a column-major C layout.
    constexpr bool kUseCShuffle =
        std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;

    using BlockTileSeq  = ck_tile::sequence<M_Tile, N_Tile, K_Tile>;
    using BlockWarpsSeq = ck_tile::sequence<M_Warp, N_Warp, K_Warp>;
    using WarpTileSeq   = ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>;

    using CodegenGemmShape = ck_tile::TileGemmShape<BlockTileSeq, BlockWarpsSeq, WarpTileSeq>;
    using TilePartitioner  = ck_tile::GemmTilePartitioner<CodegenGemmShape>;

    using ShuffledEpilogue =
        ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
                                                                   CDataType,
                                                                   kPadM,
                                                                   kPadN,
                                                                   kTilePermute,
                                                                   kOutputRank,
                                                                   1,
                                                                   0,
                                                                   TilePartitioner::kM,
                                                                   TilePartitioner::kN>>;
    using PlainEpilogue = ck_tile::Default2DEpilogue<
        ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
    using GemmEpilogue = std::conditional_t<kUseCShuffle, ShuffledEpilogue, PlainEpilogue>;

    using CodegenGemmTraits =
        ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
    using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
                                                                BDataType,
                                                                AccDataType,
                                                                CodegenGemmShape,
                                                                CodegenGemmTraits>;
    using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;

    // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
    // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
    using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;

    auto kernel_args = Kernel::MakeKernelArgs(args);

    const dim3     grids  = Kernel::GridSize(args.M, args.N, args.batch_count);
    constexpr dim3 blocks = Kernel::BlockSize();

    if(s.log_level_ > 0)
    {
        std::cout << "Launching kernel with args:"
                  << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                  << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                  << std::endl;
    }

    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kernel_args));
}
#include "run_batched_gemm_example.inc"
// Entry point of the batched GEMM example.
//
// run_batched_gemm_example() returns the verification flag converted to int
// (1 = passed or validation disabled, 0 = mismatch) and -1 when argument
// parsing fails. The previous `!result` mapping turned the -1 parse failure
// into exit code 0, i.e. reported success. Only a positive result is treated
// as success now; every other value yields a non-zero exit code.
int main(int argc, char* argv[])
{
    const int status = run_batched_gemm_example(argc, argv);
    return status > 0 ? 0 : 1;
}
example/ck_tile/16_batched_gemm/batched_gemm.hpp
0 → 100644
View file @
b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
// Maps an element type to the data types used by the batched GEMM example.
// Only the primary template declaration is given; each supported element type
// provides its own specialization.
template <typename DataType>
struct BatchedGemmTypeConfig;

// Half-precision configuration: A, B and C are half_t, accumulation is float.
template <>
struct BatchedGemmTypeConfig<ck_tile::half_t>
{
    using AccDataType = float;
    using ADataType   = ck_tile::half_t;
    using BDataType   = ADataType;
    using CDataType   = ADataType;
};

// Configuration selected for this example.
using Types = BatchedGemmTypeConfig<ck_tile::half_t>;

// Specific type aliases for easy access
using AccDataType = Types::AccDataType;
using ADataType   = Types::ADataType;
using BDataType   = Types::BDataType;
using CDataType   = Types::CDataType;
// Registers the command-line options of the batched GEMM example (name,
// default value, help text) and parses argc/argv.
//
// Returns a tuple {parse_ok, arg_parser}.
//
// Declared `inline` because this function is defined in a header: without it,
// including the header from more than one translation unit produces
// multiple-definition link errors.
inline auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "256", "m dimension")
        .insert("n", "128", "n dimension")
        .insert("k", "128", "k dimension")
        // A stride of 0 is replaced by the packed stride for the layout by the caller.
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_c", "0", "Tensor C stride")
        .insert("a_layout", "R", "A tensor data layout - Row by default")
        .insert("b_layout", "R", "B tensor data layout - Row by default")
        .insert("c_layout", "R", "C tensor data layout - Row by default")
        .insert("batch_stride_a", "32768", "Batch A stride")
        .insert("batch_stride_b", "16384", "Batch B stride")
        .insert("batch_stride_c", "32768", "Batch C stride")
        .insert("batch_count", "16", "Batch count")
        .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
        .insert("warmup", "50", "number of iterations before benchmark the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// host API
// Declares the host-side batched GEMM entry point taking the problem
// description (`args`) and the stream/benchmark configuration (`s`).
// NOTE(review): the definition visible in batched_gemm.cpp is a function
// template over <ALayout, BLayout, CLayout>, and the example calls that
// template; this non-template overload has no definition in the code shown
// here - confirm whether it is still needed.
float
batched_gemm
(
const
ck_tile
::
BatchedGemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
);
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
0 → 100644
View file @
b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// Packs the device buffers and problem sizes into BatchedGemmHostArgs, runs
// batched_gemm<ALayout, BLayout, CLayout>() with the requested number of
// warm-up and timed iterations, prints a one-line performance summary and
// returns the time reported by batched_gemm() (printed as milliseconds).
//
// Fix: the summary line used to print "Run Batched Gemmkernel ..." because
// the separating space between the operation name and "kernel" was missing.
template <typename ALayout, typename BLayout, typename CLayout>
float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                          ck_tile::DeviceMem& b_k_n_dev_buf,
                          ck_tile::DeviceMem& c_m_n_dev_buf,
                          ck_tile::index_t M,
                          ck_tile::index_t N,
                          ck_tile::index_t K,
                          ck_tile::index_t stride_A,
                          ck_tile::index_t stride_B,
                          ck_tile::index_t stride_C,
                          ck_tile::index_t batch_stride_A,
                          ck_tile::index_t batch_stride_B,
                          ck_tile::index_t batch_stride_C,
                          ck_tile::index_t batch_count,
                          int n_warmup,
                          int n_repeat)
{
    ck_tile::BatchedGemmHostArgs args;
    args.a_ptr          = a_m_k_dev_buf.GetDeviceBuffer();
    args.b_ptr          = b_k_n_dev_buf.GetDeviceBuffer();
    args.c_ptr          = c_m_n_dev_buf.GetDeviceBuffer();
    args.M              = M;
    args.N              = N;
    args.K              = K;
    args.stride_A       = stride_A;
    args.stride_B       = stride_B;
    args.stride_C       = stride_C;
    args.batch_stride_A = batch_stride_A;
    args.batch_stride_B = batch_stride_B;
    args.batch_stride_C = batch_stride_C;
    args.batch_count    = batch_count;

    float ave_time = batched_gemm<ALayout, BLayout, CLayout>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});

    std::string op_name{"Batched Gemm"};

    // 2 * batch * M * N * K floating point operations; the leading size_t keeps
    // the whole product in 64-bit arithmetic.
    std::size_t flop = std::size_t(2) * batch_count * M * N * K;

    // Bytes of A and B plus bytes of C, assuming packed tensors.
    std::size_t num_byte = sizeof(ADataType) * batch_count * M * K +
                           sizeof(BDataType) * batch_count * N * K +
                           sizeof(CDataType) * batch_count * M * N;

    // With ave_time in milliseconds: flop / 1e9 / ms == TFLOP/s and
    // bytes / 1e6 / ms == GB/s.
    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "Run " << op_name << " kernel with M =" << M << " N =" << N << " K =" << K
              << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
              << " batch_stride_A =" << batch_stride_A << " batch_stride_B =" << batch_stride_B
              << " batch_stride_C =" << batch_stride_C << " batch_count =" << batch_count << " : "
              << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << std::endl;

    return ave_time;
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
int
run_batched_gemm_example_with_layouts
(
int
argc
,
char
*
argv
[],
const
ALayout
a_layout
=
ALayout
{},
const
BLayout
b_layout
=
BLayout
{},
[[
maybe_unused
]]
const
CLayout
c_layout
=
CLayout
{})
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
K
=
arg_parser
.
get_int
(
"k"
);
ck_tile
::
index_t
stride_A
=
arg_parser
.
get_int
(
"stride_a"
);
ck_tile
::
index_t
stride_B
=
arg_parser
.
get_int
(
"stride_b"
);
ck_tile
::
index_t
stride_C
=
arg_parser
.
get_int
(
"stride_c"
);
ck_tile
::
index_t
batch_stride_A
=
arg_parser
.
get_int
(
"batch_stride_a"
);
ck_tile
::
index_t
batch_stride_B
=
arg_parser
.
get_int
(
"batch_stride_b"
);
ck_tile
::
index_t
batch_stride_C
=
arg_parser
.
get_int
(
"batch_stride_c"
);
ck_tile
::
index_t
batch_count
=
arg_parser
.
get_int
(
"batch_count"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
using
namespace
ck_tile
::
literals
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
batch_count_
,
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
std
::
size_t
batch_stride
,
auto
layout
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
ck_tile
::
HostTensorDescriptor
({
batch_count_
,
row
,
col
},
{
batch_stride
,
stride
,
1_
uz
});
}
else
{
return
ck_tile
::
HostTensorDescriptor
({
batch_count_
,
row
,
col
},
{
batch_stride
,
1_
uz
,
stride
});
}
};
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
stride
==
0
)
{
// give a chance if stride is zero, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
col
;
}
else
{
return
row
;
}
}
else
return
stride
;
};
stride_A
=
f_get_default_stride
(
M
,
K
,
stride_A
,
a_layout
);
stride_B
=
f_get_default_stride
(
K
,
N
,
stride_B
,
b_layout
);
stride_C
=
f_get_default_stride
(
M
,
N
,
stride_C
,
c_layout
);
ck_tile
::
HostTensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
batch_count
,
M
,
K
,
stride_A
,
batch_stride_A
,
a_layout
));
ck_tile
::
HostTensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
batch_count
,
K
,
N
,
stride_B
,
batch_stride_B
,
b_layout
));
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_dev_result
(
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
c_layout
));
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
ck_tile
::
DeviceMem
a_m_k_dev_buf
(
a_m_k
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b_k_n_dev_buf
(
b_k_n
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c_m_n_dev_buf
(
c_m_n_dev_result
.
get_element_space_size_in_bytes
());
a_m_k_dev_buf
.
ToDevice
(
a_m_k
.
data
());
b_k_n_dev_buf
.
ToDevice
(
b_k_n
.
data
());
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
invoke_batched_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_dev_buf
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
batch_stride_A
,
batch_stride_B
,
batch_stride_C
,
batch_count
,
n_warmup
,
n_repeat
);
c_m_n_dev_buf
.
FromDevice
(
c_m_n_dev_result
.
data
());
bool
pass
=
true
;
if
(
arg_parser
.
get_int
(
"v"
)
==
1
)
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
CLayout
{}));
c_m_n_host_ref
.
SetZero
();
const
auto
b_n_k
=
b_k_n
.
transpose
({
0
,
2
,
1
});
ck_tile
::
reference_batched_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
a_m_k
,
b_n_k
,
c_m_n_host_ref
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
);
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_gpu_ref
(
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
CLayout
{}));
ck_tile
::
DeviceMem
c_m_n_gpu_buf_ref
(
c_m_n_gpu_ref
.
get_element_space_size_in_bytes
());
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_A
,
batch_count
*
M
*
K
*
sizeof
(
ADataType
)));
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_B
,
batch_count
*
N
*
K
*
sizeof
(
BDataType
)));
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
)));
ck_tile
::
hip_check_error
(
hipMemcpy
(
d_A
,
a_m_k_dev_buf
.
GetDeviceBuffer
(),
batch_count
*
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
));
ck_tile
::
hip_check_error
(
hipMemcpy
(
d_B
,
b_k_n_dev_buf
.
GetDeviceBuffer
(),
batch_count
*
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
));
ck_tile
::
reference_batched_gemm_gpu
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
batch_stride_A
,
batch_stride_B
,
batch_stride_C
,
batch_count
);
ck_tile
::
hip_check_error
(
hipMemcpy
(
c_m_n_gpu_buf_ref
.
GetDeviceBuffer
(),
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
));
ck_tile
::
hip_check_error
(
hipFree
(
d_A
));
ck_tile
::
hip_check_error
(
hipFree
(
d_B
));
ck_tile
::
hip_check_error
(
hipFree
(
d_C
));
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
std
::
cout
<<
"The GPU verification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
return
pass
;
}
// Selects the layout-specific example instantiation from the "a_layout" and
// "b_layout" options. C is always instantiated as row-major here.
//
// Returns -1 when argument parsing fails, otherwise the result of
// run_batched_gemm_example_with_layouts(); throws std::runtime_error for a
// layout combination that is not wired up.
int run_batched_gemm_example(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    const std::string a_layout = arg_parser.get_str("a_layout");
    const std::string b_layout = arg_parser.get_str("b_layout");

    if(a_layout == "R")
    {
        if(b_layout == "R")
            return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});

        if(b_layout == "C")
            return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
    }

    // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy the
    // column-major A variants below do not work:
    // if(a_layout == "C" && b_layout == "C")
    // {
    //     return run_batched_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
    // }
    // else if(a_layout == "C" && b_layout == "R")
    // {
    //     return run_batched_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
    // }

    throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
example/ck_tile/17_grouped_gemm/CMakeLists.txt
0 → 100644
View file @
b74918bc
# Grouped GEMM example executable built from grouped_gemm.cpp.
# EXCLUDE_FROM_ALL keeps it out of the default "all" target, so it is only
# built when requested explicitly.
add_executable
(
tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp
)
Prev
1
…
3
4
5
6
7
8
9
10
11
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment