Commit b8ba0239 authored by carlushuang

support multi-thread

parent e06b9871
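This commit replaces the single `#pragma omp parallel for` over the tile loop with an explicit `#pragma omp parallel` region: every thread constructs its own threadwise copy objects and packing buffers, then claims M×N tiles round-robin via `gid = i_gpt * total_threads + tid`. The standalone sketch below only illustrates that distribution scheme; the problem sizes, the empty tile body, and the printf are illustrative and not part of the kernel.

```cpp
// Minimal sketch (not the actual kernel code) of the round-robin tile
// distribution this commit introduces: the M x N block grid is flattened
// into grid_size tiles and handed out to OpenMP threads in strides of
// total_threads, mirroring the gid = i_gpt * total_threads + tid loop below.
#include <omp.h>
#include <cstdio>
#include <algorithm>

int main()
{
    const int GemmM = 1000, GemmN = 700;          // illustrative problem sizes
    const int m_per_block = 256, n_per_block = 128;

    const int grid_m = (GemmM + m_per_block - 1) / m_per_block;
    const int grid_n = (GemmN + n_per_block - 1) / n_per_block;
    const int grid_size = grid_m * grid_n;

    const int total_threads = omp_get_max_threads();
    const int grids_per_thread = (grid_size + total_threads - 1) / total_threads;

#pragma omp parallel
    {
        const int tid = omp_get_thread_num();
        // each thread owns thread-private block buffers in the real kernel;
        // here we only show which tiles a given thread would process
        for(int i_gpt = 0; i_gpt < grids_per_thread; i_gpt++)
        {
            const int gid = i_gpt * total_threads + tid;
            if(gid >= grid_size)
                break;
            const int i_mc    = (gid / grid_n) * m_per_block;
            const int i_nc    = (gid % grid_n) * n_per_block;
            const int mc_size = std::min(GemmM - i_mc, m_per_block);
            const int nc_size = std::min(GemmN - i_nc, n_per_block);
            std::printf("tid %d -> tile (%d, %d) size %dx%d\n",
                        tid, i_mc, i_nc, mc_size, nc_size);
        }
    }
    return 0;
}
```

Built with `-fopenmp`, the thread count follows the usual OpenMP controls such as `OMP_NUM_THREADS`.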
......@@ -46,7 +46,7 @@ struct BlockwiseGemmAvx2_MxN
using CCoord = decltype(make_tensor_coordinate(CDesc{}, IndexC{}));
template <typename TensorDesc>
constexpr auto GetLeadingElement(const TensorDesc& desc)
static constexpr auto GetLeadingElement(const TensorDesc& desc)
{
// if use this function, make sure desc are known at compile time.
// otherwise, it is not efficient to calculate leading dim here
......@@ -63,12 +63,12 @@ struct BlockwiseGemmAvx2_MxN
}
}
ck::index_t GetALeadingElement(const ABlockDesc& a_block_desc) const
static ck::index_t GetALeadingElement(const ABlockDesc& a_block_desc)
{
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
ck::index_t GetBLeadingElement(const BBlockDesc& b_block_desc) const
static ck::index_t GetBLeadingElement(const BBlockDesc& b_block_desc)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
......@@ -84,12 +84,12 @@ struct BlockwiseGemmAvx2_MxN
}
}
ck::index_t GetCLeadingElement(const CDesc& c_desc) const
static ck::index_t GetCLeadingElement(const CDesc& c_desc)
{
return c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
ck::index_t GetMPerBlock(const ABlockDesc& a_block_desc) const
static ck::index_t GetMPerBlock(const ABlockDesc& a_block_desc)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
......@@ -104,7 +104,7 @@ struct BlockwiseGemmAvx2_MxN
}
}
ck::index_t GetKPerBlock(const ABlockDesc& a_block_desc) const
static ck::index_t GetKPerBlock(const ABlockDesc& a_block_desc)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
......@@ -119,7 +119,7 @@ struct BlockwiseGemmAvx2_MxN
}
}
ck::index_t GetNPerBlock(const BBlockDesc& b_block_desc) const
static ck::index_t GetNPerBlock(const BBlockDesc& b_block_desc)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
......@@ -135,8 +135,8 @@ struct BlockwiseGemmAvx2_MxN
}
}
ck::index_t
GetABlockStartOffset(const ABlockDesc& a_block_desc, const index_t i_m, const index_t) const
static ck::index_t
GetABlockStartOffset(const ABlockDesc& a_block_desc, const index_t i_m, const index_t)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
......@@ -149,8 +149,8 @@ struct BlockwiseGemmAvx2_MxN
}
}
ck::index_t
GetBBlockStartOffset(const BBlockDesc& b_block_desc, const index_t, const index_t i_n) const
static ck::index_t
GetBBlockStartOffset(const BBlockDesc& b_block_desc, const index_t, const index_t i_n)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
......@@ -165,14 +165,14 @@ struct BlockwiseGemmAvx2_MxN
}
}
ck::index_t
GetCBlockStartOffset(const CDesc& c_desc, const index_t i_m, const index_t i_n) const
static ck::index_t
GetCBlockStartOffset(const CDesc& c_desc, const index_t i_m, const index_t i_n)
{
return i_m * c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] + i_n;
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CBuffer>
void Run(const ABlockDesc& a_block_desc,
static void Run(const ABlockDesc& a_block_desc,
const ABlockBuffer& a_block_buf,
const IndexA& /* a_origin */,
......@@ -184,7 +184,7 @@ struct BlockwiseGemmAvx2_MxN
CBuffer& c_buf,
const IndexC& /* c_origin */,
bool is_accumulate_c = true) const
bool is_accumulate_c = true)
{
auto lda = GetALeadingElement(a_block_desc) * sizeof(FloatA);
auto ldb = GetBLeadingElement(b_block_desc) * sizeof(FloatB);
......
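The helper and `Run` methods of `BlockwiseGemmAvx2_MxN` above are turned into `static` members, so the functor constructed once before the parallel region holds no per-call state and can be invoked concurrently from every thread. A minimal sketch of that pattern, with hypothetical names standing in for the real blockwise GEMM:

```cpp
// Illustrative-only pattern: a stateless blockwise functor whose static Run
// can be invoked from many OpenMP threads at once because it touches only
// its arguments (per-thread buffers), never member state.
#include <omp.h>
#include <vector>
#include <cstdio>

struct BlockwiseKernelSketch // hypothetical stand-in for BlockwiseGemmAvx2_MxN
{
    static void Run(const float* a, const float* b, float* c, int n)
    {
        for(int i = 0; i < n; i++)
            c[i] += a[i] * b[i]; // stand-in for the real micro kernel
    }
};

int main()
{
    const int n = 16;
#pragma omp parallel
    {
        // each thread works on its own private buffers, as in the commit
        std::vector<float> a(n, 1.0f), b(n, 2.0f), c(n, 0.0f);
        BlockwiseKernelSketch::Run(a.data(), b.data(), c.data(), n);
        std::printf("tid %d c[0]=%f\n", omp_get_thread_num(), c[0]);
    }
    return 0;
}
```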
......@@ -9,7 +9,9 @@
#include "threadwise_tensor_slice_transfer_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <utility>
#include <unistd.h>
#include <omp.h>
namespace ck {
namespace cpu {
......@@ -168,19 +170,61 @@ struct GridwiseGemmAvx2_MxN
constexpr auto b_block_copy_dim = BGridDesc::GetNumOfDimension();
auto a_threadwise_copy = AThreadwiseCopy(a_grid_desc,
auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<const FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize());
auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<const FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize());
auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize());
auto blockwise_gemm = BlockwiseGemmAvx2_MxN<
FloatA, // FloatA,
FloatB, // FloatB,
FloatC, // FloatC,
decltype(GetABlockDescriptor(m_per_block, k_per_block)), // ABlockDesc,
decltype(GetBBlockDescriptor(k_per_block, n_per_block)), // BBlockDesc,
decltype(GetCBlockDescriptor(m_per_block, n_per_block)), // CBlockDesc,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we acces
// gemm MN to utilize micro kernel>{};
int total_threads = omp_get_max_threads();
// TODO: openmp aware ordering
//
if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value)
{
auto a_move_k_step = ck::make_multi_index(0, k_per_block);
auto b_move_k_step = ck::make_multi_index(0, k_per_block, 0);
const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
const ck::index_t grid_n = math::integer_divide_ceil(GemmN, n_per_block);
const ck::index_t grid_size = grid_m * grid_n;
const ck::index_t grids_per_thread =
math::integer_divide_ceil(grid_size, total_threads);
// This version does not consider K panel re-usage. simple for openmp
#pragma omp parallel
{
auto a_threadwise_copy =
AThreadwiseCopy(a_grid_desc,
ck::make_zero_multi_index<a_block_copy_dim>(),
GetABlockDescriptor(m_per_block, k_per_block),
ck::make_zero_multi_index<a_block_copy_dim>(),
AElementwiseOperation{});
auto b_threadwise_copy = BThreadwiseCopy(b_grid_desc,
auto b_threadwise_copy =
BThreadwiseCopy(b_grid_desc,
ck::make_zero_multi_index<b_block_copy_dim>(),
GetBBlockDescriptor(k_per_block, n_per_block),
ck::make_zero_multi_index<b_block_copy_dim>(),
BElementwiseOperation{});
auto c_threadwise_copy = CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block),
auto c_threadwise_copy =
CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block),
ck::make_zero_multi_index<2>(),
c_grid_desc,
ck::make_zero_multi_index<2>(),
......@@ -193,15 +237,6 @@ struct GridwiseGemmAvx2_MxN
DeviceAlignedMemCPU c_block_mem(m_per_block * n_per_block * sizeof(FloatC),
MemAlignmentByte);
auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<const FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize());
auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<const FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize());
auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize());
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf),
a_block_mem.mMemSize / sizeof(FloatA));
......@@ -216,32 +251,14 @@ struct GridwiseGemmAvx2_MxN
UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
: c_grid_desc.GetElementSpaceSize());
auto blockwise_gemm = BlockwiseGemmAvx2_MxN<
FloatA, // FloatA,
FloatB, // FloatB,
FloatC, // FloatC,
decltype(GetABlockDescriptor(m_per_block, k_per_block)), // ABlockDesc,
decltype(GetBBlockDescriptor(k_per_block, n_per_block)), // BBlockDesc,
decltype(GetCBlockDescriptor(m_per_block, n_per_block)), // CBlockDesc,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we acces
// gemm MN to utilize micro kernel>{};
const ck::index_t tid = omp_get_thread_num();
// TODO: openmp aware ordering
//
if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value)
for(ck::index_t i_gpt = 0; i_gpt < grids_per_thread; i_gpt++)
{
auto a_move_k_step = ck::make_multi_index(0, k_per_block);
auto b_move_k_step = ck::make_multi_index(0, k_per_block, 0);
const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
const ck::index_t grid_n = math::integer_divide_ceil(GemmN, n_per_block);
ck::index_t gid = i_gpt * total_threads + tid;
if(gid >= grid_size)
break;
const ck::index_t grid_size = grid_m * grid_n;
// This version does not consider K panel re-usage. simple for openmp
#pragma omp parallel for
for(ck::index_t gid = 0; gid < grid_size; gid++)
{
ck::index_t i_mc = (gid / grid_n) * m_per_block;
ck::index_t i_nc = (gid % grid_n) * n_per_block;
......@@ -254,7 +271,8 @@ struct GridwiseGemmAvx2_MxN
a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, ck::make_multi_index(i_mc, 0));
b_threadwise_copy.SetSrcSliceOrigin(
b_grid_desc,
ck::make_multi_index(math::integer_divide_ceil(
ck::make_multi_index(
math::integer_divide_ceil(
i_nc, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
0,
0));
......@@ -280,7 +298,8 @@ struct GridwiseGemmAvx2_MxN
auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
// printf("==> i_m:%d, i_n:%d, i_k:%d, mc:%d, nc:%d, kc:%d(%d, %d)\n", i_mc,
// printf("[tid:%d]==> i_m:%d, i_n:%d, i_k:%d, mc:%d, nc:%d, kc:%d(%d,
// %d)\n", tid, i_mc,
// i_nc, i_kc, mc_size, nc_size, kc_size, KPerBlock, GemmK); fflush(stdout);
a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
......@@ -336,6 +355,7 @@ struct GridwiseGemmAvx2_MxN
c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
}
}
}
else if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value)
{
auto a_move_k_step = ck::make_multi_index(0, k_per_block);
......@@ -345,10 +365,61 @@ struct GridwiseGemmAvx2_MxN
0,
0);
const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
const ck::index_t grid_m_per_thread = math::integer_divide_ceil(grid_m, total_threads);
// only parallel in gemm m dim
#pragma omp parallel for
for(ck::index_t i_mc = 0; i_mc < GemmM; i_mc += m_per_block)
#pragma omp parallel
{
auto a_threadwise_copy =
AThreadwiseCopy(a_grid_desc,
ck::make_zero_multi_index<a_block_copy_dim>(),
GetABlockDescriptor(m_per_block, k_per_block),
ck::make_zero_multi_index<a_block_copy_dim>(),
AElementwiseOperation{});
auto b_threadwise_copy =
BThreadwiseCopy(b_grid_desc,
ck::make_zero_multi_index<b_block_copy_dim>(),
GetBBlockDescriptor(k_per_block, n_per_block),
ck::make_zero_multi_index<b_block_copy_dim>(),
BElementwiseOperation{});
auto c_threadwise_copy =
CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block),
ck::make_zero_multi_index<2>(),
c_grid_desc,
ck::make_zero_multi_index<2>(),
CElementwiseOperation{});
DeviceAlignedMemCPU a_block_mem(m_per_block * k_per_block * sizeof(FloatA),
MemAlignmentByte);
DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB),
MemAlignmentByte);
DeviceAlignedMemCPU c_block_mem(m_per_block * n_per_block * sizeof(FloatC),
MemAlignmentByte);
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf),
a_block_mem.mMemSize / sizeof(FloatA));
auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf),
b_block_mem.mMemSize / sizeof(FloatB));
auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
: reinterpret_cast<FloatC*>(p_c_grid),
UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
: c_grid_desc.GetElementSpaceSize());
const ck::index_t tid = omp_get_thread_num();
for(ck::index_t i_gmpt = 0; i_gmpt < grid_m_per_thread; i_gmpt++)
{
ck::index_t i_mc = (i_gmpt * total_threads + tid) * m_per_block;
if(i_mc >= GemmM)
break;
ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, ck::make_multi_index(i_mc, 0));
for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
......@@ -368,18 +439,19 @@ struct GridwiseGemmAvx2_MxN
ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc need be 8x
nc_size = math::integer_least_multiple(
nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
b_threadwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
auto c_block_desc =
UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;
auto c_block_desc = UseCLocalBuffer
? GetCBlockDescriptor(mc_size, nc_size)
: c_grid_desc;
if constexpr(!UseCLocalBuffer)
{
c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.SetSrcSliceOrigin(
c_block_desc, ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.Run(
c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
}
......@@ -400,8 +472,8 @@ struct GridwiseGemmAvx2_MxN
if constexpr(UseCLocalBuffer)
{
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.SetDstSliceOrigin(
c_grid_desc, ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.Run(
c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
}
......@@ -413,6 +485,7 @@ struct GridwiseGemmAvx2_MxN
}
}
}
}
};
} // namespace cpu
......
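In both access-order branches the `DeviceAlignedMemCPU` packing buffers and the A/B/C threadwise copies are now created inside the parallel region, giving each thread private storage for its packed panels. A rough sketch of that allocation pattern, using plain `std::aligned_alloc` as a stand-in for `DeviceAlignedMemCPU` and illustrative block sizes:

```cpp
// Illustrative sketch of the per-thread packing buffers the commit creates:
// each OpenMP thread allocates its own aligned A/B/C block storage inside the
// parallel region (the role DeviceAlignedMemCPU plays in the kernel), so no
// packing buffer is ever shared between threads.
#include <omp.h>
#include <cstdlib>
#include <cstdio>

int main()
{
    const std::size_t MemAlignmentByte = 32; // AVX2-friendly alignment
    const std::size_t m_per_block = 256, n_per_block = 128, k_per_block = 64;

#pragma omp parallel
    {
        // thread-private buffers, released before the thread leaves the region
        float* a_block = static_cast<float*>(std::aligned_alloc(
            MemAlignmentByte, m_per_block * k_per_block * sizeof(float)));
        float* b_block = static_cast<float*>(std::aligned_alloc(
            MemAlignmentByte, k_per_block * n_per_block * sizeof(float)));
        float* c_block = static_cast<float*>(std::aligned_alloc(
            MemAlignmentByte, m_per_block * n_per_block * sizeof(float)));

        std::printf("tid %d buffers at %p %p %p\n", omp_get_thread_num(),
                    (void*)a_block, (void*)b_block, (void*)c_block);

        std::free(a_block);
        std::free(b_block);
        std::free(c_block);
    }
    return 0;
}
```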
......@@ -5,6 +5,8 @@ set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
add_library(device_conv2d_fwd_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE})
target_compile_features(device_conv2d_fwd_cpu_instance PUBLIC)
set_target_properties(device_conv2d_fwd_cpu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_link_libraries(device_conv2d_fwd_cpu_instance PRIVATE /opt/rocm/llvm/lib/libomp.so)
target_compile_options(device_conv2d_fwd_cpu_instance PRIVATE -fopenmp=libomp -Wno-unused-command-line-argument)
install(TARGETS device_conv2d_fwd_cpu_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_conv2d_fwd_cpu_instance)