Commit f45cc232 authored by ltqin

first change bias load

parent 87f2bbcf
@@ -189,7 +189,7 @@ int main(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method = 1;
-    bool time_kernel = false;
+    bool time_kernel = true;

     // GEMM shape
     ck::index_t M = 1024;
@@ -79,7 +79,8 @@ template <typename A0B0B1DataType, // FIXME: don't assume A0/B0/B1 have same dat
           index_t C1ShuffleGemm0NXdlPerWavePerShuffle,
           typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
-          LoopScheduler LoopSched>
+          LoopScheduler LoopSched,
+          int D0sTransferSrcScalarPerVector = 4>
 struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
 {
     static_assert(LoopSched == LoopScheduler::Default,
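D0sTransferSrcScalarPerVector is added as a defaulted non-type template parameter, so existing instantiations of the gridwise struct keep the previous vector width of 4 unless a caller overrides it. Below is a minimal standalone sketch of that pattern, with a hypothetical BiasLoader standing in for the kernel; it illustrates defaulted template parameters, not the CK API:

    #include <cstdio>

    // Hypothetical loader: SrcScalarPerVector controls how many elements are
    // consumed per step, defaulting to 4 like the new kernel parameter.
    template <typename T, int SrcScalarPerVector = 4>
    struct BiasLoader
    {
        // Assumes size is a multiple of SrcScalarPerVector, mirroring the
        // divisibility requirement a vectorized transfer would impose.
        static void Run(const T* src, T* dst, int size)
        {
            for(int i = 0; i < size; i += SrcScalarPerVector)
                for(int v = 0; v < SrcScalarPerVector; ++v)
                    dst[i + v] = src[i + v];
        }
    };

    int main()
    {
        float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
        float dst[8] = {};
        BiasLoader<float>::Run(src, dst, 8);    // default width (4)
        BiasLoader<float, 2>::Run(src, dst, 8); // explicitly overridden width
        std::printf("dst[7] = %f\n", dst[7]);
        return 0;
    }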
@@ -710,13 +711,13 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
         constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
             make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
                                                            I1, // NBlockID
-                                                           I1, // MRepeat
-                                                           I1, // NRepeat
-                                                           I1, // MWaveId
-                                                           I1, // NWaveId
-                                                           I1, // MPerXdl
-                                                           I1, // NGroupNum
-                                                           I1, // NInputNum
+                                                           m0, // MRepeat
+                                                           n0, // NRepeat
+                                                           m1, // MWaveId
+                                                           n1, // NWaveId
+                                                           m2, // MPerXdl
+                                                           n2, // NGroupNum
+                                                           n3, // NInputNum
                                                            n4)); // registerNum

         auto d0s_thread_buf = generate_tuple(
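With the lengths switched from I1 to the real m0/n0/m1/n1/m2/n2/n3 extents, the packed per-thread d0 descriptor now spans every accumulator element the thread owns rather than a single n4-wide slice, so d0s_thread_buf grows accordingly. A rough sketch of what a packed descriptor implies for the element count; the extents below are made up for illustration:

    #include <array>
    #include <cstdio>

    // A packed descriptor's element count is simply the product of its lengths.
    long PackedSize(const std::array<long, 10>& lengths)
    {
        long size = 1;
        for(long l : lengths)
            size *= l;
        return size;
    }

    int main()
    {
        // Order: MBlockId, NBlockId, MRepeat, NRepeat, MWaveId, NWaveId,
        //        MPerXdl, NGroupNum, NInputNum, registerNum (invented values).
        std::array<long, 10> lengths{1, 1, 2, 2, 1, 1, 1, 4, 1, 4};
        std::printf("d0 thread buffer elements: %ld\n", PackedSize(lengths)); // 2*2*4*4 = 64
        return 0;
    }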
@@ -732,9 +733,6 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
         const auto wave_id = GetGemm0WaveIdx();
         const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63

-        constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed(
-            make_tuple(Number<Gemm0MXdlPerWave>{}, Number<Gemm0NXdlPerWave>{}, n2, n4));
-
         auto d0s_threadwise_copy = generate_tuple(
             [&](auto i) {
                 return ThreadwiseTensorSliceTransfer_v2<
@@ -742,10 +740,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
                     A0B0B1DataType,
                     decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
                     decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
-                    Sequence<I1, I1, I1, I1, I1, I1, I1, I1, I1, n4>,
+                    Sequence<I1, // MBlockId
+                             I1, // NBlockID
+                             m0, // MRepeat
+                             n0, // NRepeat
+                             m1, // MWaveId
+                             n1, // NWaveId
+                             m2, // MPerXdl
+                             n2, // NGroupNum
+                             n3, // NInputNum
+                             n4>,
                     Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                     9,
-                    n4,
+                    D0sTransferSrcScalarPerVector,
                     1,
                     false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
                            make_multi_index(block_work_idx[I0], // MBlockId
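The copy's slice lengths now mirror the full thread descriptor, and the read width along the last (vector) dimension comes from D0sTransferSrcScalarPerVector rather than being hard-wired to n4. Such vectorized transfers generally assume the vector width divides the slice length along the vector dimension; a toy compile-time check of that constraint, with illustrative numbers only:

    #include <cstdio>

    // The width along the vector dimension (dim 9, the registerNum/n4 axis here)
    // must divide that dimension's slice length.
    template <int SliceLenAlongVectorDim, int SrcScalarPerVector>
    constexpr bool IsValidVectorWidth()
    {
        return SliceLenAlongVectorDim % SrcScalarPerVector == 0;
    }

    int main()
    {
        static_assert(IsValidVectorWidth<4, 4>(), "4 divides 4");
        static_assert(IsValidVectorWidth<4, 2>(), "2 divides 4");
        static_assert(!IsValidVectorWidth<4, 3>(), "3 does not divide 4");
        std::printf("vector-width constraints hold\n");
        return 0;
    }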
@@ -898,66 +905,42 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
             blockwise_gemm0,
             acc0_thread_buf,
             num_k_block_main_loop);

-        // bias+gelu
-        {
-            static_for<0, Gemm0MXdlPerWave, 1>{}([&](auto mr) {
-                static_for<0, Gemm0NXdlPerWave, 1>{}([&](auto nr) {
-                    static_for<0, n2, 1>{}([&](auto groupid) {
-                        static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                            d0s_threadwise_copy(i).Run(
-                                d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                                d0s_grid_buf[i],
-                                d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                                d0s_thread_buf(i));
-                        });
-
-                        static_for<0, n4, 1>{}([&](auto i) {
-                            constexpr index_t c_offset = acc0_thread_desc.CalculateOffset(
-                                make_tuple(mr, nr, groupid, i));
-
-                            // get reference to src data
-                            const auto src_data_refs = generate_tie(
-                                // return type should be lvalue
-                                [&](auto iSrc) -> const auto& {
-                                    return d0s_thread_buf[iSrc][i];
-                                },
-                                Number<NumD0Tensor>{});
-                            // get reference to dst data
-                            auto dst_data_refs = generate_tie(
-                                // return type should be lvalue
-                                [&](auto) -> auto& {
-                                    return acc0_thread_buf(Number<c_offset>{});
-                                },
-                                Number<2>{});
-                            unpack2(cde0_element_op, dst_data_refs, src_data_refs);
-                        });
-                        static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                            d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                                d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                                make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0));
-                        });
-                    });
-                    static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                        d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                            d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                            make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0));
-                    });
-                });
-                static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                    d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                        d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                        make_multi_index(0, 0, 1, -Gemm0NXdlPerWave, 0, 0, 0, 0, 0, 0));
-                });
-            });
-            static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                    d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                    make_multi_index(0, 1, -Gemm0MXdlPerWave, 0, 0, 0, 0, 0, 0, 0));
-            });
-        }
+        // multiple d
+        if constexpr(NumD0Tensor)
+        {
+            static_for<0, NumD0Tensor, 1>{}([&](auto i) {
+                d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
+                                           d0s_grid_buf[i],
+                                           d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                                           make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                                           d0s_thread_buf(i));
+            });
+            static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
+                // get reference to src data
+                const auto src_data_refs = generate_tie(
+                    // return type should be lvalue
+                    [&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; },
+                    Number<NumD0Tensor>{});
+                // get reference to dst data
+                auto dst_data_refs = generate_tie(
+                    // return type should be lvalue
+                    [&](auto) -> auto& { return acc0_thread_buf(i); },
+                    Number<2>{});
+                unpack2(cde0_element_op, dst_data_refs, src_data_refs);
+            });
+            static_for<0, NumD0Tensor, 1>{}([&](auto i) {
+                d0s_threadwise_copy(i).MoveSrcSliceWindow(
+                    d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
+                    make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+            });
+        }
+        else
+        {
+            static_for<0, acc0_thread_buf.Size(), 1>{}(
+                [&](auto i) { cde0_element_op(acc_thread_buf(i), acc0_thread_buf[i]); });
+        }

         // gemm1
         {
             // TODO: explore using dynamic buffer for a1 thread buffer
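The rewritten path loads each D0 tensor once and applies cde0_element_op in a single flat loop over all m0 * n0 * n2 * n4 register elements; when NumD0Tensor is zero the op runs on the accumulator alone. A minimal host-side sketch of that element-wise pattern, using hypothetical AddBias and PassThrough functors in place of the real CK operators:

    #include <cstdio>

    struct AddBias // hypothetical cde0_element_op for a single D0 (bias) tensor
    {
        void operator()(float& dst, const float& acc, const float& d0) const { dst = acc + d0; }
    };

    struct PassThrough // hypothetical op for the NumD0Tensor == 0 branch
    {
        void operator()(float& dst, const float& acc) const { dst = acc; }
    };

    int main()
    {
        float acc[4]  = {1, 2, 3, 4};
        float bias[4] = {10, 10, 10, 10};

        AddBias add_bias;
        for(int i = 0; i < 4; ++i)             // flat loop over every register element,
            add_bias(acc[i], acc[i], bias[i]); // like static_for<0, m0 * n0 * n2 * n4, 1>

        std::printf("%.0f %.0f %.0f %.0f\n", acc[0], acc[1], acc[2], acc[3]); // 11 12 13 14
        return 0;
    }

PassThrough is shown only to indicate the shape of the fallback branch; it is not exercised in the sketch.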