"vscode:/vscode.git/clone" did not exist on "0042efd0157b51b1a2593da6a29aaf0acd9a1c59"
Commit 48d87d9c authored by coderfeli

a 16x16 ok

parent 965c9f0c
@@ -141,8 +141,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
     // mn_xdlperwave
     1, 1,
     // a,b: loadtranfer cluster, cluster order, srcorder, srcpervec, dstpervec, lds_extra
-    S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
-    S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
+    // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
+    // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
+    S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
+    S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
     // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
     // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
     // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
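Note on the retuned transfer clusters: the `S<K0, M, K1>` triples set how the workgroup's threads are laid out for the blockwise A/B loads, so their product has to match the block size. Both the old `S<16, 16, 1>` and the new `S<8, 32, 1>` cover 256 threads; the change only redistributes them between the K0 and M/N axes. Below is a minimal standalone sketch of that invariant; the `ClusterLengths` helper and the 256-thread block size are assumptions for illustration, not CK API.

```cpp
// Sketch (hypothetical helper, not part of the CK API) of the invariant the
// cluster tuning above must satisfy: the thread-cluster lengths of the
// blockwise copy, multiplied together, must cover exactly one workgroup.
#include <cstddef>

template <std::size_t K0, std::size_t M, std::size_t K1>
struct ClusterLengths
{
    static constexpr std::size_t total = K0 * M * K1;
};

constexpr std::size_t BlockSize = 256; // assumed workgroup size

// Old and new tunings both use 256 threads; only the K0 vs. M/N split
// changes, i.e. how many K0 rows each thread covers per copy step.
static_assert(ClusterLengths<16, 16, 1>::total == BlockSize, "old tuning");
static_assert(ClusterLengths<8, 32, 1>::total == BlockSize, "new tuning");

int main() { return 0; }
```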
@@ -167,10 +169,10 @@ int main(int argc, char* argv[])
     ck::index_t N = 6144;
     ck::index_t K = 8192;
     ck::index_t experts = 8;
-    ck::index_t sorted_tile_num = 1;
+    ck::index_t sorted_tile_num = 8;
     ck::index_t sorted_tile_size = 32;
     ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
-    ck::index_t tokens = 32;
+    ck::index_t tokens = 64;
     if(argc == 1)
     {
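For orientation, the new defaults make the sorted buffer larger than the raw token count: `SORTED_SIZE = sorted_tile_num * sorted_tile_size = 8 * 32 = 256` slots for `tokens = 64` across `experts = 8`, consistent with MoE-style routing where tokens are grouped per expert into fixed-size tiles and padded. A host-side sketch of that arithmetic (plain C++, illustrative, not repo code):

```cpp
// Rough sketch of how the sorted-tile defaults above relate: tokens routed
// to experts are grouped into fixed-size tiles, so the sorted buffer holds
// tile_num * tile_size slots, typically more than the raw token count.
#include <cstdio>

int main()
{
    const int tokens           = 64; // from the new defaults above
    const int sorted_tile_num  = 8;
    const int sorted_tile_size = 32;

    const int sorted_size = sorted_tile_num * sorted_tile_size; // 256 slots

    // With top-k routing and per-expert padding to a tile boundary, the
    // sorted buffer usually exceeds the raw token count.
    std::printf("tokens=%d sorted slots=%d (expansion factor %.1fx)\n",
                tokens, sorted_size, double(sorted_size) / tokens);
    return 0;
}
```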
@@ -1137,7 +1137,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
     StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
     static_for<0, AMRepeats, 1>{}([&](auto m0) {
         gather_offsets(m0) = p_sorted_token_ids[token_pos + m0] * problem.K;
-        printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
+        // printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
     });
     const index_t m_block_data_idx_on_grid =
         __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
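The loop above precomputes one gather offset per row a thread owns: the sorted token id scaled by the row stride `problem.K`, so each subsequent load can fetch that token's row of A directly. A simplified host-side sketch of the same indexing (all concrete values are illustrative assumptions):

```cpp
// Host-side sketch of the gather-offset math in the hunk above: each of the
// AMRepeats rows a thread owns maps a sorted token id to the element offset
// of that token's row in the A matrix (row index times the K stride).
#include <array>
#include <cstdio>

int main()
{
    constexpr int AMRepeats = 4; // assumed rows per thread
    constexpr int K         = 8; // assumed row stride in elements

    const std::array<int, AMRepeats> sorted_token_ids{5, 2, 7, 0};

    std::array<int, AMRepeats> gather_offsets{};
    for (int m0 = 0; m0 < AMRepeats; ++m0)
    {
        // same formula as gather_offsets(m0) = p_sorted_token_ids[...] * problem.K
        gather_offsets[m0] = sorted_token_ids[m0] * K;
        std::printf("m %d token %d offset %d\n",
                    m0, sorted_token_ids[m0], gather_offsets[m0]);
    }
    return 0;
}
```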
@@ -1224,7 +1224,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
     auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
         static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
-    constexpr auto a_block_slice_copy_step = make_multi_index(AK0Threads, 0, 0);
+    constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
     constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
     // Blockwise GEMM pipeline
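This fix replaces a thread-count-derived step (`AK0Threads`) with the loop invariant: in the `AK0 x M x AK1` view of the A block, one main-loop iteration consumes `KPerBlock` elements of K, which is `KPerBlock / AK1` rows of AK0, independent of how many threads participate in the copy. A numeric sketch (values assumed for illustration, not the kernel's actual config):

```cpp
// Numeric sketch of why the A copy step changed: the window must advance by
// the K consumed per main-loop iteration, expressed in AK0 rows.
#include <cstdio>

int main()
{
    constexpr int KPerBlock = 64;              // assumed K tile per block
    constexpr int AK1       = 8;               // assumed elements per vector load
    constexpr int AK0Step   = KPerBlock / AK1; // = 8 AK0 rows per iteration

    std::printf("advance A block window by %d along AK0 each iteration\n", AK0Step);
    return 0;
}
```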
@@ -185,8 +185,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
     using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
     using src_vector_t = typename src_vector_type::type;
-    if(threadIdx.x==0)
-        printf("use tid %d num %d off %d %d\n", threadIdx.x, ordered_src_access_idx[Number<ordered_gather_dim>{}](), src_coord_.GetOffset(), gather_offset );
+    // if(threadIdx.x==0)
+    //     printf("use tid %d num %d off %d %d\n", threadIdx.x, ordered_src_access_idx[Number<ordered_gather_dim>{}](), src_coord_.GetOffset(), gather_offset );
     auto src_vector_container =
         src_vector_type{src_buf.template Get<src_vector_t>(ld_offset, true)};
@@ -257,8 +257,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
     ();
     // move src coord
     static_for<0, nDim, 1>{}([&](auto i) {
-        if(threadIdx.x==0)
-            printf("use tid %d ori cord: %d i %d mov %d\n", threadIdx.x, src_coord_.GetOffset(), i.value, move_on_dim[i]);
+        // if(threadIdx.x==0)
+        //     printf("use tid %d ori cord: %d i %d mov %d\n", threadIdx.x, src_coord_.GetOffset(), i.value, move_on_dim[i]);
         if (move_on_dim[i])
         {
             if constexpr(forward_sweep[i])
@@ -272,8 +272,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
                 src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
             }
         }
-        if(threadIdx.x==0)
-            printf("use tid %d moved cord: %d\n", threadIdx.x, src_coord_.GetOffset());
+        // if(threadIdx.x==0)
+        //     printf("use tid %d moved cord: %d\n", threadIdx.x, src_coord_.GetOffset());
     });
 });
@@ -666,11 +666,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
     constexpr auto reset_src_data_step = [&]() {
         Index reset_src_data_step_;
-        static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
+        static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = i.value == GatherDim ? 0 : -src_data_idx[i]; });
         return reset_src_data_step_;
     }();
     return reset_src_data_step;
 }
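The reset step is what rewinds the source coordinate after a full pass. Zeroing it along `GatherDim` is consistent with the rest of this change: the position on the gather axis is supplied externally by the precomputed per-row gather offsets, so rewinding that axis too would double-count the offset. A simplified sketch with plain loops in place of `static_for` (the dimension count and gather axis are assumptions):

```cpp
// Sketch of the reset-step change above: after a pass, every dimension's
// coordinate is rewound to its origin except the gather dimension, whose
// offset comes from the external per-row gather table.
#include <array>
#include <cstdio>

int main()
{
    constexpr int nDim      = 3;
    constexpr int GatherDim = 1;                       // assumed gather axis
    const std::array<int, nDim> src_data_idx{4, 7, 2}; // end-of-pass coordinate

    std::array<int, nDim> reset_step{};
    for (int i = 0; i < nDim; ++i)
        reset_step[i] = (i == GatherDim) ? 0 : -src_data_idx[i];

    for (int i = 0; i < nDim; ++i)
        std::printf("dim %d reset %d\n", i, reset_step[i]);
    return 0;
}
```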
@@ -740,7 +739,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
     const auto adjusted_step_idx =
         SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                    : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
-    // is it OK to construct a new step every time?
     const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);