Commit b8ba0239 authored by carlushuang

support multi-thread

parent e06b9871
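The change parallelizes the gridwise GEMM with OpenMP: the M x N macro-tile grid is flattened and dealt out to threads round-robin (gid = i_gpt * total_threads + tid), and each thread constructs its own threadwise copies and packing buffers inside the parallel region. As a standalone illustration of just the partitioning scheme, here is a minimal sketch; run_tile, GRID_M and GRID_N are made-up placeholders, not symbols from this repository.

// Minimal sketch of the round-robin tile partitioning used in this commit.
// run_tile(), GRID_M and GRID_N are hypothetical placeholders.
#include <omp.h>
#include <cstdio>

static const int GRID_M = 7; // number of M-blocks (example value)
static const int GRID_N = 5; // number of N-blocks (example value)

static void run_tile(int block_m, int block_n, int tid)
{
    std::printf("tid %d handles tile (%d, %d)\n", tid, block_m, block_n);
}

int main()
{
    const int grid_size        = GRID_M * GRID_N;
    const int total_threads    = omp_get_max_threads();
    const int grids_per_thread = (grid_size + total_threads - 1) / total_threads;

#pragma omp parallel
    {
        const int tid = omp_get_thread_num();
        // each thread walks the flattened tile grid with a stride of total_threads
        for(int i_gpt = 0; i_gpt < grids_per_thread; i_gpt++)
        {
            const int gid = i_gpt * total_threads + tid;
            if(gid >= grid_size)
                break;
            run_tile(gid / GRID_N, gid % GRID_N, tid);
        }
    }
    return 0;
}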
#ifndef CK_BLOCKWISE_GEMM_AVX2_HPP
#define CK_BLOCKWISE_GEMM_AVX2_HPP

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "threadwise_gemm_avx2.hpp"

namespace ck {
namespace cpu {

template <typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ABlockDesc,
          typename BBlockDesc,
          typename CDesc,
          ck::index_t KPerBlock,
          typename ThreadwiseGemm_Dispatch,
          typename ThreadMNAccessOrder // how we access gemm MN to utilize the micro kernel
          >
struct BlockwiseGemmAvx2_MxN
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    static constexpr index_t nDimA = ABlockDesc::GetNumOfDimension();
    static constexpr index_t nDimB = BBlockDesc::GetNumOfDimension();
    static constexpr index_t nDimC = CDesc::GetNumOfDimension();

    using IndexA = MultiIndex<nDimA>;
    using IndexB = MultiIndex<nDimB>;
    using IndexC = MultiIndex<nDimC>;

    using ACoord = decltype(make_tensor_coordinate(ABlockDesc{}, IndexA{}));
    using BCoord = decltype(make_tensor_coordinate(BBlockDesc{}, IndexB{}));
    using CCoord = decltype(make_tensor_coordinate(CDesc{}, IndexC{}));

    template <typename TensorDesc>
    static constexpr auto GetLeadingElement(const TensorDesc& desc)
    {
        // if using this function, make sure desc is known at compile time;
        // otherwise it is not efficient to calculate the leading dimension here
        if constexpr(TensorDesc::GetNumOfDimension() == 1)
        {
            return 1;
        }
        else
        {
            constexpr auto last_dims =
                typename uniform_sequence_gen<TensorDesc::GetNumOfDimension() - 1, 0>::type{};
            constexpr auto lead_dims = decltype(last_dims)::PushFront(Number<1>{});
            return desc.CalculateOffset(lead_dims);
        }
    }

    static ck::index_t GetALeadingElement(const ABlockDesc& a_block_desc)
    {
        return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
    }

    static ck::index_t GetBLeadingElement(const BBlockDesc& b_block_desc)
    {
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            // K * N
            return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
        }
        else
        {
            // N/8 * K * 8
            return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] *
                   b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
        }
    }

    static ck::index_t GetCLeadingElement(const CDesc& c_desc)
    {
        return c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
    }

    static ck::index_t GetMPerBlock(const ABlockDesc& a_block_desc)
    {
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            // M * K
            return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
        }
        else
        {
            // K * M
            return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
        }
    }

    static ck::index_t GetKPerBlock(const ABlockDesc& a_block_desc)
    {
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            // M * K
            return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
        }
        else
        {
            // K * M
            return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
        }
    }

    static ck::index_t GetNPerBlock(const BBlockDesc& b_block_desc)
    {
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            // K * N
            return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
        }
        else
        {
            // N/8 * K * 8
            return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
                   b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
        }
    }

    static ck::index_t
    GetABlockStartOffset(const ABlockDesc& a_block_desc, const index_t i_m, const index_t)
    {
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            return i_m * a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
        }
        else
        {
            return i_m;
        }
    }

    static ck::index_t
    GetBBlockStartOffset(const BBlockDesc& b_block_desc, const index_t, const index_t i_n)
    {
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            // K * N
            return i_n;
        }
        else
        {
            // N/8 * K * 8
            return i_n * b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
        }
    }

    static ck::index_t
    GetCBlockStartOffset(const CDesc& c_desc, const index_t i_m, const index_t i_n)
    {
        return i_m * c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] + i_n;
    }

    template <typename ABlockBuffer, typename BBlockBuffer, typename CBuffer>
    static void Run(const ABlockDesc& a_block_desc,
                    const ABlockBuffer& a_block_buf,
                    const IndexA& /* a_origin */,
                    const BBlockDesc& b_block_desc,
                    const BBlockBuffer& b_block_buf,
                    const IndexB& /* b_origin */,
                    const CDesc& c_desc,
                    CBuffer& c_buf,
                    const IndexC& /* c_origin */,
                    bool is_accumulate_c = true)
    {
        auto lda = GetALeadingElement(a_block_desc) * sizeof(FloatA);
        auto ldb = GetBLeadingElement(b_block_desc) * sizeof(FloatB);
        auto ldc = GetCLeadingElement(c_desc) * sizeof(FloatC);

        // printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);

        const auto k_per_block = GetKPerBlock(a_block_desc);
        const auto m_per_block = GetMPerBlock(a_block_desc);
        const auto n_per_block = GetNPerBlock(b_block_desc);

        const auto m_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxMr;
        const auto n_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxNr;

        ck::cpu::ThreadwiseGemmParam param;
        param.Kr = k_per_block;
        param.lda = lda;
        param.ldb = ldb;
        param.ldc = ldc;
        param.alpha = 1.0f; // TODO
        param.accmulate_c = is_accumulate_c ? 1 : 0;

        if constexpr(std::is_same<ThreadMNAccessOrder, ck::Sequence<0, 1>>::value)
        {
            for(ck::index_t i_m = 0; i_m < m_per_block; i_m += m_per_thread)
            {
                auto current_mr = ck::math::min(m_per_block - i_m, m_per_thread);
                param.p_a = &a_block_buf.p_data_[GetABlockStartOffset(a_block_desc, i_m, 0)];

                for(ck::index_t i_n = 0; i_n < n_per_block; i_n += n_per_thread)
                {
                    auto current_nr = ck::math::min(n_per_block - i_n, n_per_thread);
                    param.p_b = &b_block_buf.p_data_[GetBBlockStartOffset(b_block_desc, 0, i_n)];
                    param.p_c = &c_buf.p_data_[GetCBlockStartOffset(c_desc, i_m, i_n)];

                    ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
                }
            }
        }
    }
};

} // namespace cpu
} // namespace ck
#endif
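BlockwiseGemmAvx2_MxN::Run above walks the cached M x N block in micro-tiles of ThreadMaxMr x ThreadMaxNr and clamps the last tile in each dimension with min(). A reduced sketch of that edge handling follows; micro_kernel and the tile sizes are illustrative stand-ins, not the actual ThreadwiseGemm_Dispatch.

// Sketch of the Mr x Nr micro-tile walk with edge clamping, as in
// BlockwiseGemmAvx2_MxN::Run. micro_kernel() and the sizes are placeholders.
#include <algorithm>
#include <cstdio>

static void micro_kernel(int i_m, int i_n, int mr, int nr)
{
    std::printf("tile at (%d,%d) size %dx%d\n", i_m, i_n, mr, nr);
}

int main()
{
    const int m_per_block = 25, n_per_block = 20;  // example block sizes (assumed)
    const int m_per_thread = 6, n_per_thread = 16; // example micro-tile sizes (assumed)

    for(int i_m = 0; i_m < m_per_block; i_m += m_per_thread)
    {
        int current_mr = std::min(m_per_block - i_m, m_per_thread);
        for(int i_n = 0; i_n < n_per_block; i_n += n_per_thread)
        {
            int current_nr = std::min(n_per_block - i_n, n_per_thread);
            micro_kernel(i_m, i_n, current_mr, current_nr); // last tiles are partial
        }
    }
    return 0;
}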
#ifndef CK_GRIDWISE_GEMM_AVX2_HPP
#define CK_GRIDWISE_GEMM_AVX2_HPP

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <utility>
#include <unistd.h>
#include <omp.h>

namespace ck {
namespace cpu {

template <typename GridwiseGemm,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename AGridDesc,
          typename BGridDesc,
          typename CGridDesc,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
void kernel_gemm_avx_mxn(const FloatA* __restrict__ p_a_grid,
                         const FloatB* __restrict__ p_b_grid,
                         FloatC* __restrict__ p_c_grid,
                         const AGridDesc& a_grid_desc,
                         const BGridDesc& b_grid_desc,
                         const CGridDesc& c_grid_desc,
                         const AElementwiseOperation& a_element_op,
                         const BElementwiseOperation& b_element_op,
                         const CElementwiseOperation& c_element_op)
{
    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      a_grid_desc,
                      b_grid_desc,
                      c_grid_desc,
                      a_element_op,
                      b_element_op,
                      c_element_op);
}

template <typename FloatA,
          typename FloatB,
          typename FloatC,
          typename AGridDesc,
          typename BGridDesc,
          typename CGridDesc,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          typename ThreadwiseGemm_Dispatch,
          typename AThreadwiseCopy,
          typename BThreadwiseCopy,
          typename CThreadwiseCopy,
          typename BlockMNKAccessOrder, // how we access gemm MNK to better fit in cache
          typename ThreadMNAccessOrder, // how we access gemm MN to utilize the micro kernel
          bool UseALocalBuffer,
          bool UseBLocalBuffer,
          bool UseCLocalBuffer // if true, will allocate a buffer and write to it in kernel, then
                               // copy back to block buffer (need CThreadwiseCopy).
                               // if false, will write to C directly (no need CThreadwiseCopy)
          >
struct GridwiseGemmAvx2_MxN
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    // static constexpr auto Avx2RegisterVector = 8; // 8 floats
    static constexpr index_t MemAlignmentByte = 32; // 256 bit

    static auto GetABlockDescriptor(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
    {
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            // A : M, K
            auto a_block_desc_m_k =
                make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, k_per_blk));
            return a_block_desc_m_k;
        }
        else
        {
            // A : K, M
            auto a_block_desc_k_m = make_naive_tensor_descriptor_packed(
                make_tuple(k_per_blk,
                           math::integer_least_multiple(
                               m_per_blk, ThreadwiseGemm_Dispatch::MatrixAMinVectorSize)));
            return a_block_desc_k_m;
        }
    }

    static auto GetBBlockDescriptor(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
    {
        // n_per_blk should be a multiple of 8
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            // B : K, N
            auto b_block_desc_k_n =
                make_naive_tensor_descriptor_packed(make_tuple(k_per_blk, n_per_blk));
            return b_block_desc_k_n;
        }
        else
        {
            // B : N/8, K, N8
            auto b_block_desc_n0_k_n1 = make_naive_tensor_descriptor_packed(make_tuple(
                math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
                k_per_blk,
                ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
            return b_block_desc_n0_k_n1;
        }
    }

    static auto GetCBlockDescriptor(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
    {
        return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
    }

    static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
                                        const BGridDesc& b_grid_desc,
                                        const CGridDesc& c_grid_desc)
    {
        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc.)
        bool is_valid = true;
        const auto GemmN = c_grid_desc.GetLength(I1);
        if constexpr(UseCLocalBuffer)
        {
            if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value && NPerBlock < GemmN)
                is_valid &= false;
        }
        else
        {
            // TODO: need to check that the C grid is a simple transform?
            if(GemmN % 8 != 0)
                is_valid &= false;
        }
        return is_valid;
    }

    static void Run(const FloatA* __restrict__ p_a_grid,
                    const FloatB* __restrict__ p_b_grid,
                    FloatC* __restrict__ p_c_grid,
                    const AGridDesc& a_grid_desc,
                    const BGridDesc& b_grid_desc,
                    const CGridDesc& c_grid_desc,
                    const AElementwiseOperation& a_element_op,
                    const BElementwiseOperation& b_element_op,
                    const CElementwiseOperation& c_element_op)
    {
        ck::index_t m_per_block = MPerBlock;
        ck::index_t n_per_block = NPerBlock;
        ck::index_t k_per_block = KPerBlock;

        const auto GemmM = c_grid_desc.GetLength(I0);
        const auto GemmN = c_grid_desc.GetLength(I1);
        const auto GemmK = a_grid_desc.GetLength(I1);

        constexpr auto a_block_copy_dim = AGridDesc::GetNumOfDimension();
        constexpr auto b_block_copy_dim = BGridDesc::GetNumOfDimension();

        auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
            reinterpret_cast<const FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize());

        auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
            reinterpret_cast<const FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize());

        auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
            reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize());

        auto blockwise_gemm = BlockwiseGemmAvx2_MxN<
            FloatA,                                                  // FloatA
            FloatB,                                                  // FloatB
            FloatC,                                                  // FloatC
            decltype(GetABlockDescriptor(m_per_block, k_per_block)), // ABlockDesc
            decltype(GetBBlockDescriptor(k_per_block, n_per_block)), // BBlockDesc
            decltype(GetCBlockDescriptor(m_per_block, n_per_block)), // CBlockDesc
            KPerBlock,                                               // KPerBlock
            ThreadwiseGemm_Dispatch,                                 // ThreadwiseGemm_Dispatch
            ThreadMNAccessOrder>{}; // ThreadMNAccessOrder: how we access gemm MN to utilize
                                    // the micro kernel

        int total_threads = omp_get_max_threads();

        // TODO: OpenMP-aware ordering
        if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value)
        {
            auto a_move_k_step = ck::make_multi_index(0, k_per_block);
            auto b_move_k_step = ck::make_multi_index(0, k_per_block, 0);

            const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
            const ck::index_t grid_n = math::integer_divide_ceil(GemmN, n_per_block);
            const ck::index_t grid_size = grid_m * grid_n;
            const ck::index_t grids_per_thread =
                math::integer_divide_ceil(grid_size, total_threads);
            // This version does not consider K-panel re-use; it keeps things simple for OpenMP.
#pragma omp parallel
            {
                auto a_threadwise_copy =
                    AThreadwiseCopy(a_grid_desc,
                                    ck::make_zero_multi_index<a_block_copy_dim>(),
                                    GetABlockDescriptor(m_per_block, k_per_block),
                                    ck::make_zero_multi_index<a_block_copy_dim>(),
                                    AElementwiseOperation{});

                auto b_threadwise_copy =
                    BThreadwiseCopy(b_grid_desc,
                                    ck::make_zero_multi_index<b_block_copy_dim>(),
                                    GetBBlockDescriptor(k_per_block, n_per_block),
                                    ck::make_zero_multi_index<b_block_copy_dim>(),
                                    BElementwiseOperation{});

                auto c_threadwise_copy =
                    CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block),
                                    ck::make_zero_multi_index<2>(),
                                    c_grid_desc,
                                    ck::make_zero_multi_index<2>(),
                                    CElementwiseOperation{});

                DeviceAlignedMemCPU a_block_mem(m_per_block * k_per_block * sizeof(FloatA),
                                                MemAlignmentByte);
                DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB),
                                                MemAlignmentByte);
                DeviceAlignedMemCPU c_block_mem(m_per_block * n_per_block * sizeof(FloatC),
                                                MemAlignmentByte);

                auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf),
                    a_block_mem.mMemSize / sizeof(FloatA));

                auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf),
                    b_block_mem.mMemSize / sizeof(FloatB));

                auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
                                    : reinterpret_cast<FloatC*>(p_c_grid),
                    UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
                                    : c_grid_desc.GetElementSpaceSize());

                const ck::index_t tid = omp_get_thread_num();

                for(ck::index_t i_gpt = 0; i_gpt < grids_per_thread; i_gpt++)
                {
                    ck::index_t gid = i_gpt * total_threads + tid;
                    if(gid >= grid_size)
                        break;

                    ck::index_t i_mc = (gid / grid_n) * m_per_block;
                    ck::index_t i_nc = (gid % grid_n) * n_per_block;

                    ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
                    ck::index_t nc_size =
                        ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc needs to be 8x
                    nc_size = math::integer_least_multiple(
                        nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);

                    a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, ck::make_multi_index(i_mc, 0));
                    b_threadwise_copy.SetSrcSliceOrigin(
                        b_grid_desc,
                        ck::make_multi_index(
                            math::integer_divide_ceil(
                                i_nc, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
                            0,
                            0));

                    auto c_block_desc =
                        UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;

                    if constexpr(UseCLocalBuffer)
                    {
                        c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
                                                            ck::make_multi_index(i_mc, i_nc));
                    }
                    else
                    {
                        c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
                                                            ck::make_multi_index(i_mc, i_nc));
                        c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
                    }

                    for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
                    {
                        ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);

                        auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
                        auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);

                        // printf("[tid:%d]==> i_m:%d, i_n:%d, i_k:%d, mc:%d, nc:%d, kc:%d(%d, %d)\n",
                        //        tid, i_mc, i_nc, i_kc, mc_size, nc_size, kc_size, KPerBlock, GemmK);

                        a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);

                        b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);

                        blockwise_gemm.Run(a_block_desc,
                                           a_block_buf,
                                           make_zero_multi_index<a_block_copy_dim>(),
                                           b_block_desc,
                                           b_block_buf,
                                           make_zero_multi_index<b_block_copy_dim>(),
                                           c_block_desc,
                                           c_block_buf,
                                           make_zero_multi_index<2>(),
                                           i_kc != 0);

                        if((i_kc + k_per_block) < GemmK)
                        {
                            a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
                            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
                        }
                    }

                    if constexpr(UseCLocalBuffer)
                        c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
                }
            }
        }
        else if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value)
        {
            auto a_move_k_step = ck::make_multi_index(0, k_per_block);
            auto b_move_k_step = ck::make_multi_index(
                math::integer_divide_ceil(n_per_block,
                                          ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
                0,
                0);

            const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
            const ck::index_t grid_m_per_thread = math::integer_divide_ceil(grid_m, total_threads);

            // only parallelize along the gemm M dimension
#pragma omp parallel
            {
                auto a_threadwise_copy =
                    AThreadwiseCopy(a_grid_desc,
                                    ck::make_zero_multi_index<a_block_copy_dim>(),
                                    GetABlockDescriptor(m_per_block, k_per_block),
                                    ck::make_zero_multi_index<a_block_copy_dim>(),
                                    AElementwiseOperation{});

                auto b_threadwise_copy =
                    BThreadwiseCopy(b_grid_desc,
                                    ck::make_zero_multi_index<b_block_copy_dim>(),
                                    GetBBlockDescriptor(k_per_block, n_per_block),
                                    ck::make_zero_multi_index<b_block_copy_dim>(),
                                    BElementwiseOperation{});

                auto c_threadwise_copy =
                    CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block),
                                    ck::make_zero_multi_index<2>(),
                                    c_grid_desc,
                                    ck::make_zero_multi_index<2>(),
                                    CElementwiseOperation{});

                DeviceAlignedMemCPU a_block_mem(m_per_block * k_per_block * sizeof(FloatA),
                                                MemAlignmentByte);
                DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB),
                                                MemAlignmentByte);
                DeviceAlignedMemCPU c_block_mem(m_per_block * n_per_block * sizeof(FloatC),
                                                MemAlignmentByte);

                auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf),
                    a_block_mem.mMemSize / sizeof(FloatA));

                auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf),
                    b_block_mem.mMemSize / sizeof(FloatB));

                auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
                                    : reinterpret_cast<FloatC*>(p_c_grid),
                    UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
                                    : c_grid_desc.GetElementSpaceSize());

                const ck::index_t tid = omp_get_thread_num();

                for(ck::index_t i_gmpt = 0; i_gmpt < grid_m_per_thread; i_gmpt++)
                {
                    ck::index_t i_mc = (i_gmpt * total_threads + tid) * m_per_block;
                    if(i_mc >= GemmM)
                        break;

                    ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
                    a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, ck::make_multi_index(i_mc, 0));

                    for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
                    {
                        ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);

                        auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
                        a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);

                        b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc,
                                                            ck::make_multi_index(0, i_kc, 0));

                        // TODO: if using a local C buffer, this nc loop only needs to run once
                        for(ck::index_t i_nc = 0; i_nc < GemmN; i_nc += n_per_block)
                        {
                            ck::index_t nc_size =
                                ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc needs to be 8x
                            nc_size = math::integer_least_multiple(
                                nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);

                            auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);

                            b_threadwise_copy.Run(
                                b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);

                            auto c_block_desc = UseCLocalBuffer
                                                    ? GetCBlockDescriptor(mc_size, nc_size)
                                                    : c_grid_desc;

                            if constexpr(!UseCLocalBuffer)
                            {
                                c_threadwise_copy.SetSrcSliceOrigin(
                                    c_block_desc, ck::make_multi_index(i_mc, i_nc));
                                c_threadwise_copy.Run(
                                    c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
                            }

                            blockwise_gemm.Run(a_block_desc,
                                               a_block_buf,
                                               make_zero_multi_index<a_block_copy_dim>(),
                                               b_block_desc,
                                               b_block_buf,
                                               make_zero_multi_index<b_block_copy_dim>(),
                                               c_block_desc,
                                               c_block_buf,
                                               make_zero_multi_index<2>(),
                                               i_kc != 0);

                            if((i_nc + n_per_block) < GemmN)
                                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);

                            if constexpr(UseCLocalBuffer)
                            {
                                c_threadwise_copy.SetDstSliceOrigin(
                                    c_grid_desc, ck::make_multi_index(i_mc, i_nc));
                                c_threadwise_copy.Run(
                                    c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
                            }
                        }
                        if((i_kc + k_per_block) < GemmK)
                            a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
                    }
                }
            }
        }
    }
};

} // namespace cpu
} // namespace ck
#endif
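The threading-relevant detail in GridwiseGemmAvx2_MxN::Run is that the threadwise copies and the DeviceAlignedMemCPU packing buffers are constructed inside the #pragma omp parallel region, so every thread owns private A/B/C staging memory and needs no locking on it. A minimal sketch of that pattern, using a hypothetical ThreadLocalPack in place of the real buffers:

// Sketch: allocate per-thread scratch inside the parallel region so each
// OpenMP thread owns its buffers. ThreadLocalPack is a hypothetical stand-in.
#include <omp.h>
#include <vector>
#include <cstdio>

struct ThreadLocalPack
{
    std::vector<float> a_block, b_block, c_block;
    ThreadLocalPack(int m, int n, int k) : a_block(m * k), b_block(k * n), c_block(m * n) {}
};

int main()
{
    const int m_per_block = 256, n_per_block = 128, k_per_block = 64; // example tile sizes

#pragma omp parallel
    {
        // constructed after the parallel region starts -> one instance per thread
        ThreadLocalPack pack(m_per_block, n_per_block, k_per_block);
        std::printf("thread %d owns %zu + %zu + %zu floats of scratch\n",
                    omp_get_thread_num(),
                    pack.a_block.size(),
                    pack.b_block.size(),
                    pack.c_block.size());
    }
    return 0;
}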
# device_conv2d_fwd_cpu_instance
set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
    device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
)
add_library(device_conv2d_fwd_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE})
target_compile_features(device_conv2d_fwd_cpu_instance PUBLIC)
set_target_properties(device_conv2d_fwd_cpu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_link_libraries(device_conv2d_fwd_cpu_instance PRIVATE /opt/rocm/llvm/lib/libomp.so)
target_compile_options(device_conv2d_fwd_cpu_instance PRIVATE -fopenmp=libomp -Wno-unused-command-line-argument)
install(TARGETS device_conv2d_fwd_cpu_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_conv2d_fwd_cpu_instance)
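Because the instance library now links libomp and the kernel sizes its work distribution with omp_get_max_threads(), the thread count is controlled by the usual OpenMP mechanisms (OMP_NUM_THREADS, omp_set_num_threads). A quick standalone check of what the runtime will report; nothing here is specific to this repository:

// Quick check of the OpenMP runtime the kernel will see; the thread count
// follows OMP_NUM_THREADS or omp_set_num_threads().
#include <omp.h>
#include <cstdio>

int main()
{
    std::printf("max threads: %d\n", omp_get_max_threads());
#pragma omp parallel
    {
#pragma omp single
        std::printf("parallel region runs with %d threads\n", omp_get_num_threads());
    }
    return 0;
}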