Commit 965c9f0c authored by coderfeli

debug 16x16 load

parent 83970cbe
@@ -131,13 +131,18 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>;
 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, F16>;
         < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
-          AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
-          32, 128, 128,
+          AElementOp, BElementOp, CDEElementOp, GemmSpec,
+          // threadnum, mblock, nblock, kblock
+          256, 32, 128, 128,
+          // ak1, bk1
           8, 8,
+          // mn_perxdl
           32, 32,
+          // mn_xdlperwave
           1, 1,
-          S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
-          S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
+          // a, b: load transfer cluster, cluster order, srcorder, srcpervec, dstpervec, lds_extra
+          S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
+          S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
 // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
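
As a sanity check of the new 16x16 load shape, here is a minimal standalone sketch (not part of the commit; the constants mirror the instance above, and the helper names follow the gridwise kernel further down):

    #include <cstdio>

    int main()
    {
        constexpr int BlockSize  = 256;                                // threadnum
        constexpr int MPerBlock  = 32;                                 // mblock
        constexpr int AK0Threads = 16, AMThreads = 16, AK1Threads = 1; // S<16, 16, 1>
        constexpr int AKThreads  = AK0Threads * AK1Threads;            // threads along K
        constexpr int AMRepeats  = MPerBlock / AMThreads;              // M rows per thread

        static_assert(AK0Threads * AMThreads * AK1Threads == BlockSize,
                      "load cluster must cover the whole workgroup");
        static_assert(MPerBlock % AMThreads == 0, "rows must divide evenly");

        // With the old S<8, 32, 1> cluster each thread loaded one M row
        // (AMRepeats == 1); with S<16, 16, 1> it loads two, which is why the
        // static_assert in the gridwise kernel is relaxed further down.
        std::printf("AKThreads = %d, AMRepeats = %d\n", AKThreads, AMRepeats);
        return 0;
    }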
@@ -162,7 +167,7 @@ int main(int argc, char* argv[])
     ck::index_t N = 6144;
     ck::index_t K = 8192;
     ck::index_t experts = 8;
-    ck::index_t sorted_tile_num = 8;
+    ck::index_t sorted_tile_num = 1;
     ck::index_t sorted_tile_size = 32;
     ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
     ck::index_t tokens = 32;
...
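
A quick check of the debug problem shape set above (a hypothetical standalone snippet, assuming MPerBlock = 32 from the device instance): with one sorted tile of 32 tokens, the whole sorted range fits in a single M block.

    // Hypothetical check, not repository code.
    constexpr int sorted_tile_num  = 1;
    constexpr int sorted_tile_size = 32;
    constexpr int SORTED_SIZE      = sorted_tile_num * sorted_tile_size; // 32
    constexpr int MPerBlock        = 32; // mblock of the instance above
    static_assert(SORTED_SIZE == MPerBlock, "exactly one M block to debug");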
@@ -45,11 +45,12 @@ template <typename ThreadGroup,
           index_t NumThreadScratch = 1>
 struct ThreadGroupTensorSliceTransfer_v4r1_mod8
 {
+    static constexpr auto I0 = Number<0>{};
     static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
     static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
     static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
+    static constexpr index_t mod_num = ThreadClusterLengths{}.At(I0); // Dirty HACK FELIX, TODO fix
     using Index = MultiIndex<nDim>;
-    // using GatherIndex = MultiIndex<gather_num>;

     __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_mod8(
         const SrcDesc& src_desc,
@@ -86,7 +87,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(ThreadGroup::GetThreadId() % 8));
+                make_multi_index(ThreadGroup::GetThreadId() % mod_num));

             threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                 src_block_slice_origin + src_thread_cluster_idx * thread_slice_lengths);
@@ -104,7 +105,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(ThreadGroup::GetThreadId() % 8));
+                make_multi_index(ThreadGroup::GetThreadId() % mod_num));
             const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

             threadwise_transfer_.SetSrcSliceOrigin(src_desc,
...
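
A toy model of what the mod_num change does (hypothetical standalone code, not the CK template): the wrap width for mapping a thread id to its cluster row now follows the first extent of ThreadClusterLengths instead of the hard-coded 8, so it stays correct when the cluster changes from S<8, 32, 1> to S<16, 16, 1>.

    #include <cstdio>

    constexpr int K0 = 16, M = 16, K1 = 1; // ThreadClusterLengths = S<16, 16, 1>
    constexpr int mod_num = K0;            // ThreadClusterLengths{}.At(I0); was "% 8"
    static_assert(K0 * M * K1 == 256, "cluster covers the workgroup");

    int main()
    {
        for(int tid = 0; tid < 32; ++tid) // a few threads, for illustration
        {
            // the index handed to thread_cluster_desc_.CalculateBottomIndex(...)
            std::printf("tid %2d -> cluster index %2d\n", tid, tid % mod_num);
        }
        return 0;
    }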
@@ -1127,16 +1127,18 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
         const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]);
         // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
-        constexpr auto MLoadThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
-        constexpr auto KLoadThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0) * ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
-        constexpr auto MLoadRepeats = MPerBlock / MLoadThreads;
-        static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
-        const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / KLoadThreads;
-        StaticallyIndexedArray<index_t, MLoadRepeats> token_offsets; //= p_sorted_token_ids[token_pos];
-        static_for<0, MLoadRepeats, 1>{}([&](auto m0) {
-            token_offsets(m0) = p_sorted_token_ids[token_pos + MLoadThreads * m0] * problem.K;
+        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
+        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
+        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
+        constexpr auto AKThreads  = AK0Threads * AK1Threads;
+        constexpr auto AMRepeats  = MPerBlock / AMThreads;
+        // static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
+        const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
+        StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
+        static_for<0, AMRepeats, 1>{}([&](auto m0) {
+            gather_offsets(m0) = p_sorted_token_ids[token_pos + m0] * problem.K;
+            printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
         });
-        // printf("threadIdx.x %d off %d\n", threadIdx.x, token_offsets(I0));

         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
         const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
@@ -1194,7 +1196,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
             a_block_desc_ak0_m_ak1,
             make_multi_index(0, 0, 0),
             ck::tensor_operation::element_wise::PassThrough{},
-            token_offsets);
+            gather_offsets);

         // Thread-wise copy
         // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
@@ -1222,7 +1224,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(AK0Threads, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);

         // Blockwise GEMM pipeline
...
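
To make the new gather indexing concrete, here is a host-side model of the loop above (a sketch under this example's assumptions: MPerBlock = 32, an S<16, 16, 1> cluster, K = 8192, and an identity stand-in for p_sorted_token_ids):

    #include <cstdio>

    int main()
    {
        constexpr int MPerBlock = 32, AMThreads = 16, AKThreads = 16;
        constexpr int AMRepeats = MPerBlock / AMThreads; // 2 rows per thread
        constexpr int K         = 8192;

        int sorted_token_ids[MPerBlock]; // stand-in for p_sorted_token_ids
        for(int i = 0; i < MPerBlock; ++i)
            sorted_token_ids[i] = i;     // identity mapping, just for the demo

        const int block_m_id = 0;
        for(int tid = 0; tid < 256; tid += AKThreads) // one thread per row group
        {
            // matches: token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats
            const int token_pos = block_m_id * MPerBlock + tid / AKThreads * AMRepeats;
            for(int m0 = 0; m0 < AMRepeats; ++m0)
            {
                // matches: gather_offsets(m0) = p_sorted_token_ids[token_pos + m0] * problem.K
                const int gather_offset = sorted_token_ids[token_pos + m0] * K;
                std::printf("tid %3d m0 %d token %2d offset %d\n",
                            tid, m0, token_pos + m0, gather_offset);
            }
        }
        return 0;
    }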
@@ -178,15 +178,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
                 // maintain a container record is_src_valid, waiting for RunWrite use.
                 const index_t ld_offset = src_coord_.GetOffset() + gather_offset;
-                const bool is_src_valid = ld_offset < src_desc.GetElementSpaceSize() * sizeof(SrcData); //hack felix, todo use coord
+                const bool is_src_valid = ld_offset < src_desc.GetElementSpaceSize(); //hack felix, todo use coord
                 //coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_) && (gather_offset < 32*512);
                 src_oob_thread_scratch_tuple_(thread_scratch_id)
                     .template SetAsType<bool>(src_data_idx_seq, is_src_valid);

                 using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
                 using src_vector_t    = typename src_vector_type::type;
-                // if(blockIdx.x+blockIdx.y==0)
-                //     printf("tid %d off %d %d\n", threadIdx.x, src_coord_.GetOffset(), gather_offset );
+                if(threadIdx.x==0)
+                    printf("use tid %d num %d off %d %d\n", threadIdx.x, ordered_src_access_idx[Number<ordered_gather_dim>{}](), src_coord_.GetOffset(), gather_offset );
                 auto src_vector_container =
                     src_vector_type{src_buf.template Get<src_vector_t>(ld_offset, true)};
@@ -235,7 +235,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
                 //     printf("tid %d %f\n",threadIdx.x, type_convert<float>(src_vector_container.template AsType<print_vec_t>()[idx]));
                 //     });
                 // }
-                constexpr auto move_on_dim = [&]()
+                constexpr auto move_on_dim = [&]() constexpr
                 {
                     StaticallyIndexedArray<bool, nDim> move_on_dim_;
@@ -246,15 +246,20 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
                             move_on_dim_(i) &=
                                 ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
                         });
+                        move_on_dim_(i) &= i.value != ordered_gather_dim;
+                        // if(threadIdx.x==0)
+                        //     printf("i %d %d ordered_gather_dim %d\n", i.value, move_on_dim_(i), ordered_gather_dim);
                     });

                     return move_on_dim_;
                 }
                 ();

                 // move src coord
                 static_for<0, nDim, 1>{}([&](auto i) {
-                    if constexpr(move_on_dim[i])
+                    if(threadIdx.x==0)
+                        printf("use tid %d ori cord: %d i %d mov %d\n", threadIdx.x, src_coord_.GetOffset(), i.value, move_on_dim[i]);
+                    if (move_on_dim[i])
                     {
                         if constexpr(forward_sweep[i])
                         {
@@ -267,7 +272,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
                                 src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
                         }
                     }
+                    if(threadIdx.x==0)
+                        printf("use tid %d moved cord: %d\n", threadIdx.x, src_coord_.GetOffset());
                 });
             });

             // move src coordinate back to slice origin (or not)
...
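
The move_on_dim change above pins the gathered dimension: its source offset now comes from gather_offsets, so the coordinate walker must not also step it. A toy model of the rule (hypothetical standalone code, not CK):

    #include <array>
    #include <cstdio>

    int main()
    {
        constexpr int nDim = 3;
        const std::array<int, nDim> idx     = {0, 1, 3}; // ordered access index
        const std::array<int, nDim> lengths = {2, 2, 4}; // ordered access lengths
        constexpr int ordered_gather_dim    = 1;         // dim fed by gather_offsets

        for(int i = 0; i < nDim; ++i)
        {
            // a dim moves only once every faster dim has hit its last index...
            bool move = true;
            for(int j = i + 1; j < nDim; ++j)
                move = move && (idx[j] == lengths[j] - 1);
            // ...and, after this commit, never if it is the gathered dim
            move = move && (i != ordered_gather_dim);
            std::printf("dim %d: move_on_dim = %d\n", i, move);
        }
        return 0;
    }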
@@ -423,14 +423,14 @@ struct ThreadwiseTensorSliceTransfer_v7r3
                     dst_coords_[i].GetOffset(),
                     is_dst_valid,
                     dst_vectors[i].template AsType<dst_vector_t>()[I0]);
-                if(1) {
-                    static_for<0, DstScalarPerVector, 1>{}([&](auto idx) {
-                        using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
-                        using print_vec_t = typename vector_type<DstData, 1>::type;
+                // if(1) {
+                //     static_for<0, DstScalarPerVector, 1>{}([&](auto idx) {
+                //         using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
+                //         using print_vec_t = typename vector_type<DstData, 1>::type;
                 //         printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_coords_[i].GetOffset(), is_dst_valid,
                 //                type_convert<float>(dst_vectors[i].template AsType<print_vec_t>()[idx]));
-                    });
-                }
+                //     });
+                // }
             });

             // move coordinate
...