Commit 167deece authored by mtgu0705's avatar mtgu0705
Browse files

Test with experts = 8; functional test passes.

parent f904a37d
...@@ -152,7 +152,8 @@ using AElementOp = PassThrough; ...@@ -152,7 +152,8 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
#if 0 #if 1
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32; static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32; static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
...@@ -168,7 +169,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< ...@@ -168,7 +169,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row, Col, DsLayout, ELayout, Row, Col, DsLayout, ELayout,
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, AElementOp, BElementOp, CDEElementOp, GemmSpec,
64, MPerBlock, 16, KPerBlock, 256, MPerBlock, 128, KPerBlock,
AK1, BK1, AK1, BK1,
MNPerXDL, MNPerXDL, MNPerXDL, MNPerXDL,
MXDLPerWave, 1, MXDLPerWave, 1,
...@@ -208,12 +209,12 @@ int main(int argc, char* argv[]) ...@@ -208,12 +209,12 @@ int main(int argc, char* argv[])
// GEMM shape // GEMM shape
ck::index_t N = 6144; ck::index_t N = 6144;
ck::index_t K = 8192; ck::index_t K = 8192;
ck::index_t experts = 1; ck::index_t experts = 8;
ck::index_t sorted_tile_num = 1; ck::index_t sorted_tile_num = 8;
ck::index_t sorted_tile_size = MPerBlock; ck::index_t sorted_tile_size = MPerBlock;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
// ck::index_t tokens = 128; ck::index_t tokens = 128;
ck::index_t tokens = 16; // ck::index_t tokens = 16;
if(argc == 1) if(argc == 1)
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment