Unverified commit 81c942cd authored by Chao Liu, committed by GitHub

Deprecate static kernel (#42)

* deprecate static kernels
parent b8b2d0a6
#ifndef CK_THREADWISE_DIRECT_CONVOLUTION_HPP
#define CK_THREADWISE_DIRECT_CONVOLUTION_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "threadwise_tensor_slice_copy.hpp"
namespace ck {
// optimized for the scenario where p_in, p_wei and p_out are all in registers
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_1(InDesc,
TInWei* const __restrict__ p_in,
WeiDesc,
TInWei* const __restrict__ p_wei,
OutDesc,
TOut* __restrict__ p_out)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{};
constexpr auto out_desc = OutDesc{};
#if 0
if(blockIdx.x == 0 && get_thread_local_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_desc, "threadwise_direct_convolution: in_desc: ");
print_ConstantTensorDescriptor(wei_desc, "threadwise_direct_convolution: wei_desc: ");
print_ConstantTensorDescriptor(out_desc, "threadwise_direct_convolution: out_desc: ");
}
#endif
for(index_t n = 0; n < out_desc.GetLength(I0); ++n)
{
for(index_t k = 0; k < out_desc.GetLength(I1); ++k)
{
for(index_t ho = 0; ho < out_desc.GetLength(I2); ++ho)
{
for(index_t wo = 0; wo < out_desc.GetLength(I3); ++wo)
{
for(index_t c = 0; c < wei_desc.GetLength(I1); ++c)
{
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
{
for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
{
const index_t hi = ho + y;
const index_t wi = wo + x;
const index_t in_index =
in_desc.GetOffsetFromMultiIndex(n, c, hi, wi);
const index_t wei_index =
wei_desc.GetOffsetFromMultiIndex(k, c, y, x);
const index_t out_index =
out_desc.GetOffsetFromMultiIndex(n, k, ho, wo);
fused_multiply_accumulate(
p_out[out_index], p_wei[wei_index], p_in[in_index]);
}
}
}
}
}
}
}
}
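// Illustrative usage sketch (not part of the original header; shapes are
// hypothetical): each thread holds small NCHW input, KCYX weight and NKHW output
// tiles in registers and accumulates the full direct convolution, e.g.
//
//   constexpr auto in_desc  = make_ConstantTensorDescriptor_packed(Sequence<1, 2, 4, 4>{}); // NCHW
//   constexpr auto wei_desc = make_ConstantTensorDescriptor_packed(Sequence<4, 2, 3, 3>{}); // KCYX
//   constexpr auto out_desc = make_ConstantTensorDescriptor_packed(Sequence<1, 4, 2, 2>{}); // NKHW
//
//   float p_in[in_desc.GetElementSpace()];   // thread-private (register) buffers
//   float p_wei[wei_desc.GetElementSpace()];
//   float p_out[out_desc.GetElementSpace()];
//   threadwise_direct_convolution_1(in_desc, p_in, wei_desc, p_wei, out_desc, p_out);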
// Optimized for the scenario where p_in and p_wei are in LDS and p_out is in registers.
// Copy in and wei into registers before doing the convolution.
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_2(InDesc,
TInWei* const __restrict__ p_in,
WeiDesc,
TInWei* const __restrict__ p_wei,
OutDesc,
TOut* __restrict__ p_out)
{
constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{};
constexpr auto out_desc = OutDesc{};
constexpr auto in_reg_desc = make_ConstantTensorDescriptor_packed(in_desc.GetLengths());
constexpr auto wei_reg_desc = make_ConstantTensorDescriptor_packed(wei_desc.GetLengths());
// register
TInWei p_in_reg[in_reg_desc.GetElementSpace()];
TInWei p_wei_reg[wei_reg_desc.GetElementSpace()];
// copy input tensor into register
threadwise_tensor_slice_copy(
in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths(), Number<1>{});
// copy weight tensor into register
threadwise_tensor_slice_copy(
wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths(), Number<1>{});
// do convolution
threadwise_direct_convolution_1(
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
}
// optimized for the scenario where p_in and p_wei are in LDS and p_out is in registers.
// Break the non-1x1 convolution down into a sequence of 1x1 convolutions:
// load a 1x1 weight slice into registers, then do the 1x1 convolution in registers.
template <class Data, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_3(InDesc,
Data* const __restrict__ p_in,
WeiDesc,
Data* const __restrict__ p_wei,
OutDesc,
Data* __restrict__ p_out)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{};
constexpr auto out_desc = OutDesc{};
constexpr auto in_reg_desc = make_ConstantTensorDescriptor(Sequence<in_desc.GetLength(I0),
in_desc.GetLength(I1),
out_desc.GetLength(I2),
out_desc.GetLength(I3)>{});
constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(
Sequence<wei_desc.GetLength(I0), wei_desc.GetLength(I1), 1, 1>{});
Data p_in_reg[in_reg_desc.GetElementSpace()];
Data p_wei_reg[wei_reg_desc.GetElementSpace()];
constexpr index_t in_w_new_read = 1;
constexpr auto in_desc_reg_new_read =
make_ConstantTensorDescriptor(Sequence<in_reg_desc.GetLength(I0),
in_reg_desc.GetLength(I1),
in_reg_desc.GetLength(I2),
in_w_new_read>{});
#if 0
// this version reuses old input data already in registers and reads only new data from LDS
// loop over vertical direction
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
{
// read first input
threadwise_4d_tensor_copy(in_desc,
p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, 0),
in_reg_desc,
p_in_reg,
in_reg_desc.GetLengths());
// read first 1x1 weight
threadwise_4d_tensor_copy(wei_desc,
p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, 0),
wei_reg_desc,
p_wei_reg,
wei_reg_desc.GetLengths());
// do first 1x1 conv
threadwise_direct_convolution_1(
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
// loop over horizontal direction
for(index_t x = 1; x < wei_desc.GetLength(I3); ++x)
{
// read new weight
threadwise_4d_tensor_copy(wei_desc,
p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, x),
wei_reg_desc,
p_wei_reg,
wei_reg_desc.GetLengths());
// shift old input to the left
threadwise_4d_tensor_shift_down(in_reg_desc, p_in_reg, I3, Number<in_w_new_read>{});
// read new input
threadwise_4d_tensor_copy(
in_desc,
p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, x + in_reg_desc.GetLength(I3) - 1),
in_reg_desc,
p_in_reg +
in_reg_desc.GetOffsetFromMultiIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read),
in_desc_reg_new_read.GetLengths());
// do 1x1 conv
threadwise_direct_convolution_1(
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
}
}
#elif 1
// this version reads all input from LDS each time the filter moves
// loop over vertical direction
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
{
// loop over horizontal direction
for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
{
// read new weight
threadwise_4d_tensor_copy(wei_desc,
p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, x),
wei_reg_desc,
p_wei_reg,
wei_reg_desc.GetLengths());
// read new input
threadwise_4d_tensor_copy(in_desc,
p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, x),
in_reg_desc,
p_in_reg,
in_reg_desc.GetLengths());
// do 1x1 conv
threadwise_direct_convolution_1(
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
}
}
#endif
}
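// Worked example of the decomposition above (illustrative only): for a 3x3 filter
// (Y = X = 3) the enabled "#elif 1" path runs Y * X = 9 iterations. Iteration
// (y, x) copies the 1x1 weight slice wei[:, :, y, x] and the input window whose
// top-left corner is (y, x) into registers, then calls
// threadwise_direct_convolution_1 with the 1x1 wei_reg_desc, accumulating every
// partial result into the same p_out tile.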
} // namespace ck
#endif
#ifndef CK_THREADWISE_GEMM_HPP
#define CK_THREADWISE_GEMM_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "math.hpp"
namespace ck {
template <typename Float, class Matrix>
__device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
{
for(index_t i = 0; i < Matrix::NRow(); ++i)
{
for(index_t j = 0; j < Matrix::NCol(); ++j)
{
const index_t id = Matrix::CalculateOffset(i, j);
p_thread[id] = Float(0);
}
}
}
template <typename SrcMatrix,
typename DstMatrix,
index_t NSliceRow,
index_t NSliceCol,
index_t DataPerAccess>
struct ThreadwiseMatrixSliceCopy
{
__device__ constexpr ThreadwiseMatrixSliceCopy()
{
static_assert(SrcMatrix::RowStride() % DataPerAccess == 0 &&
DstMatrix::RowStride() % DataPerAccess == 0,
"wrong! wrong alignment");
static_assert(NSliceCol % DataPerAccess == 0,
"wrong! should be NSliceCol % DataPerAccess == 0");
}
template <typename Data>
__device__ static void Run(const Data* p_src, Data* p_dst)
{
using vector_t = typename vector_type<Data, DataPerAccess>::type;
for(index_t i = 0; i < NSliceRow; ++i)
{
for(index_t j = 0; j < NSliceCol; j += DataPerAccess)
{
const index_t src_index = SrcMatrix::CalculateOffset(i, j);
const index_t dst_index = DstMatrix::CalculateOffset(i, j);
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
}
}
}
};
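// Illustrative usage sketch (SrcMat/DstMat stand for hypothetical
// ConstantMatrixDescriptor types): with NSliceRow = 4, NSliceCol = 8 and
// DataPerAccess = 4, Run() moves each of the 4 rows with two 4-wide vector
// transactions, e.g.
//
//   using Copy = ThreadwiseMatrixSliceCopy<SrcMat, DstMat, 4, 8, 4>;
//   Copy::Run(p_src, p_dst);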
// C += transpose(A) * B
// Matrix elements can be vectorized data
template <typename MatrixA, typename MatrixB, typename MatrixC>
struct ThreadwiseGemmTransANormalBNormalC
{
__device__ constexpr ThreadwiseGemmTransANormalBNormalC()
{
static_assert(MatrixA::NRow() == MatrixB::NRow() && MatrixA::NCol() == MatrixC::NRow() &&
MatrixB::NCol() == MatrixC::NCol(),
"wrong!");
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
constexpr index_t M = MatrixC::NRow();
constexpr index_t N = MatrixC::NCol();
constexpr index_t K = MatrixA::NRow(); // A is transposed
for(index_t k = 0; k < K; ++k)
{
for(index_t m = 0; m < M; ++m)
{
for(index_t n = 0; n < N; ++n)
{
const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed
const index_t bindex = MatrixB::CalculateOffset(k, n);
const index_t cindex = MatrixC::CalculateOffset(m, n);
p_c[cindex] +=
inner_product_with_conversion<FloatC>{}(p_a[aindex], p_b[bindex]);
}
}
}
}
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
constexpr index_t M = MatrixC::NRow();
constexpr index_t N = MatrixC::NCol();
constexpr index_t K = MatrixA::NRow(); // A is transposed
static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
for(index_t k = 0; k < K; ++k)
{
for(index_t m = 0; m < M; ++m)
{
const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed
static_if<N == 2>{}([&](auto) {
const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
amd_assembly_outer_product_1x2(
p_a[aindex], p_b[bindex_0], p_b[bindex_1], p_c[cindex_0], p_c[cindex_1]);
});
static_if<N == 4>{}([&](auto) {
const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
const index_t bindex_2 = MatrixB::CalculateOffset(k, 2);
const index_t bindex_3 = MatrixB::CalculateOffset(k, 3);
const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
const index_t cindex_2 = MatrixC::CalculateOffset(m, 2);
const index_t cindex_3 = MatrixC::CalculateOffset(m, 3);
amd_assembly_outer_product_1x4(p_a[aindex],
p_b[bindex_0],
p_b[bindex_1],
p_b[bindex_2],
p_b[bindex_3],
p_c[cindex_0],
p_c[cindex_1],
p_c[cindex_2],
p_c[cindex_3]);
});
}
}
}
#endif
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
constexpr bool has_amd_asm = is_same<FloatC, float>{} &&
((is_same<FloatA, float>{} && is_same<FloatB, float>{}) ||
(is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
(is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));
static_if<has_amd_asm>{}([&](auto fwd) { Run_amd_asm(p_a, p_b, fwd(p_c)); })
.Else([&](auto) { Run_source(p_a, p_b, p_c); });
#else
Run_source(p_a, p_b, p_c);
#endif
}
};
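// Reference semantics of the GEMM above (illustrative restatement of Run_source):
// A is stored K x M (transposed), B is stored K x N, C is stored M x N, and
//
//   for k, m, n:  C[m][n] += convert(A[k][m]) * convert(B[k][n]);
//
// i.e. C += transpose(A) * B, which is exactly what the constructor's
// static_assert on the descriptor shapes enforces.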
} // namespace ck
#endif
#ifndef CK_THREADWISE_GENERIC_TENSOR_OP_HPP
#define CK_THREADWISE_GENERIC_TENSOR_OP_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
namespace ck {
template <class Float, class TDesc>
__device__ void threadwise_generic_tensor_set_zero(TDesc, Float* __restrict__ p)
{
static_ford<decltype(TDesc::GetLengths())>{}([&](auto multi_id) {
constexpr index_t offset = TDesc::GetOffsetFromMultiIndex(multi_id);
p[offset] = static_cast<Float>(0);
});
}
} // namespace ck
#endif
#ifndef CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP
#define CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_coordinate.hpp"
namespace ck {
// This threadwise copy allows vector access on both src and dst.
// The vector size may differ between src and dst,
// but the vector-access dimension and the dimension access order must be the same on src and dst.
// It checks the validity of the src data mapping: reads 0 if the src data has an invalid mapping.
// It checks the validity of the dst data mapping: skips the write if the dst data has an invalid mapping.
template <typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename SrcDstDimAccessOrder,
index_t SrcDstVectorReadWriteDim,
index_t SrcDataPerRead,
index_t DstDataPerWrite,
AddressSpace SrcAddressSpace = AddressSpace::Generic,
AddressSpace DstAddressSpace = AddressSpace::Generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set,
index_t SrcDataStride = 1,
index_t DstDataStride = 1>
struct ThreadwiseGenericTensorSliceCopy_v4r2
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = typename TensorCoordinate<SrcDesc>::type;
using DstCoord = typename TensorCoordinate<DstDesc>::type;
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2(const Index& src_slice_origin,
const Index& dst_slice_origin)
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
{
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::Size() &&
nDim == SrcDstDimAccessOrder::Size(),
"wrong! # of dimensions not the same");
static_assert(is_valid_sequence_map<SrcDstDimAccessOrder>{}, "wrong! map is not valid");
static_assert(SliceLengths{}[SrcDstVectorReadWriteDim] %
math::lcm(SrcDataPerRead, DstDataPerWrite) ==
0,
"wrong! cannot evenly divide");
// TODO: sanity-check whether vectorized memory read/write is allowed on src and dst
}
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2()
: ThreadwiseGenericTensorSliceCopy_v4r2(make_zero_multi_index<nDim>(),
make_zero_multi_index<nDim>())
{
}
__device__ void SetSrcSliceOrigin(SrcCoord src_slice_origin)
{
mSrcSliceOrigin = src_slice_origin;
}
__device__ void SetDstSliceOrigin(DstCoord dst_slice_origin)
{
mDstSliceOrigin = dst_slice_origin;
}
template <typename SrcData, typename DstData>
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
{
constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}(
[&](auto long_vector_access_id) {
// data id w.r.t slicing-window
auto long_vector_data_begin_id = long_vector_access_id;
long_vector_data_begin_id(vector_access_dim) =
long_vector_size * long_vector_access_id[vector_access_dim];
// buffer to hold a src long-vector
SrcData p_src_long_vector[long_vector_size];
// load data from src to the long-vector buffer
static_for<0, long_vector_size / src_data_per_access, 1>{}([&](auto i) {
auto scalar_id = make_zero_multi_index<nDim>();
scalar_id(vector_access_dim) = i * src_data_per_access;
const index_t buffer_offset = i * src_data_per_access;
const auto src_coord =
mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);
// Check the validity of the src data mapping. Only the first element of this
// src vector is checked; it is the user's responsibility to make sure all data
// in the src vector share the same valid/invalid mapping.
transfer_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
SrcDataStride,
1>(p_src,
src_coord.GetOffset(),
src_coord.IsOffsetValidAssumingUpperIndexIsValid(),
SrcDesc::GetElementSpace(),
p_src_long_vector,
buffer_offset,
true,
long_vector_size);
});
// SrcData to DstData conversion
DstData p_dst_long_vector[long_vector_size];
static_for<0, long_vector_size, 1>{}([&](auto i) {
p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
});
// store data from the long-vector buffer to dst
static_for<0, long_vector_size / dst_data_per_access, 1>{}([&](auto i) {
auto scalar_id = make_zero_multi_index<nDim>();
scalar_id(vector_access_dim) = i * dst_data_per_access;
const index_t buffer_offset = i * dst_data_per_access;
const auto dst_coord =
mDstSliceOrigin + (long_vector_data_begin_id + scalar_id);
// Check the validity of the dst data mapping. Only the first element of this
// dst vector is checked; it is the user's responsibility to make sure all data
// in the dst vector share the same valid/invalid mapping.
transfer_data<DstData,
DstDataPerWrite,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp,
1,
DstDataStride>(p_dst_long_vector,
buffer_offset,
true,
long_vector_size,
p_dst,
dst_coord.GetOffset(),
dst_coord.IsOffsetValidAssumingUpperIndexIsValid(),
DstDesc::GetElementSpace());
});
});
}
template <typename T, bool PositiveDirection>
__device__ void MoveSrcSliceWindow(const T& step_sizes_,
integral_constant<bool, PositiveDirection>)
{
const auto step_sizes = to_multi_index(step_sizes_);
static_if<PositiveDirection>{}([&](auto) { mSrcSliceOrigin += step_sizes; })
.Else([&](auto) { mSrcSliceOrigin -= step_sizes; });
}
template <typename T, bool PositiveDirection>
__device__ void MoveDstSliceWindow(const T& step_sizes_,
integral_constant<bool, PositiveDirection>)
{
const auto step_sizes = to_multi_index(step_sizes_);
static_if<PositiveDirection>{}([&](auto) { mDstSliceOrigin += step_sizes; })
.Else([&](auto) { mDstSliceOrigin -= step_sizes; });
}
private:
SrcCoord mSrcSliceOrigin;
DstCoord mDstSliceOrigin;
};
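// Worked example of the long-vector logic in Run() above (illustrative): with
// SrcDataPerRead = 4 and DstDataPerWrite = 2, long_vector_size = lcm(4, 2) = 4,
// so each long vector is filled by a single 4-wide read from src, converted
// element by element from SrcData to DstData, and written back as two 2-wide
// stores to dst. The SliceLengths entry along SrcDstVectorReadWriteDim must be
// divisible by 4, as the static_assert in the constructor requires.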
} // namespace ck
#endif
......@@ -2,7 +2,6 @@
#define CK_XDLOPS_GEMM_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "math.hpp"
#include "amd_xdlops.hpp"
......
#ifndef CK_AMD_BUFFER_ADDRESSING_HPP
#define CK_AMD_BUFFER_ADDRESSING_HPP
#include "float_type.hpp"
#include "amd_buffer_addressing_v2.hpp"
namespace ck {
template <typename T>
union BufferResource
{
// 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
int32x4_t data;
T* address[2];
int32_t range[4];
int32_t config[4];
};
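// Descriptor layout used by the helpers below: address[0] carries the 64-bit
// wavewise base pointer (dwords 0-1), range[2] carries the buffer size in bytes
// (dword 2), and config[3] carries the remaining settings, filled with
// CK_BUFFER_RESOURCE_3RD_DWORD (dword 3).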
__device__ float __llvm_amdgcn_buffer_load_f32(int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.f32");
__device__ float2_t
__llvm_amdgcn_buffer_load_f32x2(int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.v2f32");
__device__ float4_t
__llvm_amdgcn_buffer_load_f32x4(int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.v4f32");
__device__ half_t
__llvm_amdgcn_raw_buffer_load_f16(int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
__device__ ushort
__llvm_amdgcn_raw_buffer_load_bf16(int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.bf16");
__device__ void __llvm_amdgcn_buffer_store_f32(float vdata,
int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.f32");
__device__ void __llvm_amdgcn_buffer_store_f32x2(float2_t vdata,
int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v2f32");
__device__ void __llvm_amdgcn_buffer_store_f32x4(float4_t vdata,
int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v4f32");
__device__ void
__llvm_amdgcn_raw_buffer_store_f16(half_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
__device__ void
__llvm_amdgcn_raw_buffer_store_bf16(ushort vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.bf16");
#if CK_USE_AMD_BUFFER_ATOMIC_FADD
#if CK_HIP_VERSION_FLAT >= 3010020405
// starting with ROCm 3.10, the return type becomes float
__device__ float
#else
__device__ void
#endif
__llvm_amdgcn_buffer_atomic_add_f32(float vdata,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool slc) __asm("llvm.amdgcn.buffer.atomic.fadd.f32");
#endif
// buffer_load requires:
// 1) p_src_wave must be in global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
__device__ typename vector_type<T, VectorSize>::type amd_buffer_load(const T* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range);
// buffer_store requires:
// 1) p_src_thread must be in vgpr space, p_dst_wave must be in global memory space
// 2) p_dst_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
__device__ void amd_buffer_store(const T* p_src_thread,
T* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range);
// buffer_atomic requires:
// 1) p_src_thread must be in vgpr space, p_dst_wave must be in global memory space
// 2) p_dst_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
__device__ void amd_buffer_atomic_add(const T* p_src_thread,
T* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range);
template <>
__device__ float amd_buffer_load<float, 1>(const float* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResource<float> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return __llvm_amdgcn_buffer_load_f32(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#else
float tmp = __llvm_amdgcn_buffer_load_f32(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? tmp : float(0);
#endif
}
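// Illustrative usage sketch (names are hypothetical): every lane passes the same
// wavewise base pointer plus its own element offset, and invalid lanes read back 0:
//
//   float v = amd_buffer_load<float, 1>(p_global_wave_base,
//                                       thread_element_offset,
//                                       thread_element_offset < valid_elements,
//                                       valid_elements);
//
// With the experimental OOB-check trick enabled, an invalid lane instead adds
// 0x7fffffff to its byte offset so that the buffer instruction's hardware range
// check (against range[2]) fails and the load returns 0, replacing the explicit
// select on src_thread_data_valid.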
template <>
__device__ float2_t amd_buffer_load<float, 2>(const float* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResource<float> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return __llvm_amdgcn_buffer_load_f32x2(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#else
float2_t tmp = __llvm_amdgcn_buffer_load_f32x2(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? tmp : float2_t(0);
#endif
}
template <>
__device__ float4_t amd_buffer_load<float, 4>(const float* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResource<float> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return __llvm_amdgcn_buffer_load_f32x4(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#else
float4_t tmp = __llvm_amdgcn_buffer_load_f32x4(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? tmp : float4_t(0);
#endif
}
template <>
__device__ half_t amd_buffer_load<half_t, 1>(const half_t* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResource<half_t> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
// current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
// everything is passed to Voffset
return __llvm_amdgcn_raw_buffer_load_f16(
src_wave_buffer_resource.data, src_addr_shift + src_thread_addr_offset, 0, 0);
#else
half_t zero(0);
// current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
// everything is passed to Voffset
return src_thread_data_valid ? __llvm_amdgcn_raw_buffer_load_f16(
src_wave_buffer_resource.data, src_thread_addr_offset, 0, 0)
: zero;
#endif
}
template <>
__device__ half2_t amd_buffer_load<half_t, 2>(const half_t* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResource<half_t> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
return *reinterpret_cast<half2_t*>(&dst_out_tmp);
#else
half2_t zeros(0);
float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? *reinterpret_cast<half2_t*>(&dst_out_tmp) : zeros;
#endif
}
template <>
__device__ half4_t amd_buffer_load<half_t, 4>(const half_t* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResource<half_t> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
return *reinterpret_cast<half4_t*>(&dst_out_tmp);
#else
half4_t zeros(0);
float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? *reinterpret_cast<half4_t*>(&dst_out_tmp) : zeros;
#endif
}
template <>
__device__ half8_t amd_buffer_load<half_t, 8>(const half_t* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResource<half_t> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
return *reinterpret_cast<half8_t*>(&dst_out_tmp);
#else
half8_t zeros(0);
float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? *reinterpret_cast<half8_t*>(&dst_out_tmp) : zeros;
#endif
}
template <>
__device__ ushort amd_buffer_load<ushort, 1>(const ushort* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResource<ushort> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
// current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
// everything is passed to Voffset
return __llvm_amdgcn_raw_buffer_load_bf16(
src_wave_buffer_resource.data, src_addr_shift + src_thread_addr_offset, 0, 0);
#else
ushort zero(0);
// current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
// everything is passed to Voffset
return src_thread_data_valid ? __llvm_amdgcn_raw_buffer_load_bf16(
src_wave_buffer_resource.data, src_thread_addr_offset, 0, 0)
: zero;
#endif
}
template <>
__device__ ushort2_t amd_buffer_load<ushort, 2>(const ushort* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResource<ushort> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
return *reinterpret_cast<ushort2_t*>(&dst_out_tmp);
#else
ushort2_t zeros(0);
float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? *reinterpret_cast<ushort2_t*>(&dst_out_tmp) : zeros;
#endif
}
template <>
__device__ ushort4_t amd_buffer_load<ushort, 4>(const ushort* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResource<ushort> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
return *reinterpret_cast<ushort4_t*>(&dst_out_tmp);
#else
ushort4_t zeros(0);
float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? *reinterpret_cast<ushort4_t*>(&dst_out_tmp) : zeros;
#endif
}
template <>
__device__ ushort8_t amd_buffer_load<ushort, 8>(const ushort* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResource<ushort> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
return *reinterpret_cast<ushort8_t*>(&dst_out_tmp);
#else
ushort8_t zeros(0);
float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? *reinterpret_cast<ushort8_t*>(&dst_out_tmp) : zeros;
#endif
}
template <>
__device__ void amd_buffer_store<float, 1>(const float* p_src_thread,
float* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32(*p_src_thread,
dst_wave_buffer_resource.data,
0,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32(
*p_src_thread, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
}
#endif
}
template <>
__device__ void amd_buffer_store<float, 2>(const float* p_src_thread,
float* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast<const float2_t*>(p_src_thread),
dst_wave_buffer_resource.data,
0,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast<const float2_t*>(p_src_thread),
dst_wave_buffer_resource.data,
0,
dst_thread_addr_offset,
false,
false);
}
#endif
}
template <>
__device__ void amd_buffer_store<float, 4>(const float* p_src_thread,
float* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast<const float4_t*>(p_src_thread),
dst_wave_buffer_resource.data,
0,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast<const float4_t*>(p_src_thread),
dst_wave_buffer_resource.data,
0,
dst_thread_addr_offset,
false,
false);
}
#endif
}
template <>
__device__ void amd_buffer_store<half_t, 1>(const half_t* p_src_thread,
half_t* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResource<half_t> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
// current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
// everything is passed to Voffset
__llvm_amdgcn_raw_buffer_store_f16(*p_src_thread,
dst_wave_buffer_resource.data,
dst_addr_shift + dst_thread_addr_offset,
0,
0);
#else
if(dst_thread_data_valid)
{
// current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
// everything is passed to Voffset
__llvm_amdgcn_raw_buffer_store_f16(
*p_src_thread, dst_wave_buffer_resource.data, dst_thread_addr_offset, 0, 0);
}
#endif
}
template <>
__device__ void amd_buffer_store<half_t, 2>(const half_t* p_src_thread,
half_t* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResource<half_t> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
const float* p_src_tmp = reinterpret_cast<const float*>(p_src_thread);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32(*p_src_tmp,
dst_wave_buffer_resource.data,
0,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32(
*p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
}
#endif
}
template <>
__device__ void amd_buffer_store<half_t, 4>(const half_t* p_src_thread,
half_t* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResource<half_t> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
const float2_t* p_src_tmp = reinterpret_cast<const float2_t*>(p_src_thread);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32x2(*p_src_tmp,
dst_wave_buffer_resource.data,
0,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32x2(
*p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
}
#endif
}
template <>
__device__ void amd_buffer_store<half_t, 8>(const half_t* p_src_thread,
half_t* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResource<half_t> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
const float4_t* p_src_tmp = reinterpret_cast<const float4_t*>(p_src_thread);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32x4(*p_src_tmp,
dst_wave_buffer_resource.data,
0,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32x4(
*p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
}
#endif
}
template <>
__device__ void amd_buffer_store<ushort, 1>(const ushort* p_src_thread,
ushort* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResource<ushort> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_raw_buffer_store_bf16(*p_src_thread,
dst_wave_buffer_resource.data,
dst_addr_shift + dst_thread_addr_offset,
0,
0);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_raw_buffer_store_bf16(
*p_src_thread, dst_wave_buffer_resource.data, dst_thread_addr_offset, 0, 0);
}
#endif
}
template <>
__device__ void amd_buffer_store<ushort, 2>(const ushort* p_src_thread,
ushort* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResource<ushort> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
const float* p_src_tmp = reinterpret_cast<const float*>(p_src_thread);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32(*p_src_tmp,
dst_wave_buffer_resource.data,
0,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32(
*p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
}
#endif
}
template <>
__device__ void amd_buffer_store<ushort, 4>(const ushort* p_src_thread,
ushort* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResource<ushort> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
const float2_t* p_src_tmp = reinterpret_cast<const float2_t*>(p_src_thread);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32x2(*p_src_tmp,
dst_wave_buffer_resource.data,
0,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32x2(
*p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
}
#endif
}
template <>
__device__ void amd_buffer_store<ushort, 8>(const ushort* p_src_thread,
ushort* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResource<ushort> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
const float4_t* p_src_tmp = reinterpret_cast<const float4_t*>(p_src_thread);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32x4(*p_src_tmp,
dst_wave_buffer_resource.data,
0,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32x4(
*p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
}
#endif
}
#if CK_USE_AMD_BUFFER_ATOMIC_FADD
template <>
__device__ void amd_buffer_atomic_add<float, 1>(const float* p_src_thread,
float* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_atomic_add_f32(*p_src_thread,
dst_wave_buffer_resource.data,
0,
dst_addr_shift + dst_thread_addr_offset,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_atomic_add_f32(
*p_src_thread, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false);
}
#endif
}
template <>
__device__ void amd_buffer_atomic_add<float, 2>(const float* p_src_thread,
float* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
for(index_t i = 0; i < 2; ++i)
{
__llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i],
dst_wave_buffer_resource.data,
0,
dst_addr_shift + dst_thread_addr_offset +
i * sizeof(float),
false);
}
#else
if(dst_thread_data_valid)
{
for(index_t i = 0; i < 2; ++i)
{
__llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i],
dst_wave_buffer_resource.data,
0,
dst_thread_addr_offset + i * sizeof(float),
false);
}
}
#endif
}
template <>
__device__ void amd_buffer_atomic_add<float, 4>(const float* p_src_thread,
float* p_dst_wave,
index_t dst_thread_data_offset,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
for(index_t i = 0; i < 4; ++i)
{
__llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i],
dst_wave_buffer_resource.data,
0,
dst_addr_shift + dst_thread_addr_offset +
i * sizeof(float),
false);
}
#else
if(dst_thread_data_valid)
{
for(index_t i = 0; i < 4; ++i)
{
__llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i],
dst_wave_buffer_resource.data,
0,
dst_thread_addr_offset + i * sizeof(float),
false);
}
}
#endif
}
#endif // CK_USE_AMD_BUFFER_ATOMIC_FADD
} // namespace ck
#endif
......@@ -23,6 +23,48 @@ amd_inner_product_dlop<float, float, float>(const float& a, const float& b, floa
#endif
}
template <>
__device__ void
amd_inner_product_dlop<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
amd_inner_product_dlop(vector_type<float, 2>{a}.AsType<float>()[I0],
vector_type<float, 2>{b}.AsType<float>()[I0],
c);
amd_inner_product_dlop(vector_type<float, 2>{a}.AsType<float>()[I1],
vector_type<float, 2>{b}.AsType<float>()[I1],
c);
}
template <>
__device__ void
amd_inner_product_dlop<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I0],
vector_type<float, 4>{b}.AsType<float>()[I0],
c);
amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I1],
vector_type<float, 4>{b}.AsType<float>()[I1],
c);
amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I2],
vector_type<float, 4>{b}.AsType<float>()[I2],
c);
amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I3],
vector_type<float, 4>{b}.AsType<float>()[I3],
c);
}
#if CK_USE_AMD_DLOP
template <>
__device__ void
......
......@@ -13,7 +13,6 @@
#include "functional2.hpp"
#include "functional3.hpp"
#include "functional4.hpp"
#include "in_memory_operation.hpp"
#include "integral_constant.hpp"
#include "math.hpp"
#include "number.hpp"
......@@ -25,6 +24,7 @@
#include "type.hpp"
#include "utility.hpp"
#include "magic_division.hpp"
#include "amd_buffer_addressing_v2.hpp"
#include "static_buffer.hpp"
#include "dynamic_buffer.hpp"
......
#ifndef CK_CONFIG_NVIDIA_HPP
#define CK_CONFIG_NVIDIA_HPP
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <nvToolsExt.h>
// index type: unsigned or signed
#define CK_UNSIGNED_INDEX_TYPE 0
// device backend
#define CK_DEVICE_BACKEND_NVIDIA 1
// disable AMD inline asm and intrinsic
#define CK_USE_AMD_INLINE_ASM 0
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 0
#define CK_USE_AMD_BUFFER_ADDRESSING 0
#define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 0
#define CK_USE_AMD_XDLOPS 0
#define CK_USE_AMD_XDLOPS_INLINE_ASM 0
#define CK_USE_AMD_XDLOPS_EMULATE 0
// experimental implementation
#define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 0
#define CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF 0
#define CK_EXPERIMENTAL_THREADWISE_COPY_V4R2_USE_OPTIMIZED_ADDRESS_CACLULATION 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
namespace ck {
enum AddressSpace
{
Generic,
Global,
Lds,
Vgpr
};
enum InMemoryDataOperation
{
Set,
AtomicAdd
};
#if CK_UNSIGNED_INDEX_TYPE
using index_t = uint32_t;
#else
using index_t = int32_t;
#endif
} // namespace ck
#endif
#ifndef CK_FLOAT_TYPE_NVIDIA_HPP
#define CK_FLOAT_TYPE_NVIDIA_HPP
#include "number.hpp"
namespace ck {
// For some reason, CUDA needs these definitions; otherwise the compiler
// won't generate optimal load and store instructions, and the kernel can
// produce wrong results, indicating the compiler failed to generate correct
// instructions.
// float
using float2_t = float2;
using float4_t = float4;
// float
typedef float float32_t __attribute__((ext_vector_type(32)));
// bfloat16
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
// fp16
using half_t = half;
using half2_t = half2;
using half4_t = float2;
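// (note: half4_t reuses the 64-bit float2 as storage for four halves; elements
// are accessed via reinterpret_cast to half_t*, as inner_product_with_conversion
// below does)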
template <class T, index_t N>
struct vector_type
{
typedef struct
{
T scalar[N];
} type;
};
template <>
struct vector_type<float, 1>
{
using type = float;
template <index_t I>
__host__ __device__ static void SetScalar(type& v, float s, Number<I>)
{
static_assert(I < 1, "wrong");
*(reinterpret_cast<float*>(&v) + I) = s;
}
};
template <>
struct vector_type<float, 2>
{
using type = float2_t;
union DataType
{
type vector;
float scalar[2];
};
template <index_t I>
__host__ __device__ static void SetScalar(type& v, float s, Number<I>)
{
static_assert(I < 2, "wrong");
*(reinterpret_cast<float*>(&v) + I) = s;
}
__host__ __device__ static type Pack(float s0, float s1)
{
DataType data;
data.scalar[0] = s0;
data.scalar[1] = s1;
return data.vector;
}
};
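// Illustrative usage sketch of the helpers defined just above:
//
//   float2_t v = vector_type<float, 2>::Pack(1.0f, 2.0f);     // v = {1, 2}
//   vector_type<float, 2>::SetScalar(v, 3.0f, Number<1>{});   // v = {1, 3}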
template <>
struct vector_type<float, 4>
{
using type = float4_t;
__host__ __device__ static constexpr index_t GetSize() { return 4; }
template <index_t I>
__host__ __device__ static void SetScalar(type& v, float s, Number<I>)
{
static_assert(I < 4, "wrong");
*(reinterpret_cast<float*>(&v) + I) = s;
}
};
template <>
struct vector_type<half_t, 1>
{
using type = half_t;
template <index_t I>
__host__ __device__ static void SetScalar(type& v, half_t s, Number<I>)
{
static_assert(I < 1, "wrong");
*(reinterpret_cast<half_t*>(&v) + I) = s;
}
};
template <>
struct vector_type<half_t, 2>
{
using type = half2_t;
union DataType
{
type vector;
half_t scalar[2];
};
template <index_t I>
__host__ __device__ static void SetScalar(type& v, half_t s, Number<I>)
{
static_assert(I < 2, "wrong");
*(reinterpret_cast<half_t*>(&v) + I) = s;
}
__host__ __device__ static type Pack(half_t s0, half_t s1)
{
DataType data;
data.scalar[0] = s0;
data.scalar[1] = s1;
return data.vector;
}
};
// data type conversion
template <typename T>
struct type_convert
{
template <typename X>
__device__ T operator()(const X& x) const
{
return static_cast<T>(x);
}
};
template <typename T>
struct inner_product_with_conversion
{
static constexpr auto convert = type_convert<T>();
__device__ T operator()(float a, float b) const { return convert(a) * convert(b); }
__device__ T operator()(half2_t a, half2_t b) const
{
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
T acc = 0;
for(index_t v = 0; v < 2; ++v)
{
acc += convert(p_a_half[v]) * convert(p_b_half[v]);
}
return acc;
}
__device__ T operator()(half4_t a, half4_t b) const
{
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
T acc = 0;
for(index_t v = 0; v < 4; ++v)
{
acc += convert(p_a_half[v]) * convert(p_b_half[v]);
}
return acc;
}
};
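// Illustrative usage sketch: this functor is what Run_source in
// ThreadwiseGemmTransANormalBNormalC uses for its accumulation, e.g.
//
//   half2_t a2 = ..., b2 = ...;  // hypothetical operands
//   float acc = inner_product_with_conversion<float>{}(a2, b2);  // convert each half to float, then dot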
} // namespace ck
#endif
#ifndef CK_IN_MEMORY_OPERATION_AMD_HPP
#define CK_IN_MEMORY_OPERATION_AMD_HPP
#include "float_type.hpp"
#if CK_USE_AMD_BUFFER_ADDRESSING
#include "amd_buffer_addressing.hpp"
#include "amd_buffer_addressing_v2.hpp"
#endif
namespace ck {
template <typename T>
__device__ void atomic_add_impl(T* p_dst, T src)
{
atomicAdd(p_dst, src);
}
// atomicAdd does not support float vector types, so add element-wise
template <>
__device__ void atomic_add_impl<float2_t>(float2_t* p_dst, float2_t src)
{
float* p_dst_float = reinterpret_cast<float*>(p_dst);
const float* p_src_float = reinterpret_cast<const float*>(&src);
for(index_t i = 0; i < 2; ++i)
{
atomicAdd(&(p_dst_float[i]), p_src_float[i]);
}
}
template <>
__device__ void atomic_add_impl<float4_t>(float4_t* p_dst, float4_t src)
{
float* p_dst_float = reinterpret_cast<float*>(p_dst);
const float* p_src_float = reinterpret_cast<const float*>(&src);
for(index_t i = 0; i < 4; ++i)
{
atomicAdd(&(p_dst_float[i]), p_src_float[i]);
}
}
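// Illustrative usage sketch, not part of the original header: atomically accumulate a
// 2-wide partial result into global memory; the float2_t specialization above splits
// the vector into two scalar atomicAdd calls. Guarded by "#if 0"; names are hypothetical.
#if 0
__device__ void atomic_add_example(float2_t* p_global_acc, float2_t partial)
{
    atomic_add_impl(p_global_acc, partial);
}
#endif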
template <typename T, index_t DataPerAccess>
struct SetData
{
using vector_t = typename vector_type<T, DataPerAccess>::type;
    // Generic fallback kept only for compatibility; prefer the address-space-specific
    // specializations below when possible
template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
__device__ void Run(const T* p_src,
index_t src_offset,
bool src_valid,
index_t /* src_range */,
T* p_dst,
index_t dst_offset,
bool dst_valid,
index_t /* dst_range */) const
{
if(dst_valid)
{
if(src_valid)
{
#if 0
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
#else
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
*reinterpret_cast<const vector_t*>(&p_src[0x3fffffff & src_offset]);
#endif
}
else
{
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) = 0;
}
}
}
#if CK_USE_AMD_BUFFER_ADDRESSING
    // buffer_load requires:
    // 1) p_src must point to global memory and p_dst must live in VGPRs
    // 2) p_src must be a wavewise (wave-uniform) pointer.
    // It is the caller's responsibility to guarantee both.
template <>
__device__ void Run<AddressSpace::Global, AddressSpace::Vgpr>(const T* p_src,
index_t src_offset,
bool src_valid,
index_t src_range,
T* p_dst,
index_t dst_offset,
bool dst_valid,
index_t /* dst_range */) const
{
if(dst_valid)
{
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
amd_buffer_load_v2<T, DataPerAccess>(p_src, src_offset, src_valid, src_range);
}
}
    // buffer_store requires:
    // 1) p_src must live in VGPRs and p_dst must point to global memory
    // 2) p_dst must be a wavewise (wave-uniform) pointer.
    // It is the caller's responsibility to guarantee both.
template <>
__device__ void Run<AddressSpace::Vgpr, AddressSpace::Global>(const T* p_src,
index_t src_offset,
bool src_valid,
index_t /* src_range */,
T* p_dst,
index_t dst_offset,
bool dst_valid,
index_t dst_range) const
{
const auto zeros = vector_t(0);
amd_buffer_store_v2<T, DataPerAccess>(
src_valid ? *reinterpret_cast<const vector_t*>(&(p_src[src_offset])) : zeros,
p_dst,
dst_offset,
dst_valid,
dst_range);
}
#endif
};
template <typename T, index_t DataPerAccess>
struct AtomicAddData
{
using vector_t = typename vector_type<T, DataPerAccess>::type;
    // Generic fallback kept only for compatibility; prefer the address-space-specific
    // specializations below when possible
template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
__device__ void Run(const T* p_src,
index_t src_offset,
bool src_valid,
index_t /* src_range */,
T* p_dst,
index_t dst_offset,
bool dst_valid,
index_t /* dst_range */) const
{
if(src_valid && dst_valid)
{
atomic_add_impl(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
}
}
#if CK_USE_AMD_BUFFER_ADDRESSING && CK_USE_AMD_BUFFER_ATOMIC_FADD
    // buffer_atomic requires:
    // 1) p_src must live in VGPRs and p_dst must point to global memory
    // 2) p_dst must be a wavewise (wave-uniform) pointer.
    // It is the caller's responsibility to guarantee both.
template <>
__device__ void Run<AddressSpace::Vgpr, AddressSpace::Global>(const T* p_src,
index_t src_offset,
bool src_valid,
index_t /* src_range */,
T* p_dst,
index_t dst_offset,
bool dst_valid,
index_t dst_range) const
{
const auto zeros = vector_t(0);
amd_buffer_atomic_add<T, DataPerAccess>(
src_valid ? &(p_src[src_offset]) : &zeros, p_dst, dst_offset, dst_valid, dst_range);
}
#endif
};
template <typename T,
index_t DataPerAccess,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp,
index_t SrcDataStride = 1,
index_t DstDataStride = 1>
__device__ void transfer_data(const T* p_src,
index_t src_offset,
bool src_valid,
index_t src_range,
T* p_dst,
index_t dst_offset,
bool dst_valid,
index_t dst_range)
{
static_assert(DstInMemOp == InMemoryDataOperation::Set ||
DstInMemOp == InMemoryDataOperation::AtomicAdd,
"wrong! InMemoryDataOperation not supported!");
    // keep it simple: avoid static_if here, otherwise the compiler generates unexpected code
if constexpr(SrcDataStride == 1 && DstDataStride == 1)
{
if constexpr(DstInMemOp == InMemoryDataOperation::Set)
{
SetData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range);
}
else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd)
{
AtomicAddData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range);
}
}
else
{
#pragma unroll
for(index_t i = 0; i < DataPerAccess; ++i)
{
if constexpr(DstInMemOp == InMemoryDataOperation::Set)
{
SetData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src,
src_offset + i * SrcDataStride,
src_valid,
src_range,
p_dst,
dst_offset + i * DstDataStride,
dst_valid,
dst_range);
}
else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd)
{
AtomicAddData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src,
src_offset + i * SrcDataStride,
src_valid,
src_range,
p_dst,
dst_offset + i * DstDataStride,
dst_valid,
dst_range);
}
}
}
}
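// Illustrative usage sketch, not part of the original header: copy 4 floats per access
// from global memory into VGPRs with the Set operation; with CK_USE_AMD_BUFFER_ADDRESSING
// enabled this dispatches to the buffer_load specialization of SetData. Guarded by
// "#if 0"; names are hypothetical.
#if 0
__device__ void transfer_data_example(const float* p_src_global,
                                      index_t src_offset,
                                      index_t src_range,
                                      float* p_dst_vgpr,
                                      index_t dst_offset)
{
    transfer_data<float,
                  4,
                  AddressSpace::Global,
                  AddressSpace::Vgpr,
                  InMemoryDataOperation::Set>(p_src_global,
                                              src_offset,
                                              true, // src_valid
                                              src_range,
                                              p_dst_vgpr,
                                              dst_offset,
                                              true, // dst_valid
                                              4);   // dst_range (unused for VGPR dst)
}
#endif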
} // namespace ck
#endif
#ifndef CK_IN_MEMORY_OPERATION_NVIDIA_HPP
#define CK_IN_MEMORY_OPERATION_NVIDIA_HPP
namespace ck {
template <typename T>
__device__ void atomic_add_impl(T* p_dst, T src)
{
atomicAdd(p_dst, src);
}
// atomicAdd does not support vector types (float2_t/float4_t), so add element-wise
template <>
__device__ void atomic_add_impl<float2_t>(float2_t* p_dst, float2_t src)
{
float* p_dst_float = reinterpret_cast<float*>(p_dst);
const float* p_src_float = reinterpret_cast<const float*>(&src);
for(index_t i = 0; i < 2; ++i)
{
atomicAdd(&(p_dst_float[i]), p_src_float[i]);
}
}
template <>
__device__ void atomic_add_impl<float4_t>(float4_t* p_dst, float4_t src)
{
float* p_dst_float = reinterpret_cast<float*>(p_dst);
const float* p_src_float = reinterpret_cast<const float*>(&src);
for(index_t i = 0; i < 4; ++i)
{
atomicAdd(&(p_dst_float[i]), p_src_float[i]);
}
}
template <typename T, index_t DataPerAccess>
struct SetData
{
using vector_t = typename vector_type<T, DataPerAccess>::type;
template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
__device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
{
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
}
};
template <typename T, index_t DataPerAccess>
struct AtomicAddData
{
using vector_t = typename vector_type<T, DataPerAccess>::type;
template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
__device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
{
atomic_add_impl(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
}
};
template <typename T,
index_t DataPerAccess,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp,
index_t SrcDataStride = 1,
index_t DstDataStride = 1>
__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
{
static_assert(DstInMemOp == InMemoryDataOperation::Set ||
DstInMemOp == InMemoryDataOperation::AtomicAdd,
"wrong! InMemoryDataOperation not supported!");
    // keep it simple: avoid static_if for this branch, otherwise the compiler generates
    // unexpected code
if(SrcDataStride == 1 && DstDataStride == 1)
{
// TODO: use static_if::ElseIf
static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
SetData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, p_dst, dst_offset);
});
static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
AtomicAddData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, p_dst, dst_offset);
});
}
else
{
for(index_t i = 0; i < DataPerAccess; i++)
{
// TODO: use static_if::ElseIf
static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
SetData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride);
});
static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
AtomicAddData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride);
});
}
}
}
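// Illustrative usage sketch, not part of the original header: the NVIDIA-path
// transfer_data has no validity/range arguments, and the address-space template
// parameters are accepted but not used for dispatch. Guarded by "#if 0"; names are
// hypothetical.
#if 0
__device__ void transfer_data_nvidia_example(const float* p_src, float* p_dst)
{
    transfer_data<float,
                  2,
                  AddressSpace::Global,
                  AddressSpace::Vgpr,
                  InMemoryDataOperation::Set>(p_src, 0, p_dst, 0);
}
#endif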
} // namespace ck
#endif
#ifndef CK_SYNCHRONIZATION_NVIDIA_HPP
#define CK_SYNCHRONIZATION_NVIDIA_HPP
#include "config.hpp"
namespace ck {
__device__ void block_sync_lds() { __syncthreads(); }
__device__ void block_sync_lds_vmem() { __syncthreads(); }
} // namespace ck
#endif
extern "C" __global__ void
gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer(
const void* const __restrict__ p_in_global,
const void* const __restrict__ p_wei_global,
    void* const __restrict__ p_out_global)
{
}
extern "C" __global__ void gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
const void* const __restrict__ p_in_global,
const void* const __restrict__ p_wei_global,
    void* const __restrict__ p_out_global)
{
}
extern "C" __global__ void gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
const void* const __restrict__ p_in_global,
const void* const __restrict__ p_wei_global,
    void* const __restrict__ p_out_global)
{
}
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <thread>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "conv_common.hpp"
#include "host_conv_bwd_data.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp"
int main(int argc, char* argv[])
{
using namespace launcher;
#if 1
// 1x1 filter, 1x128 image
constexpr index_t N = 1;
constexpr index_t C = 256;
constexpr index_t HI = 1;
constexpr index_t WI = 128;
constexpr index_t K = 16;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 34;
constexpr index_t WI = 34;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 28x28
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 1x1 filter, 8x8 image
constexpr index_t N = 256;
constexpr index_t C = 1024;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 7x7 image
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
// 1x1 filter, 14x14 image
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 28x28 image
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 5x5 filter, 2x2 pad, 7x7 input
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 1024;
constexpr index_t Y = 5;
constexpr index_t X = 5;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>;
#elif 0
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 0
// 7x1 filter, 3x0 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 1024;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
// 3x3 filter, 2x2 stride, 2x2 dilation, 35x35 input
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 1280;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<2, 2>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#endif
constexpr auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{});
constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
constexpr auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});
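// Worked example for the active configuration above (1x1 filter, 1x128 input,
// unit stride/dilation, no padding):
//   Ho = (HI + LeftPadH + RightPadH - YEff) / StrideH + 1 = (1 + 0 + 0 - 1) / 1 + 1 = 1
//   Wo = (WI + LeftPadW + RightPadW - XEff) / StrideW + 1 = (128 + 0 + 0 - 1) / 1 + 1 = 128
// so out_nkhw_desc describes an N x K x Ho x Wo = 1 x 16 x 1 x 128 tensor.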
ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
print_array("LeftPads", LeftPads{});
print_array("LeftPads", LeftPads{});
print_array("RightPads", RightPads{});
print_array("ConvStrides", ConvStrides{});
print_array("ConvDilations", ConvDilations{});
Tensor<float> in_nchw_device(make_HostTensorDescriptor(in_nchw_desc));
Tensor<float> in_nchw_host(make_HostTensorDescriptor(in_nchw_desc));
Tensor<float> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc));
Tensor<float> out_nkhw(make_HostTensorDescriptor(out_nkhw_desc));
std::size_t num_thread = std::thread::hardware_concurrency();
if(argc != 3)
{
printf("arg1: do_verification, arg2: nrepeat\n");
exit(1);
}
bool do_verification = atoi(argv[1]);
std::size_t nrepeat = atoi(argv[2]);
if(do_verification)
{
#if 0
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
#else
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#endif
}
#if 0
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 0
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#elif 0
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#elif 1
device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk
#endif
(in_nchw_desc,
in_nchw_device,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
if(do_verification)
{
host_direct_convolution_backward_data(in_nchw_host,
wei_kcyx,
out_nkhw,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{});
check_error(in_nchw_host, in_nchw_device);
#if 0
LogRange(std::cout << "out_nkhw : ", out_nkhw.mData, ",") << std::endl;
LogRange(std::cout << "wei_kcyx : ", wei_kcyx.mData, ",") << std::endl;
LogRange(std::cout << "in_nchw_host : ", in_nchw_host.mData, ",") << std::endl;
LogRange(std::cout << "in_nchw_device : ", in_nchw_device.mData, ",") << std::endl;
#endif
}
}
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <thread>
#include <half.hpp>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor_generator.hpp"
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
int main(int argc, char* argv[])
{
using namespace ck;
if(argc != 5)
{
printf("arg1: do_verification, arg2: do_log, arg3: init_method, arg4: nrepeat\n");
exit(1);
}
const bool do_verification = atoi(argv[1]);
const bool do_log = atoi(argv[2]);
const int init_method = atoi(argv[3]);
const int nrepeat = atoi(argv[4]);
#if 0
constexpr index_t N = 256;
constexpr index_t C = 256;
constexpr index_t Hi = 16;
constexpr index_t Wi = 16;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t Hi = 1080;
constexpr index_t Wi = 1920;
constexpr index_t K = 16;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t Hi = 540;
constexpr index_t Wi = 960;
constexpr index_t K = 16;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t Hi = 270;
constexpr index_t Wi = 480;
constexpr index_t K = 16;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t Hi = 1080;
constexpr index_t Wi = 1920;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 1;
constexpr index_t Hi = 1024;
constexpr index_t Wi = 2048;
constexpr index_t K = 4;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t Hi = 540;
constexpr index_t Wi = 960;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t Hi = 270;
constexpr index_t Wi = 480;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
// 3x3, 36x36, stride 2
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t Hi = 37;
constexpr index_t Wi = 37;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35, stride 2
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t Hi = 35;
constexpr index_t Wi = 35;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 3x3, 71x71
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t Hi = 71;
constexpr index_t Wi = 71;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
// 1x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 1536;
constexpr index_t Hi = 8;
constexpr index_t Wi = 8;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 1x1, 73x73
constexpr index_t N = 128;
constexpr index_t C = 160;
constexpr index_t Hi = 73;
constexpr index_t Wi = 73;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35
constexpr index_t N = 128;
constexpr index_t C = 96;
constexpr index_t Hi = 35;
constexpr index_t Wi = 35;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
// 3x3, 71x71
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t Hi = 71;
constexpr index_t Wi = 71;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
// 7x1, 17x17
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t Hi = 17;
constexpr index_t Wi = 17;
constexpr index_t K = 128;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<3, 0>;
using InRightPads = Sequence<3, 0>;
#elif 1
// 1x7, 17x17
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t Hi = 17;
constexpr index_t Wi = 17;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 3>;
using InRightPads = Sequence<0, 3>;
#elif 0
// 3x3, 299x299 stride=2
constexpr index_t N = 128;
constexpr index_t C = 3;
constexpr index_t Hi = 299;
constexpr index_t Wi = 299;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 3x3, 147x147
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t Hi = 147;
constexpr index_t Wi = 147;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
// 3x3, 149x149
constexpr index_t N = 128;
constexpr index_t C = 32;
constexpr index_t Hi = 149;
constexpr index_t Wi = 149;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 3x3, 17x17, stride 2
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t Hi = 17;
constexpr index_t Wi = 17;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 1x1, 35x35
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t Hi = 35;
constexpr index_t Wi = 35;
constexpr index_t K = 96;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35, stride 2
constexpr index_t N = 128;
constexpr index_t C = 288;
constexpr index_t Hi = 35;
constexpr index_t Wi = 35;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 1x3, 8x8
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t Hi = 8;
constexpr index_t Wi = 8;
constexpr index_t K = 448;
constexpr index_t Y = 1;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 1>;
using InRightPads = Sequence<0, 1>;
#elif 0
// 3x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 448;
constexpr index_t Hi = 8;
constexpr index_t Wi = 8;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<1, 0>;
using InRightPads = Sequence<1, 0>;
#elif 0
// 3x3, 147x147
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t Hi = 147;
constexpr index_t Wi = 147;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 7x1, 73x73
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t Hi = 73;
constexpr index_t Wi = 73;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<3, 0>;
using InRightPads = Sequence<3, 0>;
#elif 0
// 3x3, 73x73
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t Hi = 73;
constexpr index_t Wi = 73;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14, stride 2
constexpr index_t N = 256;
constexpr index_t C = 1024;
constexpr index_t Hi = 14;
constexpr index_t Wi = 14;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14
constexpr index_t N = 256;
constexpr index_t C = 1024;
constexpr index_t Hi = 14;
constexpr index_t Wi = 14;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14, stride 2
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t Hi = 14;
constexpr index_t Wi = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 1
// 3x3, 28x28
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t Hi = 28;
constexpr index_t Wi = 28;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 1
// 3x3, 14x14
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t Hi = 14;
constexpr index_t Wi = 14;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
// 1x1, 56x56, stride 2
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t Hi = 56;
constexpr index_t Wi = 56;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 7x7, 230x230 stride=2
constexpr index_t N = 128;
constexpr index_t C = 3;
constexpr index_t Hi = 230;
constexpr index_t Wi = 230;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 7;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 1x1, 28x28, stride = 2
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t Hi = 28;
constexpr index_t Wi = 28;
constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 1x1, 28x28, stride 2
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t Hi = 28;
constexpr index_t Wi = 28;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 1
// 1x1, 7x7
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t Hi = 7;
constexpr index_t Wi = 7;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 3x3, 7x7
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t Hi = 7;
constexpr index_t Wi = 7;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
// 1x1, 56x56
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t Hi = 56;
constexpr index_t Wi = 56;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 3x3, 56x56
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t Hi = 56;
constexpr index_t Wi = 56;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#endif
constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;
constexpr index_t Ho = (Hi + InLeftPads{}[0] + InRightPads{}[0] - YEff) / ConvStrides{}[0] + 1;
constexpr index_t Wo = (Wi + InLeftPads{}[1] + InRightPads{}[1] - XEff) / ConvStrides{}[1] + 1;
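// Worked example for the active configuration above (1x7 filter, 17x17 input,
// unit stride/dilation, 0x3 padding):
//   YEff = (1 - 1) * 1 + 1 = 1,  XEff = (7 - 1) * 1 + 1 = 7
//   Ho = (17 + 0 + 0 - 1) / 1 + 1 = 17
//   Wo = (17 + 3 + 3 - 7) / 1 + 1 = 17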
#if 1
constexpr index_t in_vector_size = 1;
using in_data_t = typename vector_type<float, in_vector_size>::type;
using acc_data_t = float;
using out_data_t = float;
#elif 1
using in_data_t = half_t;
constexpr index_t in_vector_size = 1;
using acc_data_t = float;
using out_data_t = half_t;
#elif 0
constexpr index_t in_vector_size = 1;
using in_data_t = typename vector_type<float, in_vector_size>::type;
using acc_data_t = float;
using out_data_t = int8_t;
#elif 1
constexpr index_t in_vector_size = 16;
using in_data_t = typename vector_type<int8_t, in_vector_size>::type;
using acc_data_t = int32_t;
using out_data_t = int8_t;
#endif
Tensor<in_data_t> in_nchw(HostTensorDescriptor(std::initializer_list<index_t>{N, C, Hi, Wi}));
Tensor<in_data_t> wei_kcyx(HostTensorDescriptor(std::initializer_list<index_t>{K, C, Y, X}));
Tensor<out_data_t> out_nkhw_host(
HostTensorDescriptor(std::initializer_list<index_t>{N, K, Ho, Wo}));
Tensor<out_data_t> out_nkhw_device(
HostTensorDescriptor(std::initializer_list<index_t>{N, K, Ho, Wo}));
ostream_HostTensorDescriptor(in_nchw.mDesc, std::cout << "in_nchw_desc: ");
ostream_HostTensorDescriptor(wei_kcyx.mDesc, std::cout << "wei_kcyx_desc: ");
ostream_HostTensorDescriptor(out_nkhw_host.mDesc, std::cout << "out_nkhw_desc: ");
print_array("InLeftPads", InLeftPads{});
print_array("InRightPads", InRightPads{});
print_array("ConvStrides", ConvStrides{});
print_array("ConvDilations", ConvDilations{});
std::size_t num_thread = std::thread::hardware_concurrency();
if(do_verification)
{
switch(init_method)
{
case 0:
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
break;
case 1:
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
break;
case 2:
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
break;
case 3:
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
break;
default:
in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
auto gen_wei = [](auto... is) {
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
};
wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
}
}
constexpr auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, Hi, Wi>{});
constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
constexpr auto out_nkhw_desc = make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
#if 1
device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
InLeftPads{},
InRightPads{},
nrepeat);
#elif 0
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
InLeftPads{},
InRightPads{},
nrepeat);
#elif 0
device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
InLeftPads{},
InRightPads{},
nrepeat);
#endif
if(do_verification)
{
host_direct_convolution(in_nchw,
wei_kcyx,
out_nkhw_host,
ConvStrides{},
ConvDilations{},
InLeftPads{},
InRightPads{});
check_error(out_nkhw_host, out_nkhw_device);
if(do_log)
{
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
}
}
}
......@@ -24,16 +24,16 @@
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4_NHWC 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V4R5_NCHW 0
#define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R4_NHWC 1
#define USE_CONV_FWD_V4R4R2_NHWC 1
#define USE_CONV_FWD_V4R5_NCHW 1
#define USE_CONV_FWD_V4R5R2_NCHW 1
#define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R3_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
#define USE_CONV_FWD_V4R4_XDL_NCHW 1
#define USE_CONV_FWD_V4R4R2_XDL_NHWC 1
#define USE_CONV_FWD_V4R4R3_XDL_NHWC 1
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
enum ConvForwardAlgo
{
......
#ifndef CONV_COMMON_HPP
#define CONV_COMMON_HPP
#include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor.hpp"
enum ConvTensorLayout
......@@ -13,53 +12,6 @@ enum ConvTensorLayout
NHWCc
};
template <class InDesc,
class WeiDesc,
class ConvStrides,
class ConvDilations,
class LeftPads,
class RightPads>
constexpr auto get_convolution_output_default_4d_tensor_descriptor(
InDesc, WeiDesc, ConvStrides, ConvDilations, LeftPads, RightPads)
{
using namespace ck;
constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
static_assert(in_desc.GetNumOfDimension() == 4, "input nDim is not 4");
static_assert(wei_desc.GetNumOfDimension() == 4, "weight nDim is not 4");
static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
"input & weight dimension not consistent");
constexpr index_t N = in_desc.GetLength(I0);
constexpr index_t Hi = in_desc.GetLength(I2);
constexpr index_t Wi = in_desc.GetLength(I3);
constexpr index_t K = wei_desc.GetLength(I0);
constexpr index_t Y = wei_desc.GetLength(I2);
constexpr index_t X = wei_desc.GetLength(I3);
constexpr index_t LeftPadH = LeftPads{}.Get(I0);
constexpr index_t LeftPadW = LeftPads{}.Get(I1);
constexpr index_t RightPadH = RightPads{}.Get(I0);
constexpr index_t RightPadW = RightPads{}.Get(I1);
constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;
constexpr index_t Ho = (Hi + LeftPadH + RightPadH - YEff) / ConvStrides{}[0] + 1;
constexpr index_t Wo = (Wi + LeftPadW + RightPadW - XEff) / ConvStrides{}[1] + 1;
return make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
}
template <typename... InDesc,
typename... WeiDesc,
typename ConvStrides,
......@@ -131,30 +83,4 @@ calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, cons
return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
}
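// For example, the 1x7 filter, 17x17, N = K = C = 128 configuration used in the driver
// above gives 2 * 128 * 128 * 17 * 17 * 128 * 1 * 7 = 8,485,076,992 FLOP (about 8.5 GFLOP)
// for one convolution.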
template <class Float, class InDesc, class WeiDesc, class OutDesc>
constexpr std::size_t calculate_convolution_memory_size(Float, InDesc, WeiDesc, OutDesc)
{
using namespace ck;
constexpr auto wei_desc = WeiDesc{};
constexpr auto out_desc = OutDesc{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr index_t N = out_desc.GetLength(I0);
constexpr index_t K = out_desc.GetLength(I1);
constexpr index_t Ho = out_desc.GetLength(I2);
constexpr index_t Wo = out_desc.GetLength(I3);
constexpr index_t C = wei_desc.GetLength(I1);
constexpr index_t Y = wei_desc.GetLength(I2);
constexpr index_t X = wei_desc.GetLength(I3);
return sizeof(Float) *
(InDesc::GetElementSpace() + WeiDesc::GetElementSpace() + OutDesc::GetElementSpace());
}
#endif