Commit 7ba5bff4 authored by coderfeli

one tile ok

parent 8a5bb9f3
@@ -119,6 +119,7 @@ using CDEElementOp = MultiplyMultiply;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
 static constexpr ck::index_t MPerBlock = 32;
+static constexpr ck::index_t MNPerXDL = 32;
 static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType);
 static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint
 static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
@@ -142,7 +143,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
     // ak1, bk1
     AK1, BK1,
     // mn_perxdl
-    32, 32,
+    MNPerXDL, MNPerXDL,
     // mn_xdlperwave
     MXDLPerWave, 1,
     // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
@@ -173,11 +174,11 @@ int main(int argc, char* argv[])
     // GEMM shape
     ck::index_t N = 128;
     ck::index_t K = 1024;
-    ck::index_t experts = 1;
-    ck::index_t sorted_tile_num = 1;
+    ck::index_t experts = 8;
+    ck::index_t sorted_tile_num = 2;
     ck::index_t sorted_tile_size = MPerBlock;
     ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
-    ck::index_t tokens = 1;
+    ck::index_t tokens = 32;
     if(argc == 1)
     {
@@ -251,7 +252,7 @@ int main(int argc, char* argv[])
     Tensor<D1DataType> d1_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D1Layout{}));
     Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
     Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
+    e_t_n_device_result.SetZero();
     std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
     std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
     std::cout << "d1_t_n: " << d1_t_n.mDesc << std::endl;
@@ -358,8 +359,7 @@ int main(int argc, char* argv[])
     {
         invoker.Run(argument, StreamConfig{nullptr, false, 0 ,0,1});
-        e_device_buf.FromDevice(e_t_n_device_result.mData.data());
+        // e_device_buf.FromDevice(e_t_n_device_result.mData.data());
         Tensor<CShuffleDataType> c_t_n({tokens, N});
         using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
@@ -376,10 +376,11 @@ int main(int argc, char* argv[])
             sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, c_t_n, PassThrough{}, PassThrough{}, PassThrough{});
         ref_invoker.Run(ref_argument);
-        for(int m = 0; m < SORTED_SIZE; ++m)
+        for(int t = 0; t < tokens; ++t)
         {
-            const int t = sorted_token_ids(m);
+            // const int t = sorted_token_ids(m);
             for(int n = 0; n < N; ++n)
             {
                 cde_element_op(e_t_n_host_result(t, n), c_t_n(t, n), d0_t_n(t, n), d1_t_n(t, n));
@@ -389,6 +390,7 @@ int main(int argc, char* argv[])
         e_device_buf.FromDevice(e_t_n_device_result.mData.data());
         e_t_n_device_result.savetxt("out.txt");
+        e_t_n_host_result.savetxt("ref.txt");
         return ck::utils::check_err(
                    e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
                    ? 0
......
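Note on the shape changes above: the example now routes 32 tokens across 8 experts and pads the routed list to sorted_tile_num * sorted_tile_size rows. Below is a minimal host-side sketch of that convention, assuming (as the reference's t < token_cnt guard suggests) that padded slots carry an out-of-range token id; the identity routing and the padding value are illustrative assumptions, not part of the commit.

#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    const int tokens           = 32; // values taken from the example
    const int sorted_tile_size = 32; // MPerBlock
    const int sorted_tile_num  = 2;
    const int SORTED_SIZE      = sorted_tile_num * sorted_tile_size; // 64

    // Hypothetical sorted_token_ids: real code gets this from the router.
    // Slots beyond the routed tokens are padded with an out-of-range id
    // (here `tokens`), which is why the host reference only writes rows
    // with t < token_cnt.
    std::vector<int32_t> sorted_token_ids(SORTED_SIZE, tokens);
    for(int i = 0; i < tokens; ++i)
        sorted_token_ids[i] = i; // identity routing, purely illustrative

    int valid = 0;
    for(int m = 0; m < SORTED_SIZE; ++m)
        if(sorted_token_ids[m] < tokens)
            ++valid;

    std::cout << "SORTED_SIZE = " << SORTED_SIZE << ", valid rows = " << valid << "\n";
}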
@@ -101,17 +101,16 @@ struct ThreadGroupTensorSliceTransfer_v7r3
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
-            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(ThreadGroup::GetThreadId() % mod_num));
-            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
+            const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
+                make_multi_index(ThreadGroup::GetThreadId()));
             const auto src_thread_slice_origins = generate_tuple(
-                [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
+                [&](auto i) { return src_block_slice_origins[i] + src_thread_cluster_idx * thread_slice_lengths; },
                 Number<nSrc>{});
+            const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
+                make_multi_index(ThreadGroup::GetThreadId() % mod_num));
             const auto dst_thread_slice_origins = generate_tuple(
-                [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
+                [&](auto i) { return dst_block_slice_origins[i] + dst_thread_cluster_idx * thread_slice_lengths; },
                 Number<nDst>{});
             threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
......
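The ThreadGroupTensorSliceTransfer_v7r3 change above computes separate source and destination cluster indices and folds the former thread_data_idx_begin into each origin directly. Here is a 1-D scalar sketch of the underlying arithmetic (origin = block origin + cluster index * slice length), with made-up sizes rather than CK's multi-index descriptors.

#include <iostream>

int main()
{
    // Assumed 1-D illustration of the slice-origin indexing: a block of
    // threads forms a cluster, each thread owns a contiguous slice, and its
    // data origin is the block origin plus cluster_index * slice_length.
    // The real code does this per dimension with multi-indices.
    const int block_slice_origin  = 128; // where this block's tile starts (made up)
    const int thread_slice_length = 8;   // elements copied per thread (made up)
    const int threads_in_cluster  = 16;  // thread cluster size (made up)

    for(int tid = 0; tid < threads_in_cluster; ++tid)
    {
        const int thread_cluster_idx  = tid; // in 1-D the bottom index is the id itself
        const int thread_slice_origin = block_slice_origin + thread_cluster_idx * thread_slice_length;
        std::cout << "tid " << tid << " copies [" << thread_slice_origin << ", "
                  << thread_slice_origin + thread_slice_length << ")\n";
    }
}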
@@ -1115,8 +1115,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
         MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
     const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
         problem.NumTokens, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
-    // printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(),
-    //        problem.MBlock, problem.NBlock, MPerBlock, NPerBlock);
     const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
         MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
             c_grid_desc_m_n, problem.MBlock, problem.NBlock);
@@ -1393,14 +1391,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
         constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
         constexpr auto EMRepeats = MPerBlock / EMThreads;
+        constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
         static_assert(EMRepeats == 1, "only support 1 line per thread now!");
-        const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / EMThreads * EMRepeats;
+        const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats;
         StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[token_pos];
         static_for<0, EMRepeats, 1>{}([&](auto m0) {
             scatter_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.N;
-            // printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
+            printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
         });
+        // printf("tid %d pos %d offset %d size %d\n", threadIdx.x, token_pos, scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
         auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
             ThisThreadBlock,
             decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
@@ -1433,7 +1433,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
             make_tuple(make_multi_index(0, 0, block_n_id, 0)),
             c_element_op};
         // if(threadIdx.x== 0)
-        //     printf("offset %d size %d\n", scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid + scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() - scatter_offsets(I0));
         // space filling curve for threadwise C in VGPR
@@ -1461,7 +1460,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
             CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
         static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
-        // printf("eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee\n");
         static_for<0, num_access, 1>{}([&](auto access_id) {
             // make sure it's safe to write to LDS
             block_sync_lds();
......
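The epilogue hunks above derive each thread's output row from the sorted-token list: token_pos selects the entry, the low 24 bits are taken as the token id (the & 0xffffff mask is from the kernel; what the high bits encode is not shown here), and that id times N becomes the scatter offset added to p_c_grid. Below is a plain host-side sketch of that arithmetic with stand-in values; ENThreads, EMRepeats, and all constants are illustrative assumptions.

#include <cstdint>
#include <iostream>

int main()
{
    const int MPerBlock   = 32;
    const int N           = 128;
    const int ENThreads   = 64; // threads along the N direction (assumed value)
    const int EMRepeats   = 1;  // rows per thread; the kernel static_asserts this is 1
    const int block_m_id  = 1;  // which MPerBlock tile this workgroup owns
    const int threadIdx_x = 70; // example thread id within the block

    // Row of the sorted-token list this thread is responsible for.
    const int token_pos = block_m_id * MPerBlock + threadIdx_x / ENThreads * EMRepeats;

    // Example sorted_token_ids entry: the low 24 bits hold the token id and
    // are extracted exactly as the kernel does with & 0xffffff.
    const uint32_t packed_id = 0x01000005u;
    const int token          = static_cast<int>(packed_id & 0xffffff);

    // Element offset of that token's output row; the kernel adds this to
    // p_c_grid so the C-shuffle epilogue scatters rows back to token order.
    const long scatter_offset = static_cast<long>(token) * N;

    std::cout << "token_pos " << token_pos << " token " << token
              << " scatter_offset " << scatter_offset << "\n";
}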
@@ -71,13 +71,13 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
         arg.c_t_n_.SetZero();
         auto f_mk_kn_mn = [&](auto m, auto n) {
             const int K = arg.a_m_k_.mDesc.GetLengths()[1];
             AccDataType v_acc{0};
             ComputeTypeA v_a{0};
             ComputeTypeB v_b{0};
             const int t = arg.sorted_token_ids_(m);
             const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
-            const int token_cnt = arg.a_m_k_.mDesc.GetLengths()[0];
+            const int token_cnt = arg.c_t_n_.mDesc.GetLengths()[0];
             if(t < token_cnt) {
                 for(int k = 0; k < K; ++k)
                 {
@@ -105,17 +105,17 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
                     v_acc +=
                         ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
                 }
-            }
                 CDataType v_c{0};
                 arg.c_element_op_(v_c, v_acc);
                 arg.c_t_n_(t, n) += v_c;
+            }
         };
         make_ParallelTensorFunctor(
-            f_mk_kn_mn, arg.c_t_n_.mDesc.GetLengths()[0], arg.c_t_n_.mDesc.GetLengths()[1])(
-            std::thread::hardware_concurrency());
+            f_mk_kn_mn, arg.a_m_k_.mDesc.GetLengths()[0], arg.c_t_n_.mDesc.GetLengths()[1])(
+            1);
         return 0;
     }
......
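For orientation, the reference operator's inner computation amounts to: for each sorted row m, look up token t and expert e, skip padded rows, accumulate a K-long dot product per column n, and add it into c_t_n(t, n). A standalone sketch with plain vectors follows; whether A rows are addressed by the sorted position or by the raw token id is not visible in the hunk, so the sketch indexes A by m, and the example's PassThrough element ops are treated as identity.

#include <iostream>
#include <vector>

// Plain-C++ sketch of the reference MoE GEMM loop nest. All tensors are
// dense float vectors here; the real operator works on ck Tensor<> objects
// and applies a/b/c element ops.
int main()
{
    const int tokens = 4, experts = 2, N = 3, K = 8;
    const int sorted_tile_size = 2;

    std::vector<int> sorted_token_ids = {0, 2, 1, 3, 4, 4}; // last tile padded with id 4 (== tokens)
    std::vector<int> expert_ids       = {0, 1, 1};          // one expert id per tile
    const int SORTED_SIZE             = static_cast<int>(sorted_token_ids.size());

    std::vector<float> a(SORTED_SIZE * K, 1.0f); // a_m_k, rows in sorted order (assumed)
    std::vector<float> b(experts * N * K, 0.5f); // b_e_n_k
    std::vector<float> c(tokens * N, 0.0f);      // c_t_n, accumulated per token

    for(int m = 0; m < SORTED_SIZE; ++m)
    {
        const int t = sorted_token_ids[m];
        const int e = expert_ids[m / sorted_tile_size];
        if(t >= tokens) // padded row, nothing to write
            continue;
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(int k = 0; k < K; ++k)
                acc += a[m * K + k] * b[(e * N + n) * K + k];
            c[t * N + n] += acc; // += because a token may be routed to several experts
        }
    }

    std::cout << "c(0,0) = " << c[0 * N + 0] << "\n"; // expect 8 * 1.0 * 0.5 = 4
}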
@@ -17,7 +17,7 @@ fi
 cmake \
     -D CMAKE_PREFIX_PATH=/opt/rocm \
     -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-    -D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 --save-temps -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
+    -D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O1 -g --save-temps -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
     -D CMAKE_BUILD_TYPE=Release \
     -D BUILD_DEV=ON \
     -D GPU_TARGETS=$GPU_TARGETS \
......