Commit 7796fc73 authored by coderfeli's avatar coderfeli
Browse files

Fix GEMM2 scale handling; GEMM2 now produces correct results

parent 61e3c238
......@@ -358,9 +358,9 @@ int main(int argc, char* argv[])
if (time_kernel) {
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * valid_size * N * K;
std::size_t flop = std::size_t(2) * valid_tile_num * N * K;
std::size_t num_btype =
sizeof(A0DataType) * valid_size * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * valid_size * N;
sizeof(A0DataType) * valid_tile_num * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * valid_tile_num * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......
......@@ -58,7 +58,7 @@ struct MulABScaleExpertWeight
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
//real kernel use
//for real kernel use
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>
(EDataType& e,
......@@ -67,11 +67,12 @@ struct MulABScaleExpertWeight
const float& d1,
const float& d2) const
{
// e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
(void) d2;
e = ck::type_convert<EDataType>(c);
//for real kernel use
//warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. tofix:felix
(void) d0;
e = ck::type_convert<EDataType>(c * d1 * d2);
}
// for reference
// for reference cpu
template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float>
(float& e,
......@@ -80,8 +81,8 @@ struct MulABScaleExpertWeight
const float& d1,
const float& d2) const
{
(void) d2;
e = ck::type_convert<EDataType>(c);
// for reference cpu
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
}
};
......@@ -124,7 +125,7 @@ using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 64;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
......@@ -188,15 +189,13 @@ int main(int argc, char* argv[])
ck::index_t N = 6144;
ck::index_t K = 8192;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 9;
ck::index_t sorted_tile_num = 10;
ck::index_t valid_tile_num = 8;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t batch = 64;
ck::index_t tokens = 64;
ck::index_t topk = 2;
ck::index_t tokens = batch;
if(argc == 1)
{
// use default case
......@@ -236,6 +235,11 @@ int main(int argc, char* argv[])
for (int i = 0; i < sorted_tile_num; i++) {
expert_ids.mData[i] = i;
}
if (tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0;
// sorted_token_ids.mData[0] = 0;
......@@ -243,20 +247,21 @@ int main(int argc, char* argv[])
int tile_off = i % valid_size;
if(tile_off < token_per_tile)
{
sorted_token_ids.mData[i] = (tokenid % batch) | ((tokenid / batch) << 24);
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = tokens;
}
}
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({batch, topk, K}, {topk*K, K, 1}));
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk*K, K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({batch, N}, {StrideDs[0], 0}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
......@@ -274,7 +279,7 @@ int main(int argc, char* argv[])
case 0: break;
case 1:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
......@@ -366,9 +371,9 @@ int main(int argc, char* argv[])
// not result correct here because output buf not setzero
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * valid_size * N * K;
std::size_t flop = std::size_t(2) * tokens * topk * N * K;
std::size_t num_btype =
sizeof(A0DataType) * valid_size * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * valid_size * N;
sizeof(A0DataType) * tokens * K * topk + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * tokens * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......
......@@ -1139,7 +1139,7 @@ struct GridwiseMoeGemmScatter
{
ignore = b_element_op;
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.NumTokens, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
problem.NumTokens * problem.TopK, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled =
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
......
......@@ -125,7 +125,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
CDataType v_c{0};
D0DataType v_d0 = arg.d0_(m, n); // a
D0DataType v_d1 = arg.d1_(e, n); // b
arg.c_element_op_(v_c, v_acc, v_d0 * v_topk_w, v_d1, v_topk_w);
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w);
arg.c_t_n_(t, n) += v_c;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment