Unverified Commit 31b40352 authored by Chao Liu, committed by GitHub

Merge pull request #16 from ROCmSoftwarePlatform/develop

Merge develop into master
parents 5781adf5 b62bf8c3
...@@ -350,8 +350,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16> ...@@ -350,8 +350,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{ {
const auto p_a = reinterpret_cast<const ushort2_t*>(a); const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b); const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run( return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
p_a, p_b, reg_c); p_a, p_b, reg_c);
...@@ -384,8 +384,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16> ...@@ -384,8 +384,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{ {
const auto p_a = reinterpret_cast<const ushort2_t*>(a); const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b); const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c); return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
} }
...@@ -417,8 +417,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16> ...@@ -417,8 +417,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{ {
const auto p_a = reinterpret_cast<const ushort2_t*>(a); const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b); const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c); return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
} }
...@@ -450,8 +450,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16> ...@@ -450,8 +450,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{ {
const auto p_a = reinterpret_cast<const ushort2_t*>(a); const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b); const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c); return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
} }
...@@ -483,8 +483,8 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16> ...@@ -483,8 +483,8 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{ {
const auto p_a = reinterpret_cast<const ushort2_t*>(a); const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b); const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c); return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
} }
......
#ifndef CK_AMD_ADDRESS_SPACE_HPP
#define CK_AMD_ADDRESS_SPACE_HPP
#include "config.hpp"
#include "c_style_pointer_cast.hpp"
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
namespace ck {
enum AddressSpaceEnum_t
{
Generic,
Global,
Lds,
Sgpr,
Vgpr,
};
template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p)
{
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
// only a C-style pointer cast seems to compile for this conversion
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
template <typename T>
__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p)
{
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
// only a C-style pointer cast seems to compile for this conversion
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T CONSTANT*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
} // namespace ck
#endif
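A minimal usage sketch of the two casts above (the kernel and the ArgDesc struct are illustrative, not part of this change): a kernel argument annotated with CONSTANT arrives in the Constant address space and is cast back to a generic pointer before normal dereferencing.

struct ArgDesc
{
    int n;
};

__global__ void example_kernel(const ArgDesc CONSTANT* p_desc_constant, float* p_out)
{
    // cast from "Constant" address space (4) back to "Generic" address space (0)
    const ArgDesc* p_desc = ck::cast_pointer_to_generic_address_space(p_desc_constant);

    if(threadIdx.x == 0)
    {
        p_out[0] = static_cast<float>(p_desc->n);
    }
}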
#ifndef CK_AMD_BUFFER_ADDRESSING_V2_HPP #ifndef CK_AMD_BUFFER_ADDRESSING_HPP
#define CK_AMD_BUFFER_ADDRESSING_V2_HPP #define CK_AMD_BUFFER_ADDRESSING_HPP
#include "data_type.hpp" #include "data_type.hpp"
namespace ck { namespace ck {
template <typename T> template <typename T>
union BufferResource_v2 union BufferResource
{ {
// 128 bit SGPRs to supply buffer resource in buffer instructions // 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
int32x4_t data; int32x4_t content;
StaticallyIndexedArray<T*, 2> address; StaticallyIndexedArray<T*, 2> address;
StaticallyIndexedArray<int32_t, 4> range; StaticallyIndexedArray<int32_t, 4> range;
StaticallyIndexedArray<int32_t, 4> config; StaticallyIndexedArray<int32_t, 4> config;
}; };
template <typename T> template <typename T>
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size) __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
{ {
BufferResource_v2<T> wave_buffer_resource; BufferResource<T> wave_buffer_resource;
// wavewise base address (64 bit) // wavewise base address (64 bit)
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave); wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
// wavewise range (32 bit) // wavewise range (32 bit)
wave_buffer_resource.range(Number<2>{}) = data_space_size * sizeof(T); wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
// wavewise setting (32 bit) // wavewise setting (32 bit)
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD; wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
return wave_buffer_resource.data; return wave_buffer_resource.content;
} }
// load // load
...@@ -204,10 +204,9 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, ...@@ -204,10 +204,9 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
template <typename T, index_t N> template <typename T, index_t N>
__device__ typename vector_type<T, N>::type __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset,
index_t src_thread_addr_offset, index_t src_wave_addr_offset)
index_t src_wave_addr_offset)
{ {
static_assert( static_assert(
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
...@@ -412,10 +411,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -412,10 +411,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
} }
template <typename T, index_t N> template <typename T, index_t N>
__device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type src_thread_data, __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource, int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset, index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset) index_t dst_wave_addr_offset)
{ {
static_assert( static_assert(
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
...@@ -584,67 +583,95 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type ...@@ -584,67 +583,95 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
// buffer_load requires: // buffer_load requires:
// 1) p_src_wave must be in global memory space // 1) p_src_wave must be in global memory space
// 2) p_src_wave to be a wavewise pointer. // 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true. // It is user's responsibility to make sure that is true.
template <typename T, index_t N> template <typename T, index_t N>
__device__ typename vector_type_maker<T, N>::type::type __device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_v2(const T* p_src_wave, amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
index_t src_thread_data_offset, index_t src_thread_element_offset,
bool src_thread_data_valid, bool src_thread_element_valid,
index_t src_element_space) index_t src_element_space_size)
{ {
const int32x4_t src_wave_buffer_resource = const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space); make_wave_buffer_resource(p_src_wave, src_element_space_size);
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff;
return amd_buffer_load_impl_v2<scalar_t, vector_size>( return amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else #else
vector_t tmp = amd_buffer_load_impl_v2<scalar_t, vector_size>( vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_data_valid ? tmp : vector_t(0); return src_thread_element_valid ? tmp : vector_t(0);
#endif #endif
} }
// buffer_load requires:
// 1) p_src_wave must be in global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
index_t src_thread_element_offset,
bool src_thread_element_valid,
index_t src_element_space_size,
T customized_value)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size);
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(customized_value);
}
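The two flavors above differ only in how an invalid element is materialized: the zero-return flavor can use the CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK path, which shifts the thread offset by 0x7fffffff so the buffer resource's hardware range check yields zero without a branch, while the customized-value flavor always loads and then selects the pad value. A hedged call-site sketch (the wrapper function and its arguments are illustrative):

__device__ void buffer_load_example(const float* p_src_wave,
                                    ck::index_t thread_element_offset,
                                    bool is_valid,
                                    ck::index_t element_space_size)
{
    // invalid elements come back as 0
    auto v_zero = ck::amd_buffer_load_invalid_element_return_zero<float, 2>(
        p_src_wave, thread_element_offset, is_valid, element_space_size);

    // invalid elements come back as the supplied pad value (-1.0f here)
    auto v_pad = ck::amd_buffer_load_invalid_element_return_customized_value<float, 2>(
        p_src_wave, thread_element_offset, is_valid, element_space_size, -1.0f);

    (void)v_zero;
    (void)v_pad;
}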
// buffer_store requires: // buffer_store requires:
// 1) p_dst_wave must be global memory // 1) p_dst_wave must be global memory
// 2) p_dst_wave to be a wavewise pointer. // 2) p_dst_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true. // It is user's responsibility to make sure that is true.
template <typename T, index_t N> template <typename T, index_t N>
__device__ void __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
amd_buffer_store_v2(const typename vector_type_maker<T, N>::type::type src_thread_data, T* p_dst_wave,
T* p_dst_wave, const index_t dst_thread_element_offset,
const index_t dst_thread_data_offset, const bool dst_thread_element_valid,
const bool dst_thread_data_valid, const index_t dst_element_space_size)
const index_t dst_element_space)
{ {
const int32x4_t dst_wave_buffer_resource = const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space); make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T); index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type; using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type; using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
amd_buffer_store_impl_v2<scalar_t, vector_size>( amd_buffer_store_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else #else
if(dst_thread_data_valid) if(dst_thread_element_valid)
{ {
amd_buffer_store_impl_v2<scalar_t, vector_size>( amd_buffer_store_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
} }
#endif #endif
......
#ifndef CK_AMD_DLOP_HPP
#define CK_AMD_DLOP_HPP
#include "data_type.hpp"
namespace ck {
template <typename TA, typename TB, typename TC>
__device__ void amd_inner_product_dlop(const TA& a, const TB& b, TC& c);
template <>
__device__ void
amd_inner_product_dlop<float, float, float>(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
asm volatile("\n \
v_fmac_f32 %0, %1, %2 \n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#else
c += a * b;
#endif
}
template <>
__device__ void
amd_inner_product_dlop<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
amd_inner_product_dlop(vector_type<float, 2>{a}.AsType<float>()[I0],
vector_type<float, 2>{b}.AsType<float>()[I0],
c);
amd_inner_product_dlop(vector_type<float, 2>{a}.AsType<float>()[I1],
vector_type<float, 2>{b}.AsType<float>()[I1],
c);
}
template <>
__device__ void
amd_inner_product_dlop<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I0],
vector_type<float, 4>{b}.AsType<float>()[I0],
c);
amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I1],
vector_type<float, 4>{b}.AsType<float>()[I1],
c);
amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I2],
vector_type<float, 4>{b}.AsType<float>()[I2],
c);
amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I3],
vector_type<float, 4>{b}.AsType<float>()[I3],
c);
}
#if CK_USE_AMD_DLOP
template <>
__device__ void
amd_inner_product_dlop<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
asm volatile("\n \
v_dot2_f32_f16 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#else
c = __builtin_amdgcn_sdot2(a, b, c, false);
#endif
}
template <>
__device__ void
amd_inner_product_dlop<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
amd_inner_product_dlop(vector_type<half_t, 4>{a}.AsType<half2_t>()[I0],
vector_type<half_t, 4>{b}.AsType<half2_t>()[I0],
c);
amd_inner_product_dlop(vector_type<half_t, 4>{a}.AsType<half2_t>()[I1],
vector_type<half_t, 4>{b}.AsType<half2_t>()[I1],
c);
}
template <>
__device__ void
amd_inner_product_dlop<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I0],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I0],
c);
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I1],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I1],
c);
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I2],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I2],
c);
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I3],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I3],
c);
}
template <>
__device__ void amd_inner_product_dlop<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a,
const int8x4_t& b,
int32_t& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
asm volatile("\n \
v_dot4_i32_i8 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
#else
c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
}
template <>
__device__ void amd_inner_product_dlop<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a,
const int8x8_t& b,
int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
amd_inner_product_dlop(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
c);
amd_inner_product_dlop(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
c);
}
template <>
__device__ void amd_inner_product_dlop<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a,
const int8x16_t& b,
int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
c);
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
c);
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
c);
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
c);
}
#endif // CK_USE_AMD_DLOP
} // namespace ck
#endif
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
#define CK_AMD_INLINE_ASM_HPP #define CK_AMD_INLINE_ASM_HPP
#include "data_type.hpp" #include "data_type.hpp"
#include "c_style_pointer_cast.hpp"
// TODO: deprecate all amd_assembly_outer_product_xxx
namespace ck { namespace ck {
...@@ -53,9 +56,9 @@ __device__ void ...@@ -53,9 +56,9 @@ __device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1) amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{ {
// TODO remove pointer casting // TODO remove pointer casting
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a); const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0); const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1); const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
// do dot2 two times // do dot2 two times
asm volatile("\n \ asm volatile("\n \
...@@ -114,11 +117,11 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a, ...@@ -114,11 +117,11 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
float& c3) float& c3)
{ {
// TODO remove pointer casting // TODO remove pointer casting
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a); const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0); const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1); const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
const half2_t* p_b2_half2 = reinterpret_cast<const half2_t*>(&b2); const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
const half2_t* p_b3_half2 = reinterpret_cast<const half2_t*>(&b3); const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);
// do dot2 two times // do dot2 two times
asm volatile("\n \ asm volatile("\n \
...@@ -160,11 +163,11 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a, ...@@ -160,11 +163,11 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a,
{ {
// TODO remove pointer casting // TODO remove pointer casting
const half4_t* p_a_half4 = reinterpret_cast<const half4_t*>(&a); const half4_t* p_a_half4 = c_style_pointer_cast<const half4_t*>(&a);
const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0); const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1); const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
const half4_t* p_b2_half4 = reinterpret_cast<const half4_t*>(&b2); const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
const half4_t* p_b3_half4 = reinterpret_cast<const half4_t*>(&b3); const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);
amd_assembly_outer_product_1x4( amd_assembly_outer_product_1x4(
p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3); p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
...@@ -184,11 +187,11 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a, ...@@ -184,11 +187,11 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a,
float& c3) float& c3)
{ {
// TODO remove pointer casting // TODO remove pointer casting
const half8_t* p_a_half8 = reinterpret_cast<const half8_t*>(&a); const half8_t* p_a_half8 = c_style_pointer_cast<const half8_t*>(&a);
const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0); const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1); const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
const half8_t* p_b2_half8 = reinterpret_cast<const half8_t*>(&b2); const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
const half8_t* p_b3_half8 = reinterpret_cast<const half8_t*>(&b3); const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);
amd_assembly_outer_product_1x4( amd_assembly_outer_product_1x4(
p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3); p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
......
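For orientation, a hedged sketch of how one of these helpers is typically invoked (the wrapper is illustrative): c0 and c1 accumulate the dot products of the 4-wide fp16 fragment a with b0 and b1 respectively.

__device__ void outer_product_1x2_example(const ck::half4_t& a,
                                          const ck::half4_t& b0,
                                          const ck::half4_t& b1,
                                          float& c0,
                                          float& c1)
{
    // c0 += dot(a, b0) and c1 += dot(a, b1), lowered to packed dot2 inline asm
    ck::amd_assembly_outer_product_1x2(a, b0, b1, c0, c1);
}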
#ifndef CK_C_STYLE_POINTER_CAST_HPP
#define CK_C_STYLE_POINTER_CAST_HPP
#include "type.hpp"
#include "enable_if.hpp"
namespace ck {
template <typename PY,
typename PX,
typename enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false>
__host__ __device__ PY c_style_pointer_cast(PX p_x)
{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wcast-align"
return (PY)p_x; // NOLINT(old-style-cast, cast-align)
#pragma clang diagnostic pop
}
} // namespace ck
#endif
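A minimal sketch of the cast this PR substitutes for reinterpret_cast throughout (the helper below is illustrative): it reinterprets a wider ext_vector as an array of narrower pieces, the same pattern used in amd_inline_asm.hpp above.

__device__ float sum_via_float2(const ck::float4_t& v)
{
    // view one float4_t as two float2_t pieces
    const ck::float2_t* p = ck::c_style_pointer_cast<const ck::float2_t*>(&v);

    return p[0][0] + p[0][1] + p[1][0] + p[1][1];
}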
...@@ -7,13 +7,14 @@ ...@@ -7,13 +7,14 @@
#include "statically_indexed_array.hpp" #include "statically_indexed_array.hpp"
#include "container_element_picker.hpp" #include "container_element_picker.hpp"
#include "multi_index.hpp" #include "multi_index.hpp"
#include "data_type_enum.hpp"
#include "data_type.hpp" #include "data_type.hpp"
#include "data_type_helper.hpp" #include "data_type_enum.hpp"
#include "data_type_enum_helper.hpp"
#include "functional.hpp" #include "functional.hpp"
#include "functional2.hpp" #include "functional2.hpp"
#include "functional3.hpp" #include "functional3.hpp"
#include "functional4.hpp" #include "functional4.hpp"
#include "enable_if.hpp"
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "math.hpp" #include "math.hpp"
#include "number.hpp" #include "number.hpp"
...@@ -23,21 +24,21 @@ ...@@ -23,21 +24,21 @@
#include "tuple.hpp" #include "tuple.hpp"
#include "tuple_helper.hpp" #include "tuple_helper.hpp"
#include "type.hpp" #include "type.hpp"
#include "utility.hpp"
#include "magic_division.hpp" #include "magic_division.hpp"
#include "amd_buffer_addressing_v2.hpp" #include "utility.hpp"
#include "c_style_pointer_cast.hpp"
#include "amd_address_space.hpp"
#include "amd_buffer_addressing.hpp"
#include "static_buffer.hpp" #include "static_buffer.hpp"
#include "dynamic_buffer.hpp" #include "dynamic_buffer.hpp"
#include "inner_product.hpp"
// TODO: remove this // TODO: remove this
#if CK_USE_AMD_INLINE_ASM #if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp" #include "amd_inline_asm.hpp"
#endif #endif
#if CK_USE_AMD_DLOP
#include "amd_dlop.hpp"
#endif
#if CK_USE_AMD_XDLOPS #if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp" #include "amd_xdlops.hpp"
#endif #endif
......
...@@ -7,19 +7,14 @@ ...@@ -7,19 +7,14 @@
#endif #endif
#include "bfloat16_dev.hpp" #include "bfloat16_dev.hpp"
// address space for kernel parameter // "Constant" address space for kernel parameter
#define CONSTANT __attribute__((address_space(4))) #define CONSTANT __attribute__((address_space(4)))
// GPU target // GPU target
// should enable one and only one GPU target // should enable one and only one GPU target
#if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \ #if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030)) defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030))
#error Need to define a single GPU target #error Need to define (only) one GPU target
#endif
// HIP version
#ifndef CK_HIP_VERSION_FLAT
#define CK_HIP_VERSION_FLAT 0
#endif #endif
// launch bounds // launch bounds
...@@ -38,6 +33,16 @@ ...@@ -38,6 +33,16 @@
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#endif #endif
// FMA instruction
#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900)
#define CK_USE_AMD_V_MAC_F32
#elif defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || \
defined(CK_AMD_GPU_GFX1030)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
#endif
// multi index // multi index
#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0 #define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0
...@@ -46,13 +51,9 @@ ...@@ -46,13 +51,9 @@
#define CK_USE_AMD_INLINE_ASM 1 #define CK_USE_AMD_INLINE_ASM 1
#endif #endif
// AMD DLOPS // AMD inner product (DLOP)
#ifndef CK_USE_AMD_DLOP #ifndef CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
#define CK_USE_AMD_DLOP 1 #define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1
#endif
#ifndef CK_USE_AMD_DLOP_INLINE_ASM
#define CK_USE_AMD_DLOP_INLINE_ASM 1
#endif #endif
// AMD buffer addressing // AMD buffer addressing
...@@ -99,8 +100,8 @@ ...@@ -99,8 +100,8 @@
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
// thread-invariant, otherwise it's a bug // thread-invariant, otherwise it's a bug
// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" // TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
#ifndef CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE #ifndef CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
#define CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 #define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
#endif #endif
// workaround for compiler crash when compiling recursive lambda // workaround for compiler crash when compiling recursive lambda
...@@ -120,15 +121,6 @@ ...@@ -120,15 +121,6 @@
namespace ck { namespace ck {
enum AddressSpaceEnum_t
{
Generic,
Global,
Lds,
Sgpr,
Vgpr
};
enum InMemoryDataOperationEnum_t enum InMemoryDataOperationEnum_t
{ {
Set, Set,
......
...@@ -3,8 +3,7 @@ ...@@ -3,8 +3,7 @@
namespace ck { namespace ck {
// this enumerate should be synchronized with include/miopen.h enum DataTypeEnum_t
typedef enum
{ {
Half = 0, Half = 0,
Float = 1, Float = 1,
...@@ -14,7 +13,7 @@ typedef enum ...@@ -14,7 +13,7 @@ typedef enum
BFloat16 = 5, BFloat16 = 5,
Double = 6, Double = 6,
Unknown = 100, Unknown = 100,
} DataTypeEnum_t; };
} // namespace ck } // namespace ck
#endif #endif
#ifndef CK_DATA_TYPE_HELPER_HPP #ifndef CK_DATA_TYPE_ENUM_HELPER_HPP
#define CK_DATA_TYPE_HELPER_HPP #define CK_DATA_TYPE_ENUM_HELPER_HPP
#include "data_type.hpp" #include "data_type.hpp"
#include "data_type_enum.hpp" #include "data_type_enum.hpp"
......
#ifndef CK_DYNAMIC_BUFFER_HPP #ifndef CK_BUFFER_HPP
#define CK_DYNAMIC_BUFFER_HPP #define CK_BUFFER_HPP
namespace ck { #include "amd_buffer_addressing.hpp"
#include "c_style_pointer_cast.hpp"
#include "enable_if.hpp"
#include "amd_buffer_addressing_v2.hpp" namespace ck {
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize> template <AddressSpaceEnum_t BufferAddressSpace,
typename T,
typename ElementSpaceSize,
bool InvalidElementUseNumericalZeroValue>
struct DynamicBuffer struct DynamicBuffer
{ {
using type = T; using type = T;
T* p_data_; T* p_data_;
ElementSpaceSize element_space_size_; ElementSpaceSize element_space_size_;
T invalid_element_value_ = T{0};
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
: p_data_{p_data}, element_space_size_{element_space_size} : p_data_{p_data}, element_space_size_{element_space_size}
{ {
} }
__host__ __device__ constexpr DynamicBuffer(T* p_data,
ElementSpaceSize element_space_size,
T invalid_element_value)
: p_data_{p_data},
element_space_size_{element_space_size},
invalid_element_value_{invalid_element_value}
{
}
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
{ {
return BufferAddressSpace; return BufferAddressSpace;
} }
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
template <typename X, template <typename X,
typename std::enable_if< typename enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_offset) const __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
{ {
// X contains multiple T // X contains multiple T
constexpr index_t scalar_per_t_vector = constexpr index_t scalar_per_t_vector =
...@@ -44,29 +55,50 @@ struct DynamicBuffer ...@@ -44,29 +55,50 @@ struct DynamicBuffer
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T"); "wrong! X need to be multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
{
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ADDRESSING
return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>( bool constexpr use_amd_buffer_addressing = true;
p_data_, i, is_valid_offset, element_space_size_);
#else #else
return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0}; bool constexpr use_amd_buffer_addressing = false;
#endif #endif
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(InvalidElementUseNumericalZeroValue)
{
return amd_buffer_load_invalid_element_return_zero<
remove_cv_t<remove_reference_t<T>>,
t_per_x>(p_data_, i, is_valid_element, element_space_size_);
}
else
{
return amd_buffer_load_invalid_element_return_customized_value<
remove_cv_t<remove_reference_t<T>>,
t_per_x>(
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
}
} }
else else
{ {
return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0}; if constexpr(InvalidElementUseNumericalZeroValue)
{
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
}
else
{
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
: X{invalid_element_value_};
}
} }
} }
template <typename X, template <typename X,
typename std::enable_if< typename enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x) __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
{ {
// X contains multiple T // X contains multiple T
constexpr index_t scalar_per_t_vector = constexpr index_t scalar_per_t_vector =
...@@ -78,26 +110,26 @@ struct DynamicBuffer ...@@ -78,26 +110,26 @@ struct DynamicBuffer
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T"); "wrong! X need to be multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global) if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
{ {
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ADDRESSING
amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>( constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
x, p_data_, i, is_valid_offset, element_space_size_);
amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_);
#else #else
if(is_valid_offset) if(is_valid_element)
{ {
*reinterpret_cast<X*>(&p_data_[i]) = x; *c_style_pointer_cast<X*>(&p_data_[i]) = x;
} }
#endif #endif
} }
else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds) else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
{ {
if(is_valid_offset) if(is_valid_element)
{ {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
*reinterpret_cast<X*>(&p_data_[i]) = x; *c_style_pointer_cast<X*>(&p_data_[i]) = x;
#else #else
// HACK: compiler would lower IR "store<i8, 16> address_space(3)" into // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
// inefficient // inefficient
...@@ -128,24 +160,24 @@ struct DynamicBuffer ...@@ -128,24 +160,24 @@ struct DynamicBuffer
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*reinterpret_cast<int8_t*>(&p_data_[i]) = *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
*reinterpret_cast<const int8_t*>(&x); *c_style_pointer_cast<const int8_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*reinterpret_cast<int16_t*>(&p_data_[i]) = *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
*reinterpret_cast<const int16_t*>(&x); *c_style_pointer_cast<const int16_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*reinterpret_cast<int32_t*>(&p_data_[i]) = *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*reinterpret_cast<const int32_t*>(&x); *c_style_pointer_cast<const int32_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
int8x4_t>::value && int8x4_t>::value &&
...@@ -153,8 +185,8 @@ struct DynamicBuffer ...@@ -153,8 +185,8 @@ struct DynamicBuffer
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*reinterpret_cast<int32_t*>(&p_data_[i]) = *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*reinterpret_cast<const int32_t*>(&x); *c_style_pointer_cast<const int32_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
int8x8_t>::value && int8x8_t>::value &&
...@@ -162,8 +194,8 @@ struct DynamicBuffer ...@@ -162,8 +194,8 @@ struct DynamicBuffer
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*reinterpret_cast<int32x2_t*>(&p_data_[i]) = *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*reinterpret_cast<const int32x2_t*>(&x); *c_style_pointer_cast<const int32x2_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
int8x16_t>::value && int8x16_t>::value &&
...@@ -171,22 +203,22 @@ struct DynamicBuffer ...@@ -171,22 +203,22 @@ struct DynamicBuffer
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*reinterpret_cast<int32x4_t*>(&p_data_[i]) = *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
*reinterpret_cast<const int32x4_t*>(&x); *c_style_pointer_cast<const int32x4_t*>(&x);
} }
} }
else else
{ {
*reinterpret_cast<X*>(&p_data_[i]) = x; *c_style_pointer_cast<X*>(&p_data_[i]) = x;
} }
#endif #endif
} }
} }
else else
{ {
if(is_valid_offset) if(is_valid_element)
{ {
*reinterpret_cast<X*>(&p_data_[i]) = x; *c_style_pointer_cast<X*>(&p_data_[i]) = x;
} }
} }
} }
...@@ -196,12 +228,18 @@ struct DynamicBuffer ...@@ -196,12 +228,18 @@ struct DynamicBuffer
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
}; };
template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic, template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
typename T,
typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{ {
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size}; return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
}
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
__host__ __device__ constexpr auto
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value)
{
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
p, element_space_size, invalid_element_value};
} }
} // namespace ck } // namespace ck
......
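A hedged sketch of the new customized-invalid-value path end to end (the pointer and sizes are illustrative): the three-argument make_dynamic_buffer selects the DynamicBuffer specialization with InvalidElementUseNumericalZeroValue = false, so invalid Get calls return the pad value instead of zero.

__device__ void dynamic_buffer_example(const float* p_global, ck::index_t n)
{
    auto buf = ck::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(p_global, n, -1.0f);

    float in_range = buf.Get<float>(0, true);  // element 0 of p_global
    float padded   = buf.Get<float>(0, false); // forced invalid: returns -1.0f

    (void)in_range;
    (void)padded;
}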
#ifndef CK_ENABLE_IF_HPP
#define CK_ENABLE_IF_HPP
namespace ck {
template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;
template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;
} // namespace ck
#endif
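This is a thin alias over std::enable_if (so <type_traits> must already be visible); the rest of the PR swaps std::enable_if for it in template constraints. A small sketch of the SFINAE pattern it is used for (the function is illustrative):

template <typename T, typename ck::enable_if<std::is_integral<T>::value, bool>::type = false>
__host__ __device__ constexpr T next_multiple_of_4(T x)
{
    // round x up to the next multiple of 4; this overload participates only for integral T
    return (x + T{3}) / T{4} * T{4};
}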
#ifndef CK_INNER_PRODUCT_HPP
#define CK_INNER_PRODUCT_HPP
#include "data_type.hpp"
namespace ck {
template <typename TA, typename TB, typename TC>
__device__ void inner_product(const TA& a, const TB& b, TC& c);
template <>
__device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32)
asm volatile("\n \
v_mac_f32 %0, %1, %2 \n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#elif CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32)
asm volatile("\n \
v_fmac_f32 %0, %1, %2 \n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#else
c += a * b;
#endif
}
template <>
__device__ void
inner_product<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
inner_product(vector_type<float, 2>{a}.AsType<float>()[I0],
vector_type<float, 2>{b}.AsType<float>()[I0],
c);
inner_product(vector_type<float, 2>{a}.AsType<float>()[I1],
vector_type<float, 2>{b}.AsType<float>()[I1],
c);
}
template <>
__device__ void
inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
inner_product(vector_type<float, 4>{a}.AsType<float>()[I0],
vector_type<float, 4>{b}.AsType<float>()[I0],
c);
inner_product(vector_type<float, 4>{a}.AsType<float>()[I1],
vector_type<float, 4>{b}.AsType<float>()[I1],
c);
inner_product(vector_type<float, 4>{a}.AsType<float>()[I2],
vector_type<float, 4>{b}.AsType<float>()[I2],
c);
inner_product(vector_type<float, 4>{a}.AsType<float>()[I3],
vector_type<float, 4>{b}.AsType<float>()[I3],
c);
}
template <>
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{
#if defined(CK_USE_AMD_V_DOT2_F32_F16)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
asm volatile("\n \
v_dot2_f32_f16 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#else
c = __builtin_amdgcn_sdot2(a, b, c, false);
#endif
#else
// promote half_t to float before the multiply-accumulate fallback
const auto convert = type_convert<float>{};
const vector_type<half_t, 2> a_vector{a};
const vector_type<half_t, 2> b_vector{b};
static_for<0, 2, 1>{}([&](auto i) {
c += convert(a_vector.AsType<half_t>()[i]) * convert(b_vector.AsType<half_t>()[i]);
});
#endif
}
template <>
__device__ void inner_product<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
inner_product(vector_type<half_t, 4>{a}.AsType<half2_t>()[I0],
vector_type<half_t, 4>{b}.AsType<half2_t>()[I0],
c);
inner_product(vector_type<half_t, 4>{a}.AsType<half2_t>()[I1],
vector_type<half_t, 4>{b}.AsType<half2_t>()[I1],
c);
}
template <>
__device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I0],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I0],
c);
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I1],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I1],
c);
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I2],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I2],
c);
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I3],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I3],
c);
}
template <>
__device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{
#if defined(CK_USE_AMD_V_DOT4_I32_I8)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
asm volatile("\n \
v_dot4_i32_i8 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
#else
c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
#else
const auto convert = type_convert<int32_t>{};
const vector_type<int8_t, 4> a_vector{a};
const vector_type<int8_t, 4> b_vector{b};
static_for<0, 4, 1>{}([&](auto i) {
c += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
});
#endif
}
template <>
__device__ void
inner_product<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a, const int8x8_t& b, int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
c);
inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
c);
}
template <>
__device__ void
inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t& b, int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
c);
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
c);
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
c);
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
c);
}
} // namespace ck
#endif
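A minimal usage sketch (the wrapper is illustrative): the vector specializations above recursively reduce to the scalar / dot2 / dot4 forms, so an 8-wide fp16 dot product accumulating into float is a single call.

__device__ float dot8_f16(const ck::half8_t& a, const ck::half8_t& b)
{
    float c = 0.0f;

    // lowers to four v_dot2_f32_f16 (or the portable fallback) via the half2_t specialization
    ck::inner_product<ck::half8_t, ck::half8_t, float>(a, b, c);

    return c;
}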
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "number.hpp" #include "number.hpp"
#include "type.hpp" #include "type.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
namespace math { namespace math {
...@@ -27,13 +28,7 @@ struct minus ...@@ -27,13 +28,7 @@ struct minus
__host__ __device__ constexpr T operator()(T a, T b) const { return a - b; } __host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
}; };
template <typename T>
struct multiplies struct multiplies
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
};
struct multiplies_v2
{ {
template <typename A, typename B> template <typename A, typename B>
__host__ __device__ constexpr auto operator()(const A& a, const B& b) const __host__ __device__ constexpr auto operator()(const A& a, const B& b) const
...@@ -184,9 +179,7 @@ __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>) ...@@ -184,9 +179,7 @@ __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
return Number<r>{}; return Number<r>{};
} }
template <typename X, template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
typename... Ys,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto gcd(X x, Ys... ys) __host__ __device__ constexpr auto gcd(X x, Ys... ys)
{ {
return gcd(x, gcd(ys...)); return gcd(x, gcd(ys...));
...@@ -199,9 +192,7 @@ __host__ __device__ constexpr auto lcm(X x, Y y) ...@@ -199,9 +192,7 @@ __host__ __device__ constexpr auto lcm(X x, Y y)
return (x * y) / gcd(x, y); return (x * y) / gcd(x, y);
} }
template <typename X, template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
typename... Ys,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto lcm(X x, Ys... ys) __host__ __device__ constexpr auto lcm(X x, Ys... ys)
{ {
return lcm(x, lcm(ys...)); return lcm(x, lcm(ys...));
......
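A small compile-time usage sketch of the variadic overloads above (operands are illustrative, and Number arithmetic operators are assumed from the existing math utilities):

__host__ __device__ constexpr auto gcd_example()
{
    // variadic form folds right-to-left: gcd(12, gcd(18, 30)) = 6
    return ck::math::gcd(ck::Number<12>{}, ck::Number<18>{}, ck::Number<30>{});
}

__host__ __device__ constexpr auto lcm_example()
{
    // variadic form folds right-to-left: lcm(4, lcm(6, 10)) = 60
    return ck::math::lcm(ck::Number<4>{}, ck::Number<6>{}, ck::Number<10>{});
}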
...@@ -11,59 +11,11 @@ namespace ck { ...@@ -11,59 +11,11 @@ namespace ck {
template <typename T> template <typename T>
__host__ __device__ void print_array(const char* s, T a) __host__ __device__ void print_array(const char* s, T a)
{ {
using data_type = decltype(a.At(Number<0>{}));
constexpr index_t nsize = a.Size(); constexpr index_t nsize = a.Size();
#if 0
if constexpr(is_same<data_type, uint32_t>{})
{
printf("%s size %u, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", uint32_t{a[i]}); });
printf("}\n");
}
else if constexpr(is_same<data_type, int32_t>{})
{
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
printf("}\n");
}
else if constexpr(is_same<data_type, bool>{})
{
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", bool{a[i]}); });
printf("}\n");
}
#else
printf("%s size %d, {", s, nsize); printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); }); static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
printf("}\n"); printf("}\n");
#endif
}
template <typename T>
__host__ __device__ void print_array_v2(const char* s, T a)
{
using data_type = decltype(a.At(Number<0>{}));
constexpr index_t nsize = a.Size();
#if 0
if constexpr(is_same<data_type, uint32_t>{})
{
printf("%s size %u, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%u] %u, ", i.value, a[i]); });
printf("}\n");
}
else if constexpr(is_same<data_type, int32_t>{})
{
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
printf("}\n");
}
#else
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
printf("}\n");
#endif
} }
} // namespace ck } // namespace ck
......
...@@ -685,8 +685,6 @@ __host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>) ...@@ -685,8 +685,6 @@ __host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
template <index_t Y, index_t... Xs> template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>) __host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{ {
constexpr auto seq_x = Sequence<Xs...>{};
return Sequence<(Y - Xs)...>{}; return Sequence<(Y - Xs)...>{};
} }
......
...@@ -5,30 +5,66 @@ ...@@ -5,30 +5,66 @@
namespace ck { namespace ck {
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N> template <AddressSpaceEnum_t BufferAddressSpace,
typename T,
index_t N,
bool InvalidElementUseNumericalZeroValue>
struct StaticBuffer : public StaticallyIndexedArray<T, N> struct StaticBuffer : public StaticallyIndexedArray<T, N>
{ {
using type = T; using type = T;
using base = StaticallyIndexedArray<T, N>; using base = StaticallyIndexedArray<T, N>;
T invalid_element_value_ = T{0};
__host__ __device__ constexpr StaticBuffer() : base{} {} __host__ __device__ constexpr StaticBuffer() : base{} {}
__host__ __device__ constexpr StaticBuffer(T invalid_element_value)
: base{}, invalid_element_value_{invalid_element_value}
{
}
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
{ {
return BufferAddressSpace; return BufferAddressSpace;
} }
template <index_t I>
__host__ __device__ constexpr auto Get(Number<I> i, bool is_valid_element) const
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return is_valid_element ? At(i) : T{0};
}
else
{
return is_valid_element ? At(i) : invalid_element_value_;
}
}
template <index_t I>
__host__ __device__ void Set(Number<I> i, bool is_valid_element, const T& x)
{
if(is_valid_element)
{
At(i) = x;
}
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
}; };
template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic, template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
typename T,
index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>) __host__ __device__ constexpr auto make_static_buffer(Number<N>)
{ {
return StaticBuffer<BufferAddressSpace, T, N>{}; return StaticBuffer<BufferAddressSpace, T, N, true>{};
}
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>, T invalid_element_value)
{
return StaticBuffer<BufferAddressSpace, T, N, false>{invalid_element_value};
} }
} // namespace ck } // namespace ck
......
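A brief sketch of the two factory overloads above (the function is illustrative): the one-argument form zero-fills invalid reads, the two-argument form returns the supplied pad value.

__device__ void static_buffer_example()
{
    // 4-element register-resident accumulator; invalid reads return 0
    auto acc = ck::make_static_buffer<ck::AddressSpaceEnum_t::Vgpr, float>(ck::Number<4>{});

    acc.Set(ck::Number<0>{}, true, 2.0f);

    float v       = acc.Get(ck::Number<0>{}, true);  // 2.0f
    float skipped = acc.Get(ck::Number<1>{}, false); // 0.0f

    (void)v;
    (void)skipped;
}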
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "sequence.hpp" #include "sequence.hpp"
#include "type.hpp" #include "type.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
...@@ -20,10 +21,9 @@ struct TupleElement ...@@ -20,10 +21,9 @@ struct TupleElement
{ {
__host__ __device__ constexpr TupleElement() = default; __host__ __device__ constexpr TupleElement() = default;
template < template <typename T,
typename T, typename enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
typename std::enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value, bool>::type = false>
bool>::type = false>
__host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v)) __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
{ {
} }
...@@ -58,17 +58,16 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs> ...@@ -58,17 +58,16 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
{ {
__host__ __device__ constexpr TupleImpl() = default; __host__ __device__ constexpr TupleImpl() = default;
template < template <typename Y,
typename Y, typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
typename std::enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 && !is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
!is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value, bool>::type = false>
bool>::type = false>
__host__ __device__ constexpr TupleImpl(Y&& y) __host__ __device__ constexpr TupleImpl(Y&& y)
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))... : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
{ {
} }
template <typename... Ys, typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false> template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr TupleImpl(Ys&&... ys) __host__ __device__ constexpr TupleImpl(Ys&&... ys)
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))... : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
{ {
...@@ -102,16 +101,16 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -102,16 +101,16 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__ __device__ constexpr Tuple() = default; __host__ __device__ constexpr Tuple() = default;
template <typename Y, template <typename Y,
typename std::enable_if< typename enable_if<sizeof...(Xs) == 1 &&
sizeof...(Xs) == 1 && !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value, !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y)) __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
{ {
} }
template <typename... Ys, template <typename... Ys,
typename std::enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
bool>::type = false> false>
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...) __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
{ {
} }
......
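The tuple.hpp hunk above swaps std::enable_if for the project's own enable_if pulled in from the newly included enable_if.hpp. The contents of that header are not shown in this diff; the usage pattern "typename enable_if<Cond, bool>::type = false" suggests something equivalent to the following sketch (an assumption, not the actual header), presumably kept local so the device-side headers stay self-contained:

    namespace ck {
    // Assumed shape of enable_if.hpp, inferred from how it is used in the constructors above.
    template <bool B, typename T = void>
    struct enable_if
    {
    };

    template <typename T>
    struct enable_if<true, T>
    {
        using type = T;
    };
    } // namespace ck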
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define CK_TYPE_HPP #define CK_TYPE_HPP
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
...@@ -22,10 +23,7 @@ template <typename T> ...@@ -22,10 +23,7 @@ template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type; using remove_cv_t = typename std::remove_cv<T>::type;
template <typename T> template <typename T>
constexpr std::remove_reference_t<T>&& move(T&& t) noexcept inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
{
return static_cast<typename std::remove_reference<T>::type&&>(t);
}
template <typename T> template <typename T>
struct is_known_at_compile_time; struct is_known_at_compile_time;
...@@ -42,9 +40,7 @@ struct is_known_at_compile_time<integral_constant<T, X>> ...@@ -42,9 +40,7 @@ struct is_known_at_compile_time<integral_constant<T, X>>
static constexpr bool value = true; static constexpr bool value = true;
}; };
template <typename Y, template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
typename X,
typename std::enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y as_type(X x) __host__ __device__ constexpr Y as_type(X x)
{ {
union AsType union AsType
......
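In the type.hpp hunk above, the local move() helper is removed, an is_pointer_v variable template is added, and the as_type declaration is reflowed onto one line; its sizeof(X) == sizeof(Y) constraint is unchanged. A small usage sketch of as_type as a bit-cast between same-sized types; the uint32_t example values are illustrative:

    float    f    = 1.0f;
    uint32_t u    = as_type<uint32_t>(f); // reinterpret the 4 bytes of f; 0x3f800000 for 1.0f
    float    back = as_type<float>(u);    // round-trips to 1.0f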
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_dlops_v1r2.hpp" #include "gridwise_gemm_dlops_v1r2.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
using namespace ck; using namespace ck;
...@@ -64,8 +64,7 @@ constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDs ...@@ -64,8 +64,7 @@ constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDs
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP); constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP); constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);
extern "C" __global__ void extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
int n, int n,
int c, int c,
int hi, int hi,
...@@ -93,12 +92,9 @@ dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare( ...@@ -93,12 +92,9 @@ dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
const auto in_n_c_hi_wi_desc = const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi));
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi)); const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x));
const auto wei_k_c_y_x_desc = const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo));
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x));
const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo));
const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad( const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
wei_k_c_y_x_desc, wei_k_c_y_x_desc,
...@@ -117,16 +113,16 @@ dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare( ...@@ -117,16 +113,16 @@ dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
using BKNGridDesc = decltype(b_k_n_grid_desc); using BKNGridDesc = decltype(b_k_n_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc); using CMNGridDesc = decltype(c_m_n_grid_desc);
using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}))); Sequence<0, 0, 0, 0, 0>{})));
using BGridIteratorHacks = using BGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
...@@ -134,65 +130,65 @@ dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare( ...@@ -134,65 +130,65 @@ dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}), Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}))); Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using GridwiseGemm = using GridwiseGemm =
GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize, GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperationEnum_t::Set, /* ToDo tunable */ InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
AKMGridDesc, AKMGridDesc,
BKNGridDesc, BKNGridDesc,
CMNGridDesc, CMNGridDesc,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
M1PerThread, M1PerThread,
N1PerThread, N1PerThread,
KPerThread, KPerThread,
M1N1ThreadClusterM10, M1N1ThreadClusterM10,
M1N1ThreadClusterN10, M1N1ThreadClusterN10,
M1N1ThreadClusterM11, M1N1ThreadClusterM11,
M1N1ThreadClusterN11, M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M0_M1, ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1, ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1, ABlockTransferDstScalarPerVector_M1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N0_N1, BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1, BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1, BBlockTransferDstScalarPerVector_N1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridStepHacks,
BGridIteratorHacks, BGridStepHacks,
CGridIteratorHacks, CGridStepHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowIteratorHacks>; BGridMoveSliceWindowStepHacks>;
auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
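The kernel source above only tracks renames: make_dynamic_naive_tensor_descriptor_packed_v2 becomes make_naive_tensor_descriptor_packed, the *IteratorHacks aliases become *StepHacks, and GridwiseDynamicGemmDlops_km_kn_mn_v1r2 becomes GridwiseGemmDlops_km_kn_mn_v1r2; the arguments are untouched. For reference, the packed descriptor describes densely packed, row-major storage, so for the (n, c, hi, wi) input the implied strides are (c*hi*wi, hi*wi, wi, 1):

    // Equivalent call before and after; only the helper's name changes.
    // before: make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi))
    const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi));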
...@@ -216,7 +212,7 @@ extern "C" __global__ void ...@@ -216,7 +212,7 @@ extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
...@@ -230,11 +226,11 @@ extern "C" __global__ void ...@@ -230,11 +226,11 @@ extern "C" __global__ void
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto in_n_c_hi_wi_desc = constexpr auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
constexpr auto wei_k_c_y_x_desc = constexpr auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3)); make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
constexpr auto out_n_k_ho_wo_desc = constexpr auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
constexpr auto descs = constexpr auto descs =
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
...@@ -253,16 +249,16 @@ extern "C" __global__ void ...@@ -253,16 +249,16 @@ extern "C" __global__ void
using BKNGridDesc = decltype(b_k_n_grid_desc); using BKNGridDesc = decltype(b_k_n_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc); using CMNGridDesc = decltype(c_m_n_grid_desc);
using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}))); Sequence<0, 0, 0, 0, 0>{})));
using BGridIteratorHacks = using BGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
...@@ -270,65 +266,65 @@ extern "C" __global__ void ...@@ -270,65 +266,65 @@ extern "C" __global__ void
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}), Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}))); Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using GridwiseGemm = using GridwiseGemm =
GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize, GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperationEnum_t::Set, /* ToDo tunable */ InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
AKMGridDesc, AKMGridDesc,
BKNGridDesc, BKNGridDesc,
CMNGridDesc, CMNGridDesc,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
M1PerThread, M1PerThread,
N1PerThread, N1PerThread,
KPerThread, KPerThread,
M1N1ThreadClusterM10, M1N1ThreadClusterM10,
M1N1ThreadClusterN10, M1N1ThreadClusterN10,
M1N1ThreadClusterM11, M1N1ThreadClusterM11,
M1N1ThreadClusterN11, M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M0_M1, ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1, ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1, ABlockTransferDstScalarPerVector_M1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N0_N1, BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1, BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1, BBlockTransferDstScalarPerVector_N1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridStepHacks,
BGridIteratorHacks, BGridStepHacks,
CGridIteratorHacks, CGridStepHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowIteratorHacks>; BGridMoveSliceWindowStepHacks>;
constexpr auto a_k_m0_m1_grid_desc_tmp = constexpr auto a_k_m0_m1_grid_desc_tmp =
GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
......
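Because both kernels are extern "C" and lose their dynamic_ prefix, any host code that resolves them by symbol name has to pick up the new names as well. A hedged sketch of such a lookup via hipModuleGetFunction; the module handle and the omission of error handling are assumptions, not part of this change:

    #include <hip/hip_runtime.h>

    // 'module' is assumed to hold the compiled kernel object (e.g. loaded with hipModuleLoadData).
    hipFunction_t prepare_fn = nullptr;
    hipFunction_t conv_fn    = nullptr;
    hipModuleGetFunction(&prepare_fn, module,
                         "convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare");
    hipModuleGetFunction(&conv_fn, module,
                         "convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw");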