Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
072dfbfe
Commit
072dfbfe
authored
Dec 03, 2024
by
letaoqin
Browse files
gemm0 debugged
parent
69114f25
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
42 additions
and
35 deletions
+42
-35
example/ck_tile/16_fused_moe_general/main.cpp
example/ck_tile/16_fused_moe_general/main.cpp
+6
-3
include/ck_tile/host/fill.hpp
include/ck_tile/host/fill.hpp
+1
-0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+12
-1
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+23
-23
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+0
-8
No files found.
example/ck_tile/16_fused_moe_general/main.cpp
View file @
072dfbfe
...
@@ -241,6 +241,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -241,6 +241,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
.5
f
,
.5
f
}(
a_host
);
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
.5
f
,
.5
f
}(
a_host
);
ck_tile
::
FillUniformDistribution
<
GDataType
>
{
-
.5
f
,
.5
f
}(
g_host
);
ck_tile
::
FillUniformDistribution
<
GDataType
>
{
-
.5
f
,
.5
f
}(
g_host
);
// ck_tile::FillConstant<ADataType>{1}(a_host);
// ck_tile::FillConstant<GDataType>{1}(g_host);
//ck_tile::FillStepRange<GDataType>{0.0f, 32.0f*128,1.0f}(g_host);
ck_tile
::
FillUniformDistribution
<
DDataType
>
{
-
.5
f
,
.5
f
}(
d_host
);
ck_tile
::
FillUniformDistribution
<
DDataType
>
{
-
.5
f
,
.5
f
}(
d_host
);
ck_tile
::
FillUniformDistribution
<
AScaleDataType
>
{
-
.5
f
,
.5
f
}(
sa_host
);
ck_tile
::
FillUniformDistribution
<
AScaleDataType
>
{
-
.5
f
,
.5
f
}(
sa_host
);
ck_tile
::
FillUniformDistribution
<
GScaleDataType
>
{
-
.5
f
,
.5
f
}(
sg_host
);
ck_tile
::
FillUniformDistribution
<
GScaleDataType
>
{
-
.5
f
,
.5
f
}(
sg_host
);
...
@@ -282,7 +285,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -282,7 +285,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// output_matrix_2d(a_host, tokens, hidden_size);
// output_matrix_2d(a_host, tokens, hidden_size);
std
::
cout
<<
sorted_token_ids_host
<<
std
::
endl
;
std
::
cout
<<
sorted_token_ids_host
<<
std
::
endl
;
// std::cout << num_sorted_tiles_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl;
output_matrix_3d
(
g_host
,
experts
,
shared_intermediate_size_0
,
hidden_size
);
//
output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
std
::
cout
<<
sorted_expert_ids_host
<<
std
::
endl
;
std
::
cout
<<
sorted_expert_ids_host
<<
std
::
endl
;
// std::cout << topk_weight_host << std::endl;
// std::cout << topk_weight_host << std::endl;
...
@@ -290,8 +293,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -290,8 +293,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// done, preparing GPU buffer
// done, preparing GPU buffer
ck_tile
::
DeviceMem
a_buf
(
a_host
);
ck_tile
::
DeviceMem
a_buf
(
a_host
);
ck_tile
::
DeviceMem
g_perm_buf
(
g_
perm_
host
);
ck_tile
::
DeviceMem
g_perm_buf
(
g_host
);
ck_tile
::
DeviceMem
d_perm_buf
(
d_
perm_
host
);
ck_tile
::
DeviceMem
d_perm_buf
(
d_host
);
ck_tile
::
DeviceMem
sa_buf
(
sa_host
);
ck_tile
::
DeviceMem
sa_buf
(
sa_host
);
ck_tile
::
DeviceMem
sg_buf
(
sg_host
);
ck_tile
::
DeviceMem
sg_buf
(
sg_host
);
ck_tile
::
DeviceMem
sd_buf
(
sd_host
);
ck_tile
::
DeviceMem
sd_buf
(
sd_host
);
...
...
include/ck_tile/host/fill.hpp
View file @
072dfbfe
...
@@ -278,6 +278,7 @@ struct FillConstant
...
@@ -278,6 +278,7 @@ struct FillConstant
{
{
T
value_
{
0
};
T
value_
{
0
};
FillConstant
(
float
value
)
:
value_
(
ck_tile
::
type_convert
<
T
>
(
value
)){}
template
<
typename
ForwardIter
>
template
<
typename
ForwardIter
>
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
{
{
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
View file @
072dfbfe
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/core/tensor/tensor_view.hpp"
#include <string>
#include <string>
#include <type_traits>
#include <type_traits>
...
@@ -298,7 +299,6 @@ struct FusedMoeGemmGlKernel
...
@@ -298,7 +299,6 @@ struct FusedMoeGemmGlKernel
return
a_window_
;
return
a_window_
;
}();
}();
// TODO: gtile using NSub to have less register pressure
const
auto
g_window
=
[
&
]()
{
const
auto
g_window
=
[
&
]()
{
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GDataType
*>
(
kargs
.
g_ptr
)
+
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GDataType
*>
(
kargs
.
g_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_0
;
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_0
;
...
@@ -313,6 +313,17 @@ struct FusedMoeGemmGlKernel
...
@@ -313,6 +313,17 @@ struct FusedMoeGemmGlKernel
g_view_
,
g_view_
,
make_tuple
(
number
<
BlockShape
::
Block_N0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
make_tuple
(
number
<
BlockShape
::
Block_N0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
{
idx_n0
,
0
});
{
idx_n0
,
0
});
// if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// {
// for(int i = 0; i < 16; i++)
// {
// printf("in G index is %d , value is: %f\n",
// i,
// ck_tile::type_convert<float>(g_ptr[i]));
// }
// }
return
g_window_
;
return
g_window_
;
}();
}();
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
072dfbfe
...
@@ -115,6 +115,7 @@ struct FusedMoeGemmPipeline_General
...
@@ -115,6 +115,7 @@ struct FusedMoeGemmPipeline_General
a_window_
.
get_window_origin
(),
a_window_
.
get_window_origin
(),
Policy
::
template
MakeGlobalTileDistribution_A
<
Problem
>());
Policy
::
template
MakeGlobalTileDistribution_A
<
Problem
>());
// load g to register
auto
g_global_to_dram_window
=
make_tile_window
(
auto
g_global_to_dram_window
=
make_tile_window
(
g_window_
.
get_bottom_tensor_view
(),
g_window_
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
BlockShape
::
Block_N0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
make_tuple
(
number
<
BlockShape
::
Block_N0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
...
@@ -153,27 +154,26 @@ struct FusedMoeGemmPipeline_General
...
@@ -153,27 +154,26 @@ struct FusedMoeGemmPipeline_General
}
}
#endif
#endif
// load g to register
auto
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
auto
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
#if 0
#if 0
{
{
constexpr auto
a
_spans = decltype(g_dram_block)::get_distributed_spans();
constexpr auto
g
_spans = decltype(g_dram_block)::get_distributed_spans();
int counter = 0;
int counter = 0;
sweep_tile_span(
a
_spans[number<0>{}], [&](auto idxn) {
sweep_tile_span(
g
_spans[number<0>{}], [&](auto idxn) {
sweep_tile_span(
a
_spans[number<1>{}], [&](auto idxk) {
sweep_tile_span(
g
_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxn, idxk);
constexpr auto i_j_idx = make_tuple(idxn, idxk);
const auto tile_idx = get_x_indices_from_distributed_indices(
g_dram_block.get_tile_distribution(), i_j_idx);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
{
counter = counter + 1;
counter = counter + 1;
index_t idn_0 = idxn.impl_.at(0);
const auto row = tile_idx.at(number<0>{});
index_t idk_0 = idxk.impl_.at(0);
const auto col = tile_idx.at(number<1>{});
index_t idk_1 = idxk.impl_.at(1);
printf("in G row is %d , col is %d, counter is %d, value is: %f"
printf("in A idn is %d , idk_0 is %d idk_1 is %d, counter is %d, value is: "
" \n",
"%f \n",
row,
idn_0,
col,
idk_0,
idk_1,
counter,
counter,
ck_tile::type_convert<float>(g_dram_block(i_j_idx)));
ck_tile::type_convert<float>(g_dram_block(i_j_idx)));
}
}
...
@@ -185,7 +185,7 @@ struct FusedMoeGemmPipeline_General
...
@@ -185,7 +185,7 @@ struct FusedMoeGemmPipeline_General
clear_tile
(
s_acc
);
// initialize C
clear_tile
(
s_acc
);
// initialize C
constexpr
index_t
kK0
=
BlockShape
::
Block_K0
;
constexpr
index_t
kK0
=
BlockShape
::
Block_K0
;
const
index_t
k0_loops
=
ck_tile
::
integer_divide_ceil
(
intermediate
_size
,
kK0
);
const
index_t
k0_loops
=
ck_tile
::
integer_divide_ceil
(
hidden
_size
,
kK0
);
index_t
iCounter0
=
k0_loops
-
1
;
index_t
iCounter0
=
k0_loops
-
1
;
while
(
iCounter0
>
0
)
while
(
iCounter0
>
0
)
{
{
...
@@ -208,25 +208,25 @@ struct FusedMoeGemmPipeline_General
...
@@ -208,25 +208,25 @@ struct FusedMoeGemmPipeline_General
block_sync_lds
();
block_sync_lds
();
gemm_0
(
s_acc
,
a_lds_win
,
g_dram_block
);
gemm_0
(
s_acc
,
a_lds_win
,
g_dram_block
);
}
}
#if
1
#if
0
{
{
constexpr auto a_spans = decltype(s_acc)::get_distributed_spans();
constexpr auto a_spans = decltype(s_acc)::get_distributed_spans();
int counter = 0;
int counter = 0;
//a_spans[0] = 1;
//a_spans[0] = 1;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxn) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxn) {
constexpr
auto
i_j_idx
=
make_tuple
(
idxn
,
idxn
);
constexpr auto i_j_idx = make_tuple(idxm, idxn);
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
const auto tile_idx = get_x_indices_from_distributed_indices(
g_dram_block.get_tile_distribution(), i_j_idx);
if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
{
counter = counter + 1;
counter = counter + 1;
index_t
idm_0
=
idxm
.
impl_
.
at
(
0
);
const auto row = tile_idx.at(number<0>{});
index_t
idn_0
=
idxn
.
impl_
.
at
(
0
);
const auto col = tile_idx.at(number<1>{});
index_t
idn_1
=
idxn
.
impl_
.
at
(
1
);
printf("in c row is %d , col is %d, counter is %d, value is: "
printf
(
"in A idn is %d , idn_0 is %d, idn_1 is %d, counter is %d, value is: "
"%f \n",
"%f \n",
idm_0
,
row,
idn_0
,
col,
idn_1
,
counter,
counter,
ck_tile::type_convert<float>(s_acc(i_j_idx)));
ck_tile::type_convert<float>(s_acc(i_j_idx)));
}
}
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
072dfbfe
...
@@ -186,14 +186,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
...
@@ -186,14 +186,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr
auto
g_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
constexpr
auto
g_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
g_outer_dstr_enc
,
typename
WG
::
BWarpDstrEncoding
{});
g_outer_dstr_enc
,
typename
WG
::
BWarpDstrEncoding
{});
// constexpr auto g_block_dstr_encode = tile_distribution_encoding<
// sequence<1>,
// tuple<sequence<1, 4, 32>, sequence<4, 2, 4>>,
// tuple<sequence<0, 1>, sequence<2, 1>>,
// tuple<sequence<0, 1>, sequence<1, 2>>,
// sequence<1, 2, 2>,
// sequence<0, 0, 2>>{};
return
make_static_tile_distribution
(
g_block_dstr_encode
);
return
make_static_tile_distribution
(
g_block_dstr_encode
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment