Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bb7c4112
Commit
bb7c4112
authored
Nov 29, 2024
by
letaoqin
Browse files
debugging
parent
7881eff9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
99 additions
and
40 deletions
+99
-40
example/ck_tile/16_fused_moe_general/main.cpp
example/ck_tile/16_fused_moe_general/main.cpp
+19
-1
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+3
-4
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+57
-23
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+20
-12
No files found.
example/ck_tile/16_fused_moe_general/main.cpp
View file @
bb7c4112
...
...
@@ -73,6 +73,23 @@ void output_matrix_2d(ck_tile::HostTensor<IndexType>& data, int m, int n)
std
::
cout
<<
std
::
endl
;
}
}
template
<
typename
IndexType
>
void
output_matrix_3d
(
ck_tile
::
HostTensor
<
IndexType
>&
data
,
int
M
,
int
N
,
int
J
)
{
std
::
cout
<<
std
::
endl
;
for
(
int
m
=
0
;
m
<
M
;
m
++
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
std
::
cout
<<
"experts: "
<<
m
<<
" Line: "
<<
n
<<
"
\t
"
;
for
(
int
j
=
0
;
j
<
J
;
j
++
)
{
std
::
cout
<<
ck_tile
::
type_convert
<
float
>
(
data
(
m
,
n
,
j
))
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
}
}
}
template
<
typename
IndexType
>
void
topid_unique_gen
(
...
...
@@ -265,7 +282,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// output_matrix_2d(a_host, tokens, hidden_size);
// std::cout << sorted_token_ids_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl;
// std::cout << sorted_expert_ids_host << std::endl;
output_matrix_3d
(
g_host
,
experts
,
shared_intermediate_size_0
,
hidden_size
);
std
::
cout
<<
sorted_expert_ids_host
<<
std
::
endl
;
// std::cout << topk_weight_host << std::endl;
// std::cout << sorted_weight_host << std::endl;
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
View file @
bb7c4112
...
...
@@ -301,11 +301,10 @@ struct FusedMoeGemmGlKernel
// TODO: gtile using NSub to have less register pressure
const
auto
g_window
=
[
&
]()
{
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GDataType
*>
(
kargs
.
g_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_0
+
idx_n0
*
kargs
.
hidden_size
;
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_0
;
const
auto
g_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
g_ptr
,
make_tuple
(
BlockShape
::
Block_N0
,
kargs
.
hidden_size
),
make_tuple
(
kargs
.
intermediate_size
,
kargs
.
hidden_size
),
make_tuple
(
kargs
.
hidden_size
,
1
),
number
<
Pipeline
::
kAlignmentG
>
{},
number
<
1
>
{});
...
...
@@ -313,7 +312,7 @@ struct FusedMoeGemmGlKernel
const
auto
g_window_
=
make_tile_window
(
g_view_
,
make_tuple
(
number
<
BlockShape
::
Block_N0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
{
0
,
0
});
{
idx_n
0
,
0
});
return
g_window_
;
}();
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
bb7c4112
...
...
@@ -98,6 +98,9 @@ struct FusedMoeGemmPipeline_General
index_t
hidden_size
,
index_t
intermediate_size
)
{
ignore
=
d_window_
;
ignore
=
hidden_size
;
ignore
=
intermediate_size
;
CK_TILE_LDS_ADDR
ADataType
*
smem_0
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
);
auto
a_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeLdsBlockDesc_A
<
Problem
>());
...
...
@@ -126,10 +129,60 @@ struct FusedMoeGemmPipeline_General
// save tokens to lds
auto
a_dram_block
=
load_tile
(
a_global_to_dram_window
);
store_tile
(
a_lds_win
,
a_dram_block
);
#if 0
{
// check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans();
int counter = 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxm, idxk);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
counter = counter + 1;
index_t idm_0 = idxm.impl_.at(0);
index_t idk_0 = idxk.impl_.at(0);
printf("in A idm is %d , idk_ is %d , counter is %d, value is: %f \n",
idm_0,
idk_0,
counter,
ck_tile::type_convert<float>(a_dram_block(i_j_idx)));
}
});
});
}
#endif
// load g to register
auto
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
#if 1
{
constexpr
auto
a_spans
=
decltype
(
g_dram_block
)
::
get_distributed_spans
();
int
counter
=
0
;
sweep_tile_span
(
a_spans
[
number
<
0
>
{}],
[
&
](
auto
idxn
)
{
sweep_tile_span
(
a_spans
[
number
<
1
>
{}],
[
&
](
auto
idxk
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idxn
,
idxk
);
if
(
threadIdx
.
x
==
1
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
{
counter
=
counter
+
1
;
index_t
idn_0
=
idxn
.
impl_
.
at
(
0
);
index_t
idk_0
=
idxk
.
impl_
.
at
(
0
);
index_t
idk_1
=
idxk
.
impl_
.
at
(
1
);
printf
(
"in A idn is %d , idk_0 is %d idk_1 is %d, counter is %d, value is: "
"%f
\n
"
,
idn_0
,
idk_0
,
idk_1
,
counter
,
ck_tile
::
type_convert
<
float
>
(
g_dram_block
(
i_j_idx
)));
}
});
});
}
#endif
clear_tile
(
s_acc
);
// initialize C
constexpr
index_t
kK0
=
BlockShape
::
Block_K0
;
const
index_t
k0_loops
=
ck_tile
::
integer_divide_ceil
(
intermediate_size
,
kK0
);
...
...
@@ -196,6 +249,10 @@ struct FusedMoeGemmPipeline_General
block_sync_lds
();
move_tile_window
(
d_global_to_dram_window
,
{
kN1
,
0
});
d
=
load_tile
(
d_global_to_dram_window
);
// move out window and save data
auto
o
=
cast_tile
<
ODataType
>
(
o_acc
);
store_tile
(
o_window_
,
o
);
move_tile_window
(
o_window_
,
{
kN1
,
0
});
iCounter1
--
;
}
...
...
@@ -204,30 +261,7 @@ struct FusedMoeGemmPipeline_General
block_sync_lds
();
gemm_1
(
o_acc
,
y
,
d
);
}
auto
o
=
cast_tile
<
ODataType
>
(
o_acc
);
store_tile
(
o_window_
,
o
);
// store_tile(o_window_, a_dram_block);
#if 0
//check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans();
int counter = 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxm, idxk);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
counter = counter + 1;
index_t idm_0 = idxm.impl_.at(0);
index_t idn_0 = idxk.impl_.at(0);
printf("in A idm is %d , idn_ is %d , counter is %d, value is: %f \n",
idm_0,
idn_0,
counter,
ck_tile::type_convert<float>(a_dram_block(i_j_idx)));
}
});
});
#endif
}
};
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
bb7c4112
...
...
@@ -173,18 +173,26 @@ struct FusedMoeGemmPipelineGeneralPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_G
()
{
using
WG
=
decltype
(
GetWarpGemm0
<
Problem
>
());
using
S_
=
typename
Problem
::
BlockShape
;
constexpr
auto
g_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
S_
::
Repeat_N0
,
S_
::
WarpPerBlock_N0
>
,
sequence
<
S_
::
Repeat_K0
>>
,
tuple
<
sequence
<
1
>>
,
tuple
<
sequence
<
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
g_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
g_outer_dstr_enc
,
typename
WG
::
BWarpDstrEncoding
{});
// using WG = decltype(GetWarpGemm0<Problem>());
// using S_ = typename Problem::BlockShape;
// static_assert(S_::WarpPerBlock_N0==4);
// constexpr auto g_outer_dstr_enc = tile_distribution_encoding<
// sequence<S_::WarpPerBlock_M0>,
// tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>, sequence<S_::Repeat_K0>>,
// tuple<sequence<0, 1>>,
// tuple<sequence<0, 1>>,
// sequence<1, 2>,
// sequence<0, 0>>{};
// constexpr auto g_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
// g_outer_dstr_enc, typename WG::BWarpDstrEncoding{});
constexpr
auto
g_block_dstr_encode
=
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
1
,
4
,
32
>
,
sequence
<
4
,
2
,
4
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
0
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
0
,
0
,
2
>>
{};
return
make_static_tile_distribution
(
g_block_dstr_encode
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment