Commit db53dba4 authored by coderfeli's avatar coderfeli
Browse files

hotfix:gemm1 use real tokens and gemm2 ok

parent 58db931e
...@@ -35,7 +35,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -35,7 +35,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F8; using A0DataType = F8;
using B0DataType = F8; using B0DataType = F8;
using EDataType = F32; using EDataType = F16;
using AccDataType = F32; using AccDataType = F32;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using D0DataType = F32; using D0DataType = F32;
...@@ -133,6 +133,8 @@ using BElementOp = PassThrough; ...@@ -133,6 +133,8 @@ using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 32; static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32; static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32; static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType); static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType);
...@@ -156,7 +158,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm ...@@ -156,7 +158,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, AElementOp, BElementOp, CDEElementOp, GemmSpec,
//threadnum, mblock, nblock, kblock //threadnum, mblock, nblock, kblock
256, MPerBlock, 128, KPerBlock, BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock,
// ak1, bk1 // ak1, bk1
AK1, BK1, AK1, BK1,
// mn_perxdl // mn_perxdl
...@@ -196,10 +198,10 @@ int main(int argc, char* argv[]) ...@@ -196,10 +198,10 @@ int main(int argc, char* argv[])
ck::index_t valid_tile_num = 8; ck::index_t valid_tile_num = 8;
ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t batch = 64; ck::index_t tokens = 64;
ck::index_t topk = 2; ck::index_t topk = 2;
ck::index_t tokens = batch * topk; // ck::index_t tokens = batch * topk;
if(argc == 1) if(argc == 1)
{ {
...@@ -225,7 +227,6 @@ int main(int argc, char* argv[]) ...@@ -225,7 +227,6 @@ int main(int argc, char* argv[])
ck::index_t StrideA = K; ck::index_t StrideA = K;
ck::index_t StrideB = K; ck::index_t StrideB = K;
// ck::index_t StrideD = 0;
ck::index_t StrideE = N; ck::index_t StrideE = N;
constexpr ck::index_t NumDTensor = DsDataType::Size(); constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0}; constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0};
...@@ -241,14 +242,14 @@ int main(int argc, char* argv[]) ...@@ -241,14 +242,14 @@ int main(int argc, char* argv[])
for (int i = 0; i < sorted_tile_num; i++) { for (int i = 0; i < sorted_tile_num; i++) {
expert_ids.mData[i] = i; expert_ids.mData[i] = i;
} }
int token_per_tile = tokens / valid_tile_num; int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0; int tokenid = 0;
// sorted_token_ids.mData[0] = 0; // sorted_token_ids.mData[0] = 0;
for (int i = 0; i < sorted_size; i++) { for (int i = 0; i < sorted_size; i++) {
int tile_off = i % valid_size; int tile_off = i % valid_size;
if(tile_off < token_per_tile) if(tile_off < token_per_tile)
{ {
sorted_token_ids.mData[i] = (tokenid % batch) | ((tokenid / batch) << 24); sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++; tokenid++;
} }
else else
...@@ -258,13 +259,13 @@ int main(int argc, char* argv[]) ...@@ -258,13 +259,13 @@ int main(int argc, char* argv[])
} }
expert_ids.savetxt("expert_ids.txt", "int"); expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({batch, K}, {K, 1})); Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({batch, N}, {StrideDs[0], 0})); Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({batch, topk, N}, {topk * N, N, 1})); Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({batch, topk, N}, {topk * N, N, 1})); Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
...@@ -293,8 +294,6 @@ int main(int argc, char* argv[]) ...@@ -293,8 +294,6 @@ int main(int argc, char* argv[])
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0}); d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
} }
d0_t_n.savetxt("d0_t_n.txt", "int");
d1_e_n.savetxt("d1_e_n.txt", "int");
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
...@@ -304,13 +303,14 @@ int main(int argc, char* argv[]) ...@@ -304,13 +303,14 @@ int main(int argc, char* argv[])
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
a0_t_k.savetxt("a.txt"); a0_t_k.savetxt("a.txt");
d0_t_n.savetxt("d0_t_n.txt", "int");
d1_e_n.savetxt("d1_e_n.txt", "int");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data()); a0_device_buf.ToDevice(a0_t_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data()); d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data()); d1_device_buf.ToDevice(d1_e_n.mData.data());
// e_device_buf.ToDevice(e_t_n_device_result.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
...@@ -358,9 +358,9 @@ int main(int argc, char* argv[]) ...@@ -358,9 +358,9 @@ int main(int argc, char* argv[])
if (time_kernel) { if (time_kernel) {
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * sorted_size * N * K; std::size_t flop = std::size_t(2) * valid_size * N * K;
std::size_t num_btype = std::size_t num_btype =
sizeof(A0DataType) * sorted_size * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * sorted_size * N; sizeof(A0DataType) * valid_size * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * valid_size * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -376,7 +376,7 @@ int main(int argc, char* argv[]) ...@@ -376,7 +376,7 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_t_n_device_result.mData.data()); e_device_buf.FromDevice(e_t_n_device_result.mData.data());
Tensor<CShuffleDataType> c_t_k_n({batch, topk, N}, {topk * N, N, 1}); Tensor<CShuffleDataType> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm<A0DataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm<A0DataType,
B0DataType, B0DataType,
......
...@@ -67,7 +67,9 @@ struct MulABScaleExpertWeight ...@@ -67,7 +67,9 @@ struct MulABScaleExpertWeight
const float& d1, const float& d1,
const float& d2) const const float& d2) const
{ {
e = ck::type_convert<EDataType>(c * d0 * d1 * d2); // e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
(void) d2;
e = ck::type_convert<EDataType>(c * d0 * d1);
} }
// for reference // for reference
template <> template <>
...@@ -78,7 +80,8 @@ struct MulABScaleExpertWeight ...@@ -78,7 +80,8 @@ struct MulABScaleExpertWeight
const float& d1, const float& d1,
const float& d2) const const float& d2) const
{ {
e = ck::type_convert<EDataType>(c * d0 * d1 * d2); (void) d2;
e = ck::type_convert<EDataType>(c * d0 * d1);
} }
}; };
...@@ -187,10 +190,12 @@ int main(int argc, char* argv[]) ...@@ -187,10 +190,12 @@ int main(int argc, char* argv[])
ck::index_t experts = 8; ck::index_t experts = 8;
ck::index_t sorted_tile_num = 9; ck::index_t sorted_tile_num = 9;
ck::index_t valid_tile_num = 8; ck::index_t valid_tile_num = 8;
ck::index_t sorted_tile_size = MPerBlock; ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t sorted_size = sorted_tile_num * sorted_tile_size; ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * sorted_tile_size; ck::index_t batch = 64;
ck::index_t tokens = 64; ck::index_t topk = 2;
ck::index_t tokens = batch;
if(argc == 1) if(argc == 1)
{ {
...@@ -231,77 +236,83 @@ int main(int argc, char* argv[]) ...@@ -231,77 +236,83 @@ int main(int argc, char* argv[])
for (int i = 0; i < sorted_tile_num; i++) { for (int i = 0; i < sorted_tile_num; i++) {
expert_ids.mData[i] = i; expert_ids.mData[i] = i;
} }
int token_per_tile = tokens / sorted_tile_num; int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0; int tokenid = 0;
// sorted_token_ids.mData[0] = 0; // sorted_token_ids.mData[0] = 0;
for (int i = 0; i < sorted_size; i++) { for (int i = 0; i < sorted_size; i++) {
int tile_off = i % sorted_tile_size; int tile_off = i % valid_size;
if(tile_off < token_per_tile) if(tile_off < token_per_tile)
sorted_token_ids.mData[i] = tokenid++; {
sorted_token_ids.mData[i] = (tokenid % batch) | ((tokenid / batch) << 24);
tokenid++;
}
else else
{
sorted_token_ids.mData[i] = tokens; sorted_token_ids.mData[i] = tokens;
} }
}
Tensor<A0DataType> a0_m_k(HostTensorDescriptor({sorted_size, K}, {K, 1})); expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({batch, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<D0DataType> d0_m_n(HostTensorDescriptor({sorted_size, N}, {StrideDs[0], 0})); Tensor<D0DataType> d0_t_n(HostTensorDescriptor({batch, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero(); e_t_n_device_result.SetZero();
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
case 1: case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2}); a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2}); d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2}); d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2}); d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break; break;
case 2: case 2:
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{}); a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_m_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{}); d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break; break;
default: default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0}); a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0}); d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
} }
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
a0_m_k.savetxt("a.txt"); a0_t_k.savetxt("a.txt");
expert_ids.savetxt("expert_ids.txt", "int"); expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
d0_m_n.savetxt("d0_m_n.txt", "int"); d0_t_n.savetxt("d0_t_n.txt", "int");
d1_e_n.savetxt("d1_e_n.txt", "int"); d1_e_n.savetxt("d1_e_n.txt", "int");
d2_e_n.savetxt("d2_e_n.txt", "int"); d2_e_n.savetxt("d2_e_n.txt", "int");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_m_k.mData.data()); a0_device_buf.ToDevice(a0_t_k.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data()); d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data()); d1_device_buf.ToDevice(d1_e_n.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data()); d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_n_device_result.mData.data()); e_device_buf.ToDevice(e_t_n_device_result.mData.data());
...@@ -332,6 +343,7 @@ int main(int argc, char* argv[]) ...@@ -332,6 +343,7 @@ int main(int argc, char* argv[])
d2_device_buf.GetDeviceBuffer()}, d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
tokens, tokens,
topk,
sorted_size, sorted_size,
N, N,
K, K,
...@@ -354,9 +366,9 @@ int main(int argc, char* argv[]) ...@@ -354,9 +366,9 @@ int main(int argc, char* argv[])
// not result correct here because output buf not setzero // not result correct here because output buf not setzero
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * sorted_size * N * K; std::size_t flop = std::size_t(2) * valid_size * N * K;
std::size_t num_btype = std::size_t num_btype =
sizeof(A0DataType) * sorted_size * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * sorted_size * N; sizeof(A0DataType) * valid_size * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * valid_size * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -387,12 +399,12 @@ int main(int argc, char* argv[]) ...@@ -387,12 +399,12 @@ int main(int argc, char* argv[])
auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker(); auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument( auto ref_argument = ref_moe_gemm.MakeArgument(
sorted_token_ids, expert_ids, max_token_id, sorted_tile_size, a0_m_k, b0_e_n_k, d0_m_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op); sorted_token_ids, expert_ids, max_token_id, MPerBlock, a0_t_k, b0_e_n_k, d0_t_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
for(int t = 0; t < tokens; ++t) for(int t = 0; t < tokens; ++t)
{ {
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n)); e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
...@@ -402,7 +414,6 @@ int main(int argc, char* argv[]) ...@@ -402,7 +414,6 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_t_n_device_result.mData.data()); e_device_buf.FromDevice(e_t_n_device_result.mData.data());
e_t_n_device_result.savetxt("out.txt"); e_t_n_device_result.savetxt("out.txt");
e_t_n_host_result.savetxt("ref.txt"); e_t_n_host_result.savetxt("ref.txt");
return ck::utils::check_err( return ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
? 0 ? 0
......
...@@ -638,7 +638,6 @@ struct GridwiseMoeGemmGather ...@@ -638,7 +638,6 @@ struct GridwiseMoeGemmGather
BElementwiseOperation b_element_op_, BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_) CElementwiseOperation c_element_op_)
: Problem{NumTokens_, TopK_, M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, : Problem{NumTokens_, TopK_, M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
p_sorted_token_ids{p_sorted_token_ids_}, p_sorted_token_ids{p_sorted_token_ids_},
p_sorted_expert_ids{p_sorted_expert_ids_}, p_sorted_expert_ids{p_sorted_expert_ids_},
p_max_token_id{p_max_token_id_}, p_max_token_id{p_max_token_id_},
...@@ -663,7 +662,6 @@ struct GridwiseMoeGemmGather ...@@ -663,7 +662,6 @@ struct GridwiseMoeGemmGather
const index_t * p_sorted_token_ids; const index_t * p_sorted_token_ids;
const index_t * p_sorted_expert_ids; const index_t * p_sorted_expert_ids;
const index_t * p_max_token_id; const index_t * p_max_token_id;
const ADataType* p_a_grid; const ADataType* p_a_grid;
const BDataType* p_b_grid; const BDataType* p_b_grid;
DsGridPointer p_ds_grid; DsGridPointer p_ds_grid;
...@@ -1146,7 +1144,7 @@ struct GridwiseMoeGemmGather ...@@ -1146,7 +1144,7 @@ struct GridwiseMoeGemmGather
const auto b_grid_desc_bpreshuffled = const auto b_grid_desc_bpreshuffled =
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>( const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); problem.NumTokens * problem.TopK, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
// printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(), // printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(),
// problem.MBlock, problem.NBlock, MPerBlock, NPerBlock); // problem.MBlock, problem.NBlock, MPerBlock, NPerBlock);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
...@@ -1165,7 +1163,6 @@ struct GridwiseMoeGemmGather ...@@ -1165,7 +1163,6 @@ struct GridwiseMoeGemmGather
constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
constexpr auto AKThreads = AK0Threads * AK1Threads; constexpr auto AKThreads = AK0Threads * AK1Threads;
constexpr auto AMRepeats = MPerBlock / AMThreads; constexpr auto AMRepeats = MPerBlock / AMThreads;
// static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
if(token_pos >= max_token_id || token0 >= problem.NumTokens) if(token_pos >= max_token_id || token0 >= problem.NumTokens)
...@@ -1177,8 +1174,6 @@ struct GridwiseMoeGemmGather ...@@ -1177,8 +1174,6 @@ struct GridwiseMoeGemmGather
gather_offsets(m0) = token_offset * problem.K; gather_offsets(m0) = token_offset * problem.K;
// printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0)); // printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
}); });
// const index_t m_block_data_idx_on_grid =
// __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
// N0, K0, Blocksize*KPack // N0, K0, Blocksize*KPack
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_mod8.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...@@ -536,6 +536,7 @@ struct GridwiseMoeGemmScatter ...@@ -536,6 +536,7 @@ struct GridwiseMoeGemmScatter
struct Problem struct Problem
{ {
__host__ __device__ Problem(index_t NumTokens_, __host__ __device__ Problem(index_t NumTokens_,
index_t TopK_,
index_t M_, index_t M_,
index_t N_, index_t N_,
index_t K_, index_t K_,
...@@ -546,6 +547,7 @@ struct GridwiseMoeGemmScatter ...@@ -546,6 +547,7 @@ struct GridwiseMoeGemmScatter
index_t KBatch_) index_t KBatch_)
: :
NumTokens{NumTokens_}, NumTokens{NumTokens_},
TopK{TopK_},
M{M_}, M{M_},
N{N_}, N{N_},
K{K_}, K{K_},
...@@ -571,6 +573,7 @@ struct GridwiseMoeGemmScatter ...@@ -571,6 +573,7 @@ struct GridwiseMoeGemmScatter
{ {
std::cout << "problem {" std::cout << "problem {"
<< "NumTokens:" << NumTokens << ", " << "NumTokens:" << NumTokens << ", "
<< "TopK:" << TopK << ", "
<< "M:" << M << ", " << "M:" << M << ", "
<< "N:" << N << ", " << "N:" << N << ", "
<< "K:" << K << ", " << "K:" << K << ", "
...@@ -588,6 +591,7 @@ struct GridwiseMoeGemmScatter ...@@ -588,6 +591,7 @@ struct GridwiseMoeGemmScatter
} }
index_t NumTokens; index_t NumTokens;
index_t TopK;
index_t M; index_t M;
index_t N; index_t N;
index_t K; index_t K;
...@@ -621,6 +625,7 @@ struct GridwiseMoeGemmScatter ...@@ -621,6 +625,7 @@ struct GridwiseMoeGemmScatter
std::array<const void*, NumDTensor> p_ds_grid_, std::array<const void*, NumDTensor> p_ds_grid_,
CDataType* p_c_grid_, CDataType* p_c_grid_,
index_t NumTokens_, index_t NumTokens_,
index_t TopK_,
index_t M_, index_t M_,
index_t N_, index_t N_,
index_t K_, index_t K_,
...@@ -632,8 +637,7 @@ struct GridwiseMoeGemmScatter ...@@ -632,8 +637,7 @@ struct GridwiseMoeGemmScatter
AElementwiseOperation a_element_op_, AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_, BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_) CElementwiseOperation c_element_op_)
: Problem{NumTokens_, M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, : Problem{NumTokens_, TopK_, M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
p_sorted_token_ids{p_sorted_token_ids_}, p_sorted_token_ids{p_sorted_token_ids_},
p_sorted_expert_ids{p_sorted_expert_ids_}, p_sorted_expert_ids{p_sorted_expert_ids_},
p_max_token_id{p_max_token_id_}, p_max_token_id{p_max_token_id_},
...@@ -1135,7 +1139,7 @@ struct GridwiseMoeGemmScatter ...@@ -1135,7 +1139,7 @@ struct GridwiseMoeGemmScatter
{ {
ignore = b_element_op; ignore = b_element_op;
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); problem.NumTokens, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled = const auto b_grid_desc_bpreshuffled =
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
...@@ -1151,12 +1155,25 @@ struct GridwiseMoeGemmScatter ...@@ -1151,12 +1155,25 @@ struct GridwiseMoeGemmScatter
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
const index_t m_block_data_idx_on_grid = // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
constexpr auto AKThreads = AK0Threads * AK1Threads;
constexpr auto AMRepeats = MPerBlock / AMThreads;
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
if(m_block_data_idx_on_grid >= max_token_id || token0 >= problem.NumTokens) if(token_pos >= max_token_id || token0 >= problem.NumTokens)
return; return;
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t token_offset = (token_pos + m0 < max_token_id) ?
(p_sorted_token_ids[token_pos + m0] & 0xffffff) : problem.NumTokens;
gather_offsets(m0) = token_offset * problem.K;
// printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
});
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
// N0, K0, Blocksize*KPack // N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid = const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
...@@ -1177,7 +1194,7 @@ struct GridwiseMoeGemmScatter ...@@ -1177,7 +1194,7 @@ struct GridwiseMoeGemmScatter
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, ThreadGroupTensorSliceTransfer_v4r1_mod8<ThisThreadBlock,
AElementwiseOperation, AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
...@@ -1198,13 +1215,15 @@ struct GridwiseMoeGemmScatter ...@@ -1198,13 +1215,15 @@ struct GridwiseMoeGemmScatter
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true, true,
1,
BlockwiseGemmPipe::GlobalBufferNum>( BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0), make_multi_index(0, 0, 0),
a_element_op, a_element_op,
a_block_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{},
gather_offsets);
// Thread-wise copy // Thread-wise copy
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
...@@ -1384,11 +1403,11 @@ struct GridwiseMoeGemmScatter ...@@ -1384,11 +1403,11 @@ struct GridwiseMoeGemmScatter
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>; using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
const DDataType *ptr_ = p_ds_grid[i]; const DDataType *ptr_ = p_ds_grid[i];
// hack logic here to support different kind of strides. todo fix it. // hack logic here to support different kind of strides. todo fix it.
// ascale M, 1; bscale E, N, 1, move ptr to E // ascale t, 1; bscale E, N, 1, move ptr to E
if (i.value == 1) if (i.value == 1)
{ {
ptr_ += expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1); ptr_ += expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1);
// if ( threadIdx.x ==0) // if ( threadIdx.x % 16 ==0)
// printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]); // printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]);
} }
return make_dynamic_buffer<AddressSpaceEnum::Global>( return make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -1428,7 +1447,6 @@ struct GridwiseMoeGemmScatter ...@@ -1428,7 +1447,6 @@ struct GridwiseMoeGemmScatter
using CDEBlockTransferCluster = using CDEBlockTransferCluster =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
constexpr auto EMRepeats = MPerBlock / EMThreads; constexpr auto EMRepeats = MPerBlock / EMThreads;
constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
...@@ -1436,10 +1454,12 @@ struct GridwiseMoeGemmScatter ...@@ -1436,10 +1454,12 @@ struct GridwiseMoeGemmScatter
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos]; StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme // too hack here, 2 specific for topk weights, fixme
const float *p_sorted_weights = p_ds_grid[I2]; const float *p_sorted_weights_2 = p_ds_grid[I2];
const float *p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, EMRepeats, 1>{}([&](auto m0) { static_for<0, EMRepeats, 1>{}([&](auto m0) {
scatter_offsets(m0) = (p_sorted_token_ids[c_token_pos + m0] & 0xffffff) * problem.N; scatter_offsets(m0) = (p_sorted_token_ids[c_token_pos + m0] & 0xffffff) * problem.N;
scatter_weights(m0) = p_sorted_weights[c_token_pos + m0]; scatter_weights(m0) = p_sorted_weights_2[c_token_pos + m0]
* p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0)); // printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
}); });
......
...@@ -35,7 +35,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -35,7 +35,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
const Tensor<ck::index_t>& expert_ids, const Tensor<ck::index_t>& expert_ids,
const Tensor<ck::index_t>& max_token_id, const Tensor<ck::index_t>& max_token_id,
const index_t sorted_tile_size, const index_t sorted_tile_size,
const Tensor<ADataType>& a_m_k, const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k, const Tensor<BDataType>& b_e_n_k,
const Tensor<D0DataType>& d0, const Tensor<D0DataType>& d0,
const Tensor<D1DataType>& d1, const Tensor<D1DataType>& d1,
...@@ -48,7 +48,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -48,7 +48,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
expert_ids_{expert_ids}, expert_ids_{expert_ids},
max_token_id_{max_token_id}, max_token_id_{max_token_id},
sorted_tile_size_{sorted_tile_size}, sorted_tile_size_{sorted_tile_size},
a_m_k_{a_m_k}, a_t_k_{a_t_k},
b_e_n_k_{b_e_n_k}, b_e_n_k_{b_e_n_k},
d0_{d0}, d0_{d0},
d1_{d1}, d1_{d1},
...@@ -64,7 +64,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -64,7 +64,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
const Tensor<ck::index_t>& expert_ids_; const Tensor<ck::index_t>& expert_ids_;
const Tensor<ck::index_t>& max_token_id_; const Tensor<ck::index_t>& max_token_id_;
index_t sorted_tile_size_; index_t sorted_tile_size_;
const Tensor<ADataType>& a_m_k_; const Tensor<ADataType>& a_t_k_;
const Tensor<BDataType>& b_e_n_k_; const Tensor<BDataType>& b_e_n_k_;
const Tensor<D0DataType>& d0_; const Tensor<D0DataType>& d0_;
const Tensor<D1DataType>& d1_; const Tensor<D1DataType>& d1_;
...@@ -85,7 +85,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -85,7 +85,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
{ {
arg.c_t_n_.SetZero(); arg.c_t_n_.SetZero();
auto f_mk_kn_mn = [&](auto m, auto n) { auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_m_k_.mDesc.GetLengths()[1]; const int K = arg.a_t_k_.mDesc.GetLengths()[1];
AccDataType v_acc{0}; AccDataType v_acc{0};
ComputeTypeA v_a{0}; ComputeTypeA v_a{0};
ComputeTypeB v_b{0}; ComputeTypeB v_b{0};
...@@ -101,11 +101,11 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -101,11 +101,11 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
if constexpr(is_same_v<AElementwiseOperation, if constexpr(is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>) ck::tensor_operation::element_wise::ConvertBF16RTN>)
{ {
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k)); ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_t_k_(t, k));
} }
else else
{ {
arg.a_element_op_(v_a, arg.a_m_k_(m, k)); arg.a_element_op_(v_a, arg.a_t_k_(t, k));
} }
// same for B matrix // same for B matrix
if constexpr(is_same_v<BElementwiseOperation, if constexpr(is_same_v<BElementwiseOperation,
...@@ -157,7 +157,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -157,7 +157,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
const Tensor<ck::index_t>& expert_ids, const Tensor<ck::index_t>& expert_ids,
const Tensor<ck::index_t>& max_token_id, const Tensor<ck::index_t>& max_token_id,
const index_t sorted_tile_size, const index_t sorted_tile_size,
const Tensor<ADataType>& a_m_k, const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k, const Tensor<BDataType>& b_e_n_k,
const Tensor<D0DataType>& d0, const Tensor<D0DataType>& d0,
const Tensor<D1DataType>& d1, const Tensor<D1DataType>& d1,
...@@ -167,7 +167,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -167,7 +167,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
return Argument{sorted_token_ids, expert_ids, max_token_id, sorted_tile_size, a_m_k, b_e_n_k, d0, d1, d2, c_t_n, a_element_op, b_element_op, c_element_op}; return Argument{sorted_token_ids, expert_ids, max_token_id, sorted_tile_size, a_t_k, b_e_n_k, d0, d1, d2, c_t_n, a_element_op, b_element_op, c_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment