Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
eaf8e616
"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "35e5c53294b0161426ed7a5b465f01d5ee4ab78b"
Commit
eaf8e616
authored
Nov 22, 2024
by
letaoqin
Browse files
write a data to lds
parent
3b51749a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
68 additions
and
84 deletions
+68
-84
example/ck_tile/16_fused_moe_general/main.cpp
example/ck_tile/16_fused_moe_general/main.cpp
+19
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+37
-26
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+12
-58
No files found.
example/ck_tile/16_fused_moe_general/main.cpp
View file @
eaf8e616
...
...
@@ -59,6 +59,21 @@ auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype,
return
t
;
}
// Debug helper: print an m x n host tensor row by row to stdout, converting
// each element to float so low-precision types (fp16/bf16/...) are readable.
//
// @param data  host tensor to dump (read via data(i, j); passed by non-const
//              reference to match existing callers — TODO confirm whether
//              HostTensor::operator() has a const overload so this could be
//              tightened to const&)
// @param m     number of rows to print
// @param n     number of columns to print
template <typename DataType>
void output_matrix_2d(ck_tile::HostTensor<DataType>& data, int m, int n)
{
    std::cout << '\n';
    for(int i = 0; i < m; i++)
    {
        std::cout << "Line " << i << " \t ";
        for(int j = 0; j < n; j++)
        {
            std::cout << ck_tile::type_convert<float>(data(i, j)) << " \t ";
        }
        // '\n' instead of std::endl: avoid flushing the stream once per row,
        // which is very slow for large matrices. One flush at the end suffices.
        std::cout << '\n';
    }
    std::cout << std::flush;
}
template
<
typename
IndexType
>
void
topid_unique_gen
(
std
::
vector
<
IndexType
>&
host_tensor
,
int
tokens
,
int
topk
,
int
num_expert
,
int
seed
)
...
...
@@ -256,6 +271,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// }
// std::cout << std::endl;
// }
output_matrix_2d
(
a_host
,
tokens
,
hidden_size
);
// std::cout << sorted_token_ids_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl;
// std::cout << sorted_expert_ids_host << std::endl;
...
...
@@ -277,6 +293,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
DeviceMem
sorted_weight_buf
(
sorted_weight_host
);
ck_tile
::
DeviceMem
sorted_expert_ids_buf
(
sorted_expert_ids_host
);
ck_tile
::
DeviceMem
num_sorted_tiles_buf
(
num_sorted_tiles_host
);
o_buf
.
SetZero
();
fused_moegemm_traits
traits
{
prec_i
,
prec_w
,
...
...
@@ -363,6 +380,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
pass
&=
ck_tile
::
check_err
(
o_dev
,
o_host
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
;
output_matrix_2d
(
o_dev
,
tokens
,
hidden_size
);
}
std
::
cout
<<
std
::
flush
<<
std
::
endl
;
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
eaf8e616
...
...
@@ -70,15 +70,15 @@ struct FusedMoeGemmPipeline_General
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
//
//
matrix a or tokens smem
//
constexpr index_t smem_mat_a =
//
BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType);
//
//
shuffle C matrix
//
constexpr index_t smem_bridge =
//
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
//
return max(smem_mat_a, smem_bridge);
return
Policy
::
template
GetSmemSize
<
Problem
>();
// matrix a or tokens smem
constexpr
index_t
smem_mat_a
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_K0
*
sizeof
(
ADataType
);
// shuffle C matrix
constexpr
index_t
smem_bridge
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
sizeof
(
YDataType
);
return
max
(
smem_mat_a
,
smem_bridge
);
//
return Policy::template GetSmemSize<Problem>();
}
// this is the thread-offset along row/col
...
...
@@ -105,35 +105,46 @@ struct FusedMoeGemmPipeline_General
ignore
=
hidden_size
;
ignore
=
intermediate_size
;
//
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
//
auto a_lds_view = make_tensor_view<address_space_enum::lds>(
//
smem_0, Policy::template MakeLdsStoreDesc_A<Problem>());
//
auto a_lds_win = make_tile_window(a_lds_view, make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}), {0, 0});
CK_TILE_LDS_ADDR
ADataType
*
smem_0
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
);
auto
a_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeLdsStoreDesc_A
<
Problem
>());
auto
a_lds_win
=
make_tile_window
(
a_lds_view
,
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
{
0
,
0
});
auto
a_global_to_dram_window
=
make_tile_window
(
a_window_
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
a_window_
.
get_window_origin
(),
Policy
::
template
MakeGlobalTileDistribution_A
<
Problem
>());
// auto o_win = make_tile_window_linear(
// o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
auto
a_dram_block
=
load_tile
(
a_global_to_dram_window
);
//store_tile(a_lds_win, a_dram_block);
ignore
=
a_dram_block
;
store_tile
(
a_lds_win
,
a_dram_block
);
store_tile
(
o_window_
,
a_dram_block
);
#if 0
//check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram)::get_distributed_spans();
int counter = 0;
constexpr auto a_spans = decltype(a_dram
_block
)::get_distributed_spans();
int counter
= 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk){
constexpr auto i_j_idx = make_tuple(idxm, idxk);
if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0){
counter = counter + 1;
index_t idm_0 = idxm.impl_.at(0);
index_t idn_0 = idxk.impl_.at(0);
printf("in A idm is %d , idn_ is %d , counter is %d, value is: %f \n", idm_0, idn_0, counter, ck_tile::type_convert<float>(a_dram(i_j_idx)));
}
});
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxm, idxk);
if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0)
{
counter = counter + 1;
index_t idm_0 = idxm.impl_.at(0);
index_t idn_0 = idxk.impl_.at(0);
printf("in A idm is %d , idn_ is %d , counter is %d, value is: %f \n",
idm_0,
idn_0,
counter,
ck_tile::type_convert<float>(a_dram_block(i_j_idx)));
}
});
});
#endif
}
};
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
eaf8e616
...
...
@@ -232,53 +232,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
}
}
#if 0
// Caution: this will require global memory pre-shuffled to follow the mfma layout
template <index_t NPerBlock,
index_t KPerBlock,
index_t WavesPerBlock_N,
index_t WavesPerBlock_K,
typename WarpGemm,
index_t Alignment,
FusedMoeGemmWeightPermuteEnum PermuteEnum =
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_MatrixCore_Swizzled()
{
static_assert(Alignment % WarpGemm::WarpGemmAttribute::Impl::kABKPerLane == 0);
if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
constexpr index_t Kv = Alignment;
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
constexpr index_t Nr_p = WavesPerBlock_N;
constexpr index_t Kr_p = WavesPerBlock_K;
constexpr index_t Nr_y = Nr / Nr_p;
constexpr index_t Kr_y = Kr / Kr_p;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>, // 0
// major 1 2 3
// minor 0 1 0 1 0 1 2
tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>,
// Nr_p, Kr_p Kw Nw
tuple<sequence<1, 2>, sequence<3, 3>>,
tuple<sequence<1, 1>, sequence<0, 1>>,
// Nr_y Kr_y Kv
sequence<1, 2, 3>,
sequence<0, 0, 2>>{});
// clang-format on
}
}
#endif
template
<
index_t
WarpPerBlock_N_
,
index_t
WarpPerBlock_K_
,
index_t
Repeat_N_
,
...
...
@@ -414,11 +367,11 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr
auto
lds_block_desc_issues_warps_lanes
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
NumIssues
>
{}
),
make_merge_transform
(
make_tuple
(
number
<
wavesPerM
>
{},
number
<
wavesPer
K
>
{})),
make_merge_transform
(
make_tuple
(
number
<
warpSize
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
,
2
>
{},
sequence
<
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}
,
sequence
<
2
>
{}
));
//
make_pass_through_transform(),
make_merge_transform
(
make_tuple
(
number
<
NumIssues
>
{},
number
<
wavesPer
M
>
{})),
make_merge_transform
(
make_tuple
(
number
<
wavesPerK
>
{},
number
<
warpSize
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
,
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
lds_block_desc_issues_warps_lanes
;
}
...
...
@@ -446,12 +399,13 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr
auto
lds_block_desc_issues_warps_lanes
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
NumIssues
>
{}),
make_pass_through_transform
(
number
<
NumWarps
>
{}),
make_merge_transform
(
make_tuple
(
number
<
LaneGroups
>
{},
number
<
LanesPerK
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
>
{},
sequence
<
2
>
{},
sequence
<
1
,
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}));
make_tuple
(
//make_pass_through_transform(number<NumIssues>{}),
//make_pass_through_transform(number<NumWarps>{}),
make_merge_transform
(
make_tuple
(
number
<
NumIssues
>
{},
number
<
LaneGroups
>
{},
number
<
NumWarps
>
{})),
make_merge_transform
(
make_tuple
(
number
<
LanesPerK
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
2
>
{},
sequence
<
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
lds_block_desc_issues_warps_lanes
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment