Commit 9a7fa123 authored by carlushuang's avatar carlushuang
Browse files

support gcc with cpu only compile

parent ad09ebdb
...@@ -38,6 +38,7 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N> ...@@ -38,6 +38,7 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
} }
}; };
#ifndef CK_NOGPU
// static buffer for vector // static buffer for vector
template <AddressSpaceEnum AddressSpace, template <AddressSpaceEnum AddressSpace,
typename S, typename S,
...@@ -151,6 +152,7 @@ struct StaticBufferTupleOfVector ...@@ -151,6 +152,7 @@ struct StaticBufferTupleOfVector
static_for<0, Number<numScalars>{}, 1>{}([&](auto i) { SetAsType(i, S{0}); }); static_for<0, Number<numScalars>{}, 1>{}([&](auto i) { SetAsType(i, S{0}); });
} }
}; };
#endif
template <AddressSpaceEnum AddressSpace, typename T, index_t N> template <AddressSpaceEnum AddressSpace, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>) __host__ __device__ constexpr auto make_static_buffer(Number<N>)
......
#ifndef CK_SYNCHRONIZATION_AMD_HPP #ifndef CK_SYNCHRONIZATION_AMD_HPP
#define CK_SYNCHRONIZATION_AMD_HPP #define CK_SYNCHRONIZATION_AMD_HPP
#ifndef CK_NOGPU
#include "config.hpp" #include "config.hpp"
...@@ -19,3 +20,4 @@ __device__ void block_sync_lds() ...@@ -19,3 +20,4 @@ __device__ void block_sync_lds()
} // namespace ck } // namespace ck
#endif #endif
#endif
#pragma once #pragma once
#ifndef CK_NOGPU
#include "get_id.hpp" #include "get_id.hpp"
namespace ck { namespace ck {
...@@ -16,3 +17,4 @@ struct ThisThreadBlock ...@@ -16,3 +17,4 @@ struct ThisThreadBlock
}; };
} // namespace ck } // namespace ck
#endif
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "statically_indexed_array.hpp" #include "statically_indexed_array.hpp"
#include "data_type.hpp" #include "data_type.hpp"
#ifndef CK_NOGPU
namespace ck { namespace ck {
template <typename S, template <typename S,
...@@ -166,3 +167,4 @@ struct transpose_vectors<int8_t, NX, NY> ...@@ -166,3 +167,4 @@ struct transpose_vectors<int8_t, NX, NY>
} // namespace ck } // namespace ck
#endif #endif
#endif
\ No newline at end of file
...@@ -4,12 +4,14 @@ ...@@ -4,12 +4,14 @@
#include <functional> #include <functional>
#include <thread> #include <thread>
#include <chrono> #include <chrono>
#include "ck/options.hpp"
#ifndef CK_NOGPU
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#endif
#include "stream_config.hpp" #include "stream_config.hpp"
#include "ck/options.hpp"
#ifndef CK_NOGPU
inline void hip_check_error(hipError_t x) inline void hip_check_error(hipError_t x)
{ {
if(x != hipSuccess) if(x != hipSuccess)
...@@ -36,22 +38,6 @@ struct DeviceMem ...@@ -36,22 +38,6 @@ struct DeviceMem
std::size_t mMemSize; std::size_t mMemSize;
}; };
struct DeviceAlignedMemCPU
{
DeviceAlignedMemCPU() = delete;
DeviceAlignedMemCPU(std::size_t mem_size, std::size_t alignment);
void* GetDeviceBuffer();
std::size_t GetBufferSize();
void ToDevice(const void* p);
void FromDevice(void* p);
void SetZero();
~DeviceAlignedMemCPU();
void* mpDeviceBuf;
std::size_t mMemSize;
std::size_t mAlignment;
};
struct KernelTimerImpl; struct KernelTimerImpl;
struct KernelTimer struct KernelTimer
...@@ -65,19 +51,6 @@ struct KernelTimer ...@@ -65,19 +51,6 @@ struct KernelTimer
std::unique_ptr<KernelTimerImpl> impl; std::unique_ptr<KernelTimerImpl> impl;
}; };
struct WallTimerImpl;
struct WallTimer
{
WallTimer();
~WallTimer();
void Start();
void End();
float GetElapsedTime() const;
std::unique_ptr<WallTimerImpl> impl;
};
using device_stream_t = hipStream_t; using device_stream_t = hipStream_t;
template <typename... Args, typename F> template <typename... Args, typename F>
...@@ -136,6 +109,36 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -136,6 +109,36 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
return 0; return 0;
#endif #endif
} }
#endif
struct DeviceAlignedMemCPU
{
DeviceAlignedMemCPU() = delete;
DeviceAlignedMemCPU(std::size_t mem_size, std::size_t alignment);
void* GetDeviceBuffer();
std::size_t GetBufferSize();
void ToDevice(const void* p);
void FromDevice(void* p);
void SetZero();
~DeviceAlignedMemCPU();
void* mpDeviceBuf;
std::size_t mMemSize;
std::size_t mAlignment;
};
struct WallTimerImpl;
struct WallTimer
{
WallTimer();
~WallTimer();
void Start();
void End();
float GetElapsedTime() const;
std::unique_ptr<WallTimerImpl> impl;
};
template <typename... Args, typename F> template <typename... Args, typename F>
void launch_cpu_kernel(F kernel, Args... args) void launch_cpu_kernel(F kernel, Args... args)
...@@ -162,4 +165,3 @@ float launch_and_time_cpu_kernel(F kernel, int nrepeat, Args... args) ...@@ -162,4 +165,3 @@ float launch_and_time_cpu_kernel(F kernel, int nrepeat, Args... args)
return timer.GetElapsedTime() / nrepeat; return timer.GetElapsedTime() / nrepeat;
} }
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <utility> #include <utility>
#include <cassert> #include <cassert>
#include <iostream> #include <iostream>
#include "data_type.hpp" #include "common_header.hpp"
template <typename Range> template <typename Range>
std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
......
#include <chrono> #include <chrono>
#include <assert.h>
#include <string.h>
#include "device.hpp" #include "device.hpp"
#ifndef CK_NOGPU
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{ {
hip_check_error(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize)); hip_check_error(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
...@@ -24,45 +27,6 @@ void DeviceMem::SetZero() { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)) ...@@ -24,45 +27,6 @@ void DeviceMem::SetZero() { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize))
DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); } DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); }
DeviceAlignedMemCPU::DeviceAlignedMemCPU(std::size_t mem_size, std::size_t alignment)
: mMemSize(mem_size), mAlignment(alignment)
{
if(mem_size == 0)
{
mpDeviceBuf = nullptr;
}
else
{
assert(!(alignment == 0 || (alignment & (alignment - 1)))); // check pow of 2
void* p1;
void** p2;
int offset = alignment - 1 + sizeof(void*);
p1 = malloc(mem_size + offset);
assert(p1 != nullptr);
p2 = reinterpret_cast<void**>((reinterpret_cast<size_t>(p1) + offset) & ~(alignment - 1));
p2[-1] = p1;
mpDeviceBuf = reinterpret_cast<void*>(p2);
}
}
void* DeviceAlignedMemCPU::GetDeviceBuffer() { return mpDeviceBuf; }
std::size_t DeviceAlignedMemCPU::GetBufferSize() { return mMemSize; }
void DeviceAlignedMemCPU::ToDevice(const void* p) { memcpy(mpDeviceBuf, p, mMemSize); }
void DeviceAlignedMemCPU::FromDevice(void* p) { memcpy(p, mpDeviceBuf, mMemSize); }
void DeviceAlignedMemCPU::SetZero() { memset(mpDeviceBuf, 0, mMemSize); }
DeviceAlignedMemCPU::~DeviceAlignedMemCPU()
{
if(mpDeviceBuf != nullptr)
free((reinterpret_cast<void**>(mpDeviceBuf))[-1]);
}
struct KernelTimerImpl struct KernelTimerImpl
{ {
KernelTimerImpl() KernelTimerImpl()
...@@ -108,6 +72,46 @@ void KernelTimer::Start() { impl->Start(); } ...@@ -108,6 +72,46 @@ void KernelTimer::Start() { impl->Start(); }
void KernelTimer::End() { impl->End(); } void KernelTimer::End() { impl->End(); }
float KernelTimer::GetElapsedTime() const { return impl->GetElapsedTime(); } float KernelTimer::GetElapsedTime() const { return impl->GetElapsedTime(); }
#endif
DeviceAlignedMemCPU::DeviceAlignedMemCPU(std::size_t mem_size, std::size_t alignment)
: mMemSize(mem_size), mAlignment(alignment)
{
if(mem_size == 0)
{
mpDeviceBuf = nullptr;
}
else
{
assert(!(alignment == 0 || (alignment & (alignment - 1)))); // check pow of 2
void* p1;
void** p2;
int offset = alignment - 1 + sizeof(void*);
p1 = malloc(mem_size + offset);
assert(p1 != nullptr);
p2 = reinterpret_cast<void**>((reinterpret_cast<size_t>(p1) + offset) & ~(alignment - 1));
p2[-1] = p1;
mpDeviceBuf = reinterpret_cast<void*>(p2);
}
}
void* DeviceAlignedMemCPU::GetDeviceBuffer() { return mpDeviceBuf; }
std::size_t DeviceAlignedMemCPU::GetBufferSize() { return mMemSize; }
void DeviceAlignedMemCPU::ToDevice(const void* p) { memcpy(mpDeviceBuf, p, mMemSize); }
void DeviceAlignedMemCPU::FromDevice(void* p) { memcpy(p, mpDeviceBuf, mMemSize); }
void DeviceAlignedMemCPU::SetZero() { memset(mpDeviceBuf, 0, mMemSize); }
DeviceAlignedMemCPU::~DeviceAlignedMemCPU()
{
if(mpDeviceBuf != nullptr)
free((reinterpret_cast<void**>(mpDeviceBuf))[-1]);
}
struct WallTimerImpl struct WallTimerImpl
{ {
......
...@@ -6,8 +6,8 @@ set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE ...@@ -6,8 +6,8 @@ set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
add_library(device_conv2d_fwd_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE})
target_compile_features(device_conv2d_fwd_cpu_instance PUBLIC) target_compile_features(device_conv2d_fwd_cpu_instance PUBLIC)
set_target_properties(device_conv2d_fwd_cpu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_cpu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_link_libraries(device_conv2d_fwd_cpu_instance PRIVATE /opt/rocm/llvm/lib/libomp.so) target_link_libraries(device_conv2d_fwd_cpu_instance PRIVATE "${OMP_LIBRARY}")
target_compile_options(device_conv2d_fwd_cpu_instance PRIVATE -fopenmp=libomp -Wno-unused-command-line-argument) target_compile_options(device_conv2d_fwd_cpu_instance PRIVATE "${OMP_CXX_FLAG}")
install(TARGETS device_conv2d_fwd_cpu_instance LIBRARY DESTINATION lib) install(TARGETS device_conv2d_fwd_cpu_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_conv2d_fwd_cpu_instance) clang_tidy_check(device_conv2d_fwd_cpu_instance)
#include <stdlib.h> #include <stdlib.h>
#include "convolution_forward_specialization_cpu.hpp" #include "config.hpp"
#include "config.hpp" #include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp" #include "device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp"
#include "element_wise_operation_cpu.hpp" #include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp" #include "device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace cpu { namespace cpu {
namespace device { namespace device {
namespace device_conv2d_fwd_avx2_instance { namespace device_conv2d_fwd_avx2_instance {
using InType = float; using InType = float;
using WeiType = float; using WeiType = float;
using OutType = float; using OutType = float;
using AccType = float; using AccType = float;
using InLayout = ck::tensor_layout::gemm::RowMajor; // NHWC using InLayout = ck::tensor_layout::gemm::RowMajor; // NHWC
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXCK8 using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXCK8
static constexpr bool NonTemporalStore = false; static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough; using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu; using Relu = ck::tensor_operation::cpu::element_wise::Relu;
static constexpr auto ConvFwdDefault = static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 = static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 = static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop = static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop; ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC = static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC; ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK; static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN; static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off // clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \ #define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf> DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>
// clang-format on // clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, false)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, false)>;
// clang-format on // clang-format on
// use this in single thread, but gemm_n is not multiple of 8 // use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on // clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some // use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...) // time no local c is better...)
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 48, 24, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 48, 24, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 72, 16, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 72, 16, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 72, 32, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 72, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 96, 32, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 96, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 96, 64, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 96, 64, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 120, 32, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 120, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 120, 64, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 120, 64, 128, 6, 16, true, true, true),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true), // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on // clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, false)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, false)>;
// clang-format on // clang-format on
// use this in single thread, but gemm_n is not multiple of 8 // use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on // clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some // use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...) // time no local c is better...)
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 48, 24, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 48, 24, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, true, true, true),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true), // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on // clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk( void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances) std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances{}); instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances{});
} }
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c( void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances) std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances{}); instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances{});
} }
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt( void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances) std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances{}); instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances{});
} }
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu( void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances) std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances{}); instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances{});
} }
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu( void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances) std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances{}); instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances{});
} }
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu( void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances) std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances{}); instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances{});
} }
} // namespace device_conv2d_fwd_avx2_instance } // namespace device_conv2d_fwd_avx2_instance
} // namespace device } // namespace device
} // namespace cpu } // namespace cpu
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#!/bin/bash
rm -f CMakeCache.txt
rm -f *.cmake
rm -rf CMakeFiles
AVX2_FLAGS='-m64 -mavx2 -mf16c -mfma -DHALF_ENABLE_F16C_INTRINSICS=0 '
rm -rf build/
mkdir build && cd build
MY_PROJECT_SOURCE=..
MY_PROJECT_INSTALL=../install.dir
rm -rf $MY_PROJECT_INSTALL
mkdir $MY_PROJECT_INSTALL
cmake \
-D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \
-D BUILD_DEV=OFF \
-D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS="$AVX2_FLAGS " \
-D CMAKE_CXX_COMPILER=g++ \
-D CMAKE_PREFIX_PATH=/usr/local \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D CK_NOGPU=ON \
${MY_PROJECT_SOURCE}
add_test_executable(test_conv2d_fwd_cpu conv2d_fwd_cpu.cpp) add_test_executable(test_conv2d_fwd_cpu conv2d_fwd_cpu.cpp)
target_link_libraries(test_conv2d_fwd_cpu PRIVATE host_tensor) target_link_libraries(test_conv2d_fwd_cpu PRIVATE host_tensor)
target_link_libraries(test_conv2d_fwd_cpu PRIVATE device_conv2d_fwd_cpu_instance) target_link_libraries(test_conv2d_fwd_cpu PRIVATE device_conv2d_fwd_cpu_instance)
# 3.13 introduce target_link_directories, which is better # 3.13 introduce target_link_directories, which is better
set_target_properties(test_conv2d_fwd_cpu PROPERTIES LINK_FLAGS -Wl,-rpath,/opt/rocm/llvm/lib ) set_target_properties(test_conv2d_fwd_cpu PROPERTIES LINK_FLAGS "${OMP_LINK_FLAG}")
target_link_libraries(test_conv2d_fwd_cpu PRIVATE /opt/rocm/llvm/lib/libomp.so) target_link_libraries(test_conv2d_fwd_cpu PRIVATE "${OMP_LIBRARY}")
target_compile_options(test_conv2d_fwd_cpu PRIVATE -fopenmp=libomp -Wno-unused-command-line-argument) target_compile_options(test_conv2d_fwd_cpu PRIVATE "${OMP_CXX_FLAG}")
add_test_executable(test_cpu_gemm_uk cpu_gemm_uk.cpp) add_test_executable(test_cpu_gemm_uk cpu_gemm_uk.cpp)
target_link_libraries(test_cpu_gemm_uk PRIVATE host_tensor) target_link_libraries(test_cpu_gemm_uk PRIVATE host_tensor)
# 3.13 introduce target_link_directories, which is better # 3.13 introduce target_link_directories, which is better
set_target_properties(test_cpu_gemm_uk PROPERTIES LINK_FLAGS -Wl,-rpath,/opt/rocm/llvm/lib ) set_target_properties(test_cpu_gemm_uk PROPERTIES LINK_FLAGS "${OMP_LINK_FLAG}")
target_link_libraries(test_cpu_gemm_uk PRIVATE /opt/rocm/llvm/lib/libomp.so) target_link_libraries(test_cpu_gemm_uk PRIVATE "${OMP_LIBRARY}")
target_compile_options(test_cpu_gemm_uk PRIVATE -fopenmp=libomp -Wno-unused-command-line-argument) target_compile_options(test_cpu_gemm_uk PRIVATE "${OMP_CXX_FLAG}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment