Commit 167deece authored by mtgu0705's avatar mtgu0705
Browse files

test expert = 8 and function pass.

parent f904a37d
......@@ -152,7 +152,8 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
#if 0
#if 1
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
......@@ -168,7 +169,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row, Col, DsLayout, ELayout,
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
64, MPerBlock, 16, KPerBlock,
256, MPerBlock, 128, KPerBlock,
AK1, BK1,
MNPerXDL, MNPerXDL,
MXDLPerWave, 1,
......@@ -208,12 +209,12 @@ int main(int argc, char* argv[])
// GEMM shape
ck::index_t N = 6144;
ck::index_t K = 8192;
ck::index_t experts = 1;
ck::index_t sorted_tile_num = 1;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 8;
ck::index_t sorted_tile_size = MPerBlock;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
// ck::index_t tokens = 128;
ck::index_t tokens = 16;
ck::index_t tokens = 128;
// ck::index_t tokens = 16;
if(argc == 1)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment