"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "c2714fcbfd600c2a13efbc42bab95b49b0b4fa33"
Commit a956d60e authored by wangshaojie6

use tuple for skip both lds

parent 244b9ffb
@@ -43,13 +43,14 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipAllLds
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BThreadTransfer| CThreadTransfer| CThreadTransfer|
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| SrcScalar| SrcDstVectorDim| DstScalar|
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | PerVector|
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per|MultiK0| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BThreadTransfer| CThreadTransfer| CThreadTransfer|
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| SrcScalar| SrcDstVectorDim| DstScalar|
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | PerVector|
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 8, 7, 1>;
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 8, 7, 1>;
//< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 8, 7, 1>;
//< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 8, 7, 1>;
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 8, 7, 1>;
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
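The new MultiK0 column (8 in the active instance, inserted between K0PerBlock and K1) controls how many K0PerBlock-sized stages each thread prefetches into registers per outer-loop trip. A quick back-of-the-envelope check of the K footprint that implies, assuming my reading of the parameter columns is correct (the snippet is illustrative, not part of the example file):

    // Per prefetch group, each 16x16 output tile consumes
    // K0PerBlock * MultiK0 * K1 elements along K.
    #include <cstdio>

    int main()
    {
        constexpr int K0PerBlock = 4, MultiK0 = 8, K1 = 8;
        constexpr int MPerBlock = 16, NPerBlock = 16;
        constexpr int k_per_group = K0PerBlock * MultiK0 * K1; // 4 * 8 * 8 = 256
        std::printf("tile %dx%d, K elements per prefetch group: %d\n",
                    MPerBlock, NPerBlock, k_per_group);
        return 0;
    }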
@@ -15,7 +15,7 @@
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MIN_BLOCK_PER_CU 2
#define CK_MIN_BLOCK_PER_CU 1
#endif
// check GPU target
@@ -98,7 +98,7 @@
#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0
// experimental feature: buffer load/store/atomic-add/ OOB trick
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
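Two knobs change in this config header: CK_MIN_BLOCK_PER_CU drops from 2 to 1 and, in the second hunk, the buffer-load OOB offset trick is switched off. A rough sketch of where the first pair of macros usually ends up, assuming the common CK __launch_bounds__ pattern (dummy_kernel is a placeholder, not this repo's kernel); a weaker min-blocks-per-CU hint leaves the compiler more registers per thread for the enlarged prefetch buffers:

    #include <hip/hip_runtime.h>

    #define CK_MAX_THREAD_PER_BLOCK 256
    #define CK_MIN_BLOCK_PER_CU 1 // was 2; a weaker occupancy floor

    // __launch_bounds__(max threads per block, min blocks per CU) caps the
    // register budget the compiler may assume for this kernel.
    __global__ void __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
    dummy_kernel(float* out)
    {
        out[threadIdx.x] = 0.0f; // placeholder body; real CK kernels call GridwiseGemm::Run
    }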
@@ -32,6 +32,7 @@ template <typename ADataType,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t MultiK0,
ck::index_t K1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
@@ -189,6 +190,7 @@ struct DeviceGemmXdlSkipAllLds
MPerBlock,
NPerBlock,
K0PerBlock,
MultiK0,
MPerXDL,
NPerXDL,
K1,
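Threading MultiK0 through DeviceGemmXdlSkipAllLds means the prefetch group size is now fixed per device-op instantiation rather than hard-coded in the gridwise kernel. Because the remainder path (CalculateResMainK0BlockLoop further down) is still unused, an argument check along the following lines is implied; the helper is hypothetical and only mirrors the constraint, it is not part of this diff:

    // Hypothetical host-side check: K0 must split into whole MultiK0-sized
    // groups of K0PerBlock stages for the rewritten pipeline to line up.
    #include <cstdio>

    bool is_supported_k(int K, int K1, int K0PerBlock, int MultiK0)
    {
        const int K0 = K / K1; // the grid descriptor uses K0 = K / K1
        return (K % K1 == 0) && (K0 % K0PerBlock == 0) &&
               ((K0 / K0PerBlock) % MultiK0 == 0);
    }

    int main()
    {
        // K0PerBlock = 4, K1 = 8, MultiK0 = 8 as in the active instance.
        std::printf("K=512: %d\n", is_supported_k(512, 8, 4, 8)); // K0 = 64 -> supported
        std::printf("K=384: %d\n", is_supported_k(384, 8, 4, 8)); // K0 = 48 -> not supported
        return 0;
    }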
@@ -3,6 +3,7 @@
namespace ck {
// N-stage prefetch
template <index_t NumPrefetch>
struct GridwiseGemmPipeline_v2;
@@ -46,18 +46,19 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
//__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
// p_shared,
a_grid_desc_k0_m_k1,
a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
GridwiseGemm::template Run<HasMainK0BlockLoop>(
p_a_grid,
p_b_grid,
p_c_grid,
// p_shared,
a_grid_desc_k0_m_k1,
a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
@@ -86,6 +87,7 @@ template <index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t MultiK0,
index_t MPerXDL,
index_t NPerXDL,
index_t K1Value,
@@ -115,7 +117,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto MultiK0 = 4 * 1;
//static constexpr auto MultiK0 = 16 * 1;
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
@@ -227,11 +229,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
// TODO move this function into GEMM-pipeline class
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = (K0 / (MultiK0 * K0PerBlock)) > 1;
const bool has_main_k0_block_loop = K0 > (MultiK0 * K0PerBlock);
return has_main_k0_block_loop;
}
// TODO move this function into GEMM-pipeline class
__host__ __device__ static constexpr index_t CalculateResMainK0BlockLoop(index_t K0)
{
const index_t res_main_k0_block_loop = (K0 / K0PerBlock) % MultiK0;
return res_main_k0_block_loop;
}
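The predicate change is more than cosmetic: with integer division, the old form reports no main loop for any K0 strictly between one and two full MultiK0 * K0PerBlock groups, whereas the new form only asks for more than one group's worth of K0. A worked check with the values used in this commit (standalone snippet, the function names are mine):

    constexpr bool has_main_loop_old(int K0, int K0PerBlock, int MultiK0)
    {
        return (K0 / (MultiK0 * K0PerBlock)) > 1;
    }
    constexpr bool has_main_loop_new(int K0, int K0PerBlock, int MultiK0)
    {
        return K0 > (MultiK0 * K0PerBlock);
    }
    constexpr int res_main_k0_block_loop(int K0, int K0PerBlock, int MultiK0)
    {
        return (K0 / K0PerBlock) % MultiK0;
    }

    // K0 = 48, K0PerBlock = 4, MultiK0 = 8: one and a half groups.
    static_assert(has_main_loop_old(48, 4, 8) == false, "48 / 32 == 1, not > 1");
    static_assert(has_main_loop_new(48, 4, 8) == true, "48 > 32");
    static_assert(res_main_k0_block_loop(48, 4, 8) == 4,
                  "12 stages: one full group of 8 plus 4 left over");

    int main() { return 0; }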
__host__ __device__ static constexpr auto
MakeAGridDescriptor_K0_K1_K2_M0_M1_M2_M3_K3(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
{
@@ -396,7 +406,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
using AGridDesc_K0_K1_K2_M0_M1_M2_M3_K3 =
decltype(MakeAGridDescriptor_K0_K1_K2_M0_M1_M2_M3_K3(AGridDesc_K0_M_K1{}));
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
template <bool HasMainK0BlockLoop,
typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
@@ -420,6 +431,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
//const auto ResMainK0BlockLoop = CalculateResMainK0BlockLoop(K0);
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
@@ -444,11 +457,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
I1, // NPerXdlops
Number<K1>{}));
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3.GetElementSpaceSize(),
true>
a_thread_buf[MultiK0]; //, a_thread_buf_1, a_thread_buf_2, a_thread_buf_3;
auto a_thread_buf = generate_tuple(
[&](auto i) {
ignore = i;
return StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3.GetElementSpaceSize(),
true>{};
},
Number<MultiK0>{});
//StaticBuffer<AddressSpaceEnum::Vgpr,
// FloatAB,
// a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3.GetElementSpaceSize(),
// true>
// a_thread_buf[MultiK0];
ignore = b_element_op;
// B matrix threadwise copy
@@ -462,11 +484,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
I1, // NPerXdlops
Number<K1>{}));
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
true>
b_thread_buf[MultiK0]; //_0, b_thread_buf_1, b_thread_buf_2, b_thread_buf_3;
auto b_thread_buf = generate_tuple(
[&](auto i) {
ignore = i;
return StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
true>{};
},
Number<MultiK0>{});
//StaticBuffer<AddressSpaceEnum::Vgpr,
// FloatAB,
// b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
// true>
// b_thread_buf[MultiK0];
const auto wave_id = GetWaveIdx();
const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
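This is the heart of the commit: the per-stage register buffers stop being a C array of StaticBuffer and become a tuple built by generate_tuple, so every stage is addressed with a compile-time Number<i> inside the unrolled loops that follow. A standalone C++17 analogy of the pattern, with std::tuple and std::array standing in for CK's types (nothing below is CK API):

    #include <array>
    #include <cstddef>
    #include <tuple>
    #include <utility>

    // Build a tuple of N independent buffers from a generator, so each element
    // is a distinct object reachable through a compile-time index; that is the
    // role generate_tuple plus Number<i> play for a_thread_buf / b_thread_buf.
    template <typename MakeBuf, std::size_t... Is>
    auto make_buffer_tuple_impl(MakeBuf make, std::index_sequence<Is...>)
    {
        return std::make_tuple(make(std::integral_constant<std::size_t, Is>{})...);
    }

    template <std::size_t N, typename MakeBuf>
    auto make_buffer_tuple(MakeBuf make)
    {
        return make_buffer_tuple_impl(make, std::make_index_sequence<N>{});
    }

    int main()
    {
        constexpr std::size_t MultiK0 = 8; // as in the new device instance
        auto thread_buf = make_buffer_tuple<MultiK0>(
            [](auto /*i*/) { return std::array<float, 16>{}; }); // stand-in for StaticBuffer

        // Fully unrolled walk over the stages, the counterpart of
        // static_for<0, MultiK0, 1>{} with thread_buf(Number<i_k>{}) indexing.
        std::apply([](auto&... stage) {
            std::size_t k = 0;
            ((stage[0] = static_cast<float>(k++)), ...);
        }, thread_buf);

        return 0;
    }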
@@ -564,58 +596,55 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
// constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock * MultiK0, 0, 0);
constexpr auto a_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
// preload data to register and LDS
if constexpr(HasMainK0BlockLoop)
{
// Read
index_t i_pre = 0;
do
{
static_for<0, MultiK0, 1>{}([&](auto i_pre) {
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf[i_pre]);
b_thread_buf(Number<i_pre>{}));
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf[i_pre]);
a_thread_buf(Number<i_pre>{}));
asm volatile("s_nop 0" ::);
// Move
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
i_pre++;
} while(i_pre < MultiK0);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
});
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainK0BlockLoop)
{
index_t K0BlockMainLoop =
__builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
index_t i = 0;
index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
index_t i = 0;
do
{
index_t i_k = 0;
do
{
blockwise_gemm.Run(a_thread_buf[i_k], b_thread_buf[i_k], c_thread_buf);
static_for<0, MultiK0, 1>{}([&](auto i_k) {
blockwise_gemm.Run(a_thread_buf(Number<i_k>{}), b_thread_buf(Number<i_k>{}), c_thread_buf);
asm volatile("s_nop 0" ::);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf[i_k]);
b_thread_buf(Number<i_k>{}));
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf[i_k]);
a_thread_buf(Number<i_k>{}));
asm volatile("s_nop 0" ::);
@@ -623,8 +652,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
b_thread_slice_copy_step);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
i_k++;
} while(i_k < MultiK0);
});
i += MultiK0;
} while(i < (K0BlockMainLoop - MultiK0));
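With the buffers in a tuple, both the preload and the main body become static_for-unrolled passes over MultiK0 stages, and the outer do-while now advances by MultiK0 and stops MultiK0 short so the tail can finish the stages prefetched last. A small host-side sketch of that accounting, written from my reading of the control flow (stages_computed is mine, not CK code):

    #include <cassert>

    // Count the K0PerBlock-sized stages the HasMainK0BlockLoop path multiplies.
    int stages_computed(int K0, int K0PerBlock, int MultiK0)
    {
        const int k0_block_main_loop = K0 / K0PerBlock;
        int computed = 0;
        int i = 0;
        do
        {
            computed += MultiK0; // one static_for<0, MultiK0, 1> pass: MFMA + next-stage copies
            i += MultiK0;
        } while(i < (k0_block_main_loop - MultiK0));
        computed += MultiK0;     // tail: consume the stages prefetched in the last trip
        return computed;
    }

    int main()
    {
        // K = 512, K1 = 8 -> K0 = 64, with K0PerBlock = 4 and MultiK0 = 8 as above.
        assert(stages_computed(64, 4, 8) == 64 / 4);
        // For K0 = 48 the count comes out as 16 instead of 12, so the schedule
        // still assumes (K0 / K0PerBlock) % MultiK0 == 0; the new (still unused)
        // CalculateResMainK0BlockLoop looks like a first step toward a remainder path.
        return 0;
    }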
@@ -632,26 +660,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
// tail
{
//index_t loop_num = ResMainK0BlockLoop == 0 ? MultiK0 : ResMainK0BlockLoop;
static_for<0, MultiK0, 1>{}([&](auto i) {
blockwise_gemm.Run(a_thread_buf[i], b_thread_buf[i], c_thread_buf);
if constexpr(i < MultiK0 - 4)
{
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf[i]);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf[i]);
// only move b windows
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
}
blockwise_gemm.Run(
a_thread_buf(Number<i>{}), b_thread_buf(Number<i>{}), c_thread_buf);
});
}
}
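A note on the dropped tail copies: with the previously hard-coded MultiK0 = 4, the old guard if constexpr(i < MultiK0 - 4) could never be taken, and with MultiK0 = 8 it would have issued reads past the final K0 stages, which is presumably why the whole block goes away and the tail now only runs the blockwise GEMM. A small check of that arithmetic (illustrative, the helper name is mine):

    // Mirrors `if constexpr(i < MultiK0 - 4)` over i = 0 .. MultiK0 - 1.
    constexpr bool old_tail_guard_ever_fires(int multi_k0)
    {
        for(int i = 0; i < multi_k0; ++i)
            if(i < multi_k0 - 4)
                return true;
        return false;
    }

    static_assert(!old_tail_guard_ever_fires(4),
                  "with the old hard-coded MultiK0 = 4 the guarded copies were unreachable");
    static_assert(old_tail_guard_ever_fires(8),
                  "with MultiK0 = 8 they would fire and read past the final stages");

    int main() { return 0; }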