Commit 43adf1fa authored by Harisankar Sadasivan

clang format

parent ab3d3b4a
@@ -11,7 +11,7 @@
 #ifndef KERNARG_PRELOAD
 template <typename... Args, typename F>
-float launch_and_time_kernel(const StreamConfig &stream_config,
+float launch_and_time_kernel(const StreamConfig& stream_config,
                              F kernel,
                              dim3 grid_dim,
                              dim3 block_dim,
@@ -19,7 +19,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
                              Args... args)
 {
 #if CK_TIME_KERNEL
-    if (stream_config.time_kernel_)
+    if(stream_config.time_kernel_)
     {
 #if DEBUG_LOG
         printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
@@ -49,7 +49,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
         hip_check_error(hipDeviceSynchronize());
         hip_check_error(hipEventRecord(start, stream_config.stream_id_));
-        for (int i = 0; i < nrepeat; ++i)
+        for(int i = 0; i < nrepeat; ++i)
         {
             kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
             hip_check_error(hipGetLastError());
@@ -81,7 +81,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
 #else
 template <typename... Args, typename F>
-float launch_and_time_kernel(const StreamConfig &stream_config,
+float launch_and_time_kernel(const StreamConfig& stream_config,
                              F kernel,
                              dim3 grid_dim,
                              dim3 block_dim,
@@ -92,7 +92,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
     // hipGetErrorString(hipMalloc(&args1, sizeof(Args)));
     // hip_check_error(hipMemcpy(args1, &args, sizeof(Args), hipMemcpyHostToDevice));
 #if CK_TIME_KERNEL
-    if (stream_config.time_kernel_)
+    if(stream_config.time_kernel_)
     {
 #if DEBUG_LOG
         printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
@@ -109,9 +109,9 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
         //
         // warm up
         const int nrepeat = 1000;
-        for (auto i = 0; i < nrepeat; i++)
-            hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_,
-                               args...);
+        for(auto i = 0; i < nrepeat; i++)
+            hipLaunchKernelGGL(
+                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
         hip_check_error(hipGetLastError());
 #if DEBUG_LOG
@@ -127,9 +127,9 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
         hip_check_error(hipEventRecord(start, stream_config.stream_id_));
-        for (int i = 0; i < nrepeat; ++i)
-            hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_,
-                               args...);
+        for(int i = 0; i < nrepeat; ++i)
+            hipLaunchKernelGGL(
+                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
         // hip_check_error(hipGetLastError());
         hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
@@ -140,8 +140,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
     }
     else
    {
-        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(
-            args...);
+        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
         hip_check_error(hipGetLastError());
         return 0;
@@ -155,7 +154,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
 }
 #endif
 template <typename... Args, typename F, typename PreProcessFunc>
-float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config,
+float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                              PreProcessFunc preprocess,
                                              F kernel,
                                              dim3 grid_dim,
@@ -164,7 +163,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config,
                                              Args... args)
 {
 #if CK_TIME_KERNEL
-    if (stream_config.time_kernel_)
+    if(stream_config.time_kernel_)
     {
 #if DEBUG_LOG
         printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
@@ -195,7 +194,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config,
         hip_check_error(hipDeviceSynchronize());
         hip_check_error(hipEventRecord(start, stream_config.stream_id_));
-        for (int i = 0; i < nrepeat; ++i)
+        for(int i = 0; i < nrepeat; ++i)
         {
             preprocess();
             kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
...
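The timed path in launch_and_time_kernel above is the standard hipEvent pattern: synchronize the device, record a start event, launch the kernel nrepeat times, record a stop event, and average the elapsed milliseconds over nrepeat. A minimal self-contained sketch of that pattern, with a trivial dummy_kernel standing in for the templated kernel and no StreamConfig wrapper (both are assumptions for illustration, not part of this commit):

#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void dummy_kernel(float* p) { p[threadIdx.x] *= 2.0f; }

int main()
{
    float* d_p = nullptr;
    (void)hipMalloc(&d_p, 256 * sizeof(float));

    hipEvent_t start, stop;
    (void)hipEventCreate(&start);
    (void)hipEventCreate(&stop);

    // warm up once so the first timed launch does not pay one-time costs
    dummy_kernel<<<1, 256>>>(d_p);
    (void)hipDeviceSynchronize();

    const int nrepeat = 1000;
    (void)hipEventRecord(start, nullptr);
    for(int i = 0; i < nrepeat; ++i)
        dummy_kernel<<<1, 256>>>(d_p);
    (void)hipEventRecord(stop, nullptr);
    (void)hipEventSynchronize(stop);

    float total_ms = 0.f;
    (void)hipEventElapsedTime(&total_ms, start, stop); // milliseconds between the two events
    printf("avg kernel time: %f ms\n", total_ms / nrepeat);

    (void)hipFree(d_p);
    return 0;
}

hipEventElapsedTime reports milliseconds, so dividing by nrepeat gives the average per-launch time, which is presumably what the float returned by launch_and_time_kernel carries.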
@@ -16,10 +16,9 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-namespace ck
-{
+namespace ck {
 template <typename GridwiseTsmm,
           typename FloatAB,
           typename FloatC,
           typename BLayout,
@@ -27,35 +26,68 @@ namespace ck
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop,
           typename Block2CTileMap>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
     kernel_tsmm_dl_v1r3(
-        const FloatAB *p_a_grid, const FloatAB *p_b_grid, FloatC *p_c_grid, index_t M, index_t N, index_t K,
-        index_t K0, index_t k_batch, index_t MPadded, index_t NPadded, const Block2CTileMap block_2_ctile_map) //: in __global__ functions, struct is
+        const FloatAB* p_a_grid,
+        const FloatAB* p_b_grid,
+        FloatC* p_c_grid,
+        index_t M,
+        index_t N,
+        index_t K,
+        index_t K0,
+        index_t k_batch,
+        index_t MPadded,
+        index_t NPadded,
+        const Block2CTileMap block_2_ctile_map) //: in __global__ functions, struct is
                                                 // better for reduced load overhead
 {
     // strides depend on B's layout
-    if constexpr (is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+    if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
     {
         GridwiseTsmm::template Run<HasMainKBlockLoop,
                                    HasDoubleTailKBlockLoop,
                                    GridwiseTsmm,
-                                   CGlobalMemoryDataOperation>(p_a_grid, p_b_grid, p_c_grid, M, N, K,
-                                       K0, k_batch, K, N, N, MPadded, NPadded, block_2_ctile_map);
+                                   CGlobalMemoryDataOperation>(p_a_grid,
+                                                               p_b_grid,
+                                                               p_c_grid,
+                                                               M,
+                                                               N,
+                                                               K,
+                                                               K0,
+                                                               k_batch,
+                                                               K,
+                                                               N,
+                                                               N,
+                                                               MPadded,
+                                                               NPadded,
+                                                               block_2_ctile_map);
     }
     else
     {
         GridwiseTsmm::template Run<HasMainKBlockLoop,
                                    HasDoubleTailKBlockLoop,
                                    GridwiseTsmm,
-                                   CGlobalMemoryDataOperation>(p_a_grid, p_b_grid, p_c_grid, M, N, K,
-                                       K0, k_batch, K, K, N, MPadded, NPadded, block_2_ctile_map);
+                                   CGlobalMemoryDataOperation>(p_a_grid,
+                                                               p_b_grid,
+                                                               p_c_grid,
+                                                               M,
+                                                               N,
+                                                               K,
+                                                               K0,
+                                                               k_batch,
+                                                               K,
+                                                               K,
+                                                               N,
+                                                               MPadded,
+                                                               NPadded,
+                                                               block_2_ctile_map);
     }
 }
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatC,
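kernel_tsmm_dl_v1r3 selects B's strides at compile time with if constexpr on the BLayout tag, so each template instantiation keeps only one of the two Run calls. A small hypothetical kernel showing the same dispatch pattern (RowMajor, ColumnMajor and copy_with_layout are illustrative names, not CK types):

#include <hip/hip_runtime.h>
#include <type_traits>

struct RowMajor {};
struct ColumnMajor {};

// One branch is discarded at compile time, as in kernel_tsmm_dl_v1r3 above.
template <typename BLayout>
__global__ void copy_with_layout(const float* src, float* dst, int K, int N)
{
    const int k = blockIdx.y;
    const int n = blockIdx.x * blockDim.x + threadIdx.x;
    if(k >= K || n >= N)
        return;

    if constexpr(std::is_same<BLayout, RowMajor>::value)
        dst[k * N + n] = src[k * N + n]; // row-major: element (k, n) lives at k * N + n
    else
        dst[n * K + k] = src[n * K + k]; // column-major: element (k, n) lives at n * K + k
}

Passing block_2_ctile_map by value rather than through a pointer lets the mapping travel with the kernel arguments instead of requiring an extra dereference per workgroup, which appears to be what the comment about reduced load overhead refers to.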
@@ -83,8 +115,8 @@ namespace ck
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector>
 struct GridwiseTsmmDl_km_kn_mn
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
@@ -96,9 +128,9 @@
     // Argument
     struct Argument : public tensor_operation::device::BaseArgument //
     {
-        Argument(const FloatAB *p_a_grid_,
-                 const FloatAB *p_b_grid_,
-                 FloatC *p_c_grid_,
+        Argument(const FloatAB* p_a_grid_,
+                 const FloatAB* p_b_grid_,
+                 FloatC* p_c_grid_,
                  index_t M_,
                  index_t N_,
                  index_t K_,
@@ -128,9 +160,9 @@
         }
         // private:
-        const FloatAB *p_a_grid;
-        const FloatAB *p_b_grid;
-        FloatC *p_c_grid;
+        const FloatAB* p_a_grid;
+        const FloatAB* p_b_grid;
+        FloatC* p_c_grid;
         index_t M, N, K;
         index_t StrideA, StrideB, StrideC;
@@ -214,19 +246,18 @@
         index_t M, index_t MPad, index_t K, index_t StrideA, index_t KBatch, index_t K0)
     {
-        const auto a_grid_desc_m_k = [&]()
-        {
-            if constexpr (is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
+        const auto a_grid_desc_m_k = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
             }
-            else if constexpr (is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
             }
         }();
-        if constexpr (GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
+        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
        {
             return transform_tensor_descriptor(
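The naive A descriptor built here encodes the layout purely through strides: make_tuple(StrideA, I1) means row-major (offset = m * StrideA + k), while make_tuple(I1, StrideA) means column-major (offset = m + k * StrideA). A tiny host-side sketch with concrete numbers (offset_2d is a hypothetical helper, not a CK function):

#include <cstdio>

// offset = m * stride_m + k * stride_k, mirroring the stride tuples above
static long offset_2d(long m, long k, long stride_m, long stride_k)
{
    return m * stride_m + k * stride_k;
}

int main()
{
    const long M = 4, K = 8;
    // row-major M x K with StrideA = K = 8: element (1, 2) -> 1 * 8 + 2 = 10
    printf("row-major    (1,2) -> %ld\n", offset_2d(1, 2, K, 1));
    // column-major M x K with StrideA = M = 4: element (1, 2) -> 1 + 2 * 4 = 9
    printf("column-major (1,2) -> %ld\n", offset_2d(1, 2, 1, M));
    return 0;
}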
@@ -255,19 +286,18 @@
         index_t K, index_t NPad, index_t N, index_t StrideB, index_t KBatch, index_t K0)
     {
-        const auto b_grid_desc_k_n = [&]()
-        {
-            if constexpr (is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+        const auto b_grid_desc_k_n = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
             }
-            else if constexpr (is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
             }
         }();
-        if constexpr (GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
+        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
         {
             return transform_tensor_descriptor(
@@ -290,19 +320,18 @@
     __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
     {
-        const auto c_grid_desc_m_n = [&]()
-        {
-            if constexpr (is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
+        const auto c_grid_desc_m_n = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
             }
-            else if constexpr (is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
             }
         }();
-        if constexpr (GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
+        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
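The MNPadding branch rounds M and N up to the next multiple of the block tile; writing the pad as (PerBlock - X % PerBlock) % PerBlock makes already-aligned sizes get zero padding instead of a whole extra tile. A short worked sketch (pad_to_multiple is a hypothetical helper):

#include <cstdio>

static int pad_to_multiple(int x, int per_block)
{
    // (per_block - x % per_block) % per_block: 0 when x is already a multiple
    return (per_block - x % per_block) % per_block;
}

int main()
{
    const int MPerBlock = 128;
    // M = 500: 500 % 128 = 116, pad = 12, padded M = 512
    printf("M=500 -> pad %d, padded %d\n",
           pad_to_multiple(500, MPerBlock),
           500 + pad_to_multiple(500, MPerBlock));
    // M = 512 is already a multiple of 128, so the pad collapses to 0
    printf("M=512 -> pad %d, padded %d\n",
           pad_to_multiple(512, MPerBlock),
           512 + pad_to_multiple(512, MPerBlock));
    return 0;
}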
@@ -335,7 +364,7 @@
     using BGridDesc_Kbatch_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1, 1));
     using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
-    __host__ __device__ static constexpr bool CheckValidity(const Argument &karg)
+    __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
     {
         // const auto MPadded = CalculateMPadded(karg.M);
@@ -361,7 +390,7 @@
     // KBatch, K0, M, K1 -> KBatch, K0, M0, M1 (MPerBlock), K1
     __host__ __device__ static constexpr auto MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(
-        const AGridDesc_Kbatch_K0_M_K1 &a_grid_desc_kbatch_k0_m_k1)
+        const AGridDesc_Kbatch_K0_M_K1& a_grid_desc_kbatch_k0_m_k1)
     {
         const auto KBatch = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
         const auto K0 = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
@@ -383,7 +412,7 @@
     }
     __host__ __device__ static constexpr auto MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(
-        const BGridDesc_Kbatch_K0_N_K1 &b_grid_desc_kbatch_k0_n_k1)
+        const BGridDesc_Kbatch_K0_N_K1& b_grid_desc_kbatch_k0_n_k1)
     {
         const auto KBatch = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
         const auto K0 = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
@@ -405,7 +434,7 @@
     }
     __host__ __device__ static constexpr auto
-    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N &c_grid_desc_m_n)
+    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
     {
         const auto M = c_grid_desc_m_n.GetLength(I0);
         const auto N = c_grid_desc_m_n.GetLength(I1);
@@ -451,8 +480,20 @@
               bool HasDoubleTailKBlockLoop,
               typename GridwiseTsmm,
               InMemoryDataOperationEnum CGlobalMemoryDataOperation>
-    __device__ static void Run(const FloatAB *p_a_grid, const FloatAB *p_b_grid, FloatC *p_c_grid, index_t M, index_t N, index_t K,
-        index_t K0, index_t k_batch, index_t StrideA, index_t StrideB, index_t StrideC, index_t MPadded, index_t NPadded, const Block2CTileMap &block_2_ctile_map)
+    __device__ static void Run(const FloatAB* p_a_grid,
+                               const FloatAB* p_b_grid,
+                               FloatC* p_c_grid,
+                               index_t M,
+                               index_t N,
+                               index_t K,
+                               index_t K0,
+                               index_t k_batch,
+                               index_t StrideA,
+                               index_t StrideB,
+                               index_t StrideC,
+                               index_t MPadded,
+                               index_t NPadded,
+                               const Block2CTileMap& block_2_ctile_map)
     {
         constexpr index_t shared_block_size =
@@ -464,8 +505,7 @@
             M, MPadded, K, StrideA, k_batch, K0); //
         const auto b_grid_desc_kbatch_k0_n_k1 = GridwiseTsmm::MakeBGridDescriptor_KBatch_K0_N_K1(
             K, NPadded, N, StrideB, k_batch, K0); //
-        const auto c_grid_desc_m_n =
-            GridwiseTsmm::MakeCGridDescriptor_M_N(M, N, StrideC);
+        const auto c_grid_desc_m_n = GridwiseTsmm::MakeCGridDescriptor_M_N(M, N, StrideC);
         const auto a_grid_desc_kbatch_k0_m0_m1_k1 =
             GridwiseTsmm::MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1); //
@@ -482,15 +522,15 @@
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
-        const auto c_m0_n0_block_cluster_idx = block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(
-            get_block_1d_id(), N, k_batch);
+        const auto c_m0_n0_block_cluster_idx =
+            block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(get_block_1d_id(), N, k_batch);
         // HACK: this force index data into SGPR
         const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
         const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
         const index_t kbatch_id = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I2]);
-        if (!block_2_ctile_map.ValidCTileIndex(
+        if(!block_2_ctile_map.ValidCTileIndex(
                make_tuple(im0, in0),
                make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
                           c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
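The "HACK" comment is about __builtin_amdgcn_readfirstlane: the tile coordinates produced by block_2_ctile_map depend only on the block index, so they are uniform across every wavefront in the workgroup, and broadcasting lane 0's copy lets the compiler keep them in scalar registers (SGPRs) instead of one VGPR per lane. A minimal illustration of the idiom on a made-up kernel (scale_rows is hypothetical, not CK code):

#include <hip/hip_runtime.h>

__global__ void scale_rows(const float* in, float* out, int row_len)
{
    // blockIdx.x is identical for every lane, so readfirstlane does not change the
    // value; it only tells the compiler the result is wavefront-uniform.
    const int row = __builtin_amdgcn_readfirstlane(static_cast<int>(blockIdx.x));
    const int col = threadIdx.x;
    if(col < row_len)
        out[row * row_len + col] = 2.0f * in[row * row_len + col];
}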
@@ -593,7 +633,7 @@
         constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
             a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
-        FloatAB *p_a_block_double = p_shared_block;
+        FloatAB* p_a_block_double = p_shared_block;
         auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
             b_k0_n_k1_thread_desc.GetElementSpaceSize());
@@ -632,7 +672,7 @@
                 b_thread_even_buf);
         }
-        if constexpr (HasMainKBlockLoop)
+        if constexpr(HasMainKBlockLoop)
         {
             // const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);
@@ -691,11 +731,11 @@
                 a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_even_buf);
                 k_block_data_begin += 2 * K0PerBlock;
-            } while (k_block_data_begin < K0 - 2 * K0PerBlock);
+            } while(k_block_data_begin < K0 - 2 * K0PerBlock);
         }
         // LDS double buffer: tail
-        if constexpr (HasDoubleTailKBlockLoop) // if has 2 iteration left
+        if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
         {
             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
                                                 a_block_slice_copy_step);
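The surrounding loop is the usual LDS double-buffer schedule: the main do/while consumes two K0PerBlock slabs per trip (one through the even buffers, one through the odd buffers) and advances k_block_data_begin by 2 * K0PerBlock, and HasDoubleTailKBlockLoop then says whether the tail still has two slabs to drain or just one. A control-flow-only sketch of that schedule under assumed sizes (plain host code, with the blockwise copies and the GEMM itself elided):

#include <cstdio>

int main()
{
    const int K0 = 12, K0PerBlock = 2; // assumed sizes: six slabs of K0PerBlock
    int k_block_data_begin = 0;

    if(K0 > 2 * K0PerBlock) // roughly the HasMainKBlockLoop case
    {
        do
        {
            printf("main loop: even slab %d, odd slab %d\n",
                   k_block_data_begin,
                   k_block_data_begin + K0PerBlock);
            k_block_data_begin += 2 * K0PerBlock;
        } while(k_block_data_begin < K0 - 2 * K0PerBlock);
    }

    // with these sizes two slabs remain (8 and 10): the HasDoubleTailKBlockLoop tail
    printf("tail begins at slab %d\n", k_block_data_begin);
    return 0;
}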
@@ -780,5 +820,5 @@
                 c_grid_buf);
         }
     }
 };
 } // namespace ck