Commit a8169558 authored by Jing Zhang's avatar Jing Zhang
Browse files

imprve threadwise gemm with dot2

parent 8eaa6d5d
......@@ -29,21 +29,20 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
index_t w;
};
static constexpr index_t KPerThreadLoop = 4;
static constexpr auto KPerThread = ThreadMatrixC{}.GetLength(I0);
static constexpr auto HPerThread = ThreadMatrixC{}.GetLength(I2);
static constexpr auto WPerThread = ThreadMatrixC{}.GetLength(I3);
// HACK: fix this @Jing Zhang
static constexpr index_t KPerThreadSubC = 4;
static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{}));
static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
Number<KPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
......@@ -110,13 +109,8 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
// HACK: fix this @Jing Zhang
constexpr auto HoPerThreadSubC = HPerThread;
constexpr auto WoPerThreadSubC = WPerThread;
static_assert(KPerThread % KPerThreadSubC == 0, "");
static_assert(HPerThread % HoPerThreadSubC == 0, "");
static_assert(WPerThread % WoPerThreadSubC == 0, "");
static_assert(EPerBlock % EPerThreadLoop == 0, "");
static_assert(KPerThread % KPerThreadLoop == 0, "");
// thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatAB, a_thread_mtx_.GetElementSpaceSize(), true>
......@@ -127,12 +121,10 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
FloatC,
decltype(a_thread_mtx_),
decltype(b_thread_mtx_),
decltype(c_thread_mtx_),
HoPerThreadSubC,
WoPerThreadSubC>{};
decltype(c_thread_mtx_)>{};
static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) {
static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) {
static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) {
a_thread_copy_.Run(a_block_mtx,
make_tuple(e_begin, k_begin),
a_block_buf,
......@@ -140,16 +132,12 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
make_tuple(I0, I0),
a_thread_buf);
static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) {
static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) {
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0),
b_thread_buf,
make_tuple(e_begin, I0, h_begin, w_begin),
c_thread_buf,
make_tuple(k_begin, I0, h_begin, w_begin));
});
});
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0),
b_thread_buf,
make_tuple(e_begin, I0, I0, I0),
c_thread_buf,
make_tuple(k_begin, I0, I0, I0));
});
});
}
......@@ -167,7 +155,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
FloatAB,
BlockMatrixA,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<EPerThreadLoop, KPerThreadLoop>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_K,
......
......@@ -17,8 +17,6 @@ template <typename FloatA,
typename AThreadDesc_E_K,
typename BThreadDesc_E_N_Ho_Wo,
typename CThreadDesc_K_N_Ho_Wo,
index_t H,
index_t W,
typename enable_if<AThreadDesc_E_K::IsKnownAtCompileTime() &&
BThreadDesc_E_N_Ho_Wo::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
......@@ -56,98 +54,52 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type");
constexpr index_t Vec = 2;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto E = AThreadDesc_E_K{}.GetLength(I0);
constexpr auto K = AThreadDesc_E_K{}.GetLength(I1);
constexpr auto H = BThreadDesc_E_N_Ho_Wo{}.GetLength(I2);
constexpr auto W = BThreadDesc_E_N_Ho_Wo{}.GetLength(I3);
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, E, 1>{}([&](auto e) {
static_for<0, K, 1>{}([&](auto k) {
constexpr index_t a_offset =
AThreadDesc_E_K{}.CalculateOffset(a_origin_idx + make_tuple(e, k));
#if 0
if constexpr(H == 2 && W == 2)
{
constexpr index_t b_offset_0 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
constexpr index_t b_offset_1 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 1));
constexpr index_t b_offset_2 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
constexpr index_t b_offset_3 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 1));
constexpr index_t c_offset_0 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
constexpr index_t c_offset_1 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 1));
constexpr index_t c_offset_2 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
constexpr index_t c_offset_3 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 1));
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset_0>{}],
b_buf[Number<b_offset_1>{}],
b_buf[Number<b_offset_2>{}],
b_buf[Number<b_offset_3>{}],
c_buf(Number<c_offset_0>{}),
c_buf(Number<c_offset_1>{}),
c_buf(Number<c_offset_2>{}),
c_buf(Number<c_offset_3>{}));
}
else if constexpr(H == 4 && W == 1)
{
constexpr index_t b_offset_0 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
constexpr index_t b_offset_1 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
constexpr index_t b_offset_2 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 2, 0));
constexpr index_t b_offset_3 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 3, 0));
constexpr index_t c_offset_0 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
constexpr index_t c_offset_1 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
constexpr index_t c_offset_2 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 2, 0));
constexpr index_t c_offset_3 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 3, 0));
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset_0>{}],
b_buf[Number<b_offset_1>{}],
b_buf[Number<b_offset_2>{}],
b_buf[Number<b_offset_3>{}],
c_buf(Number<c_offset_0>{}),
c_buf(Number<c_offset_1>{}),
c_buf(Number<c_offset_2>{}),
c_buf(Number<c_offset_3>{}));
}
else
#endif
{
static_for<0, H, 1>{}([&](auto h) {
static_for<0, W, 1>{}([&](auto w) {
constexpr index_t b_offset = BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(
b_origin_idx + make_tuple(e, 0, h, w));
static_for<0, K, 1>{}([&](auto k) {
static_for<0, H, 1>{}([&](auto h) {
static_for<0, W, 1>{}([&](auto w) {
static_for<0, E, Vec>{}([&](auto e) {
vector_type<FloatA, Vec> a_vec;
vector_type<FloatB, Vec> b_vec;
constexpr index_t c_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
c_origin_idx + make_tuple(k, 0, h, w));
static_for<0, Vec, 1>{}([&](auto v) {
constexpr index_t a_offset = AThreadDesc_E_K{}.CalculateOffset(
a_origin_idx + make_tuple(e + v, k));
constexpr index_t b_offset = BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(
b_origin_idx + make_tuple(e + v, 0, h, w));
c_buf(Number<c_offset>{}) += inner_product_with_conversion<FloatC>{}(
a_buf[Number<a_offset>{}], b_buf[Number<b_offset>{}]);
a_vec.template AsType<FloatA>()(v) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(v) = b_buf[Number<b_offset>{}];
});
using a_vector_t = typename vector_type<FloatA, Vec>::type;
using b_vector_t = typename vector_type<FloatB, Vec>::type;
constexpr index_t c_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
c_origin_idx + make_tuple(k, 0, h, w));
inner_product<a_vector_t, b_vector_t, FloatC>(
a_vec.template AsType<a_vector_t>()[I0],
b_vec.template AsType<b_vector_t>()[I0],
c_buf(Number<c_offset>{}));
});
}
});
});
});
}
......
......@@ -106,7 +106,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t WoPerBlock = 8;
constexpr index_t E1 = 16;
constexpr index_t EPerBlock = 8;
constexpr index_t EPerBlock = 16;
constexpr index_t KPerThread = KPerBlock;
constexpr index_t HoPerThread = 2;
......
......@@ -126,7 +126,7 @@ int main(int argc, char* argv[])
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
#endif
#if 1
#if 0
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
......
......@@ -11,7 +11,7 @@ cmake
-D HALF_INCLUDE_DIR="/root/workspace/external/half/include" \
-D BUILD_DEV=OFF \
-D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX906 -O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX1030 -O3 --amdgpu-target=gfx1030 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment