Commit fc1558e3 authored by mtgu0705's avatar mtgu0705
Browse files

update int4 moe with latest input changes.

parent 9ff2394e
...@@ -154,12 +154,15 @@ using AElementOp = PassThrough; ...@@ -154,12 +154,15 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
#if 1 #if 0
static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MXDLPerWave = 2;
static constexpr ck::index_t NXDLPerWave = 2;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32; static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType); static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
...@@ -171,17 +174,18 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< ...@@ -171,17 +174,18 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row, Col, DsLayout, ELayout, Row, Col, DsLayout, ELayout,
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, MPerBlock, 128, KPerBlock, BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock,
AK1, BK1, AK1, BK1,
MNPerXDL, MNPerXDL, MNPerXDL, MNPerXDL,
MXDLPerWave, 1, MXDLPerWave, NXDLPerWave,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
CShuffleMXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>, MXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
// clang-format on // clang-format on
#else #else
static constexpr ck::index_t MPerBlock = 16; static constexpr ck::index_t MPerBlock = 16;
static constexpr ck::index_t Nswizzle = false;
// clang-format off // clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row, Col, DsLayout, ELayout, Row, Col, DsLayout, ELayout,
...@@ -194,7 +198,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< ...@@ -194,7 +198,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
1, 1, S<1, 16, 1, 4>, S<4, 1, 1>, 1, 1, S<1, 16, 1, 4>, S<4, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
// clang-format on // clang-format on
#endif #endif
...@@ -209,26 +213,30 @@ int main(int argc, char* argv[]) ...@@ -209,26 +213,30 @@ int main(int argc, char* argv[])
// experts = 8 // experts = 8
// per expert: // per expert:
// GEMM shape // GEMM shape
ck::index_t N = 6144; ck::index_t N = 14336 * 2;
ck::index_t K = 8192; ck::index_t K = 4096;
ck::index_t experts = 8; ck::index_t experts = 8;
ck::index_t sorted_tile_num = 8; ck::index_t sorted_tile_num = 16;
ck::index_t sorted_tile_size = MPerBlock; ck::index_t valid_tile_num = 13;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t tokens = 128; ck::index_t valid_size = valid_tile_num * MPerBlock;
// ck::index_t tokens = 16; ck::index_t tokens = 64;
ck::index_t topk = 2;
// ck::index_t tokens = batch * topk;
if(argc == 1) if(argc == 1)
{ {
// use default case // use default case
} }
else if(argc == 6) else if(argc == 7)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]); N = std::stoi(argv[4]);
K = std::stoi(argv[5]); K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
} }
else else
{ {
...@@ -236,10 +244,15 @@ int main(int argc, char* argv[]) ...@@ -236,10 +244,15 @@ int main(int argc, char* argv[])
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf( printf(
"arg4 to 5: N, K\n"); "arg4 to 5: N, K, tokens\n");
exit(0); exit(0);
} }
if (tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
ck::index_t StrideA = K; ck::index_t StrideA = K;
ck::index_t StrideB = K; ck::index_t StrideB = K;
ck::index_t StrideE = N; ck::index_t StrideE = N;
...@@ -249,21 +262,29 @@ int main(int argc, char* argv[]) ...@@ -249,21 +262,29 @@ int main(int argc, char* argv[])
ck::index_t KBatch = 1; ck::index_t KBatch = 1;
// const ck::index_t experts = 8; // const ck::index_t experts = 8;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1})); Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({SORTED_SIZE}, {1})); Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 1, 2,2,0,0,0};
int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 7, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
for (int i = 0; i < sorted_tile_num; i++) { for (int i = 0; i < sorted_tile_num; i++) {
expert_ids.mData[i] = i; expert_ids.mData[i] = eids[i];
} }
int token_per_tile = tokens / sorted_tile_num; int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0; int tokenid = 0;
// sorted_token_ids.mData[0] = 0; // sorted_token_ids.mData[0] = 0;
for (int i = 0; i < SORTED_SIZE; i++) { for (int i = 0; i < sorted_size; i++) {
int tile_off = i % sorted_tile_size; int tile_off = i % MPerBlock;
if(tile_off < token_per_tile) if(tile_off < token_per_tile)
sorted_token_ids.mData[i] = tokenid++; {
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else else
{
sorted_token_ids.mData[i] = tokens; sorted_token_ids.mData[i] = tokens;
} }
}
expert_ids.savetxt("expert_ids.txt", "int"); expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
...@@ -271,41 +292,29 @@ int main(int argc, char* argv[]) ...@@ -271,41 +292,29 @@ int main(int argc, char* argv[])
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N*K, 1, K})); Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N*K, 1, K}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<EDataType> e_m_n_host_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1})); Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_m_n_device_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1})); Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
case 1: case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2}); a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{1, 3});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 3});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1}); d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1}); d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
break; break;
case 4: case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2}); a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1}); d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1}); d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
break; break;
default: default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0}); a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
...@@ -313,23 +322,23 @@ int main(int argc, char* argv[]) ...@@ -313,23 +322,23 @@ int main(int argc, char* argv[])
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0}); d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
} }
d0_t_n.savetxt("d0_t_n.txt", "int");
d1_e_n.savetxt("d1_e_n.txt", "int");
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
a0_t_k.savetxt("a.txt"); a0_t_k.savetxt("a.txt");
d0_t_n.savetxt("d0_t_n.txt", "int");
d1_e_n.savetxt("d1_e_n.txt", "int");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data()); a0_device_buf.ToDevice(a0_t_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data()); d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data()); d1_device_buf.ToDevice(d1_e_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
...@@ -432,13 +441,15 @@ int main(int argc, char* argv[]) ...@@ -432,13 +441,15 @@ int main(int argc, char* argv[])
auto argument = auto argument =
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(), expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(), a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(), std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()}, d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
tokens, tokens,
SORTED_SIZE, topk,
sorted_size,
N, N,
K, K,
StrideA, StrideA,
...@@ -456,13 +467,12 @@ int main(int argc, char* argv[]) ...@@ -456,13 +467,12 @@ int main(int argc, char* argv[])
"wrong! device_gemm with the specified compilation parameters does " "wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"); "not support this GEMM problem");
} }
if (time_kernel) { if (time_kernel) {
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * SORTED_SIZE * N * K; std::size_t flop = std::size_t(2) * tokens * topk * N * K;
std::size_t num_btype = std::size_t num_btype =
sizeof(A0DataType) * SORTED_SIZE * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * SORTED_SIZE * N; sizeof(A0DataType) * valid_tile_num * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * valid_tile_num * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -476,9 +486,9 @@ int main(int argc, char* argv[]) ...@@ -476,9 +486,9 @@ int main(int argc, char* argv[])
{ {
invoker.Run(argument, StreamConfig{nullptr, false, 0 ,0,1}); invoker.Run(argument, StreamConfig{nullptr, false, 0 ,0,1});
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_t_n_device_result.mData.data());
Tensor<CShuffleDataType> c_m_n({SORTED_SIZE, N}); Tensor<CShuffleDataType> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm<A0DataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm<A0DataType,
B0DataType, B0DataType,
...@@ -491,108 +501,37 @@ int main(int argc, char* argv[]) ...@@ -491,108 +501,37 @@ int main(int argc, char* argv[])
auto ref_invoker = ref_moe_gemm.MakeInvoker(); auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument( auto ref_argument = ref_moe_gemm.MakeArgument(
sorted_token_ids, expert_ids, sorted_tile_size, a0_t_k, b0_e_n_k, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); sorted_token_ids, expert_ids, max_token_id, MPerBlock, a0_t_k, b0_e_n_k, c_t_k_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
for(int m = 0; m < SORTED_SIZE; ++m) for(int m = 0; m < valid_size; ++m)
{
const int t = sorted_token_ids(m);
const int e = expert_ids(m / sorted_tile_size);
for(int n = 0; n < N; ++n)
{ {
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(e, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); const int fuse_t = sorted_token_ids.mData[m];
e_m_n_device_result.savetxt("out.txt"); const int t = fuse_t & 0xffffff;
e_m_n_host_result.savetxt("ref.txt"); const int topk_id = (fuse_t & 0xff000000) >> 24;
// printf("m %d fuset %d %d %d\n",m, fuse_t, t, topk_id);
#if 0 if (t >= tokens)
printf("A Matrix:\n");
for(int t = 0; t < tokens; t++)
{ {
for(int k = 0; k < K; k++) continue;
{
printf("%f,", ck::type_convert<float>(a0_t_k(t, k)));
}
printf("\n");
} }
printf("\n"); const int e = expert_ids(m / MPerBlock);
for(int n = 0; n < N; ++n)
printf("B Matrix:\n");
for(int e = 0; e < experts; e++)
{
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
ck::pk_i4_t i4x2 = b0_e_n_k(e, k, n).data;
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
printf("%f,", i4_to_f32_gfx9(i4));
}
printf("\n");
}
printf("\n");
}
printf("\n");
printf("B preshuflled Matrix:\n");
for(int e = 0; e < experts; e++)
{
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
ck::pk_i4_t i4x2 = b0_preshuffled(e, k, n).data;
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
printf("%f,", i4_to_f32_gfx9(i4));
}
printf("\n");
}
printf("\n");
}
printf("\n");
printf("C device Matrix:\n");
for(int m = 0; m < SORTED_SIZE; m++)
{
for(int n = 0; n < N; n++)
{
printf("%f,", ck::type_convert<float>(e_m_n_device_result(m, n)));
}
printf("\n");
}
printf("\n");
printf("C host Matrix:\n");
for(int m = 0; m < SORTED_SIZE; m++)
{
for(int n = 0; n < N; n++)
{ {
printf("%f,", ck::type_convert<float>(e_m_n_host_result(m, n))); cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n), d0_t_n(t, n), d1_e_n(e, n));
// printf("m %d n %d topk %d token %d %f %f\n",m, n,topk_id, t, e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n));
} }
printf("\n");
} }
#endif
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
e_t_n_device_result.savetxt("out.txt");
e_t_n_host_result.savetxt("ref.txt");
return ck::utils::check_err( return ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
? 0 ? 0
: 1; : 1;
} }
printf("end of kernel\n");
return 0; return 0;
} }
...@@ -57,7 +57,7 @@ struct MulABScaleExpertWeight ...@@ -57,7 +57,7 @@ struct MulABScaleExpertWeight
template <typename E, typename C, typename D0, typename D1, typename D2> template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
//real kernel use //for real kernel use
template <> template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float> __host__ __device__ constexpr void operator()<EDataType, float, float, float, float>
(EDataType& e, (EDataType& e,
...@@ -66,9 +66,12 @@ struct MulABScaleExpertWeight ...@@ -66,9 +66,12 @@ struct MulABScaleExpertWeight
const float& d1, const float& d1,
const float& d2) const const float& d2) const
{ {
e = ck::type_convert<EDataType>(c * d0 * d1 * d2); //for real kernel use
//warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. tofix:felix
(void) d0;
e = ck::type_convert<EDataType>(c * d1 * d2);
} }
// for reference // for reference cpu
template <> template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float> __host__ __device__ constexpr void operator()<float, float, float, float, float>
(float& e, (float& e,
...@@ -77,6 +80,7 @@ struct MulABScaleExpertWeight ...@@ -77,6 +80,7 @@ struct MulABScaleExpertWeight
const float& d1, const float& d1,
const float& d2) const const float& d2) const
{ {
// for reference cpu
e = ck::type_convert<EDataType>(c * d0 * d1 * d2); e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
} }
}; };
...@@ -121,14 +125,16 @@ using BElementOp = PassThrough; ...@@ -121,14 +125,16 @@ using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight; using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 64; static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t BLOCKSIZE = 256; static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t MXDLPerWave = 2;
static constexpr ck::index_t NXDLPerWave = 2;
static constexpr ck::index_t NPerBlock = 128; static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32; static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType); static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint // static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint
static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32; // static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
static constexpr ck::index_t CShuffleNLane = NPerBlock / 2; static constexpr ck::index_t CShuffleNLane = 32;
static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
...@@ -143,11 +149,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm ...@@ -143,11 +149,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock, BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock,
AK1, BK1, AK1, BK1,
MNPerXDL, MNPerXDL, MNPerXDL, MNPerXDL,
MXDLPerWave, 1, MXDLPerWave, NXDLPerWave,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
CShuffleMXDLPerWave, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>, MXDLPerWave, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, A0DataType>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
// clang-format on // clang-format on
int main(int argc, char* argv[]) int main(int argc, char* argv[])
...@@ -161,25 +167,35 @@ int main(int argc, char* argv[]) ...@@ -161,25 +167,35 @@ int main(int argc, char* argv[])
// experts = 8 // experts = 8
// per expert: // per expert:
// GEMM shape // GEMM shape
ck::index_t N = 6144; ck::index_t N = 4096;
ck::index_t K = 8192; ck::index_t K = 14336;
ck::index_t experts = 8; ck::index_t experts = 8;
ck::index_t sorted_tile_num = 8; ck::index_t sorted_tile_num = 16;
ck::index_t sorted_tile_size = MPerBlock; ck::index_t valid_tile_num = 13;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t tokens = 64; ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t tokens = 512;
ck::index_t topk = 2;
if(argc == 1) if(argc == 1)
{ {
// use default case // use default case
} }
else if(argc == 6) else if(argc == 3)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]); N = std::stoi(argv[4]);
K = std::stoi(argv[5]); K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
} }
else else
{ {
...@@ -187,7 +203,7 @@ int main(int argc, char* argv[]) ...@@ -187,7 +203,7 @@ int main(int argc, char* argv[])
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf( printf(
"arg4 to 5: N, K\n"); "arg4 to 6: N, K, tokens\n");
exit(0); exit(0);
} }
...@@ -200,80 +216,97 @@ int main(int argc, char* argv[]) ...@@ -200,80 +216,97 @@ int main(int argc, char* argv[])
ck::index_t KBatch = 1; ck::index_t KBatch = 1;
// const ck::index_t experts = 8; // const ck::index_t experts = 8;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1})); Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({SORTED_SIZE}, {1})); Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData[0] = valid_size;
int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 7, 7, 3, 3, 3};
for (int i = 0; i < sorted_tile_num; i++) { for (int i = 0; i < sorted_tile_num; i++) {
expert_ids.mData[i] = i; expert_ids.mData[i] = eids[i];
}
if (tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
} }
int token_per_tile = tokens / sorted_tile_num; int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0; int tokenid = 0;
// sorted_token_ids.mData[0] = 0; // sorted_token_ids.mData[0] = 0;
for (int i = 0; i < SORTED_SIZE; i++) { for (int i = 0; i < sorted_size; i++) {
int tile_off = i % sorted_tile_size; int tile_off = i % MPerBlock;
if(tile_off < token_per_tile) if(tile_off < token_per_tile)
sorted_token_ids.mData[i] = tokenid++; {
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else else
{
sorted_token_ids.mData[i] = tokens; sorted_token_ids.mData[i] = tokens;
} }
Tensor<A0DataType> a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1})); }
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk*K, K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N*K, 1, K})); Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N*K, 1, K}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N*K, 1, K})); Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N*K, 1, K}));
Tensor<D0DataType> d0_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {StrideDs[0], 0})); Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({SORTED_SIZE, N}, {1, 0})); Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero(); e_t_n_device_result.SetZero();
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
case 1: case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2}); a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2}); d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2}); d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2}); d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break; break;
case 2: case 2:
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{}); a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_m_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{}); d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break; break;
default: default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0}); a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0}); d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
} }
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
a0_m_k.savetxt("a.txt"); a0_t_k_k.savetxt("a.txt");
expert_ids.savetxt("expert_ids.txt", "int"); expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
d0_m_n.savetxt("d0_m_n.txt", "int"); d0_t_n.savetxt("d0_t_n.txt", "int");
d1_e_n.savetxt("d1_e_n.txt", "int"); d1_e_n.savetxt("d1_e_n.txt", "int");
d2_e_n.savetxt("d2_e_n.txt", "int"); d2_e_n.savetxt("d2_e_n.txt", "int");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data());
a0_device_buf.ToDevice(a0_m_k.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data()); a0_device_buf.ToDevice(a0_t_k_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data()); d1_device_buf.ToDevice(d1_e_n.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data()); d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_n_device_result.mData.data()); e_device_buf.ToDevice(e_t_n_device_result.mData.data());
...@@ -345,6 +378,7 @@ int main(int argc, char* argv[]) ...@@ -345,6 +378,7 @@ int main(int argc, char* argv[])
auto argument = auto argument =
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(), expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(), a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(), std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
...@@ -352,7 +386,8 @@ int main(int argc, char* argv[]) ...@@ -352,7 +386,8 @@ int main(int argc, char* argv[])
d2_device_buf.GetDeviceBuffer()}, d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
tokens, tokens,
SORTED_SIZE, topk,
sorted_size,
N, N,
K, K,
StrideA, StrideA,
...@@ -374,9 +409,9 @@ int main(int argc, char* argv[]) ...@@ -374,9 +409,9 @@ int main(int argc, char* argv[])
// not result correct here because output buf not setzero // not result correct here because output buf not setzero
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * SORTED_SIZE * N * K; std::size_t flop = std::size_t(2) * tokens * topk * N * K;
std::size_t num_btype = std::size_t num_btype =
sizeof(A0DataType) * SORTED_SIZE * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * SORTED_SIZE * N; sizeof(A0DataType) * tokens * K * topk + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * tokens * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -409,10 +444,9 @@ int main(int argc, char* argv[]) ...@@ -409,10 +444,9 @@ int main(int argc, char* argv[])
auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker(); auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument( auto ref_argument = ref_moe_gemm.MakeArgument(
sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, d0_m_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op); sorted_token_ids, expert_ids, max_token_id, MPerBlock, a0_t_k_k, b0_e_n_k, d0_t_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
for(int t = 0; t < tokens; ++t) for(int t = 0; t < tokens; ++t)
{ {
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment