Commit b3ae04f8 authored by coderfeli's avatar coderfeli
Browse files

fix ref gemm no padding

parent 1078d229
......@@ -132,7 +132,7 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType);
......@@ -255,13 +255,13 @@ int main(int argc, char* argv[])
}
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({batch, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({batch, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({batch, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({batch, topk, N}, {topk * N, N, 1}));
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
......@@ -370,7 +370,7 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
Tensor<CShuffleDataType> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
Tensor<CShuffleDataType> c_t_k_n({batch, topk, N}, {topk * N, N, 1});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm<A0DataType,
B0DataType,
......@@ -401,8 +401,8 @@ int main(int argc, char* argv[])
const int e = expert_ids(m / sorted_tile_size);
for(int n = 0; n < N; ++n)
{
cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(m, topk_id, n), d0_t_n(t, n), d1_e_n(e, n));
printf("m %d fuset %d %d %d %f %f\n",m, topk_id, t, n, e_t_n_host_result(t, topk_id, n), c_t_k_n(m, topk_id, n));
cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n), d0_t_n(t, n), d1_e_n(e, n));
printf("m %d n %d topk %d token %d %f %f\n",m, n,topk_id, t, e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n));
}
}
......
......@@ -74,8 +74,6 @@ struct ReferenceMoeGemm : public device::BaseOperator
AccDataType v_acc{0};
ComputeTypeA v_a{0};
ComputeTypeB v_b{0};
if(m >= max_sorted_num)
return;
const int t = arg.sorted_token_ids_(m) & 0xffffff;
const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24;
const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
......@@ -112,6 +110,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
arg.c_element_op_(v_c, v_acc);
arg.c_t_k_n_(t, topk_id, n) = v_c;
printf("ref m %d n %d t %d topk %d v %f\n", m, n, t, topk_id, v_c);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment