Commit e15351ca authored by coderfeli
Browse files

tile m = 64 ok

parent 48d87d9c
...@@ -133,13 +133,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu ...@@ -133,13 +133,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, AElementOp, BElementOp, CDEElementOp, GemmSpec,
//threadnum, mblock, nblock, kblock //threadnum, mblock, nblock, kblock
256, 32, 128, 128, 256, 64, 128, 128,
// ak1, bk1 // ak1, bk1
8, 8, 8, 8,
// mn_perxdl // mn_perxdl
32, 32, 32, 32,
// mn_xdlperwave // mn_xdlperwave
1, 1, 2, 1,
// a,b: loadtranfer cluster, cluster order, srcorder, srcpervec, dstpervec, lds_extra // a,b: loadtranfer cluster, cluster order, srcorder, srcpervec, dstpervec, lds_extra
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
...@@ -169,8 +169,8 @@ int main(int argc, char* argv[]) ...@@ -169,8 +169,8 @@ int main(int argc, char* argv[])
ck::index_t N = 6144; ck::index_t N = 6144;
ck::index_t K = 8192; ck::index_t K = 8192;
ck::index_t experts = 8; ck::index_t experts = 8;
ck::index_t sorted_tile_num = 8; ck::index_t sorted_tile_num = 1;
ck::index_t sorted_tile_size = 32; ck::index_t sorted_tile_size = 64;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
ck::index_t tokens = 64; ck::index_t tokens = 64;
...@@ -368,7 +368,7 @@ int main(int argc, char* argv[]) ...@@ -368,7 +368,7 @@ int main(int argc, char* argv[])
auto ref_invoker = ref_moe_gemm.MakeInvoker(); auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument( auto ref_argument = ref_moe_gemm.MakeArgument(
sorted_token_ids, expert_ids, a0_t_k, b0_e_n_k, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); sorted_token_ids, expert_ids, sorted_tile_size, a0_t_k, b0_e_n_k, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
for(int m = 0; m < SORTED_SIZE; ++m) for(int m = 0; m < SORTED_SIZE; ++m)
......
...@@ -176,8 +176,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -176,8 +176,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static_assert(NWave * warpSize == BlockSize); static_assert(NWave * warpSize == BlockSize);
// static constexpr index_t NumTokens = 1; // static constexpr index_t NumTokens = 1;
static constexpr index_t Experts = 8; static constexpr index_t SortedTileSize = MPerBlock;
static constexpr index_t SortedTileSize = 32;
static constexpr auto MakeDsGridPointer() static constexpr auto MakeDsGridPointer()
......
...@@ -30,14 +30,16 @@ struct ReferenceMoeGemm : public device::BaseOperator ...@@ -30,14 +30,16 @@ struct ReferenceMoeGemm : public device::BaseOperator
{ {
Argument(const Tensor<ck::index_t>& sorted_token_ids, Argument(const Tensor<ck::index_t>& sorted_token_ids,
const Tensor<ck::index_t>& expert_ids, const Tensor<ck::index_t>& expert_ids,
const index_t sorted_tile_size,
const Tensor<ADataType>& a_t_k, const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k, const Tensor<BDataType>& b_e_n_k,
Tensor<CDataType>& c_m_n, Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
: expert_ids_{expert_ids}, : sorted_token_ids_{sorted_token_ids},
sorted_token_ids_{sorted_token_ids}, expert_ids_{expert_ids},
sorted_tile_size_{sorted_tile_size},
a_t_k_{a_t_k}, a_t_k_{a_t_k},
b_e_n_k_{b_e_n_k}, b_e_n_k_{b_e_n_k},
c_m_n_{c_m_n}, c_m_n_{c_m_n},
...@@ -56,7 +58,7 @@ struct ReferenceMoeGemm : public device::BaseOperator ...@@ -56,7 +58,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
index_t sorted_tile_size = 32; index_t sorted_tile_size_;
}; };
// Invoker // Invoker
...@@ -73,7 +75,7 @@ struct ReferenceMoeGemm : public device::BaseOperator ...@@ -73,7 +75,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
ComputeTypeA v_a{0}; ComputeTypeA v_a{0};
ComputeTypeB v_b{0}; ComputeTypeB v_b{0};
const int t = arg.sorted_token_ids_(m); const int t = arg.sorted_token_ids_(m);
const int e = arg.expert_ids_(m / arg.sorted_tile_size); const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0]; const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0];
if(t < token_cnt) { if(t < token_cnt) {
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
...@@ -135,6 +137,7 @@ struct ReferenceMoeGemm : public device::BaseOperator ...@@ -135,6 +137,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
static auto MakeArgument(const Tensor<ck::index_t>& sorted_token_ids, static auto MakeArgument(const Tensor<ck::index_t>& sorted_token_ids,
const Tensor<ck::index_t>& expert_ids, const Tensor<ck::index_t>& expert_ids,
const index_t sorted_tile_size,
const Tensor<ADataType>& a_t_k, const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k, const Tensor<BDataType>& b_e_n_k,
Tensor<CDataType>& c_m_n, Tensor<CDataType>& c_m_n,
...@@ -142,7 +145,7 @@ struct ReferenceMoeGemm : public device::BaseOperator ...@@ -142,7 +145,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
return Argument{sorted_token_ids, expert_ids, a_t_k, b_e_n_k, c_m_n, a_element_op, b_element_op, c_element_op}; return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_t_k, b_e_n_k, c_m_n, a_element_op, b_element_op, c_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment