Commit 7822382b authored by coderfeli's avatar coderfeli
Browse files

tileM 32,64,128 ok

parent e15351ca
...@@ -118,6 +118,8 @@ using BElementOp = PassThrough; ...@@ -118,6 +118,8 @@ using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply; using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...@@ -133,13 +135,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu ...@@ -133,13 +135,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, AElementOp, BElementOp, CDEElementOp, GemmSpec,
//threadnum, mblock, nblock, kblock //threadnum, mblock, nblock, kblock
256, 64, 128, 128, 256, MPerBlock, 128, 128,
// ak1, bk1 // ak1, bk1
8, 8, 8, 8,
// mn_perxdl // mn_perxdl
32, 32, 32, 32,
// mn_xdlperwave // mn_xdlperwave
2, 1, MXDLPerWave, 1,
// a,b: load transfer cluster, cluster order, srcorder, srcpervec, dstpervec, lds_extra // a,b: load transfer cluster, cluster order, srcorder, srcpervec, dstpervec, lds_extra
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
...@@ -169,10 +171,10 @@ int main(int argc, char* argv[]) ...@@ -169,10 +171,10 @@ int main(int argc, char* argv[])
ck::index_t N = 6144; ck::index_t N = 6144;
ck::index_t K = 8192; ck::index_t K = 8192;
ck::index_t experts = 8; ck::index_t experts = 8;
ck::index_t sorted_tile_num = 1; ck::index_t sorted_tile_num = 8;
ck::index_t sorted_tile_size = 64; ck::index_t sorted_tile_size = MPerBlock;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
ck::index_t tokens = 64; ck::index_t tokens = 512;
if(argc == 1) if(argc == 1)
{ {
...@@ -337,7 +339,7 @@ int main(int argc, char* argv[]) ...@@ -337,7 +339,7 @@ int main(int argc, char* argv[])
if (time_kernel) { if (time_kernel) {
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * SORTED_SIZE * N * K * experts; std::size_t flop = std::size_t(2) * SORTED_SIZE * N * K;
std::size_t num_btype = std::size_t num_btype =
sizeof(A0DataType) * SORTED_SIZE * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * SORTED_SIZE * N; sizeof(A0DataType) * SORTED_SIZE * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * SORTED_SIZE * N;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment