Unverified Commit f63a23ac authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

[MIOpen Downstream] Initial MIOpen integration (#52)

* update online kernel wrapper bundle all descriptors in a tuple

* change __CONSTANT__ to CONSTANT

* rename

* adding tuning

* added IsValidCompileParameter

* reorginze

* adding tunable for fp16 and int8

* fix kernel compile warning and bug fixes

* suppress warning about cast CONSTANT (address space 4) pointer

* fix building issue
parent 12649254
#ifndef CK_AMD_BUFFER_ADDRESSING_V2_HPP #ifndef CK_AMD_BUFFER_ADDRESSING_V2_HPP
#define CK_AMD_BUFFER_ADDRESSING_V2_HPP #define CK_AMD_BUFFER_ADDRESSING_V2_HPP
#include "float_type.hpp" #include "data_type.hpp"
namespace ck { namespace ck {
...@@ -33,175 +33,175 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_siz ...@@ -33,175 +33,175 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_siz
// load // load
__device__ int8_t __device__ int8_t
__llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
__device__ int8x2_t __device__ int8x2_t
__llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
__device__ int8x4_t __device__ int8x4_t
__llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
__device__ int16_t __device__ int16_t
__llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
__device__ int32_t __device__ int32_t
__llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
__device__ int32x2_t __device__ int32x2_t
__llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
__device__ int32x4_t __device__ int32x4_t
__llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
// half // half
__device__ half_t __device__ half_t
__llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
__device__ half2_t __device__ half2_t
__llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
__device__ half4_t __device__ half4_t
__llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
// float // float
__device__ float __device__ float
__llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");
__device__ float2_t __device__ float2_t
__llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");
__device__ float4_t __device__ float4_t
__llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
// store // store
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
__device__ void
llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata, llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
__device__ void
llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_i16(int16_t vdata, llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
// half
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, llvm_amdgcn_raw_buffer_store_fp16(half_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
// half
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_fp16(half_t vdata, llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
// float
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata, llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata, llvm_amdgcn_raw_buffer_store_fp32x2(float2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
// float
__device__ void
__llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
__device__ void
__llvm_amdgcn_raw_buffer_store_fp32x2(float2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
template <typename T, index_t N> template <typename T, index_t N>
__device__ typename vector_type<T, N>::type __device__ typename vector_type<T, N>::type
...@@ -220,31 +220,31 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -220,31 +220,31 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
{ {
if constexpr(N == 1) if constexpr(N == 1)
{ {
return __llvm_amdgcn_raw_buffer_load_fp32( return llvm_amdgcn_raw_buffer_load_fp32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
} }
else if constexpr(N == 2) else if constexpr(N == 2)
{ {
return __llvm_amdgcn_raw_buffer_load_fp32x2( return llvm_amdgcn_raw_buffer_load_fp32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
} }
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
return __llvm_amdgcn_raw_buffer_load_fp32x4( return llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
} }
else if constexpr(N == 8) else if constexpr(N == 8)
{ {
vector_type<float, 8> tmp; vector_type<float, 8> tmp;
tmp.AsType<float4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( tmp.AsType<float4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.AsType<float4_t>()(Number<1>{}) = tmp.AsType<float4_t>()(Number<1>{}) =
__llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset, src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float), src_wave_addr_offset + 4 * sizeof(float),
0); 0);
return tmp.AsType<float8_t>()(Number<0>{}); return tmp.AsType<float8_t>()(Number<0>{});
} }
...@@ -253,17 +253,17 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -253,17 +253,17 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
{ {
if constexpr(N == 1) if constexpr(N == 1)
{ {
return __llvm_amdgcn_raw_buffer_load_fp16( return llvm_amdgcn_raw_buffer_load_fp16(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
} }
else if constexpr(N == 2) else if constexpr(N == 2)
{ {
return __llvm_amdgcn_raw_buffer_load_fp16x2( return llvm_amdgcn_raw_buffer_load_fp16x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
} }
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
return __llvm_amdgcn_raw_buffer_load_fp16x4( return llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
} }
else if constexpr(N == 8) else if constexpr(N == 8)
...@@ -271,18 +271,18 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -271,18 +271,18 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
#if 0 #if 0
vector_type<half_t, 8> tmp; vector_type<half_t, 8> tmp;
tmp.AsType<half4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4( tmp.AsType<half4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.AsType<half4_t>()(Number<1>{}) = tmp.AsType<half4_t>()(Number<1>{}) =
__llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
src_thread_addr_offset, src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(half_t), src_wave_addr_offset + 4 * sizeof(half_t),
0); 0);
return tmp.AsType<half8_t>()(Number<0>{}); return tmp.AsType<half8_t>()(Number<0>{});
#else #else
float4_t tmp = __llvm_amdgcn_raw_buffer_load_fp32x4( float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<half8_t>(tmp); return as_type<half8_t>(tmp);
...@@ -293,31 +293,31 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -293,31 +293,31 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
{ {
if constexpr(N == 1) if constexpr(N == 1)
{ {
return __llvm_amdgcn_raw_buffer_load_i32( return llvm_amdgcn_raw_buffer_load_i32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
} }
else if constexpr(N == 2) else if constexpr(N == 2)
{ {
return __llvm_amdgcn_raw_buffer_load_i32x2( return llvm_amdgcn_raw_buffer_load_i32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
} }
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
return __llvm_amdgcn_raw_buffer_load_i32x4( return llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
} }
else if constexpr(N == 8) else if constexpr(N == 8)
{ {
vector_type<int32_t, 8> tmp; vector_type<int32_t, 8> tmp;
tmp.AsType<int32x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4( tmp.AsType<int32x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.AsType<int32x4_t>()(Number<1>{}) = tmp.AsType<int32x4_t>()(Number<1>{}) =
__llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset, src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t), src_wave_addr_offset + 4 * sizeof(int32_t),
0); 0);
return tmp.AsType<int32x8_t>()(Number<0>{}); return tmp.AsType<int32x8_t>()(Number<0>{});
} }
} }
...@@ -325,16 +325,16 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -325,16 +325,16 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
{ {
if constexpr(N == 1) if constexpr(N == 1)
{ {
return __llvm_amdgcn_raw_buffer_load_i8( return llvm_amdgcn_raw_buffer_load_i8(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
} }
else if constexpr(N == 2) else if constexpr(N == 2)
{ {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return __llvm_amdgcn_raw_buffer_load_i8x2( return llvm_amdgcn_raw_buffer_load_i8x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else #else
int16_t tmp = __llvm_amdgcn_raw_buffer_load_i16( int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<int8x2_t>(tmp); return as_type<int8x2_t>(tmp);
...@@ -343,10 +343,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -343,10 +343,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return __llvm_amdgcn_raw_buffer_load_i8x4( return llvm_amdgcn_raw_buffer_load_i8x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else #else
int32_t tmp = __llvm_amdgcn_raw_buffer_load_i32( int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<int8x4_t>(tmp); return as_type<int8x4_t>(tmp);
...@@ -357,18 +357,18 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -357,18 +357,18 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type<int8_t, 8> tmp; vector_type<int8_t, 8> tmp;
tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4( tmp.AsType<int8x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.AsType<int8x4_t>()(Number<1>{}) = tmp.AsType<int8x4_t>()(Number<1>{}) =
__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset, src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int8_t), src_wave_addr_offset + 4 * sizeof(int8_t),
0); 0);
return tmp.AsType<int8x8_t>()(Number<0>{}); return tmp.AsType<int8x8_t>()(Number<0>{});
#else #else
int32x2_t tmp = __llvm_amdgcn_raw_buffer_load_i32x2( int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<int8x8_t>(tmp); return as_type<int8x8_t>(tmp);
...@@ -379,30 +379,30 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -379,30 +379,30 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type<int8_t, 16> tmp; vector_type<int8_t, 16> tmp;
tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4( tmp.AsType<int8x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.AsType<int8x4_t>()(Number<1>{}) = tmp.AsType<int8x4_t>()(Number<1>{}) =
__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset, src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int8_t), src_wave_addr_offset + 4 * sizeof(int8_t),
0); 0);
tmp.AsType<int8x4_t>()(Number<2>{}) = tmp.AsType<int8x4_t>()(Number<2>{}) =
__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset, src_thread_addr_offset,
src_wave_addr_offset + 8 * sizeof(int8_t), src_wave_addr_offset + 8 * sizeof(int8_t),
0); 0);
tmp.AsType<int8x4_t>()(Number<3>{}) = tmp.AsType<int8x4_t>()(Number<3>{}) =
__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset, src_thread_addr_offset,
src_wave_addr_offset + 12 * sizeof(int8_t), src_wave_addr_offset + 12 * sizeof(int8_t),
0); 0);
return tmp.AsType<int8x16_t>()(Number<0>{}); return tmp.AsType<int8x16_t>()(Number<0>{});
#else #else
int32x4_t tmp = __llvm_amdgcn_raw_buffer_load_i32x4( int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<int8x16_t>(tmp); return as_type<int8x16_t>(tmp);
...@@ -428,156 +428,156 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type ...@@ -428,156 +428,156 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
{ {
if constexpr(N == 1) if constexpr(N == 1)
{ {
__llvm_amdgcn_raw_buffer_store_fp32(src_thread_data, llvm_amdgcn_raw_buffer_store_fp32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
} }
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
__llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data, llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
} }
} }
else if constexpr(is_same<T, int32_t>::value) else if constexpr(is_same<T, int32_t>::value)
{ {
if constexpr(N == 1) if constexpr(N == 1)
{ {
__llvm_amdgcn_raw_buffer_store_i32(src_thread_data, llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
} }
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
__llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data, llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
} }
} }
else if constexpr(is_same<T, int8_t>::value) else if constexpr(is_same<T, int8_t>::value)
{ {
if constexpr(N == 1) if constexpr(N == 1)
{ {
__llvm_amdgcn_raw_buffer_store_i8(src_thread_data, llvm_amdgcn_raw_buffer_store_i8(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
} }
else if constexpr(N == 2) else if constexpr(N == 2)
{ {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
__llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data, llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
#else #else
__llvm_amdgcn_raw_buffer_store_i16(as_type<int16_t>(src_thread_data), llvm_amdgcn_raw_buffer_store_i16(as_type<int16_t>(src_thread_data),
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
#endif #endif
} }
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
__llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data, llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
#else #else
__llvm_amdgcn_raw_buffer_store_i32(as_type<int32_t>(src_thread_data), llvm_amdgcn_raw_buffer_store_i32(as_type<int32_t>(src_thread_data),
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
#endif #endif
} }
else if constexpr(N == 8) else if constexpr(N == 8)
{ {
__llvm_amdgcn_raw_buffer_store_i32x2(as_type<int32x2_t>(src_thread_data), llvm_amdgcn_raw_buffer_store_i32x2(as_type<int32x2_t>(src_thread_data),
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
} }
else if constexpr(N == 16) else if constexpr(N == 16)
{ {
__llvm_amdgcn_raw_buffer_store_i32x4(as_type<int32x4_t>(src_thread_data), llvm_amdgcn_raw_buffer_store_i32x4(as_type<int32x4_t>(src_thread_data),
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
} }
} }
else if constexpr(is_same<T, half_t>::value) else if constexpr(is_same<T, half_t>::value)
{ {
if constexpr(N == 1) if constexpr(N == 1)
{ {
__llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
} }
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
__llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
} }
else if constexpr(N == 8) else if constexpr(N == 8)
{ {
vector_type<half_t, 8> tmp{src_thread_data}; vector_type<half_t, 8> tmp{src_thread_data};
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}], llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}], llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t), dst_wave_addr_offset + 4 * sizeof(half_t),
0); 0);
} }
} }
} }
......
#ifndef CK_AMD_DLOP_HPP #ifndef CK_AMD_DLOP_HPP
#define CK_AMD_DLOP_HPP #define CK_AMD_DLOP_HPP
#include "float_type.hpp" #include "data_type.hpp"
namespace ck { namespace ck {
......
#ifndef CK_AMD_INLINE_ASM_HPP #ifndef CK_AMD_INLINE_ASM_HPP
#define CK_AMD_INLINE_ASM_HPP #define CK_AMD_INLINE_ASM_HPP
#include "float_type.hpp" #include "data_type.hpp"
namespace ck { namespace ck {
......
#ifndef CK_AMD_LLVM_INTRINSIC_HPP #ifndef CK_AMD_LLVM_INTRINSIC_HPP
#define CK_AMD_LLVM_INTRINSIC_HPP #define CK_AMD_LLVM_INTRINSIC_HPP
#include "float_type.hpp" #include "data_type.hpp"
namespace ck { namespace ck {
__device__ int32_t __llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane"); __device__ int32_t llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane");
} // namespace ck } // namespace ck
#endif #endif
#ifndef CK_AMD_XDLOPS_HPP #ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP #define CK_AMD_XDLOPS_HPP
#include "float_type.hpp" #include "data_type.hpp"
namespace ck { namespace ck {
......
...@@ -7,8 +7,9 @@ ...@@ -7,8 +7,9 @@
#include "statically_indexed_array.hpp" #include "statically_indexed_array.hpp"
#include "container_element_picker.hpp" #include "container_element_picker.hpp"
#include "multi_index.hpp" #include "multi_index.hpp"
#include "data_type_enum.hpp"
#include "data_type.hpp" #include "data_type.hpp"
#include "float_type.hpp" #include "data_type_helper.hpp"
#include "functional.hpp" #include "functional.hpp"
#include "functional2.hpp" #include "functional2.hpp"
#include "functional3.hpp" #include "functional3.hpp"
......
...@@ -8,18 +8,13 @@ ...@@ -8,18 +8,13 @@
#include "bfloat16_dev.hpp" #include "bfloat16_dev.hpp"
// address space for kernel parameter // address space for kernel parameter
#define __CONSTANT__ __attribute__((address_space(4))) #define CONSTANT __attribute__((address_space(4)))
// device backend // GPU target
#define CK_DEVICE_BACKEND_AMD 1 // should enable one and only one GPU target
#if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
// GPU ID defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030))
#if 0 #error Need to define a single GPU target
#define CK_AMD_GPU_GFX906 1
#elif 1
#define CK_AMD_GPU_GFX908 1
#elif 0
#define CK_AMD_GPU_GFX1030 1
#endif #endif
// HIP version // HIP version
...@@ -36,7 +31,8 @@ ...@@ -36,7 +31,8 @@
#endif #endif
// buffer resourse // buffer resourse
#if defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) #if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(CK_AMD_GPU_GFX1030) #elif defined(CK_AMD_GPU_GFX1030)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
...@@ -50,10 +46,6 @@ ...@@ -50,10 +46,6 @@
#define CK_USE_AMD_INLINE_ASM 1 #define CK_USE_AMD_INLINE_ASM 1
#endif #endif
#ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1
#endif
// AMD DLOPS // AMD DLOPS
#ifndef CK_USE_AMD_DLOP #ifndef CK_USE_AMD_DLOP
#define CK_USE_AMD_DLOP 1 #define CK_USE_AMD_DLOP 1
...@@ -78,14 +70,6 @@ ...@@ -78,14 +70,6 @@
#define CK_USE_AMD_XDLOPS 0 #define CK_USE_AMD_XDLOPS 0
#endif #endif
#ifndef CK_USE_AMD_XDLOPS_INLINE_ASM
#define CK_USE_AMD_XDLOPS_INLINE_ASM 0
#endif
#ifndef CK_USE_AMD_XDLOPS_EMULATE
#define CK_USE_AMD_XDLOPS_EMULATE 0 // For internal debug purposes
#endif
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0) // block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#ifndef CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM #ifndef CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 #define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
...@@ -104,18 +88,6 @@ ...@@ -104,18 +88,6 @@
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1
#endif #endif
#ifndef CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
#define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 1
#endif
#ifndef CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK 0
#endif
#ifndef CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK 0
#endif
// pass tensor descriptor by value or void* // pass tensor descriptor by value or void*
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1
...@@ -131,17 +103,6 @@ ...@@ -131,17 +103,6 @@
#define CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 #define CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
#endif #endif
// workaround: put all workaround here
// workaround for unnecessary VGPR <--> AGPR data movement when using mfma LLVM intrinsic
#ifndef CK_WORKAROUND_SWDEV_229564
#define CK_WORKAROUND_SWDEV_229564 1
#endif
// workaround for accvgpr over-allocation
#ifndef CK_WORKAROUND_SWDEV_241664
#define CK_WORKAROUND_SWDEV_241664 1
#endif
// workaround for compiler crash when compiling recursive lambda // workaround for compiler crash when compiling recursive lambda
#ifndef CK_WORKAROUND_SWDEV_275126 #ifndef CK_WORKAROUND_SWDEV_275126
#define CK_WORKAROUND_SWDEV_275126 1 #define CK_WORKAROUND_SWDEV_275126 1
...@@ -159,7 +120,7 @@ ...@@ -159,7 +120,7 @@
namespace ck { namespace ck {
enum AddressSpace enum AddressSpaceEnum_t
{ {
Generic, Generic,
Global, Global,
...@@ -168,7 +129,7 @@ enum AddressSpace ...@@ -168,7 +129,7 @@ enum AddressSpace
Vgpr Vgpr
}; };
enum InMemoryDataOperation enum InMemoryDataOperationEnum_t
{ {
Set, Set,
AtomicAdd AtomicAdd
......
#ifndef CK_DATA_TYPE_HPP #ifndef CK_FLOAT_TYPE_AMD_HPP
#define CK_DATA_TYPE_HPP #define CK_FLOAT_TYPE_AMD_HPP
#include "statically_indexed_array.hpp"
namespace ck { namespace ck {
using half_t = _Float16;
// vector_type
template <typename T, index_t N>
struct vector_type;
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
// vectors"
template <typename T, index_t V, index_t N>
struct vector_type<T __attribute__((ext_vector_type(V))), N>;
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
// vectors"
template <typename T, index_t V, index_t N>
struct vector_type<vector_type<T, V>, N>;
// vector_type_maker
// This is the right way to handle "vector of vectors": making a bigger vector instead
template <typename T, index_t N>
struct vector_type_maker
{
using type = vector_type<T, N>;
};
template <typename T, index_t N0, index_t N1>
struct vector_type_maker<T __attribute__((ext_vector_type(N1))), N0>
{
using type = vector_type<T, N0 * N1>;
};
template <typename T, index_t N0, index_t N1>
struct vector_type_maker<vector_type<T, N1>, N0>
{
using type = vector_type<T, N0 * N1>;
};
template <typename T, index_t N>
using vector_type_maker_t = typename vector_type_maker<T, N>::type;
template <typename T, index_t N>
__host__ __device__ constexpr auto make_vector_type(Number<N>)
{
return typename vector_type_maker<T, N>::type{};
}
// scalar_type
template <typename TV>
struct scalar_type;
template <typename T, index_t N>
struct scalar_type<T __attribute__((ext_vector_type(N)))>
{
using type = T;
static constexpr index_t vector_size = N;
};
template <typename T, index_t N>
struct scalar_type<vector_type<T, N>>
{
using type = T;
static constexpr index_t vector_size = N;
};
//
template <>
struct scalar_type<float>
{
using type = float;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<half_t>
{
using type = half_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<ushort>
{
using type = ushort;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<int32_t>
{
using type = int32_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<int8_t>
{
using type = int8_t;
static constexpr index_t vector_size = 1;
};
//
template <typename T>
struct vector_type<T, 1>
{
using d1_t = T;
using type = d1_t;
union
{
T d1_;
StaticallyIndexedArray<T, 1> d1x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value, "wrong!");
return data_.d1x1_;
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value, "wrong!");
return data_.d1x1_;
}
};
template <typename T>
struct vector_type<T, 2>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
using type = d2_t;
union
{
d2_t d2_;
StaticallyIndexedArray<d1_t, 2> d1x2_;
StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x2_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x2_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
}
};
template <typename T>
struct vector_type<T, 4>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
using type = d4_t;
union
{
d4_t d4_;
StaticallyIndexedArray<d1_t, 4> d1x4_;
StaticallyIndexedArray<d2_t, 2> d2x2_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
}
};
template <typename T>
struct vector_type<T, 8>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
using type = d8_t;
union
{
d8_t d8_;
StaticallyIndexedArray<d1_t, 8> d1x8_;
StaticallyIndexedArray<d2_t, 4> d2x4_;
StaticallyIndexedArray<d4_t, 2> d4x2_;
StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x8_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x4_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x8_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x4_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
}
};
template <typename T>
struct vector_type<T, 16>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
using type = d16_t;
union
{
d16_t d16_;
StaticallyIndexedArray<d1_t, 16> d1x16_;
StaticallyIndexedArray<d2_t, 8> d2x8_;
StaticallyIndexedArray<d4_t, 4> d4x4_;
StaticallyIndexedArray<d8_t, 2> d8x2_;
StaticallyIndexedArray<d16_t, 1> d16x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x8_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x2_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x8_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x2_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x1_;
}
}
};
template <typename T>
struct vector_type<T, 32>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
using type = d32_t;
union
{
d32_t d32_;
StaticallyIndexedArray<d1_t, 32> d1x32_;
StaticallyIndexedArray<d2_t, 16> d2x16_;
StaticallyIndexedArray<d4_t, 8> d4x8_;
StaticallyIndexedArray<d8_t, 4> d8x4_;
StaticallyIndexedArray<d16_t, 2> d16x2_;
StaticallyIndexedArray<d32_t, 1> d32x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
}
};
template <typename T>
struct vector_type<T, 64>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
using type = d64_t;
union
{
d64_t d64_;
StaticallyIndexedArray<d1_t, 64> d1x64_;
StaticallyIndexedArray<d2_t, 32> d2x32_;
StaticallyIndexedArray<d4_t, 16> d4x16_;
StaticallyIndexedArray<d8_t, 8> d8x8_;
StaticallyIndexedArray<d16_t, 4> d16x4_;
StaticallyIndexedArray<d32_t, 2> d32x2_;
StaticallyIndexedArray<d64_t, 1> d64x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
}
};
template <typename T>
struct vector_type<T, 128>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
typedef T d128_t __attribute__((ext_vector_type(128)));
using type = d128_t;
union
{
d128_t d128_;
StaticallyIndexedArray<d1_t, 128> d1x128_;
StaticallyIndexedArray<d2_t, 64> d2x64_;
StaticallyIndexedArray<d4_t, 32> d4x32_;
StaticallyIndexedArray<d8_t, 16> d8x16_;
StaticallyIndexedArray<d16_t, 8> d16x8_;
StaticallyIndexedArray<d32_t, 4> d32x4_;
StaticallyIndexedArray<d64_t, 2> d64x2_;
StaticallyIndexedArray<d128_t, 1> d128x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x64_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x32_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x16_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x4_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x2_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x64_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x32_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x16_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x4_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x2_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
}
}
};
template <typename T>
struct vector_type<T, 256>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
typedef T d128_t __attribute__((ext_vector_type(128)));
typedef T d256_t __attribute__((ext_vector_type(256)));
using type = d256_t;
union
{
d256_t d256_;
StaticallyIndexedArray<d1_t, 256> d1x256_;
StaticallyIndexedArray<d2_t, 128> d2x128_;
StaticallyIndexedArray<d4_t, 64> d4x64_;
StaticallyIndexedArray<d8_t, 32> d8x32_;
StaticallyIndexedArray<d16_t, 16> d16x16_;
StaticallyIndexedArray<d32_t, 8> d32x8_;
StaticallyIndexedArray<d64_t, 4> d64x4_;
StaticallyIndexedArray<d128_t, 2> d128x2_;
StaticallyIndexedArray<d256_t, 1> d256x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x256_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x128_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x64_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x32_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x16_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x8_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x4_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x2_;
}
else if constexpr(is_same<X, d256_t>::value)
{
return data_.d256x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x256_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x128_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x64_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x32_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x16_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x8_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x4_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x2_;
}
else if constexpr(is_same<X, d256_t>::value)
{
return data_.d256x1_;
}
}
};
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
using float16_t = typename vector_type<float, 16>::type;
using float32_t = typename vector_type<float, 32>::type;
using float64_t = typename vector_type<float, 64>::type;
// fp16
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;
// bfp16
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
using ushort16_t = typename vector_type<ushort, 16>::type;
using ushort32_t = typename vector_type<ushort, 32>::type;
using ushort64_t = typename vector_type<ushort, 64>::type;
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
using int32x16_t = typename vector_type<int32_t, 16>::type;
using int32x32_t = typename vector_type<int32_t, 32>::type;
using int32x64_t = typename vector_type<int32_t, 64>::type;
// i8
using int8x2_t = typename vector_type<int8_t, 2>::type;
using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// data type conversion
template <typename T>
struct type_convert
{
template <typename X>
__device__ T operator()(X x) const
{
return static_cast<T>(x);
}
};
template <>
template <>
__device__ float type_convert<float>::operator()<ushort>(ushort x) const
{
return bfloat16_to_float(x);
}
template <>
template <>
__device__ ushort type_convert<ushort>::operator()<float>(float x) const
{
return float_to_bfloat16(x);
}
// TODO: deprecate this
template <typename T>
struct inner_product_with_conversion
{
static constexpr auto convert = type_convert<T>();
template <typename X, index_t N>
__device__ T operator()(typename vector_type<X, N>::type a,
typename vector_type<X, N>::type b) const
{
const vector_type<X, N> a_vector{a};
const vector_type<X, N> b_vector{b};
T acc = 0;
static_for<0, N, 1>{}([&](auto i) {
acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
});
return acc;
}
__device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); }
__device__ T operator()(int8x4_t a, int8x4_t b) const
{
const vector_type<int8_t, 4> a_vector{a};
const vector_type<int8_t, 4> b_vector{b};
T acc = 0;
static_for<0, 4, 1>{}([&](auto i) {
acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
});
return acc;
}
__device__ T operator()(int8x8_t a, int8x8_t b) const
{
const vector_type<int8_t, 8> a_vector{a};
const vector_type<int8_t, 8> b_vector{b};
T acc = 0;
static_for<0, 8, 1>{}([&](auto i) {
acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
});
return acc;
}
__device__ T operator()(int8x16_t a, int8x16_t b) const
{
const vector_type<int8_t, 16> a_vector{a};
const vector_type<int8_t, 16> b_vector{b};
T acc = 0;
static_for<0, 16, 1>{}([&](auto i) {
acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
});
return acc;
}
};
template <typename T> template <typename T>
struct NumericLimits; struct NumericLimits;
......
#ifndef CK_DATA_TYPE_ENUM_HPP
#define CK_DATA_TYPE_ENUM_HPP
namespace ck {
// this enumerate should be synchronized with include/miopen.h
typedef enum {
Half = 0,
Float = 1,
Int32 = 2,
Int8 = 3,
Int8x4 = 4,
BFloat16 = 5,
Double = 6,
Unknown = 100,
} DataTypeEnum_t;
} // namespace ck
#endif
#ifndef CK_DATA_TYPE_HELPER_HPP
#define CK_DATA_TYPE_HELPER_HPP
#include "data_type.hpp"
#include "data_type_enum.hpp"
namespace ck {
template <DataTypeEnum_t DataTypeEnum>
struct get_datatype_from_enum;
template <>
struct get_datatype_from_enum<DataTypeEnum_t::Int8>
{
using type = int8_t;
};
template <>
struct get_datatype_from_enum<DataTypeEnum_t::Int32>
{
using type = int32_t;
};
template <>
struct get_datatype_from_enum<DataTypeEnum_t::Half>
{
using type = half_t;
};
template <>
struct get_datatype_from_enum<DataTypeEnum_t::Float>
{
using type = float;
};
template <>
struct get_datatype_from_enum<DataTypeEnum_t::Double>
{
using type = double;
};
template <typename T>
struct get_datatype_enum_from_type;
template <>
struct get_datatype_enum_from_type<int8_t>
{
static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int8;
};
template <>
struct get_datatype_enum_from_type<int32_t>
{
static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int32;
};
template <>
struct get_datatype_enum_from_type<half_t>
{
static constexpr DataTypeEnum_t value = DataTypeEnum_t::Half;
};
template <>
struct get_datatype_enum_from_type<float>
{
static constexpr DataTypeEnum_t value = DataTypeEnum_t::Float;
};
template <>
struct get_datatype_enum_from_type<double>
{
static constexpr DataTypeEnum_t value = DataTypeEnum_t::Double;
};
} // namespace ck
#endif
...@@ -5,7 +5,7 @@ namespace ck { ...@@ -5,7 +5,7 @@ namespace ck {
#include "amd_buffer_addressing_v2.hpp" #include "amd_buffer_addressing_v2.hpp"
template <AddressSpace BufferAddressSpace, typename T, typename ElementSpaceSize> template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
struct DynamicBuffer struct DynamicBuffer
{ {
using type = T; using type = T;
...@@ -18,7 +18,7 @@ struct DynamicBuffer ...@@ -18,7 +18,7 @@ struct DynamicBuffer
{ {
} }
__host__ __device__ static constexpr AddressSpace GetAddressSpace() __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
{ {
return BufferAddressSpace; return BufferAddressSpace;
} }
...@@ -32,7 +32,7 @@ struct DynamicBuffer ...@@ -32,7 +32,7 @@ struct DynamicBuffer
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr const auto Get(index_t i, bool is_valid_offset) const __host__ __device__ constexpr auto Get(index_t i, bool is_valid_offset) const
{ {
// X contains multiple T // X contains multiple T
constexpr index_t scalar_per_t_vector = constexpr index_t scalar_per_t_vector =
...@@ -46,7 +46,7 @@ struct DynamicBuffer ...@@ -46,7 +46,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(GetAddressSpace() == AddressSpace::Global) if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
{ {
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ADDRESSING
return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>( return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
...@@ -80,7 +80,7 @@ struct DynamicBuffer ...@@ -80,7 +80,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(GetAddressSpace() == AddressSpace::Global) if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
{ {
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ADDRESSING
amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>( amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
...@@ -92,14 +92,15 @@ struct DynamicBuffer ...@@ -92,14 +92,15 @@ struct DynamicBuffer
} }
#endif #endif
} }
else if constexpr(GetAddressSpace() == AddressSpace::Lds) else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
{ {
if(is_valid_offset) if(is_valid_offset)
{ {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
*reinterpret_cast<X*>(&p_data_[i]) = x; *reinterpret_cast<X*>(&p_data_[i]) = x;
#else #else
// HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
// inefficient
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128 // ds_write_b128
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
...@@ -119,7 +120,8 @@ struct DynamicBuffer ...@@ -119,7 +120,8 @@ struct DynamicBuffer
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) || is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value && (is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value), is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
"wrong! not implemented for this combination, please add implementation"); "wrong! not implemented for this combination, please add "
"implementation");
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value) is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value)
...@@ -194,7 +196,7 @@ struct DynamicBuffer ...@@ -194,7 +196,7 @@ struct DynamicBuffer
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
}; };
template <AddressSpace BufferAddressSpace = AddressSpace::Generic, template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
typename T, typename T,
typename ElementSpaceSize> typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
......
#ifndef CK_FLOAT_TYPE_AMD_HPP
#define CK_FLOAT_TYPE_AMD_HPP
#include "statically_indexed_array.hpp"
namespace ck {
using half_t = _Float16;
// vector_type
template <typename T, index_t N>
struct vector_type;
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
// vectors"
template <typename T, index_t V, index_t N>
struct vector_type<T __attribute__((ext_vector_type(V))), N>;
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
// vectors"
template <typename T, index_t V, index_t N>
struct vector_type<vector_type<T, V>, N>;
// vector_type_maker
// This is the right way to handle "vector of vectors": making a bigger vector instead
template <typename T, index_t N>
struct vector_type_maker
{
using type = vector_type<T, N>;
};
template <typename T, index_t N0, index_t N1>
struct vector_type_maker<T __attribute__((ext_vector_type(N1))), N0>
{
using type = vector_type<T, N0 * N1>;
};
template <typename T, index_t N0, index_t N1>
struct vector_type_maker<vector_type<T, N1>, N0>
{
using type = vector_type<T, N0 * N1>;
};
template <typename T, index_t N>
using vector_type_maker_t = typename vector_type_maker<T, N>::type;
template <typename T, index_t N>
__host__ __device__ constexpr auto make_vector_type(Number<N>)
{
return typename vector_type_maker<T, N>::type{};
}
// scalar_type
template <typename TV>
struct scalar_type;
template <typename T, index_t N>
struct scalar_type<T __attribute__((ext_vector_type(N)))>
{
using type = T;
static constexpr index_t vector_size = N;
};
template <typename T, index_t N>
struct scalar_type<vector_type<T, N>>
{
using type = T;
static constexpr index_t vector_size = N;
};
//
template <>
struct scalar_type<float>
{
using type = float;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<half_t>
{
using type = half_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<ushort>
{
using type = ushort;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<int32_t>
{
using type = int32_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<int8_t>
{
using type = int8_t;
static constexpr index_t vector_size = 1;
};
//
template <typename T>
struct vector_type<T, 1>
{
using d1_t = T;
using type = d1_t;
union
{
T d1_;
StaticallyIndexedArray<T, 1> d1x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value, "wrong!");
return data_.d1x1_;
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value, "wrong!");
return data_.d1x1_;
}
};
template <typename T>
struct vector_type<T, 2>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
using type = d2_t;
union
{
d2_t d2_;
StaticallyIndexedArray<d1_t, 2> d1x2_;
StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x2_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x2_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
}
};
template <typename T>
struct vector_type<T, 4>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
using type = d4_t;
union
{
d4_t d4_;
StaticallyIndexedArray<d1_t, 4> d1x4_;
StaticallyIndexedArray<d2_t, 2> d2x2_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
}
};
template <typename T>
struct vector_type<T, 8>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
using type = d8_t;
union
{
d8_t d8_;
StaticallyIndexedArray<d1_t, 8> d1x8_;
StaticallyIndexedArray<d2_t, 4> d2x4_;
StaticallyIndexedArray<d4_t, 2> d4x2_;
StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x8_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x4_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x8_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x4_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
}
};
template <typename T>
struct vector_type<T, 16>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
using type = d16_t;
union
{
d16_t d16_;
StaticallyIndexedArray<d1_t, 16> d1x16_;
StaticallyIndexedArray<d2_t, 8> d2x8_;
StaticallyIndexedArray<d4_t, 4> d4x4_;
StaticallyIndexedArray<d8_t, 2> d8x2_;
StaticallyIndexedArray<d16_t, 1> d16x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x8_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x2_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x8_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x2_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x1_;
}
}
};
template <typename T>
struct vector_type<T, 32>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
using type = d32_t;
union
{
d32_t d32_;
StaticallyIndexedArray<d1_t, 32> d1x32_;
StaticallyIndexedArray<d2_t, 16> d2x16_;
StaticallyIndexedArray<d4_t, 8> d4x8_;
StaticallyIndexedArray<d8_t, 4> d8x4_;
StaticallyIndexedArray<d16_t, 2> d16x2_;
StaticallyIndexedArray<d32_t, 1> d32x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
}
};
template <typename T>
struct vector_type<T, 64>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
using type = d64_t;
union
{
d64_t d64_;
StaticallyIndexedArray<d1_t, 64> d1x64_;
StaticallyIndexedArray<d2_t, 32> d2x32_;
StaticallyIndexedArray<d4_t, 16> d4x16_;
StaticallyIndexedArray<d8_t, 8> d8x8_;
StaticallyIndexedArray<d16_t, 4> d16x4_;
StaticallyIndexedArray<d32_t, 2> d32x2_;
StaticallyIndexedArray<d64_t, 1> d64x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
}
};
template <typename T>
struct vector_type<T, 128>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
typedef T d128_t __attribute__((ext_vector_type(128)));
using type = d128_t;
union
{
d128_t d128_;
StaticallyIndexedArray<d1_t, 128> d1x128_;
StaticallyIndexedArray<d2_t, 64> d2x64_;
StaticallyIndexedArray<d4_t, 32> d4x32_;
StaticallyIndexedArray<d8_t, 16> d8x16_;
StaticallyIndexedArray<d16_t, 8> d16x8_;
StaticallyIndexedArray<d32_t, 4> d32x4_;
StaticallyIndexedArray<d64_t, 2> d64x2_;
StaticallyIndexedArray<d128_t, 1> d128x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x64_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x32_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x16_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x4_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x2_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x64_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x32_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x16_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x4_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x2_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
}
}
};
template <typename T>
struct vector_type<T, 256>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
typedef T d128_t __attribute__((ext_vector_type(128)));
typedef T d256_t __attribute__((ext_vector_type(256)));
using type = d256_t;
union
{
d256_t d256_;
StaticallyIndexedArray<d1_t, 256> d1x256_;
StaticallyIndexedArray<d2_t, 128> d2x128_;
StaticallyIndexedArray<d4_t, 64> d4x64_;
StaticallyIndexedArray<d8_t, 32> d8x32_;
StaticallyIndexedArray<d16_t, 16> d16x16_;
StaticallyIndexedArray<d32_t, 8> d32x8_;
StaticallyIndexedArray<d64_t, 4> d64x4_;
StaticallyIndexedArray<d128_t, 2> d128x2_;
StaticallyIndexedArray<d256_t, 1> d256x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x256_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x128_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x64_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x32_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x16_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x8_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x4_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x2_;
}
else if constexpr(is_same<X, d256_t>::value)
{
return data_.d256x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x256_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x128_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x64_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x32_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x16_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x8_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x4_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x2_;
}
else if constexpr(is_same<X, d256_t>::value)
{
return data_.d256x1_;
}
}
};
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
using float16_t = typename vector_type<float, 16>::type;
using float32_t = typename vector_type<float, 32>::type;
using float64_t = typename vector_type<float, 64>::type;
// fp16
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;
// bfp16
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
using ushort16_t = typename vector_type<ushort, 16>::type;
using ushort32_t = typename vector_type<ushort, 32>::type;
using ushort64_t = typename vector_type<ushort, 64>::type;
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
using int32x16_t = typename vector_type<int32_t, 16>::type;
using int32x32_t = typename vector_type<int32_t, 32>::type;
using int32x64_t = typename vector_type<int32_t, 64>::type;
// i8
using int8x2_t = typename vector_type<int8_t, 2>::type;
using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// data type conversion
template <typename T>
struct type_convert
{
template <typename X>
__device__ T operator()(X x) const
{
return static_cast<T>(x);
}
};
template <>
template <>
__device__ float type_convert<float>::operator()<ushort>(ushort x) const
{
return bfloat16_to_float(x);
}
template <>
template <>
__device__ ushort type_convert<ushort>::operator()<float>(float x) const
{
return float_to_bfloat16(x);
}
template <typename T>
struct inner_product_with_conversion
{
static constexpr auto convert = type_convert<T>();
template <typename X, index_t N>
__device__ T operator()(typename vector_type<X, N>::type a,
typename vector_type<X, N>::type b) const
{
const vector_type<X, N> a_vector{a};
const vector_type<X, N> b_vector{b};
T acc = 0;
static_for<0, N, 1>{}([&](auto i) {
acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
});
return acc;
}
__device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); }
__device__ T operator()(int8x4_t a, int8x4_t b) const
{
const vector_type<int8_t, 4> a_vector{a};
const vector_type<int8_t, 4> b_vector{b};
T acc = 0;
static_for<0, 4, 1>{}([&](auto i) {
acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
});
return acc;
}
__device__ T operator()(int8x8_t a, int8x8_t b) const
{
const vector_type<int8_t, 8> a_vector{a};
const vector_type<int8_t, 8> b_vector{b};
T acc = 0;
static_for<0, 8, 1>{}([&](auto i) {
acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
});
return acc;
}
__device__ T operator()(int8x16_t a, int8x16_t b) const
{
const vector_type<int8_t, 16> a_vector{a};
const vector_type<int8_t, 16> b_vector{b};
T acc = 0;
static_for<0, 16, 1>{}([&](auto i) {
acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
});
return acc;
}
};
} // namespace ck
#endif
...@@ -127,7 +127,8 @@ struct MagicDivision ...@@ -127,7 +127,8 @@ struct MagicDivision
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{ {
uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32); uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
uint32_t tmp = ((uint64_t)dividend_u32 * (uint64_t)multiplier) >> 32; uint32_t tmp =
(static_cast<uint64_t>(dividend_u32) * static_cast<uint64_t>(multiplier)) >> 32;
return (tmp + dividend_u32) >> shift; return (tmp + dividend_u32) >> shift;
} }
#else #else
......
...@@ -150,7 +150,15 @@ __host__ __device__ constexpr auto min(X x, Ys... ys) ...@@ -150,7 +150,15 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
// greatest common divisor, aka highest common factor // greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y) __host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{ {
if(x == y || x == 0) if(x < 0)
{
return gcd(-x, y);
}
else if(y < 0)
{
return gcd(x, -y);
}
else if(x == y || x == 0)
{ {
return y; return y;
} }
...@@ -160,11 +168,11 @@ __host__ __device__ constexpr index_t gcd(index_t x, index_t y) ...@@ -160,11 +168,11 @@ __host__ __device__ constexpr index_t gcd(index_t x, index_t y)
} }
else if(x > y) else if(x > y)
{ {
return gcd(x - y, y); return gcd(x % y, y);
} }
else else
{ {
return gcd(x, y - x); return gcd(x, y % x);
} }
} }
...@@ -181,7 +189,7 @@ template <typename X, ...@@ -181,7 +189,7 @@ template <typename X,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false> typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto gcd(X x, Ys... ys) __host__ __device__ constexpr auto gcd(X x, Ys... ys)
{ {
return gcd(x, ys...); return gcd(x, gcd(ys...));
} }
// least common multiple // least common multiple
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
namespace ck { namespace ck {
template <AddressSpace BufferAddressSpace, typename T, index_t N> template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N> struct StaticBuffer : public StaticallyIndexedArray<T, N>
{ {
using type = T; using type = T;
...@@ -13,7 +13,7 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N> ...@@ -13,7 +13,7 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
__host__ __device__ constexpr StaticBuffer() : base{} {} __host__ __device__ constexpr StaticBuffer() : base{} {}
__host__ __device__ static constexpr AddressSpace GetAddressSpace() __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
{ {
return BufferAddressSpace; return BufferAddressSpace;
} }
...@@ -23,7 +23,9 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N> ...@@ -23,7 +23,9 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
}; };
template <AddressSpace BufferAddressSpace = AddressSpace::Generic, typename T, index_t N> template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
typename T,
index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>) __host__ __device__ constexpr auto make_static_buffer(Number<N>)
{ {
return StaticBuffer<BufferAddressSpace, T, N>{}; return StaticBuffer<BufferAddressSpace, T, N>{};
......
...@@ -5,8 +5,6 @@ ...@@ -5,8 +5,6 @@
namespace ck { namespace ck {
__device__ void __llvm_amdgcn_s_barrier() __asm("llvm.amdgcn.s.barrier");
__device__ void block_sync_lds() __device__ void block_sync_lds()
{ {
#if CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM #if CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
...@@ -15,11 +13,9 @@ __device__ void block_sync_lds() ...@@ -15,11 +13,9 @@ __device__ void block_sync_lds()
s_barrier \ s_barrier \
" ::); " ::);
#else #else
__llvm_amdgcn_s_barrier(); __syncthreads();
#endif #endif
} }
__device__ void block_sync_lds_vmem() { __llvm_amdgcn_s_barrier(); }
} // namespace ck } // namespace ck
#endif #endif
#ifndef CK_TYPE_HELPER_HPP
#define CK_TYPE_HELPER_HPP
#include "float_type.hpp"
namespace ck {
template <char tid>
struct get_type_from_type_id
{
using type = float;
};
template <>
struct get_type_from_type_id<'H'>
{
using type = half_t;
};
template <>
struct get_type_from_type_id<'F'>
{
using type = float;
};
template <>
struct get_type_from_type_id<'D'>
{
using type = double;
};
} // namespace ck
#endif
#include "common_header.hpp" #include "common_header.hpp"
#include "type_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v1r2.hpp" #include "gridwise_dynamic_gemm_dlops_v1r2.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
using namespace ck; using namespace ck;
using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type; constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type; constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_CONV_COMPTYPE)>::type; constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize; constexpr index_t BlockSize = CK_PARAM_BlockSize;
...@@ -61,7 +64,8 @@ constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDs ...@@ -61,7 +64,8 @@ constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDs
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP); constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP); constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);
extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw_prepare( extern "C" __global__ void
dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
int n, int n,
int c, int c,
int hi, int hi,
...@@ -147,48 +151,48 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v4r4_nchw_k ...@@ -147,48 +151,48 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v4r4_nchw_k
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using GridwiseGemm = using GridwiseGemm =
GridwiseDynamicGemm_km_kn_mn_v1r2<BlockSize, GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperation::Set, /* ToDo tunable */ InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
AKMGridDesc, AKMGridDesc,
BKNGridDesc, BKNGridDesc,
CMNGridDesc, CMNGridDesc,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
M1PerThread, M1PerThread,
N1PerThread, N1PerThread,
KPerThread, KPerThread,
M1N1ThreadClusterM10, M1N1ThreadClusterM10,
M1N1ThreadClusterN10, M1N1ThreadClusterN10,
M1N1ThreadClusterM11, M1N1ThreadClusterM11,
M1N1ThreadClusterN11, M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M0_M1, ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1, ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1, ABlockTransferDstScalarPerVector_M1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N0_N1, BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1, BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1, BBlockTransferDstScalarPerVector_N1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridIteratorHacks,
BGridIteratorHacks, BGridIteratorHacks,
CGridIteratorHacks, CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>; BGridMoveSliceWindowIteratorHacks>;
auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
...@@ -212,14 +216,14 @@ extern "C" __global__ void ...@@ -212,14 +216,14 @@ extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_k_m0_m1_grid_desc, const void CONSTANT* p_a_k_m0_m1_grid_desc,
const void __CONSTANT__* p_b_k_n0_n1_grid_desc, const void CONSTANT* p_b_k_n0_n1_grid_desc,
const void __CONSTANT__* p_c_m0_m10_m11_n0_n10_n11_grid_desc, const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor) const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -283,48 +287,48 @@ extern "C" __global__ void ...@@ -283,48 +287,48 @@ extern "C" __global__ void
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using GridwiseGemm = using GridwiseGemm =
GridwiseDynamicGemm_km_kn_mn_v1r2<BlockSize, GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperation::Set, /* ToDo tunable */ InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
AKMGridDesc, AKMGridDesc,
BKNGridDesc, BKNGridDesc,
CMNGridDesc, CMNGridDesc,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
M1PerThread, M1PerThread,
N1PerThread, N1PerThread,
KPerThread, KPerThread,
M1N1ThreadClusterM10, M1N1ThreadClusterM10,
M1N1ThreadClusterN10, M1N1ThreadClusterN10,
M1N1ThreadClusterM11, M1N1ThreadClusterM11,
M1N1ThreadClusterN11, M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M0_M1, ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1, ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1, ABlockTransferDstScalarPerVector_M1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N0_N1, BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1, BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1, BBlockTransferDstScalarPerVector_N1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridIteratorHacks, AGridIteratorHacks,
BGridIteratorHacks, BGridIteratorHacks,
CGridIteratorHacks, CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>; BGridMoveSliceWindowIteratorHacks>;
constexpr auto a_k_m0_m1_grid_desc_tmp = constexpr auto a_k_m0_m1_grid_desc_tmp =
GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
......
#include "common_header.hpp" #include "common_header.hpp"
#include "type_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" #include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"
...@@ -7,9 +6,13 @@ ...@@ -7,9 +6,13 @@
using namespace ck; using namespace ck;
using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type; constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type; constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_CONV_COMPTYPE)>::type; constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize; constexpr index_t BlockSize = CK_PARAM_BlockSize;
...@@ -149,7 +152,7 @@ dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare( ...@@ -149,7 +152,7 @@ dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare(
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
AK0MK1GridDesc, AK0MK1GridDesc,
BK0NK1GridDesc, BK0NK1GridDesc,
CMNGridDesc, CMNGridDesc,
...@@ -213,10 +216,10 @@ extern "C" __global__ void ...@@ -213,10 +216,10 @@ extern "C" __global__ void
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_k0_m_k1_grid_desc, const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void __CONSTANT__* p_b_k0_n_k1_grid_desc, const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc, const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor) const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -286,7 +289,7 @@ extern "C" __global__ void ...@@ -286,7 +289,7 @@ extern "C" __global__ void
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
AK0MK1GridDesc, AK0MK1GridDesc,
BK0NK1GridDesc, BK0NK1GridDesc,
CMNGridDesc, CMNGridDesc,
......
#include "common_header.hpp" #include "common_header.hpp"
#include "type_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" #include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"
...@@ -7,9 +6,13 @@ ...@@ -7,9 +6,13 @@
using namespace ck; using namespace ck;
using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type; constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type; constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_CONV_COMPTYPE)>::type; constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize; constexpr index_t BlockSize = CK_PARAM_BlockSize;
...@@ -149,7 +152,7 @@ dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare( ...@@ -149,7 +152,7 @@ dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
AK0MK1GridDesc, AK0MK1GridDesc,
BK0NK1GridDesc, BK0NK1GridDesc,
CMNGridDesc, CMNGridDesc,
...@@ -213,10 +216,10 @@ extern "C" __global__ void ...@@ -213,10 +216,10 @@ extern "C" __global__ void
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_k0_m_k1_grid_desc, const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void __CONSTANT__* p_b_k0_n_k1_grid_desc, const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc, const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor) const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -287,7 +290,7 @@ extern "C" __global__ void ...@@ -287,7 +290,7 @@ extern "C" __global__ void
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperation::Set, InMemoryDataOperationEnum_t::Set,
AK0MK1GridDesc, AK0MK1GridDesc,
BK0NK1GridDesc, BK0NK1GridDesc,
CMNGridDesc, CMNGridDesc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment