"test/vscode:/vscode.git/clone" did not exist on "d68df25588d7462d970900ae7ae7514b1e82bbce"
Commit 7796fc73 authored by coderfeli's avatar coderfeli
Browse files

Fix gemm2 scale; gemm2 works now

parent 61e3c238
...@@ -358,9 +358,9 @@ int main(int argc, char* argv[]) ...@@ -358,9 +358,9 @@ int main(int argc, char* argv[])
if (time_kernel) { if (time_kernel) {
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * valid_size * N * K; std::size_t flop = std::size_t(2) * valid_tile_num * N * K;
std::size_t num_btype = std::size_t num_btype =
sizeof(A0DataType) * valid_size * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * valid_size * N; sizeof(A0DataType) * valid_tile_num * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * valid_tile_num * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......
...@@ -58,7 +58,7 @@ struct MulABScaleExpertWeight ...@@ -58,7 +58,7 @@ struct MulABScaleExpertWeight
template <typename E, typename C, typename D0, typename D1, typename D2> template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
//real kernel use //for real kernel use
template <> template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float> __host__ __device__ constexpr void operator()<EDataType, float, float, float, float>
(EDataType& e, (EDataType& e,
...@@ -67,11 +67,12 @@ struct MulABScaleExpertWeight ...@@ -67,11 +67,12 @@ struct MulABScaleExpertWeight
const float& d1, const float& d1,
const float& d2) const const float& d2) const
{ {
// e = ck::type_convert<EDataType>(c * d0 * d1 * d2); //for real kernel use
(void) d2; //warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. tofix:felix
e = ck::type_convert<EDataType>(c); (void) d0;
e = ck::type_convert<EDataType>(c * d1 * d2);
} }
// for reference // for reference cpu
template <> template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float> __host__ __device__ constexpr void operator()<float, float, float, float, float>
(float& e, (float& e,
...@@ -80,8 +81,8 @@ struct MulABScaleExpertWeight ...@@ -80,8 +81,8 @@ struct MulABScaleExpertWeight
const float& d1, const float& d1,
const float& d2) const const float& d2) const
{ {
(void) d2; // for reference cpu
e = ck::type_convert<EDataType>(c); e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
} }
}; };
...@@ -124,7 +125,7 @@ using BElementOp = PassThrough; ...@@ -124,7 +125,7 @@ using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight; using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 64; static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t BLOCKSIZE = 256; static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t NPerBlock = 128; static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32; static constexpr ck::index_t MNPerXDL = 32;
...@@ -188,15 +189,13 @@ int main(int argc, char* argv[]) ...@@ -188,15 +189,13 @@ int main(int argc, char* argv[])
ck::index_t N = 6144; ck::index_t N = 6144;
ck::index_t K = 8192; ck::index_t K = 8192;
ck::index_t experts = 8; ck::index_t experts = 8;
ck::index_t sorted_tile_num = 9; ck::index_t sorted_tile_num = 10;
ck::index_t valid_tile_num = 8; ck::index_t valid_tile_num = 8;
ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t batch = 64; ck::index_t tokens = 64;
ck::index_t topk = 2; ck::index_t topk = 2;
ck::index_t tokens = batch;
if(argc == 1) if(argc == 1)
{ {
// use default case // use default case
...@@ -236,6 +235,11 @@ int main(int argc, char* argv[]) ...@@ -236,6 +235,11 @@ int main(int argc, char* argv[])
for (int i = 0; i < sorted_tile_num; i++) { for (int i = 0; i < sorted_tile_num; i++) {
expert_ids.mData[i] = i; expert_ids.mData[i] = i;
} }
if (tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
int token_per_tile = tokens * topk / valid_tile_num; int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0; int tokenid = 0;
// sorted_token_ids.mData[0] = 0; // sorted_token_ids.mData[0] = 0;
...@@ -243,20 +247,21 @@ int main(int argc, char* argv[]) ...@@ -243,20 +247,21 @@ int main(int argc, char* argv[])
int tile_off = i % valid_size; int tile_off = i % valid_size;
if(tile_off < token_per_tile) if(tile_off < token_per_tile)
{ {
sorted_token_ids.mData[i] = (tokenid % batch) | ((tokenid / batch) << 24); sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++; tokenid++;
} }
else else
{ {
sorted_token_ids.mData[i] = tokens; sorted_token_ids.mData[i] = tokens;
} }
} }
expert_ids.savetxt("expert_ids.txt", "int"); expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({batch, topk, K}, {topk*K, K, 1})); Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk*K, K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({batch, N}, {StrideDs[0], 0})); Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
...@@ -274,7 +279,7 @@ int main(int argc, char* argv[]) ...@@ -274,7 +279,7 @@ int main(int argc, char* argv[])
case 0: break; case 0: break;
case 1: case 1:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2}); a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2}); d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2}); d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2}); d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
...@@ -366,9 +371,9 @@ int main(int argc, char* argv[]) ...@@ -366,9 +371,9 @@ int main(int argc, char* argv[])
// result is not correct here because the output buffer is not zeroed // result is not correct here because the output buffer is not zeroed
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * valid_size * N * K; std::size_t flop = std::size_t(2) * tokens * topk * N * K;
std::size_t num_btype = std::size_t num_btype =
sizeof(A0DataType) * valid_size * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * valid_size * N; sizeof(A0DataType) * tokens * K * topk + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * tokens * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......
...@@ -1139,7 +1139,7 @@ struct GridwiseMoeGemmScatter ...@@ -1139,7 +1139,7 @@ struct GridwiseMoeGemmScatter
{ {
ignore = b_element_op; ignore = b_element_op;
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.NumTokens, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); problem.NumTokens * problem.TopK, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled = const auto b_grid_desc_bpreshuffled =
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
......
...@@ -125,7 +125,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -125,7 +125,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
CDataType v_c{0}; CDataType v_c{0};
D0DataType v_d0 = arg.d0_(m, n); // a D0DataType v_d0 = arg.d0_(m, n); // a
D0DataType v_d1 = arg.d1_(e, n); // b D0DataType v_d1 = arg.d1_(e, n); // b
arg.c_element_op_(v_c, v_acc, v_d0 * v_topk_w, v_d1, v_topk_w); arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w);
arg.c_t_n_(t, n) += v_c; arg.c_t_n_(t, n) += v_c;
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment