Commit e15351ca authored by coderfeli
Browse files

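tile m = 64 ok: MPerBlock is raised from 32 to 64, SortedTileSize now follows MPerBlock, and the reference MoE GEMM takes sorted_tile_size as an argument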

parent 48d87d9c
......@@ -133,13 +133,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
//threadnum, mblock, nblock, kblock
256, 32, 128, 128,
256, 64, 128, 128,
// ak1, bk1
8, 8,
// mn_perxdl
32, 32,
// mn_xdlperwave
1, 1,
2, 1,
// a,b: load transfer cluster, cluster order, src order, src per vec, dst per vec, lds_extra
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
......@@ -169,8 +169,8 @@ int main(int argc, char* argv[])
ck::index_t N = 6144;
ck::index_t K = 8192;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 8;
ck::index_t sorted_tile_size = 32;
ck::index_t sorted_tile_num = 1;
ck::index_t sorted_tile_size = 64;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
ck::index_t tokens = 64;
......@@ -368,7 +368,7 @@ int main(int argc, char* argv[])
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(
sorted_token_ids, expert_ids, a0_t_k, b0_e_n_k, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
sorted_token_ids, expert_ids, sorted_tile_size, a0_t_k, b0_e_n_k, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < SORTED_SIZE; ++m)
......
......@@ -176,8 +176,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static_assert(NWave * warpSize == BlockSize);
// static constexpr index_t NumTokens = 1;
static constexpr index_t Experts = 8;
static constexpr index_t SortedTileSize = 32;
static constexpr index_t SortedTileSize = MPerBlock;
static constexpr auto MakeDsGridPointer()
......
......@@ -29,15 +29,17 @@ struct ReferenceMoeGemm : public device::BaseOperator
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ck::index_t>& sorted_token_ids,
const Tensor<ck::index_t>& expert_ids,
const Tensor<ADataType>& a_t_k,
const Tensor<ck::index_t>& expert_ids,
const index_t sorted_tile_size,
const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k,
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: expert_ids_{expert_ids},
sorted_token_ids_{sorted_token_ids},
: sorted_token_ids_{sorted_token_ids},
expert_ids_{expert_ids},
sorted_tile_size_{sorted_tile_size},
a_t_k_{a_t_k},
b_e_n_k_{b_e_n_k},
c_m_n_{c_m_n},
......@@ -56,7 +58,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
index_t sorted_tile_size = 32;
index_t sorted_tile_size_;
};
// Invoker
......@@ -73,7 +75,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
ComputeTypeA v_a{0};
ComputeTypeB v_b{0};
const int t = arg.sorted_token_ids_(m);
const int e = arg.expert_ids_(m / arg.sorted_tile_size);
const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0];
if(t < token_cnt) {
for(int k = 0; k < K; ++k)
......@@ -135,6 +137,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
static auto MakeArgument(const Tensor<ck::index_t>& sorted_token_ids,
const Tensor<ck::index_t>& expert_ids,
const index_t sorted_tile_size,
const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k,
Tensor<CDataType>& c_m_n,
......@@ -142,7 +145,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{sorted_token_ids, expert_ids, a_t_k, b_e_n_k, c_m_n, a_element_op, b_element_op, c_element_op};
return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_t_k, b_e_n_k, c_m_n, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment