Commit f904a37d authored by mtgu0705

Fix bug in moe_gemm1.cpp; the function now passes.

parent 0f444e0c
@@ -89,17 +89,70 @@ struct MulABScaleSilu
    }
};
// using DsLayout = DsLayoutGate;
// using DsDataType = DsDataTypeGate;
using CDEElementOp = MulABScale;
// using CDEElementOp = MulABScaleSiluMulGate;
#if 1
// Pre-shuffle the packed-int4 weight buffer from row-major [N, K] into the blocked
// layout consumed by the device kernel: N -> (N0, NLane), K -> (K0, KLane, KPack),
// stored as [N0, K0, KLane, NLane, KPack]. B0DataType is pk_i4_t, i.e. two 4-bit
// values per byte, so element indices are halved when addressing the byte buffers.
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
{
    int KPack = 32; // int4 -> 32, fp8 -> 16, fp16 -> 8
    int NLane = NXdl;
    int KLane = 64 / NLane;
    int K0    = K / (KLane * KPack);
    // K -> K0 KLane KPack
    // N -> N0 NLane
    // N, K -> N0 K0 KLane NLane KPack
    int tempk;
    for(int n = 0; n < N; ++n)
    {
        for(int k = 0; k < K; ++k)
        {
            int n0 = n / NLane;
            int n1 = n % NLane;
            int k0 = k / (KLane * KPack);
            tempk  = k % (KLane * KPack);
            int k1 = tempk / KPack;
            int k2 = tempk % KPack;

            int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
                              k1 * KPack * NLane + n1 * KPack + k2;

            dst[outputIndex / 2] = src[(n * K + k) / 2];
        }
    }
}
#endif
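// Hypothetical standalone helper mirroring the index math above; it assumes NXdl = 16,
// as used by the 16x16 XDL instance selected below, and ignores the pk_i4 packing-by-two
// for clarity.
int shuffledOffset(int n, int k, int K, int NXdl)
{
    const int KPack = 32;         // int4 -> 32, fp8 -> 16, fp16 -> 8
    const int NLane = NXdl;
    const int KLane = 64 / NLane; // 64 = wavefront size on gfx9
    const int K0    = K / (KLane * KPack);

    const int n0 = n / NLane, n1 = n % NLane;
    const int k0 = k / (KLane * KPack);
    const int t  = k % (KLane * KPack);
    const int k1 = t / KPack, k2 = t % KPack;

    // offset within the [N0, K0, KLane, NLane, KPack] layout
    return n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
           k1 * KPack * NLane + n1 * KPack + k2;
}
// e.g. shuffledOffset(17, 130, 8192, 16) gives the shuffled position of element (n = 17, k = 130).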
// Decode a 4-bit weight code the way gfx9 interprets it: a two's-complement int4
// scaled by 1/16 (e.g. 0b1000 = -8 -> -0.5, 0b0111 = 7 -> 0.4375).
float i4_to_f32_gfx9(uint8_t i4)
{
    static std::unordered_map<uint8_t, float> u = {{0b1000, -0.5000f},
                                                   {0b1001, -0.4375f},
                                                   {0b1010, -0.3750f},
                                                   {0b1011, -0.3125f},
                                                   {0b1100, -0.2500f},
                                                   {0b1101, -0.1875f},
                                                   {0b1110, -0.1250f},
                                                   {0b1111, -0.0625f},
                                                   {0b0000, +0.0000f},
                                                   {0b0001, +0.0625f},
                                                   {0b0010, +0.1250f},
                                                   {0b0011, +0.1875f},
                                                   {0b0100, +0.2500f},
                                                   {0b0101, +0.3125f},
                                                   {0b0110, +0.3750f},
                                                   {0b0111, +0.4375f}};
    return u[i4];
}
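// Hypothetical usage sketch: decode one packed pk_i4 byte into its two values,
// following the nibble convention used by the verification code below
// (element at even k in the high nibble, element at odd k in the low nibble).
inline void decode_pk_i4_byte(uint8_t i4x2, float& val_even_k, float& val_odd_k)
{
    val_even_k = i4_to_f32_gfx9((i4x2 >> 4) & 0xf);
    val_odd_k  = i4_to_f32_gfx9((i4x2 >> 0) & 0xf);
}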
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp  = PassThrough;
using BElementOp  = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
#if 0
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
@@ -115,7 +168,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row, Col, DsLayout, ELayout,
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, MPerBlock, 128, KPerBlock,
64, MPerBlock, 16, KPerBlock,
AK1, BK1,
MNPerXDL, MNPerXDL,
MXDLPerWave, 1,
@@ -124,6 +177,23 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
CShuffleMXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>;
// clang-format on
#else
static constexpr ck::index_t MPerBlock = 16;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row, Col, DsLayout, ELayout,
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
64, 16, 16, 128,
16, 32,
16, 16,
1, 1,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
1, 1, S<1, 16, 1, 4>, S<4, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>;
// clang-format on
#endif
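// Read positionally against the disabled branch above, this debug-sized instance appears
// to use BlockSize = 64, MPerBlock = NPerBlock = 16, KPerBlock = 128, AK1 = 16, BK1 = 32,
// a 16x16 XDL, and one XDL per wave in each dimension (inferred from argument order).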
int main(int argc, char* argv[])
{
@@ -138,11 +208,12 @@ int main(int argc, char* argv[])
// GEMM shape
ck::index_t N = 6144;
ck::index_t K = 8192;
ck::index_t experts = 8;
ck::index_t experts = 1;
ck::index_t sorted_tile_num = 8;
ck::index_t sorted_tile_num = 1;
ck::index_t sorted_tile_size = MPerBlock;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
ck::index_t tokens = 128;
// ck::index_t tokens = 128;
ck::index_t tokens = 16;
if(argc == 1)
{
@@ -169,7 +240,6 @@ int main(int argc, char* argv[])
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
ck::index_t batch_stride_B = K * N;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0};
@@ -194,8 +264,8 @@ int main(int argc, char* argv[])
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N*K, 1, K}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N*K, 1, K}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<EDataType> e_m_n_host_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
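// Note on the B descriptors: HostTensorDescriptor takes (lengths, strides), so with lengths
// {experts, K, N} and strides {N*K, 1, K} an access b0_e_n_k(e, k, n) lands at offset
// e*N*K + n*K + k, i.e. each expert's weights stay row-major [N, K] in memory while being
// indexed as (e, k, n).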
@@ -217,10 +287,22 @@ int main(int argc, char* argv[])
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 3});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
break;
case 4:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
@@ -238,6 +320,7 @@ int main(int argc, char* argv[])
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_t_k.savetxt("a.txt");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data());
@@ -252,8 +335,9 @@ int main(int argc, char* argv[])
// do GEMM
auto device_op = DeviceOpInstance{};
// preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl);
#if 1
printf("Start PreShuffle\n");
preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, device_op.GetPreShuffleParameters());
#else
// weight pre-shuffle
int KPack = 32; // int4 -> 32, fp8 -> 16, fp16 -> 8
int NLane = device_op.GetPreShuffleParameters();
@@ -279,20 +363,20 @@ int main(int argc, char* argv[])
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
                  k1 * KPack * NLane + n1 * KPack + k2;
b0_preshuffled(e * batch_stride_B + outputIndex) =
b0_preshuffled(e, outputIndex % K, outputIndex / K) = b0_e_n_k(e, k, n);
b0_e_n_k(e * batch_stride_B + n * K + k);
}
}
}
printf("End PreShuffle, and Start vector permute\n");
#endif
// vector pk_i4x4 permute
for(int e = 0; e < experts; e++)
{
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
for(int j = 0; j < K; j += 8)
{
int input[8];
@@ -341,7 +425,6 @@ int main(int argc, char* argv[])
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
printf("End Permute and Start GEMM\n");
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
@@ -370,6 +453,7 @@ int main(int argc, char* argv[])
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if (time_kernel) {
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
@@ -381,8 +465,8 @@ int main(int argc, char* argv[])
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
          << std::endl;
          << " GB/s" << device_op.GetTypeString() << std::endl;
}
if(do_verification)
@@ -421,11 +505,91 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
e_m_n_device_result.savetxt("out.txt");
e_m_n_host_result.savetxt("ref.txt");
#if 0
printf("A Matrix:\n");
for(int t = 0; t < tokens; t++)
{
for(int k = 0; k < K; k++)
{
printf("%f,", ck::type_convert<float>(a0_t_k(t, k)));
}
printf("\n");
}
printf("\n");
printf("B Matrix:\n");
for(int e = 0; e < experts; e++)
{
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
ck::pk_i4_t i4x2 = b0_e_n_k(e, k, n).data;
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
printf("%f,", i4_to_f32_gfx9(i4));
}
printf("\n");
}
printf("\n");
}
printf("\n");
printf("B preshuflled Matrix:\n");
for(int e = 0; e < experts; e++)
{
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
ck::pk_i4_t i4x2 = b0_preshuffled(e, k, n).data;
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
printf("%f,", i4_to_f32_gfx9(i4));
}
printf("\n");
}
printf("\n");
}
printf("\n");
printf("C device Matrix:\n");
for(int m = 0; m < SORTED_SIZE; m++)
{
for(int n = 0; n < N; n++)
{
printf("%f,", ck::type_convert<float>(e_m_n_device_result(m, n)));
}
printf("\n");
}
printf("\n");
printf("C host Matrix:\n");
for(int m = 0; m < SORTED_SIZE; m++)
{
for(int n = 0; n < N; n++)
{
printf("%f,", ck::type_convert<float>(e_m_n_host_result(m, n)));
}
printf("\n");
}
#endif
return ck::utils::check_err(
           e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
           ? 0
           : 1;
}
printf("end of kernel\n");
return 0;
}
@@ -446,6 +446,34 @@ struct DeviceMoeGemm
throw std::runtime_error("todo: only v1 & v2 support now");
}
}
#if 1
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
// if(arg.KBatch > 1)
// {
// const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle<
// GridwiseGemm,
// false,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Odd>;
// Run(kernel);
// }
// else
{
const auto kernel = kernel_moe_gemm_gather<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
}
}
#endif
return ave_time;
}
...
@@ -1086,7 +1086,7 @@ struct GridwiseMoeGemmGather
}
// check gridwise gemm pipeline
#if 1
#if 0
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
@@ -1193,7 +1193,7 @@ struct GridwiseMoeGemmGather
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
    p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
    p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
    p_b_grid + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize());
// if(threadIdx.x==0)
// printf("tid %d eid %d expert_stride %d bufsize %d\n",
// threadIdx.x, expert_id, expert_stride, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -1248,7 +1248,7 @@ struct GridwiseMoeGemmGather
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<0, 1, 2, 3>,
Sequence<1, 2, 0, 3>,
3,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
...
@@ -109,7 +109,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
// }
if constexpr(is_same_v<BDataType, pk_i4_t>)
{
    uint8_t i4x2 = arg.b_e_n_k_(e, n, k).data;
    uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data;
    uint8_t i4 = 0;
    if(k % 2 == 1)
        i4 = (i4x2 >> 0) & 0xf;
...