Unverified Commit 50643dd5 authored by ltqin, committed by GitHub

Add bias scalar vector load = 1 for gemm bias gemm (#791)

* first change to bias load

* add bias dim and scalar-per-vector parameter

* make CDE0BlockTransferSrcVectorDim not take effect

* change to instance

* add limit for CDE0BlockTransferSrcScalarPerVector
parent 844b215d
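
The change threads two new compile-time parameters, a bias (D0) source vector dimension and a scalar-per-vector count, from the example instance through the device layer into the gridwise kernel, so a bias tensor that is not contiguous along the load dimension can fall back to scalar (= 1) loads. The helper below is a minimal, self-contained sketch of that selection idea; the function name and the stride-based rule are illustrative assumptions, not code from this commit.

    #include <cstdint>
    #include <iostream>

    // Hypothetical helper: pick the per-thread vector load width for the bias (D0)
    // tensor. If the bias is contiguous along the fastest-varying (vector) dimension,
    // several elements can be loaded at once; otherwise fall back to scalar loads (1).
    constexpr std::int64_t ChooseD0ScalarPerVector(std::int64_t d0_vector_dim_stride,
                                                   std::int64_t max_vector_size)
    {
        return d0_vector_dim_stride == 1 ? max_vector_size : 1;
    }

    int main()
    {
        // Contiguous bias row: vectorized loads are legal (e.g. 4 elements at a time).
        std::cout << ChooseD0ScalarPerVector(/*stride*/ 1, /*max*/ 4) << '\n';    // prints 4
        // Broadcast / strided bias: only scalar loads are legal.
        std::cout << ChooseD0ScalarPerVector(/*stride*/ 1024, /*max*/ 4) << '\n'; // prints 1
        return 0;
    }
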
@@ -173,6 +173,8 @@ using DeviceGemmInstance =
         8,
         8,
         true,
+        9,           // D0sTransferSrcVectorDim
+        4,           // D0sTransferSrcScalarPerVector
         S<8, 32, 1>, // B1BlockTransfer
         S<0, 2, 1>,
         S<0, 2, 1>,
@@ -189,7 +191,7 @@ int main(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method      = 1;
-    bool time_kernel     = false;
+    bool time_kernel     = true;

     // GEMM shape
     ck::index_t M = 1024;
...
@@ -196,6 +196,8 @@ template <typename A0Layout,
           index_t B0BlockTransferSrcScalarPerVector,
           index_t B0BlockTransferDstScalarPerVector_BK1,
           bool B0BlockLdsExtraN,
+          index_t CDE0BlockTransferSrcVectorDim,
+          index_t CDE0BlockTransferSrcScalarPerVector,
           typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
           typename B1BlockTransferThreadClusterArrangeOrder,
           typename B1BlockTransferSrcAccessOrder,
@@ -492,6 +494,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
         B0BlockTransferDstScalarPerVector_BK1,
         true,
         B0BlockLdsExtraN,
+        CDE0BlockTransferSrcVectorDim,
+        CDE0BlockTransferSrcScalarPerVector,
         B1BlockTransferThreadClusterLengths_BK0_N_BK1,
         B1BlockTransferThreadClusterArrangeOrder,
         B1BlockTransferSrcAccessOrder,
...
@@ -67,6 +67,8 @@ template <typename A0B0B1DataType, // FIXME: don't assume A0/B0/B1 have same data type
           index_t B0BlockTransferDstScalarPerVector_BK1,
           bool B0ThreadTransferSrcResetCoordinateAfterRun, // ignored
           index_t B0BlockLdsExtraN,
+          index_t CDE0BlockTransferSrcVectorDim,
+          index_t CDE0BlockTransferSrcScalarPerVector,
           typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
           typename B1BlockTransferThreadClusterArrangeOrder,
           typename B1BlockTransferSrcAccessOrder,
@@ -710,13 +712,13 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
        constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
            make_naive_tensor_descriptor_packed(make_tuple(I1,   // MBlockId
                                                           I1,   // NBlockID
-                                                          I1,   // MRepeat
-                                                          I1,   // NRepeat
-                                                          I1,   // MWaveId
-                                                          I1,   // NWaveId
-                                                          I1,   // MPerXdl
-                                                          I1,   // NGroupNum
-                                                          I1,   // NInputNum
+                                                          m0,   // MRepeat
+                                                          n0,   // NRepeat
+                                                          m1,   // MWaveId
+                                                          n1,   // NWaveId
+                                                          m2,   // MPerXdl
+                                                          n2,   // NGroupNum
+                                                          n3,   // NInputNum
                                                           n4)); // registerNum

        auto d0s_thread_buf = generate_tuple(
@@ -732,8 +734,9 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
        const auto wave_id     = GetGemm0WaveIdx();
        const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63

-       constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed(
-           make_tuple(Number<Gemm0MXdlPerWave>{}, Number<Gemm0NXdlPerWave>{}, n2, n4));
+       static_assert(CDE0BlockTransferSrcScalarPerVector <= n4,
+                     "vector load must not be greater than n4");
+       static_assert(n4 % CDE0BlockTransferSrcScalarPerVector == 0);

        auto d0s_threadwise_copy = generate_tuple(
            [&](auto i) {
@@ -742,10 +745,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
                    A0B0B1DataType,
                    decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
                    decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
-                   Sequence<I1, I1, I1, I1, I1, I1, I1, I1, I1, n4>,
+                   Sequence<I1, // MBlockId
+                            I1, // NBlockID
+                            m0, // MRepeat
+                            n0, // NRepeat
+                            m1, // MWaveId
+                            n1, // NWaveId
+                            m2, // MPerXdl
+                            n2, // NGroupNum
+                            n3, // NInputNum
+                            n4>,
                    Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
-                   9,
-                   n4,
+                   9, // CDE0BlockTransferSrcVectorDim
+                   CDE0BlockTransferSrcScalarPerVector,
                    1,
                    false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
                           make_multi_index(block_work_idx[I0], // MBlockId
@@ -898,38 +910,27 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
                blockwise_gemm0,
                acc0_thread_buf,
                num_k_block_main_loop);

-           // bias+gelu
+           // multiple d
+           if constexpr(NumD0Tensor)
            {
-               static_for<0, Gemm0MXdlPerWave, 1>{}([&](auto mr) {
-                   static_for<0, Gemm0NXdlPerWave, 1>{}([&](auto nr) {
-                       static_for<0, n2, 1>{}([&](auto groupid) {
-                           static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                               d0s_threadwise_copy(i).Run(
-                                   d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                                   d0s_grid_buf[i],
-                                   d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                   make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                                   d0s_thread_buf(i));
-                           });
-
-                           static_for<0, n4, 1>{}([&](auto i) {
-                               constexpr index_t c_offset = acc0_thread_desc.CalculateOffset(
-                                   make_tuple(mr, nr, groupid, i));
-                               // get reference to src data
-                               const auto src_data_refs = generate_tie(
-                                   // return type should be lvalue
-                                   [&](auto iSrc) -> const auto& {
-                                       return d0s_thread_buf[iSrc][i];
-                                   },
-                                   Number<NumD0Tensor>{});
-                               // get reference to dst data
-                               auto dst_data_refs = generate_tie(
-                                   // return type should be lvalue
-                                   [&](auto) -> auto& {
-                                       return acc0_thread_buf(Number<c_offset>{});
-                                   },
-                                   Number<2>{});
-                               unpack2(cde0_element_op, dst_data_refs, src_data_refs);
+               static_for<0, NumD0Tensor, 1>{}([&](auto i) {
+                   d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
+                                              d0s_grid_buf[i],
+                                              d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                                              d0s_thread_buf(i));
+               });
+
+               static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
+                   // get reference to src data
+                   const auto src_data_refs = generate_tie(
+                       // return type should be lvalue
+                       [&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; },
+                       Number<NumD0Tensor>{});
+                   // get reference to dst data
+                   auto dst_data_refs = generate_tie(
+                       // return type should be lvalue
+                       [&](auto) -> auto& { return acc0_thread_buf(i); },
+                       Number<2>{});
+                   unpack2(cde0_element_op, dst_data_refs, src_data_refs);
@@ -937,27 +938,14 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
-                           });
-
-                           static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                               d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                                   d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                                   make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0));
-                           });
-                       });
-                       static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                           d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                               d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                               make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0));
-                       });
-                   });
-                   static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                       d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                           d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                           make_multi_index(0, 0, 1, -Gemm0NXdlPerWave, 0, 0, 0, 0, 0, 0));
-                   });
-               });
-               static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                   d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                       d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                       make_multi_index(0, 1, -Gemm0MXdlPerWave, 0, 0, 0, 0, 0, 0, 0));
+               });
+
+               static_for<0, NumD0Tensor, 1>{}([&](auto i) {
+                   d0s_threadwise_copy(i).MoveSrcSliceWindow(
+                       d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
+                       make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
                });
            }
+           else
+           {
+               static_for<0, acc0_thread_buf.Size(), 1>{}(
+                   [&](auto i) { cde0_element_op(acc_thread_buf(i), acc0_thread_buf[i]); });
+           }

            // gemm1
            {
                // TODO: explore using dynamic buffer for a1 thread buffer
...
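
The new static_asserts in the gridwise kernel limit CDE0BlockTransferSrcScalarPerVector by n4, the per-thread register count of the accumulator layout. Below is a stand-alone sketch of the same constraint; the value n4 = 4 is assumed purely for illustration and is not taken from this commit.

    #include <cstdint>

    // Illustrative values; in the kernel n4 comes from the XDL accumulator layout.
    constexpr std::int64_t n4                                   = 4;
    constexpr std::int64_t CDE0BlockTransferSrcScalarPerVector  = 1; // the new "= 1" case

    // Mirrors the checks added in this commit: the bias vector load width must not
    // exceed n4 and must divide it evenly, so 1, 2 and 4 would be legal here.
    static_assert(CDE0BlockTransferSrcScalarPerVector <= n4,
                  "vector load must not be greater than n4");
    static_assert(n4 % CDE0BlockTransferSrcScalarPerVector == 0,
                  "n4 must be divisible by the vector load width");

    int main() { return 0; }
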