Commit 61e3c238 authored by coderfeli
Browse files

fix moe gemm2

parent db53dba4
...@@ -69,7 +69,7 @@ struct MulABScaleExpertWeight ...@@ -69,7 +69,7 @@ struct MulABScaleExpertWeight
{ {
// e = ck::type_convert<EDataType>(c * d0 * d1 * d2); // e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
(void) d2; (void) d2;
e = ck::type_convert<EDataType>(c * d0 * d1); e = ck::type_convert<EDataType>(c);
} }
// for reference // for reference
template <> template <>
...@@ -81,7 +81,7 @@ struct MulABScaleExpertWeight ...@@ -81,7 +81,7 @@ struct MulABScaleExpertWeight
const float& d2) const const float& d2) const
{ {
(void) d2; (void) d2;
e = ck::type_convert<EDataType>(c * d0 * d1); e = ck::type_convert<EDataType>(c);
} }
}; };
...@@ -253,7 +253,7 @@ int main(int argc, char* argv[]) ...@@ -253,7 +253,7 @@ int main(int argc, char* argv[])
} }
expert_ids.savetxt("expert_ids.txt", "int"); expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({batch, K}, {K, 1})); Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({batch, topk, K}, {topk*K, K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({batch, N}, {StrideDs[0], 0})); Tensor<D0DataType> d0_t_n(HostTensorDescriptor({batch, N}, {StrideDs[0], 0}));
...@@ -262,7 +262,7 @@ int main(int argc, char* argv[]) ...@@ -262,7 +262,7 @@ int main(int argc, char* argv[])
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero(); e_t_n_device_result.SetZero();
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
...@@ -273,21 +273,21 @@ int main(int argc, char* argv[]) ...@@ -273,21 +273,21 @@ int main(int argc, char* argv[])
{ {
case 0: break; case 0: break;
case 1: case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2}); a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2}); d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2}); d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2}); d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break; break;
case 2: case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{}); a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{}); d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break; break;
default: default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0}); a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0}); d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
...@@ -296,13 +296,13 @@ int main(int argc, char* argv[]) ...@@ -296,13 +296,13 @@ int main(int argc, char* argv[])
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
a0_t_k.savetxt("a.txt"); a0_t_k_k.savetxt("a.txt");
expert_ids.savetxt("expert_ids.txt", "int"); expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
d0_t_n.savetxt("d0_t_n.txt", "int"); d0_t_n.savetxt("d0_t_n.txt", "int");
...@@ -311,7 +311,7 @@ int main(int argc, char* argv[]) ...@@ -311,7 +311,7 @@ int main(int argc, char* argv[])
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data()); a0_device_buf.ToDevice(a0_t_k_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data()); d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data()); d1_device_buf.ToDevice(d1_e_n.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data()); d2_device_buf.ToDevice(d2_e_n.mData.data());
...@@ -399,7 +399,7 @@ int main(int argc, char* argv[]) ...@@ -399,7 +399,7 @@ int main(int argc, char* argv[])
auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker(); auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument( auto ref_argument = ref_moe_gemm.MakeArgument(
sorted_token_ids, expert_ids, max_token_id, MPerBlock, a0_t_k, b0_e_n_k, d0_t_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op); sorted_token_ids, expert_ids, max_token_id, MPerBlock, a0_t_k_k, b0_e_n_k, d0_t_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
for(int t = 0; t < tokens; ++t) for(int t = 0; t < tokens; ++t)
......
...@@ -1167,8 +1167,8 @@ struct GridwiseMoeGemmScatter ...@@ -1167,8 +1167,8 @@ struct GridwiseMoeGemmScatter
return; return;
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos]; StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, AMRepeats, 1>{}([&](auto m0) { static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t token_offset = (token_pos + m0 < max_token_id) ? const index_t fused_token = p_sorted_token_ids[token_pos + m0];
(p_sorted_token_ids[token_pos + m0] & 0xffffff) : problem.NumTokens; const index_t token_offset = (fused_token & 0xffffff) * problem.TopK + (fused_token >> 24);
gather_offsets(m0) = token_offset * problem.K; gather_offsets(m0) = token_offset * problem.K;
// printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0)); // printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
}); });
......
...@@ -35,7 +35,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -35,7 +35,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
const Tensor<ck::index_t>& expert_ids, const Tensor<ck::index_t>& expert_ids,
const Tensor<ck::index_t>& max_token_id, const Tensor<ck::index_t>& max_token_id,
const index_t sorted_tile_size, const index_t sorted_tile_size,
const Tensor<ADataType>& a_t_k, const Tensor<ADataType>& a_t_k_k,
const Tensor<BDataType>& b_e_n_k, const Tensor<BDataType>& b_e_n_k,
const Tensor<D0DataType>& d0, const Tensor<D0DataType>& d0,
const Tensor<D1DataType>& d1, const Tensor<D1DataType>& d1,
...@@ -48,7 +48,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -48,7 +48,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
expert_ids_{expert_ids}, expert_ids_{expert_ids},
max_token_id_{max_token_id}, max_token_id_{max_token_id},
sorted_tile_size_{sorted_tile_size}, sorted_tile_size_{sorted_tile_size},
a_t_k_{a_t_k}, a_t_k_k_{a_t_k_k},
b_e_n_k_{b_e_n_k}, b_e_n_k_{b_e_n_k},
d0_{d0}, d0_{d0},
d1_{d1}, d1_{d1},
...@@ -64,7 +64,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -64,7 +64,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
const Tensor<ck::index_t>& expert_ids_; const Tensor<ck::index_t>& expert_ids_;
const Tensor<ck::index_t>& max_token_id_; const Tensor<ck::index_t>& max_token_id_;
index_t sorted_tile_size_; index_t sorted_tile_size_;
const Tensor<ADataType>& a_t_k_; const Tensor<ADataType>& a_t_k_k_;
const Tensor<BDataType>& b_e_n_k_; const Tensor<BDataType>& b_e_n_k_;
const Tensor<D0DataType>& d0_; const Tensor<D0DataType>& d0_;
const Tensor<D1DataType>& d1_; const Tensor<D1DataType>& d1_;
...@@ -85,11 +85,12 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -85,11 +85,12 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
{ {
arg.c_t_n_.SetZero(); arg.c_t_n_.SetZero();
auto f_mk_kn_mn = [&](auto m, auto n) { auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_t_k_.mDesc.GetLengths()[1]; const int K = arg.a_t_k_k_.mDesc.GetLengths()[2];
AccDataType v_acc{0}; AccDataType v_acc{0};
ComputeTypeA v_a{0}; ComputeTypeA v_a{0};
ComputeTypeB v_b{0}; ComputeTypeB v_b{0};
const int t = arg.sorted_token_ids_(m); const int t = arg.sorted_token_ids_(m) & 0xffffff;
const int topk_id = arg.sorted_token_ids_(m) >> 24;
const int e = arg.expert_ids_(m / arg.sorted_tile_size_); const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
const int token_cnt = arg.c_t_n_.mDesc.GetLengths()[0]; const int token_cnt = arg.c_t_n_.mDesc.GetLengths()[0];
D2DataType v_topk_w = arg.d2_(m, 0); //expert D2DataType v_topk_w = arg.d2_(m, 0); //expert
...@@ -101,11 +102,11 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -101,11 +102,11 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
if constexpr(is_same_v<AElementwiseOperation, if constexpr(is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>) ck::tensor_operation::element_wise::ConvertBF16RTN>)
{ {
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_t_k_(t, k)); ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_t_k_k_(t, topk_id, k));
} }
else else
{ {
arg.a_element_op_(v_a, arg.a_t_k_(t, k)); arg.a_element_op_(v_a, arg.a_t_k_k_(t, topk_id, k));
} }
// same for B matrix // same for B matrix
if constexpr(is_same_v<BElementwiseOperation, if constexpr(is_same_v<BElementwiseOperation,
...@@ -124,7 +125,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -124,7 +125,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
CDataType v_c{0}; CDataType v_c{0};
D0DataType v_d0 = arg.d0_(m, n); // a D0DataType v_d0 = arg.d0_(m, n); // a
D0DataType v_d1 = arg.d1_(e, n); // b D0DataType v_d1 = arg.d1_(e, n); // b
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w); arg.c_element_op_(v_c, v_acc, v_d0 * v_topk_w, v_d1, v_topk_w);
arg.c_t_n_(t, n) += v_c; arg.c_t_n_(t, n) += v_c;
} }
...@@ -157,7 +158,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -157,7 +158,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
const Tensor<ck::index_t>& expert_ids, const Tensor<ck::index_t>& expert_ids,
const Tensor<ck::index_t>& max_token_id, const Tensor<ck::index_t>& max_token_id,
const index_t sorted_tile_size, const index_t sorted_tile_size,
const Tensor<ADataType>& a_t_k, const Tensor<ADataType>& a_t_k_k,
const Tensor<BDataType>& b_e_n_k, const Tensor<BDataType>& b_e_n_k,
const Tensor<D0DataType>& d0, const Tensor<D0DataType>& d0,
const Tensor<D1DataType>& d1, const Tensor<D1DataType>& d1,
...@@ -167,7 +168,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -167,7 +168,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
return Argument{sorted_token_ids, expert_ids, max_token_id, sorted_tile_size, a_t_k, b_e_n_k, d0, d1, d2, c_t_n, a_element_op, b_element_op, c_element_op}; return Argument{sorted_token_ids, expert_ids, max_token_id, sorted_tile_size, a_t_k_k, b_e_n_k, d0, d1, d2, c_t_n, a_element_op, b_element_op, c_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment