"examples/git@developer.sourcefind.cn:guobj/qwen_lmdeploy.git" did not exist on "6df4a6ac36d1fe239c4355d5459fc7832c0109d2"
Commit a8169558 authored by Jing Zhang
Browse files

improve threadwise gemm with dot2

parent 8eaa6d5d
...@@ -29,21 +29,20 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -29,21 +29,20 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
index_t w; index_t w;
}; };
static constexpr index_t KPerThreadLoop = 4;
static constexpr auto KPerThread = ThreadMatrixC{}.GetLength(I0); static constexpr auto KPerThread = ThreadMatrixC{}.GetLength(I0);
static constexpr auto HPerThread = ThreadMatrixC{}.GetLength(I2); static constexpr auto HPerThread = ThreadMatrixC{}.GetLength(I2);
static constexpr auto WPerThread = ThreadMatrixC{}.GetLength(I3); static constexpr auto WPerThread = ThreadMatrixC{}.GetLength(I3);
// HACK: fix this @Jing Zhang
static constexpr index_t KPerThreadSubC = 4;
static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed( static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{})); make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{}));
static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{})); Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{})); Number<KPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
...@@ -110,13 +109,8 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -110,13 +109,8 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
constexpr auto EPerBlock = a_block_mtx.GetLength(I0); constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
// HACK: fix this @Jing Zhang static_assert(EPerBlock % EPerThreadLoop == 0, "");
constexpr auto HoPerThreadSubC = HPerThread; static_assert(KPerThread % KPerThreadLoop == 0, "");
constexpr auto WoPerThreadSubC = WPerThread;
static_assert(KPerThread % KPerThreadSubC == 0, "");
static_assert(HPerThread % HoPerThreadSubC == 0, "");
static_assert(WPerThread % WoPerThreadSubC == 0, "");
// thread A buffer for GEMM // thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatAB, a_thread_mtx_.GetElementSpaceSize(), true> StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatAB, a_thread_mtx_.GetElementSpaceSize(), true>
...@@ -127,12 +121,10 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -127,12 +121,10 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
FloatC, FloatC,
decltype(a_thread_mtx_), decltype(a_thread_mtx_),
decltype(b_thread_mtx_), decltype(b_thread_mtx_),
decltype(c_thread_mtx_), decltype(c_thread_mtx_)>{};
HoPerThreadSubC,
WoPerThreadSubC>{};
static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) { static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) {
static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) { static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) {
a_thread_copy_.Run(a_block_mtx, a_thread_copy_.Run(a_block_mtx,
make_tuple(e_begin, k_begin), make_tuple(e_begin, k_begin),
a_block_buf, a_block_buf,
...@@ -140,16 +132,12 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -140,16 +132,12 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
make_tuple(I0, I0), make_tuple(I0, I0),
a_thread_buf); a_thread_buf);
static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) { threadwise_gemm.Run(a_thread_buf,
static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) { make_tuple(I0, I0),
threadwise_gemm.Run(a_thread_buf, b_thread_buf,
make_tuple(I0, I0), make_tuple(e_begin, I0, I0, I0),
b_thread_buf, c_thread_buf,
make_tuple(e_begin, I0, h_begin, w_begin), make_tuple(k_begin, I0, I0, I0));
c_thread_buf,
make_tuple(k_begin, I0, h_begin, w_begin));
});
});
}); });
}); });
} }
...@@ -167,7 +155,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -167,7 +155,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
FloatAB, FloatAB,
BlockMatrixA, BlockMatrixA,
decltype(a_thread_mtx_), decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadSubC>, Sequence<EPerThreadLoop, KPerThreadLoop>,
Sequence<0, 1>, Sequence<0, 1>,
1, 1,
ThreadGemmADataPerRead_K, ThreadGemmADataPerRead_K,
......
...@@ -17,8 +17,6 @@ template <typename FloatA, ...@@ -17,8 +17,6 @@ template <typename FloatA,
typename AThreadDesc_E_K, typename AThreadDesc_E_K,
typename BThreadDesc_E_N_Ho_Wo, typename BThreadDesc_E_N_Ho_Wo,
typename CThreadDesc_K_N_Ho_Wo, typename CThreadDesc_K_N_Ho_Wo,
index_t H,
index_t W,
typename enable_if<AThreadDesc_E_K::IsKnownAtCompileTime() && typename enable_if<AThreadDesc_E_K::IsKnownAtCompileTime() &&
BThreadDesc_E_N_Ho_Wo::IsKnownAtCompileTime() && BThreadDesc_E_N_Ho_Wo::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(), CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
...@@ -56,98 +54,52 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3 ...@@ -56,98 +54,52 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value && is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type"); "wrong! inconsistent type");
constexpr index_t Vec = 2;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto E = AThreadDesc_E_K{}.GetLength(I0); constexpr auto E = AThreadDesc_E_K{}.GetLength(I0);
constexpr auto K = AThreadDesc_E_K{}.GetLength(I1); constexpr auto K = AThreadDesc_E_K{}.GetLength(I1);
constexpr auto H = BThreadDesc_E_N_Ho_Wo{}.GetLength(I2);
constexpr auto W = BThreadDesc_E_N_Ho_Wo{}.GetLength(I3);
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, E, 1>{}([&](auto e) { static_for<0, K, 1>{}([&](auto k) {
static_for<0, K, 1>{}([&](auto k) { static_for<0, H, 1>{}([&](auto h) {
constexpr index_t a_offset = static_for<0, W, 1>{}([&](auto w) {
AThreadDesc_E_K{}.CalculateOffset(a_origin_idx + make_tuple(e, k)); static_for<0, E, Vec>{}([&](auto e) {
vector_type<FloatA, Vec> a_vec;
#if 0 vector_type<FloatB, Vec> b_vec;
if constexpr(H == 2 && W == 2)
{
constexpr index_t b_offset_0 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
constexpr index_t b_offset_1 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 1));
constexpr index_t b_offset_2 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
constexpr index_t b_offset_3 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 1));
constexpr index_t c_offset_0 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
constexpr index_t c_offset_1 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 1));
constexpr index_t c_offset_2 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
constexpr index_t c_offset_3 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 1));
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset_0>{}],
b_buf[Number<b_offset_1>{}],
b_buf[Number<b_offset_2>{}],
b_buf[Number<b_offset_3>{}],
c_buf(Number<c_offset_0>{}),
c_buf(Number<c_offset_1>{}),
c_buf(Number<c_offset_2>{}),
c_buf(Number<c_offset_3>{}));
}
else if constexpr(H == 4 && W == 1)
{
constexpr index_t b_offset_0 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
constexpr index_t b_offset_1 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
constexpr index_t b_offset_2 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 2, 0));
constexpr index_t b_offset_3 =
BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 3, 0));
constexpr index_t c_offset_0 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
constexpr index_t c_offset_1 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
constexpr index_t c_offset_2 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 2, 0));
constexpr index_t c_offset_3 =
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 3, 0));
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset_0>{}],
b_buf[Number<b_offset_1>{}],
b_buf[Number<b_offset_2>{}],
b_buf[Number<b_offset_3>{}],
c_buf(Number<c_offset_0>{}),
c_buf(Number<c_offset_1>{}),
c_buf(Number<c_offset_2>{}),
c_buf(Number<c_offset_3>{}));
}
else
#endif
{
static_for<0, H, 1>{}([&](auto h) {
static_for<0, W, 1>{}([&](auto w) {
constexpr index_t b_offset = BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(
b_origin_idx + make_tuple(e, 0, h, w));
constexpr index_t c_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( static_for<0, Vec, 1>{}([&](auto v) {
c_origin_idx + make_tuple(k, 0, h, w)); constexpr index_t a_offset = AThreadDesc_E_K{}.CalculateOffset(
a_origin_idx + make_tuple(e + v, k));
constexpr index_t b_offset = BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(
b_origin_idx + make_tuple(e + v, 0, h, w));
c_buf(Number<c_offset>{}) += inner_product_with_conversion<FloatC>{}( a_vec.template AsType<FloatA>()(v) = a_buf[Number<a_offset>{}];
a_buf[Number<a_offset>{}], b_buf[Number<b_offset>{}]); b_vec.template AsType<FloatB>()(v) = b_buf[Number<b_offset>{}];
}); });
using a_vector_t = typename vector_type<FloatA, Vec>::type;
using b_vector_t = typename vector_type<FloatB, Vec>::type;
constexpr index_t c_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
c_origin_idx + make_tuple(k, 0, h, w));
inner_product<a_vector_t, b_vector_t, FloatC>(
a_vec.template AsType<a_vector_t>()[I0],
b_vec.template AsType<b_vector_t>()[I0],
c_buf(Number<c_offset>{}));
}); });
} });
}); });
}); });
} }
......
...@@ -106,7 +106,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -106,7 +106,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t WoPerBlock = 8; constexpr index_t WoPerBlock = 8;
constexpr index_t E1 = 16; constexpr index_t E1 = 16;
constexpr index_t EPerBlock = 8; constexpr index_t EPerBlock = 16;
constexpr index_t KPerThread = KPerBlock; constexpr index_t KPerThread = KPerBlock;
constexpr index_t HoPerThread = 2; constexpr index_t HoPerThread = 2;
......
...@@ -126,7 +126,7 @@ int main(int argc, char* argv[]) ...@@ -126,7 +126,7 @@ int main(int argc, char* argv[])
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
#endif #endif
#if 1 #if 0
using in_data_t = float; using in_data_t = float;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = float; using out_data_t = float;
......
...@@ -11,7 +11,7 @@ cmake ...@@ -11,7 +11,7 @@ cmake
-D HALF_INCLUDE_DIR="/root/workspace/external/half/include" \ -D HALF_INCLUDE_DIR="/root/workspace/external/half/include" \
-D BUILD_DEV=OFF \ -D BUILD_DEV=OFF \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX906 -O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ -D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX1030 -O3 --amdgpu-target=gfx1030 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment