Commit 965c9f0c authored by coderfeli

debug 16x16 load

parent 83970cbe
@@ -131,13 +131,18 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, F16>;
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
32, 128, 128,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
// BlockSize, MPerBlock, NPerBlock, KPerBlock
256, 32, 128, 128,
// ak1, bk1
8, 8,
// MPerXdl, NPerXdl
32, 32,
// MXdlPerWave, NXdlPerWave
1, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// A/B block transfer: thread cluster lengths, cluster arrange order, src access order, src vector dim, src scalar per vector, dst scalar per vector, lds extra pad
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
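As a sanity check on the new 16x16 A-load cluster, the following stand-alone sketch recomputes the per-thread load shape from the values commented above (BlockSize 256, MPerBlock 32, KPerBlock 128, AK1 8, A thread cluster S<16, 16, 1>). The names and the program itself are illustrative, not part of the library:

// sketch: per-thread A-load shape implied by the 16x16 cluster
#include <cstdio>

int main()
{
    constexpr int BlockSize = 256;
    constexpr int MPerBlock = 32;
    constexpr int KPerBlock = 128;
    constexpr int AK1       = 8;
    constexpr int AK0       = KPerBlock / AK1; // 16

    // ABlockTransferThreadClusterLengths_AK0_M_AK1 = S<16, 16, 1>
    constexpr int AK0Threads = 16, AMThreads = 16, AK1Threads = 1;
    static_assert(AK0Threads * AMThreads * AK1Threads == BlockSize,
                  "thread cluster must cover the whole workgroup");

    constexpr int AKThreads = AK0Threads * AK1Threads; // 16 threads along K
    constexpr int AMRepeats = MPerBlock / AMThreads;   // 2 sorted rows per thread
    std::printf("per-thread slice (AK0, M, AK1) = (%d, %d, %d), AKThreads = %d, AMRepeats = %d\n",
                AK0 / AK0Threads, AMRepeats, AK1 / AK1Threads, AKThreads, AMRepeats);
    return 0;
}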
@@ -162,7 +167,7 @@ int main(int argc, char* argv[])
ck::index_t N = 6144;
ck::index_t K = 8192;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 8;
ck::index_t sorted_tile_num = 1;
ck::index_t sorted_tile_size = 32;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
ck::index_t tokens = 32;
......
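With the new value the sorted problem collapses to a single tile; a quick arithmetic check (values copied from this hunk and from the device configuration above, plain C++ rather than library code):

int main()
{
    constexpr int sorted_tile_num  = 1;
    constexpr int sorted_tile_size = 32;
    constexpr int SORTED_SIZE      = sorted_tile_num * sorted_tile_size; // 32
    constexpr int tokens           = 32;
    // one sorted tile of 32 rows, equal to tokens and to the MPerBlock of 32
    // chosen above, so a single M block is exercised while debugging the load
    static_assert(SORTED_SIZE == tokens, "single-tile debug configuration");
    return 0;
}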
@@ -45,11 +45,12 @@ template <typename ThreadGroup,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r1_mod8
{
static constexpr auto I0 = Number<0>{};
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
static constexpr index_t mod_num = ThreadClusterLengths{}.At(I0); // Dirty HACK FELIX, TODO fix
using Index = MultiIndex<nDim>;
// using GatherIndex = MultiIndex<gather_num>;
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_mod8(
const SrcDesc& src_desc,
@@ -86,7 +87,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId() % 8));
make_multi_index(ThreadGroup::GetThreadId() % mod_num));
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + src_thread_cluster_idx * thread_slice_lengths);
@@ -104,7 +105,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId() % 8));
make_multi_index(ThreadGroup::GetThreadId() % mod_num));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
......
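The two hunks above replace the hard-coded "% 8" with mod_num, the leading ThreadClusterLengths entry, so the thread-id wrap now follows the configured cluster instead of assuming 8 threads along dimension 0. A minimal stand-alone illustration (not library code; the two clusters are the ones that appear in this commit):

#include <cstdio>

int main()
{
    // the leading cluster length is what mod_num now picks up
    constexpr int old_cluster[3] = {8, 32, 1};  // previous A-load cluster
    constexpr int new_cluster[3] = {16, 16, 1}; // cluster used in this commit
    constexpr int old_mod = old_cluster[0];     // 8, same as the old literal "% 8"
    constexpr int new_mod = new_cluster[0];     // 16

    std::printf("old mod_num %d, new mod_num %d\n", old_mod, new_mod);
    std::printf("tid 20 wraps to %d (old) and %d (new)\n", 20 % old_mod, 20 % new_mod);
    return 0;
}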
@@ -1127,16 +1127,18 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]);
// constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto MLoadThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto KLoadThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0) * ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
constexpr auto MLoadRepeats = MPerBlock / MLoadThreads;
static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / KLoadThreads;
StaticallyIndexedArray<index_t, MLoadRepeats> token_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, MLoadRepeats, 1>{}([&](auto m0) {
token_offsets(m0) = p_sorted_token_ids[token_pos + MLoadThreads * m0] * problem.K;
constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
constexpr auto AKThreads = AK0Threads * AK1Threads;
constexpr auto AMRepeats = MPerBlock / AMThreads;
// static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, AMRepeats, 1>{}([&](auto m0) {
gather_offsets(m0) = p_sorted_token_ids[token_pos + m0] * problem.K;
printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
});
// printf("threadIdx.x %d off %d\n", threadIdx.x, token_offsets(I0));
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
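Rewritten for the host, the new indexing gives each group of AKThreads threads one starting sorted row and lets every thread gather AMRepeats consecutive rows; a small sketch with the values of this configuration (S<16, 16, 1>, MPerBlock 32). Names and the loop are illustrative only:

#include <cstdio>

int main()
{
    constexpr int MPerBlock  = 32;
    constexpr int AK0Threads = 16, AMThreads = 16, AK1Threads = 1; // S<16, 16, 1>
    constexpr int AKThreads  = AK0Threads * AK1Threads;            // 16
    constexpr int AMRepeats  = MPerBlock / AMThreads;               // 2

    const int block_m_id = 0;
    for (int tid : {0, 15, 16, 31, 255})
    {
        // every AKThreads-sized group shares a starting row; each thread then
        // owns AMRepeats consecutive sorted rows (their K columns differ)
        const int token_pos = block_m_id * MPerBlock + tid / AKThreads * AMRepeats;
        std::printf("tid %3d -> sorted rows [%d, %d)\n", tid, token_pos, token_pos + AMRepeats);
    }
    return 0;
}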
@@ -1194,7 +1196,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
token_offsets);
gather_offsets);
// Thread-wise copy
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
@@ -1222,7 +1224,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto a_block_slice_copy_step = make_multi_index(AK0Threads, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
// Blockwise GEMM pipeline
......
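The copy step along K now advances by AK0Threads instead of KPerBlock / AK1Number, presumably so that one step equals the K0 span a single cooperative copy actually covers. For the 16x16 cluster the two expressions coincide; a quick comparison (plain arithmetic, not library code):

#include <cstdio>

int main()
{
    constexpr int KPerBlock  = 128, AK1 = 8;
    constexpr int K0PerBlock = KPerBlock / AK1; // old step: 16

    constexpr int AK0Threads_8x32  = 8;  // S<8, 32, 1>
    constexpr int AK0Threads_16x16 = 16; // S<16, 16, 1>

    std::printf("old step (KPerBlock / AK1): %d\n", K0PerBlock);
    std::printf("new step, 8x32 cluster    : %d\n", AK0Threads_8x32);
    std::printf("new step, 16x16 cluster   : %d\n", AK0Threads_16x16);
    return 0;
}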
@@ -178,15 +178,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// record is_src_valid in a per-access container, to be consumed later by RunWrite.
const index_t ld_offset = src_coord_.GetOffset() + gather_offset;
const bool is_src_valid = ld_offset < src_desc.GetElementSpaceSize() * sizeof(SrcData);//hack felix, todo use coord
const bool is_src_valid = ld_offset < src_desc.GetElementSpaceSize();//hack felix, todo use coord
//coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_) && (gather_offset < 32*512);
src_oob_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
// if(blockIdx.x+blockIdx.y==0)
// printf("tid %d off %d %d\n", threadIdx.x, src_coord_.GetOffset(), gather_offset );
if(threadIdx.x==0)
printf("use tid %d num %d off %d %d\n", threadIdx.x, ordered_src_access_idx[Number<ordered_gather_dim>{}](), src_coord_.GetOffset(), gather_offset );
auto src_vector_container =
src_vector_type{src_buf.template Get<src_vector_t>(ld_offset, true)};
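The bound check now compares the accumulated offset directly against GetElementSpaceSize(). Both src_coord_.GetOffset() and gather_offset (token id times problem.K) count elements, so this keeps the comparison in a single unit instead of mixing elements with a byte-scaled bound. A host-side sketch of the guard under that assumption, with made-up sizes:

#include <cstdio>

int main()
{
    const long element_space_size = 32L * 8192; // e.g. tokens * K, in elements
    const long coord_offset       = 5L * 8192;  // element offset from the coordinate
    const long gather_offset      = 31L * 8192; // token_id * K, also in elements

    const long ld_offset    = coord_offset + gather_offset;
    const bool is_src_valid = ld_offset < element_space_size; // out of range here
    std::printf("ld_offset %ld valid %d\n", ld_offset, is_src_valid);
    return 0;
}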
@@ -235,7 +235,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// printf("tid %d %f\n",threadIdx.x, type_convert<float>(src_vector_container.template AsType<print_vec_t>()[idx]));
// });
// }
constexpr auto move_on_dim = [&]() constexpr
auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
@@ -246,15 +246,20 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
move_on_dim_(i) &=
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
});
move_on_dim_(i) &= i.value != ordered_gather_dim;
// if(threadIdx.x==0)
// printf("i %d %d ordered_gather_dim %d\n", i.value, move_on_dim_(i), ordered_gather_dim);
});
return move_on_dim_;
}
();
// move src coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
if(threadIdx.x==0)
printf("use tid %d ori cord: %d i %d mov %d\n", threadIdx.x, src_coord_.GetOffset(), i.value, move_on_dim[i]);
if (move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
@@ -267,7 +272,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
}
}
if(threadIdx.x==0)
printf("use tid %d moved cord: %d\n", threadIdx.x, src_coord_.GetOffset());
});
});
// move src coordinate back to slice origin (or not)
......
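The coordinate update above behaves like an odometer: a dimension is stepped only when it still has room to move and every faster-varying dimension has just wrapped to its last index, and the added i.value != ordered_gather_dim term keeps the gather dimension from ever being stepped, since its offset is supplied through gather_offsets instead. A minimal host-side sketch of that carry rule (example lengths and index, not library code):

#include <cstdio>

int main()
{
    constexpr int nDim = 3;
    const int lengths[nDim] = {2, 2, 4}; // ordered access lengths (example values)
    const int idx[nDim]     = {0, 1, 3}; // current ordered access index
    const int gather_dim    = 1;         // dimension fed from gather_offsets

    for (int i = 0; i < nDim; ++i)
    {
        bool move = idx[i] < lengths[i] - 1;   // this dim still has room to move
        for (int j = i + 1; j < nDim; ++j)     // ...and all faster dims just wrapped
            move = move && (idx[j] == lengths[j] - 1);
        move = move && (i != gather_dim);      // never step the gather dimension
        std::printf("dim %d move %d\n", i, move);
    }
    return 0;
}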
@@ -423,14 +423,14 @@ struct ThreadwiseTensorSliceTransfer_v7r3
dst_coords_[i].GetOffset(),
is_dst_valid,
dst_vectors[i].template AsType<dst_vector_t>()[I0]);
if(1) {
static_for<0, DstScalarPerVector, 1>{}([&](auto idx) {
using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
using print_vec_t = typename vector_type<DstData, 1>::type;
// if(1) {
// static_for<0, DstScalarPerVector, 1>{}([&](auto idx) {
// using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
// using print_vec_t = typename vector_type<DstData, 1>::type;
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_coords_[i].GetOffset(), is_dst_valid,
// type_convert<float>(dst_vectors[i].template AsType<print_vec_t>()[idx]));
});
}
// });
// }
});
// move coordinate
......