Commit d4b8f1e3 authored by coderfeli

add code for scatter output

parent 82e1f1b9
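
Throughout the diff below, each sorted_token_ids entry is a packed value: the token index sits in the low 24 bits and the topk slot in the high 8 bits, while padding entries hold the total token count so they can be skipped. A minimal sketch of that convention, using the same masks as the kernel and reference code; the helper names are illustrative and not part of this change:

    // pack a token index and its topk slot into one id (helper names are illustrative)
    inline int pack_token_id(int token, int topk_id) { return (token & 0xffffff) | (topk_id << 24); }
    // unpack, matching the 0xffffff / 0xff000000 masks used below
    inline int unpack_token(int fused) { return fused & 0xffffff; }
    inline int unpack_topk_id(int fused) { return (fused & 0xff000000) >> 24; }

A row is treated as padding whenever the unpacked token index is >= the token count.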
@@ -195,7 +195,10 @@ int main(int argc, char* argv[])
     ck::index_t sorted_tile_num = 8;
     ck::index_t sorted_tile_size = MPerBlock;
     ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
-    ck::index_t tokens = 128;
+    ck::index_t batch = 64;
+    ck::index_t topk = 2;
+    ck::index_t tokens = batch * topk;
     if(argc == 1)
     {
@@ -241,10 +244,15 @@ int main(int argc, char* argv[])
     for (int i = 0; i < SORTED_SIZE; i++) {
         int tile_off = i % sorted_tile_size;
         if(tile_off < token_per_tile)
-            sorted_token_ids.mData[i] = tokenid++;
+        {
+            sorted_token_ids.mData[i] = (tokenid % batch) | ((tokenid / batch) << 24);
+            tokenid++;
+        }
         else
+        {
             sorted_token_ids.mData[i] = tokens;
+        }
     }
     expert_ids.savetxt("expert_ids.txt", "int");
     sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
     Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
@@ -252,14 +260,14 @@ int main(int argc, char* argv[])
     Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
     Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
     Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
-    Tensor<EDataType> e_m_n_host_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
-    Tensor<EDataType> e_m_n_device_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
+    Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
+    Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
     std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
     std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
     std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
     std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
-    std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
+    std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
     switch(init_method)
     {
@@ -290,14 +298,14 @@ int main(int argc, char* argv[])
     DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
     DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
     DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
     a0_t_k.savetxt("a.txt");
     sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
     expert_ids_dev.ToDevice(expert_ids.mData.data());
     a0_device_buf.ToDevice(a0_t_k.mData.data());
     d0_device_buf.ToDevice(d0_t_n.mData.data());
     d1_device_buf.ToDevice(d1_e_n.mData.data());
-    e_device_buf.ToDevice(e_m_n_device_result.mData.data());
+    // e_device_buf.ToDevice(e_t_n_device_result.mData.data());
     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
@@ -322,6 +330,7 @@ int main(int argc, char* argv[])
         d1_device_buf.GetDeviceBuffer()},
         e_device_buf.GetDeviceBuffer(),
         tokens,
+        topk,
         SORTED_SIZE,
         N,
         K,
@@ -359,9 +368,9 @@ int main(int argc, char* argv[])
     {
         invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
-        e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+        e_device_buf.FromDevice(e_t_n_device_result.mData.data());
-        Tensor<CShuffleDataType> c_m_n({SORTED_SIZE, N});
+        Tensor<CShuffleDataType> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
         using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm<A0DataType,
                                                                                    B0DataType,
@@ -374,25 +383,31 @@ int main(int argc, char* argv[])
         auto ref_invoker = ref_moe_gemm.MakeInvoker();
         auto ref_argument = ref_moe_gemm.MakeArgument(
-            sorted_token_ids, expert_ids, sorted_tile_size, a0_t_k, b0_e_n_k, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
+            sorted_token_ids, expert_ids, sorted_tile_size, a0_t_k, b0_e_n_k, c_t_k_n, PassThrough{}, PassThrough{}, PassThrough{});
         ref_invoker.Run(ref_argument);
         for(int m = 0; m < SORTED_SIZE; ++m)
         {
-            const int t = sorted_token_ids(m);
+            const int fuse_t = sorted_token_ids(m);
+            const int t = fuse_t & 0xffffff;
+            if(t >= tokens)
+            {
+                continue;
+            }
+            const int topk_id = (fuse_t & 0xff000000) >> 24;
             const int e = expert_ids(m / sorted_tile_size);
             for(int n = 0; n < N; ++n)
             {
-                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(e, n));
+                cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n), d0_t_n(t, n), d1_e_n(e, n));
             }
         }
-        e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+        e_device_buf.FromDevice(e_t_n_device_result.mData.data());
-        e_m_n_device_result.savetxt("out.txt");
+        e_t_n_device_result.savetxt("out.txt");
-        e_m_n_host_result.savetxt("ref.txt");
+        e_t_n_host_result.savetxt("ref.txt");
         return ck::utils::check_err(
-                   e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
+                   e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
                    ? 0
                    : 1;
     }
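
Since padding rows and the packed layout are easy to get wrong on the host side, a small sanity check over the ids built above can help; this is an illustrative sketch (it assumes <cassert> is included), not part of the commit:

    // verify every non-padding entry unpacks to a valid (token, topk) pair
    for(int i = 0; i < SORTED_SIZE; i++)
    {
        const int fused = sorted_token_ids.mData[i];
        const int t     = fused & 0xffffff;
        if(t >= tokens)
            continue; // padding entry, set to `tokens` above
        assert(((fused & 0xff000000) >> 24) < topk);
    }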
@@ -505,6 +505,7 @@ struct DeviceMoeGemm
         std::array<const void*, NumDTensor> p_ds,
         void* p_c,
         index_t NumTokens,
+        index_t TopK,
         index_t M,
         index_t N,
         index_t K,
@@ -524,6 +525,7 @@ struct DeviceMoeGemm
         p_ds,
         static_cast<CDataType*>(p_c),
         NumTokens,
+        TopK,
         M,
         N,
         K,
@@ -563,7 +565,8 @@ struct DeviceMoeGemm
         static_cast<const BDataType*>(p_b),
         p_ds,
         static_cast<CDataType*>(p_c),
-        M,
+        M, // randomly set, not used
+        0,
         M,
         N,
         K,
@@ -545,6 +545,7 @@ struct GridwiseMoeGemmGather
             index_t KBatch_)
             :
               NumTokens{NumTokens_},
+              TopK{TopK_},
               M{M_},
               N{N_},
               K{K_},
@@ -570,6 +571,7 @@ struct GridwiseMoeGemmGather
         {
             std::cout << "problem {"
                       << "NumTokens:" << NumTokens << ", "
+                      << "TopK:" << TopK << ", "
                       << "M:" << M << ", "
                       << "N:" << N << ", "
                       << "K:" << K << ", "
@@ -587,6 +589,7 @@ struct GridwiseMoeGemmGather
         }
         index_t NumTokens;
+        index_t TopK;
         index_t M;
         index_t N;
         index_t K;
@@ -619,6 +622,7 @@ struct GridwiseMoeGemmGather
                  std::array<const void*, NumDTensor> p_ds_grid_,
                  CDataType* p_c_grid_,
                  index_t NumTokens_,
+                 index_t TopK_,
                  index_t M_,
                  index_t N_,
                  index_t K_,
@@ -630,7 +634,7 @@ struct GridwiseMoeGemmGather
                  AElementwiseOperation a_element_op_,
                  BElementwiseOperation b_element_op_,
                  CElementwiseOperation c_element_op_)
-            : Problem{NumTokens_, M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
+            : Problem{NumTokens_, TopK_, M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
              p_sorted_token_ids{p_sorted_token_ids_},
              p_sorted_expert_ids{p_sorted_expert_ids_},
@@ -1155,10 +1159,10 @@ struct GridwiseMoeGemmGather
         // static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
         const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
-        const index_t t0 = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
-        if(t0 >= problem.NumTokens)
+        const index_t t0 = p_sorted_token_ids[block_m_id * MPerBlock];
+        if((t0 & 0xffffff) >= problem.NumTokens)
             return;
+        const index_t topk_id = (t0 & 0xff000000) >> 24;
         StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
         static_for<0, AMRepeats, 1>{}([&](auto m0) {
             gather_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.K;
@@ -1450,7 +1454,7 @@ struct GridwiseMoeGemmGather
         // too hacky here, too specific to topk weights, FIXME
         const float *p_sorted_weights = p_ds_grid[I0];
         static_for<0, EMRepeats, 1>{}([&](auto m0) {
-            scatter_offsets(m0) = 0;
+            scatter_offsets(m0) = ((p_sorted_token_ids[c_token_pos + m0] & 0xffffff) * problem.TopK + topk_id) * problem.N;
             scatter_weights(m0) = p_sorted_weights[(c_token_pos + m0) * problem.StrideDs[0]];
             // if(threadIdx.x % 16 == 0)
             //     printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
@@ -1482,13 +1486,13 @@ struct GridwiseMoeGemmGather
             false>>,          // ThreadTransferSrcResetCoordinateAfterRunFlags
             Sequence<false>,  // ThreadTransferDstResetCoordinateAfterRunFlags
             1,                // ScatterDim
-            false,            // OutputScatter: false, only use scatter weights
+            true,             // OutputScatter: true, scatter output rows by token id
             1                 // ScatterWeightIdx: ascale
             >
             {c_ds_desc_refs,
              idx_c_ds_block_begin,
              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-             make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
+             make_tuple(make_multi_index(0, 0, block_n_id, 0)),
              c_element_op,
              scatter_offsets,
              scatter_weights};
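
The scatter offset introduced in the kernel above addresses the output as a row-major [NumTokens, TopK, N] tensor, matching the e_t_n host tensor built in the example with strides {topk * N, N, 1}. A worked sketch of the arithmetic, with illustrative values:

    // destination offset of element (token, topk_id, n):
    //     offset = (token * TopK + topk_id) * N + n
    // e.g. token = 5, topk_id = 1, TopK = 2, N = 4096, n = 0:
    //     (5 * 2 + 1) * 4096 = 45056
    const index_t scatter_off = ((fused_id & 0xffffff) * problem.TopK + topk_id) * problem.N;

Here fused_id stands for one packed sorted_token_ids entry; in the kernel these per-row offsets are filled into scatter_offsets(m0), and the block's m coordinate in the destination multi-index is zeroed, presumably because the destination row now comes entirely from the scatter offset.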
@@ -33,7 +33,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
              const index_t sorted_tile_size,
              const Tensor<ADataType>& a_t_k,
              const Tensor<BDataType>& b_e_n_k,
-             Tensor<CDataType>& c_m_n,
+             Tensor<CDataType>& c_t_k_n,
              AElementwiseOperation a_element_op,
              BElementwiseOperation b_element_op,
              CElementwiseOperation c_element_op)
@@ -42,7 +42,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
              sorted_tile_size_{sorted_tile_size},
              a_t_k_{a_t_k},
              b_e_n_k_{b_e_n_k},
-             c_m_n_{c_m_n},
+             c_t_k_n_{c_t_k_n},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
@@ -54,7 +54,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
         index_t sorted_tile_size_;
         const Tensor<ADataType>& a_t_k_;
         const Tensor<BDataType>& b_e_n_k_;
-        Tensor<CDataType>& c_m_n_;
+        Tensor<CDataType>& c_t_k_n_;
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
@@ -74,7 +74,8 @@ struct ReferenceMoeGemm : public device::BaseOperator
             AccDataType v_acc{0};
             ComputeTypeA v_a{0};
             ComputeTypeB v_b{0};
-            const int t = arg.sorted_token_ids_(m);
+            const int t = arg.sorted_token_ids_(m) & 0xffffff;
+            const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24;
             const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
             const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0];
             if(t < token_cnt) {
@@ -110,11 +111,11 @@ struct ReferenceMoeGemm : public device::BaseOperator
             arg.c_element_op_(v_c, v_acc);
-            arg.c_m_n_(m, n) = v_c;
+            arg.c_t_k_n_(t, topk_id, n) = v_c;
         };
         make_ParallelTensorFunctor(
-            f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
+            f_mk_kn_mn, arg.sorted_tile_size_, arg.c_t_k_n_.mDesc.GetLengths()[2])(
             std::thread::hardware_concurrency());
         return 0;
@@ -140,12 +141,12 @@ struct ReferenceMoeGemm : public device::BaseOperator
                        const index_t sorted_tile_size,
                        const Tensor<ADataType>& a_t_k,
                        const Tensor<BDataType>& b_e_n_k,
-                       Tensor<CDataType>& c_m_n,
+                       Tensor<CDataType>& c_t_k_n,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op)
     {
-        return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_t_k, b_e_n_k, c_m_n, a_element_op, b_element_op, c_element_op};
+        return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_t_k, b_e_n_k, c_t_k_n, a_element_op, b_element_op, c_element_op};
     }
     static auto MakeInvoker() { return Invoker{}; }