Commit 211dae82 authored by ltqin's avatar ltqin
Browse files

Merge branch 'develop' into miopen_downstream_all

parents 5890e300 d5297aba
...@@ -202,6 +202,22 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, ...@@ -202,6 +202,22 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
// atomic add
// int
__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
int32_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
// float
__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
float vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
template <typename T, index_t N> template <typename T, index_t N>
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
...@@ -624,8 +640,130 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -624,8 +640,130 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
} }
} }
template <typename T, index_t N>
__device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
"wrong! not implemented");
if constexpr(is_same<T, float>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
vector_type<float, 2> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(float),
0);
}
else if constexpr(N == 4)
{
vector_type<float, 4> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(float),
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<2>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 2 * sizeof(float),
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<3>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 3 * sizeof(float),
0);
}
}
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
vector_type<int32_t, 2> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t),
0);
}
else if constexpr(N == 4)
{
vector_type<int32_t, 4> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t),
0);
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<2>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 2 * sizeof(int32_t),
0);
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<3>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 3 * sizeof(int32_t),
0);
}
}
}
// buffer_load requires: // buffer_load requires:
// 1) p_src_wave must be in global memory space // 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer. // 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true. // It is user's responsibility to make sure that is true.
template <typename T, index_t N> template <typename T, index_t N>
...@@ -659,7 +797,7 @@ amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave, ...@@ -659,7 +797,7 @@ amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
} }
// buffer_load requires: // buffer_load requires:
// 1) p_src_wave must be in global memory space // 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer. // 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true. // It is user's responsibility to make sure that is true.
template <typename T, index_t N> template <typename T, index_t N>
...@@ -687,8 +825,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, ...@@ -687,8 +825,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
} }
// buffer_store requires: // buffer_store requires:
// 1) p_dst_wave must be global memory // 1) p_dst_wave must point to global memory
// 2) p_dst_wave to be a wavewise pointer. // 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true. // It is user's responsibility to make sure that is true.
template <typename T, index_t N> template <typename T, index_t N>
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data, __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
...@@ -720,5 +858,40 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -720,5 +858,40 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#endif #endif
} }
// buffer_atomic_add requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ void
amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_element_valid)
{
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -51,304 +51,196 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16( ...@@ -51,304 +51,196 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16( extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16"); ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
template <index_t MPerWave, index_t NPerWave, index_t COffset> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x1f32; struct intrin_mfma_f32_32x32x1f32;
template <index_t COffset> template <>
struct intrin_mfma_f32_32x32x1f32<64, 64, COffset> struct intrin_mfma_f32_32x32x1f32<64, 64>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) = reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
llvm_intrin_amdgcn_mfma_f32_32x32x1f32( reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
reg_a, reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_b, reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
1,
1,
0);
} }
}; };
template <index_t COffset> template <>
struct intrin_mfma_f32_32x32x1f32<32, 64, COffset> struct intrin_mfma_f32_32x32x1f32<32, 64>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) = reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
llvm_intrin_amdgcn_mfma_f32_32x32x1f32( reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
} }
}; };
template <index_t MPerWave, index_t NPerWave, index_t COffset> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x2f32; struct intrin_mfma_f32_32x32x2f32;
template <index_t COffset> template <>
struct intrin_mfma_f32_32x32x2f32<32, 32, COffset> struct intrin_mfma_f32_32x32x2f32<32, 32>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) = reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
llvm_intrin_amdgcn_mfma_f32_32x32x2f32( reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
} }
}; };
template <index_t MPerWave, index_t NPerWave, index_t COffset> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x4f32; struct intrin_mfma_f32_16x16x4f32;
template <index_t COffset> template <>
struct intrin_mfma_f32_16x16x4f32<16, 16, COffset> struct intrin_mfma_f32_16x16x4f32<16, 16>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) = reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
llvm_intrin_amdgcn_mfma_f32_16x16x4f32( reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
} }
}; };
template <index_t MPerWave, index_t NPerWave, index_t COffset> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x1f32; struct intrin_mfma_f32_16x16x1f32;
template <index_t COffset> template <>
struct intrin_mfma_f32_16x16x1f32<16, 64, COffset> struct intrin_mfma_f32_16x16x1f32<16, 64>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) = reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
llvm_intrin_amdgcn_mfma_f32_16x16x1f32( reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
2,
0,
0);
} }
}; };
template <index_t MPerWave, index_t NPerWave, index_t COffset> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x1f32; struct intrin_mfma_f32_4x4x1f32;
template <index_t COffset> template <>
struct intrin_mfma_f32_4x4x1f32<4, 64, COffset> struct intrin_mfma_f32_4x4x1f32<4, 64>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) = reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
llvm_intrin_amdgcn_mfma_f32_4x4x1f32( reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
} }
}; };
template <index_t COffset> template <>
struct intrin_mfma_f32_4x4x1f32<8, 64, COffset> struct intrin_mfma_f32_4x4x1f32<8, 64>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) = reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
llvm_intrin_amdgcn_mfma_f32_4x4x1f32( reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
reg_a, reg_c.template AsType<float4_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
reg_b, reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
4,
1,
0);
} }
}; };
template <index_t MPerWave, index_t NPerWave, index_t COffset> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4f16; struct intrin_mfma_f32_32x32x4f16;
template <index_t COffset> template <>
struct intrin_mfma_f32_32x32x4f16<64, 64, COffset> struct intrin_mfma_f32_32x32x4f16<64, 64>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) = reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
llvm_intrin_amdgcn_mfma_f32_32x32x4f16( reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
reg_a, reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
reg_b, reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
1,
1,
0);
} }
}; };
template <index_t COffset> template <>
struct intrin_mfma_f32_32x32x4f16<32, 64, COffset> struct intrin_mfma_f32_32x32x4f16<32, 64>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) = reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
llvm_intrin_amdgcn_mfma_f32_32x32x4f16( reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
} }
}; };
template <index_t MPerWave, index_t NPerWave, index_t COffset> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8f16; struct intrin_mfma_f32_32x32x8f16;
template <index_t COffset> template <>
struct intrin_mfma_f32_32x32x8f16<32, 32, COffset> struct intrin_mfma_f32_32x32x8f16<32, 32>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) = reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
llvm_intrin_amdgcn_mfma_f32_32x32x8f16( reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
} }
}; };
template <index_t MPerWave, index_t NPerWave, index_t COffset> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x16f16; struct intrin_mfma_f32_16x16x16f16;
template <index_t COffset> template <>
struct intrin_mfma_f32_16x16x16f16<16, 16, COffset> struct intrin_mfma_f32_16x16x16f16<16, 16>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) = reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
llvm_intrin_amdgcn_mfma_f32_16x16x16f16( reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
} }
}; };
template <index_t MPerWave, index_t NPerWave, index_t COffset> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x4f16; struct intrin_mfma_f32_16x16x4f16;
template <index_t COffset> template <>
struct intrin_mfma_f32_16x16x4f16<16, 64, COffset> struct intrin_mfma_f32_16x16x4f16<16, 64>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) = reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
llvm_intrin_amdgcn_mfma_f32_16x16x4f16( reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
2,
0,
0);
} }
}; };
template <index_t MPerWave, index_t NPerWave, index_t COffset> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x4f16; struct intrin_mfma_f32_4x4x4f16;
template <index_t COffset> template <>
struct intrin_mfma_f32_4x4x4f16<4, 64, COffset> struct intrin_mfma_f32_4x4x4f16<4, 64>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) = reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
llvm_intrin_amdgcn_mfma_f32_4x4x4f16( reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
} }
}; };
template <index_t COffset> template <>
struct intrin_mfma_f32_4x4x4f16<8, 64, COffset> struct intrin_mfma_f32_4x4x4f16<8, 64>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) = reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
llvm_intrin_amdgcn_mfma_f32_4x4x4f16( reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
reg_a, reg_c.template AsType<float4_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
reg_b, reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
4,
1,
0);
} }
}; };
...@@ -448,7 +340,6 @@ template <index_t MPerWave, index_t NPerWave> ...@@ -448,7 +340,6 @@ template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a, __device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b, const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c); c_vec16_1_t::VecType reg_c);
template <> template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a, __device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
const ushort2_t* reg_b, const ushort2_t* reg_b,
......
...@@ -48,7 +48,7 @@ struct Array<TData, 0> ...@@ -48,7 +48,7 @@ struct Array<TData, 0>
template <typename X, typename... Xs> template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) __host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
{ {
using data_type = remove_cv_t<remove_reference_t<X>>; using data_type = remove_cvref_t<X>;
return Array<data_type, sizeof...(Xs) + 1>{{std::forward<X>(x), std::forward<Xs>(xs)...}}; return Array<data_type, sizeof...(Xs) + 1>{{std::forward<X>(x), std::forward<Xs>(xs)...}};
} }
......
...@@ -85,13 +85,13 @@ ...@@ -85,13 +85,13 @@
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#endif #endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK #ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#endif #endif
// pass tensor descriptor by value or void* // pass tensor descriptor by value or void*
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
// merge transformation use magic number division // merge transformation use magic number division
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0 #define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0
......
...@@ -43,18 +43,15 @@ struct DynamicBuffer ...@@ -43,18 +43,15 @@ struct DynamicBuffer
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
template <typename X, template <typename X,
typename enable_if< typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, typename scalar_type<remove_cvref_t<T>>::type>::value,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, bool>::type = false>
bool>::type = false>
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
{ {
// X contains multiple T // X contains multiple T
constexpr index_t scalar_per_t_vector = constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
constexpr index_t scalar_per_x_vector = constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T"); "wrong! X need to be multiple T");
...@@ -71,15 +68,14 @@ struct DynamicBuffer ...@@ -71,15 +68,14 @@ struct DynamicBuffer
if constexpr(InvalidElementUseNumericalZeroValue) if constexpr(InvalidElementUseNumericalZeroValue)
{ {
return amd_buffer_load_invalid_element_return_return_zero< return amd_buffer_load_invalid_element_return_return_zero<remove_cvref_t<T>,
remove_cv_t<remove_reference_t<T>>, t_per_x>(
t_per_x>(p_data_, i, is_valid_element, element_space_size_); p_data_, i, is_valid_element, element_space_size_);
} }
else else
{ {
return amd_buffer_load_invalid_element_return_customized_value< return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
remove_cv_t<remove_reference_t<T>>, t_per_x>(
t_per_x>(
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
} }
} }
...@@ -98,18 +94,15 @@ struct DynamicBuffer ...@@ -98,18 +94,15 @@ struct DynamicBuffer
} }
template <typename X, template <typename X,
typename enable_if< typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, typename scalar_type<remove_cvref_t<T>>::type>::value,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, bool>::type = false>
bool>::type = false>
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
{ {
// X contains multiple T // X contains multiple T
constexpr index_t scalar_per_t_vector = constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
constexpr index_t scalar_per_x_vector = constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T"); "wrong! X need to be multiple T");
...@@ -119,7 +112,7 @@ struct DynamicBuffer ...@@ -119,7 +112,7 @@ struct DynamicBuffer
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ADDRESSING
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>( amd_buffer_store<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_); x, p_data_, i, is_valid_element, element_space_size_);
#else #else
if(is_valid_element) if(is_valid_element)
...@@ -140,70 +133,65 @@ struct DynamicBuffer ...@@ -140,70 +133,65 @@ struct DynamicBuffer
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128 // ds_write_b128
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type, if constexpr(is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value)
int8_t>::value)
{ {
static_assert( static_assert((is_same<remove_cvref_t<T>, int8_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && is_same<remove_cvref_t<X>, int8_t>::value) ||
is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value) || (is_same<remove_cvref_t<T>, int8_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && is_same<remove_cvref_t<X>, int8x2_t>::value) ||
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) || (is_same<remove_cvref_t<T>, int8_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && is_same<remove_cvref_t<X>, int8x4_t>::value) ||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) || (is_same<remove_cvref_t<T>, int8x4_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value && is_same<remove_cvref_t<X>, int8x4_t>::value) ||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) || (is_same<remove_cvref_t<T>, int8x8_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value && is_same<remove_cvref_t<X>, int8x8_t>::value) ||
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) || (is_same<remove_cvref_t<T>, int8x16_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value && is_same<remove_cvref_t<X>, int8x16_t>::value),
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value), "wrong! not implemented for this combination, please add "
"wrong! not implemented for this combination, please add " "implementation");
"implementation");
if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && is_same<remove_cvref_t<X>, int8_t>::value)
is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*c_style_pointer_cast<int8_t*>(&p_data_[i]) = *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int8_t*>(&x); *c_style_pointer_cast<const int8_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) is_same<remove_cvref_t<X>, int8x2_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*c_style_pointer_cast<int16_t*>(&p_data_[i]) = *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int16_t*>(&x); *c_style_pointer_cast<const int16_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) is_same<remove_cvref_t<X>, int8x4_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) = *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x); *c_style_pointer_cast<const int32_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
int8x4_t>::value && is_same<remove_cvref_t<X>, int8x4_t>::value)
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) = *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x); *c_style_pointer_cast<const int32_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, else if constexpr(is_same<remove_cvref_t<T>, int8x8_t>::value &&
int8x8_t>::value && is_same<remove_cvref_t<X>, int8x8_t>::value)
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) = *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x2_t*>(&x); *c_style_pointer_cast<const int32x2_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, else if constexpr(is_same<remove_cvref_t<T>, int8x16_t>::value &&
int8x16_t>::value && is_same<remove_cvref_t<X>, int8x16_t>::value)
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
...@@ -227,6 +215,35 @@ struct DynamicBuffer ...@@ -227,6 +215,35 @@ struct DynamicBuffer
} }
} }
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
__host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T");
static_assert(GetAddressSpace() == AddressSpaceEnum_t::Global, "only support global mem");
#if CK_USE_AMD_BUFFER_ADDRESSING
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_);
#else
if(is_valid_element)
{
atomicAdd(&p_data_[i], x);
}
#endif
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return false; } __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
......
...@@ -114,12 +114,11 @@ struct MagicDivision ...@@ -114,12 +114,11 @@ struct MagicDivision
__host__ __device__ static constexpr uint32_t __host__ __device__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{ {
uint32_t tmp = (uint64_t(dividend) * uint64_t(multiplier)) >> 32; uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift; return (tmp + dividend) >> shift;
} }
#if 1 // debug // magic division for int32_t
// HACK: magic division for int32_t
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
// non-negative for result to be correct // non-negative for result to be correct
// TODO: figure out how to do magic number divison for int32_t as dividended // TODO: figure out how to do magic number divison for int32_t as dividended
...@@ -127,27 +126,9 @@ struct MagicDivision ...@@ -127,27 +126,9 @@ struct MagicDivision
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{ {
uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32); uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
uint32_t tmp = uint32_t tmp = __umulhi(dividend_u32, multiplier);
(static_cast<uint64_t>(dividend_u32) * static_cast<uint64_t>(multiplier)) >> 32;
return (tmp + dividend_u32) >> shift; return (tmp + dividend_u32) >> shift;
} }
#else
// the inline ASM is producing wrong result
__host__ __device__ static int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t r;
asm volatile("\n \
v_mul_hi_u32 %0, %1, %2 \n \
v_add_u32_e32 %0, %1, %0 \n \
v_lshrrev_b32_e32 %0, %3, %0 \n \
"
: "=v"(r)
: "v"(as_type<uint32_t>(dividend_i32)), "s"(multiplier), "s"(shift));
return as_type<int32_t>(r);
}
#endif
}; };
} // namespace ck } // namespace ck
......
...@@ -55,6 +55,98 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N> ...@@ -55,6 +55,98 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
}; };
template <AddressSpaceEnum_t BufferAddressSpace,
typename T,
index_t N,
bool InvalidElementUseNumericalZeroValue>
struct StaticBufferV2 : public StaticallyIndexedArray<T, N>
{
using type = T;
using base = StaticallyIndexedArray<T, N>;
using VecBaseType = typename T::d1_t;
__host__ __device__ static constexpr index_t GetVectorSize()
{
return sizeof(typename T::type) / sizeof(VecBaseType);
}
static constexpr index_t vector_size = GetVectorSize();
VecBaseType invalid_element_value_ = VecBaseType{0};
T invalid_vec_value_ = T{0};
__host__ __device__ constexpr StaticBufferV2() : base{} {}
__host__ __device__ constexpr StaticBufferV2(VecBaseType invalid_element_value)
: base{},
invalid_vec_value_{invalid_element_value},
invalid_element_value_{invalid_element_value}
{
}
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
{
return BufferAddressSpace;
}
template <index_t I>
__host__ __device__ constexpr auto& GetVector(Number<I> vec_id)
{
return this->At(vec_id);
}
template <index_t I>
__host__ __device__ constexpr const auto& GetVector(Number<I> vec_id) const
{
return this->At(vec_id);
}
template <index_t I>
__host__ __device__ constexpr auto& GetElement(Number<I> i, bool)
{
constexpr auto vec_id = Number<i / vector_size>{};
constexpr auto vec_off = Number<i % vector_size>{};
return this->At(vec_id).template AsType<VecBaseType>()(vec_off);
}
template <index_t I>
__host__ __device__ constexpr auto GetElement(Number<I> i, bool is_valid_element) const
{
constexpr auto vec_id = Number<i / vector_size>{};
constexpr auto vec_off = Number<i % vector_size>{};
if constexpr(InvalidElementUseNumericalZeroValue)
{
return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
: VecBaseType{0};
}
else
{
return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
: invalid_element_value_;
}
}
template <index_t I>
__host__ __device__ constexpr auto operator[](Number<I> i) const
{
return GetElement(i, true);
}
template <index_t I>
__host__ __device__ constexpr auto& operator()(Number<I> i)
{
return GetElement(i, true);
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N> template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>) __host__ __device__ constexpr auto make_static_buffer(Number<N>)
{ {
......
...@@ -159,7 +159,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -159,7 +159,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
template <typename... Xs> template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs) __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{ {
return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...); return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
} }
} // namespace ck } // namespace ck
......
...@@ -14,9 +14,7 @@ struct is_known_at_compile_time<Tuple<Ts...>> ...@@ -14,9 +14,7 @@ struct is_known_at_compile_time<Tuple<Ts...>>
return container_reduce( return container_reduce(
Tuple<Ts...>{}, Tuple<Ts...>{},
[](auto x, bool r) { [](auto x, bool r) {
return is_known_at_compile_time< return is_known_at_compile_time<remove_cvref_t<decltype(x)>>::value & r;
remove_cv_t<remove_reference_t<decltype(x)>>>::value &
r;
}, },
true); true);
} }
......
...@@ -374,13 +374,8 @@ extern "C" __global__ void ...@@ -374,13 +374,8 @@ extern "C" __global__ void
CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{}, CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{},
CGridBlockCluster_BlockId_To_GM10_GN10{})); CGridBlockCluster_BlockId_To_GM10_GN10{}));
const auto desc_tuple = *reinterpret_cast<const DescTuple*>( const auto desc_tuple =
#pragma clang diagnostic push *reinterpret_cast<const DescTuple*>(cast_pointer_to_generic_address_space(p_desc_tuple));
#pragma clang diagnostic ignored "-Wold-style-cast"
// TODO: how to cast?
(const void*)p_desc_tuple
#pragma clang diagnostic pop
);
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0]; const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0];
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1]; const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1];
......
...@@ -13,9 +13,15 @@ include_directories(BEFORE ...@@ -13,9 +13,15 @@ include_directories(BEFORE
set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp) set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp) set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp)
set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp)
set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp)
add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE}) add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE}) add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})
add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE})
add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE})
target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor) target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor)
target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor) target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor)
target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor)
target_link_libraries(gemm_driver_offline PRIVATE host_tensor)
#ifndef DEBUG_HPP
#define DEBUG_HPP
namespace debug {
namespace debug_driver_gemm_xdlops_v2r3 {
// these vars are on host, they control block_id to C matrix tile idx (m0, n0) mapping
static ck::index_t M01 = 1;
static ck::index_t N01 = 1;
} // namespace debug_driver_gemm_xdlops_v2r3
} // namespace debug
#endif
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp" #include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp" #include "driver_gemm_xdlops_v2r3.hpp"
#include "debug.hpp"
template <typename TInWei, template <typename TInWei,
typename TAcc, typename TAcc,
...@@ -48,17 +49,17 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -48,17 +49,17 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 1 #if 0
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32 // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4; constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2; constexpr index_t NRepeat = 2;
...@@ -76,7 +77,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -76,7 +77,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1 #elif 0
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16 // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -84,9 +85,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -84,9 +85,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2; constexpr index_t NRepeat = 2;
...@@ -105,16 +106,16 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -105,16 +106,16 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1 #elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16 // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4; constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2; constexpr index_t NRepeat = 2;
...@@ -133,16 +134,16 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -133,16 +134,16 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1 #elif 1
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16 // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4; constexpr index_t NRepeat = 4;
...@@ -159,34 +160,6 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -159,34 +160,6 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 4]
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#endif #endif
...@@ -208,40 +181,42 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -208,40 +181,42 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
// HACK: hacks that control index calculation when iterating over A, B, C matrix // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( // clang-format off
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 3+: NWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 7+: N1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: NRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 3-: NWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
//clang-format on
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
...@@ -263,8 +238,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -263,8 +238,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
GemmMPerWave, GemmMPerXDL,
GemmNPerWave, GemmNPerXDL,
GemmK1, GemmK1,
MRepeat, MRepeat,
NRepeat, NRepeat,
...@@ -289,19 +264,23 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -289,19 +264,23 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
GemmCThreadTransferDstScalarPerVector, GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(in_m0_m1_m2_n_grid_step_hacks), decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat false, // CAccessOrderMRepeatNRepeat
false, // ABlockLdsExtraM
false // BBlockLdsExtraN
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()), >(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc, wei_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmk0_gemmn_gemmk1_grid_desc, out_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc, in_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks, wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_step_hacks, out_gemmk0_gemmn_gemmk1_grid_step_hacks,
in_m0_m1_m2_n_grid_step_hacks, in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat); nrepeat);
......
...@@ -49,7 +49,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -49,7 +49,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 0 #if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32 // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmMPerBlock = 256;
...@@ -77,7 +77,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -77,7 +77,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0 #elif 0
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32 // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
...@@ -104,8 +104,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -104,8 +104,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1 #elif 0
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16 // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmMPerBlock = 256;
...@@ -133,7 +133,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -133,7 +133,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1 #elif 1
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16 // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
...@@ -160,23 +160,91 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -160,23 +160,91 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif #elif 1
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
const auto descs = constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(out_n_ho_wo_k_desc, constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
wei_k_y_x_c_desc,
in_n_hi_wi_c_desc, constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
conv_strides, #elif 0
conv_dilations, // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
in_left_pads, constexpr index_t BlockSize = 128;
in_right_pads,
I0, constexpr index_t GemmMPerBlock = 128;
I0, constexpr index_t GemmNPerBlock = 64;
Number<GemmK1>{}); constexpr index_t GemmKPerBlock = 4;
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; constexpr index_t GemmMPerWave = 32;
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; constexpr index_t GemmNPerWave = 32;
const auto in_gemmm_gemmn_grid_desc = descs[I2]; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
// HACK: hacks that control index calculation when iterating over A, B, C matrix // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
...@@ -185,7 +253,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -185,7 +253,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-:
// gemmk1
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
...@@ -195,25 +264,27 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -195,25 +264,27 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( // clang-format off
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 2+: MWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 4+: M0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 5+: M1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 6+: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 2-: MWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 4-: M0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 5-: M1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
// clang-format on
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
...@@ -223,64 +294,110 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -223,64 +294,110 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = driver_gemm_xdlops_v2r3< const auto ConvStrideH = conv_strides[I0];
BlockSize, const auto ConvStrideW = conv_strides[I1];
TInWei,
TAcc, const auto ConvDilationH = conv_dilations[I0];
TOut, const auto ConvDilationW = conv_dilations[I1];
InMemoryDataOperationEnum_t::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc), const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
decltype(in_gemmm_gemmn_grid_desc),
GemmMPerBlock, const auto YTilda = ConvStrideH / GcdStrideDilationH;
GemmNPerBlock, const auto XTilda = ConvStrideW / GcdStrideDilationW;
GemmKPerBlock,
GemmMPerWave, float ave_time = 0;
GemmNPerWave,
GemmK1, for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda)
MRepeat, {
NRepeat, for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda)
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, {
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, const auto descs =
Sequence<1, 0, 2>, transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
Sequence<1, 0, 2>, out_n_ho_wo_k_desc,
2, wei_k_y_x_c_desc,
GemmABlockTransferSrcScalarPerVector_GemmK1, in_n_hi_wi_c_desc,
GemmABlockTransferDstScalarPerVector_GemmK1, conv_strides,
false, // don't move back src coordinate after threadwise copy conv_dilations,
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, in_left_pads,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, in_right_pads,
Sequence<2, 0, 1>, i_ytilda,
Sequence<0, 2, 1>, i_xtilda,
1, Number<GemmK1>{});
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1, const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
false, // don't move back src coordinate after threadwise copy const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto in_gemmm_gemmn_grid_desc = descs[I2];
const auto GemmK0 = out_gemmk0_gemmm_gemmk1_grid_desc.GetLength(I0);
if(GemmK0 != 0)
{
ave_time += driver_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<2, 0, 1>,
Sequence<0, 2, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
#if 0 #if 0
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
#else #else
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
#endif #endif
7, 7,
GemmCThreadTransferDstScalarPerVector, GemmCThreadTransferDstScalarPerVector,
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(in_m0_m1_m2_n_grid_step_hacks), decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
true // CAccessOrderMRepeatNRepeat true, // CAccessOrderMRepeatNRepeat
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), false, // ABlockLdsExtraM
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()), false // BBlockLdsExtraN
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), >(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
out_gemmk0_gemmm_gemmk1_grid_desc, static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmn_gemmk1_grid_desc, static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
in_gemmm_gemmn_grid_desc, out_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmk0_gemmm_gemmk1_grid_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks, in_gemmm_gemmn_grid_desc,
in_m0_m1_m2_n_grid_step_hacks, debug::debug_driver_gemm_xdlops_v2r3::M01,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, debug::debug_driver_gemm_xdlops_v2r3::N01,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, out_gemmk0_gemmm_gemmk1_grid_step_hacks,
nrepeat); wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
}
}
}
{ {
const auto N = out_n_ho_wo_k_lengths[I0]; const auto N = out_n_ho_wo_k_lengths[I0];
......
#include <unistd.h> #include <unistd.h>
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" #include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp" #include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei, template <typename TInWei,
...@@ -14,17 +14,17 @@ template <typename TInWei, ...@@ -14,17 +14,17 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1(
const InLengths& in_n_hi_wi_c_lengths, const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths, const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths, const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations&,
const InLeftPads& in_left_pads, const InLeftPads&,
const InRightPads& in_right_pads, const InRightPads&,
const Tensor<TInWei>& in_n_hi_wi_c, Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c, const Tensor<TInWei>& wei_k_y_x_c,
Tensor<TOut>& out_n_ho_wo_k, const Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat) ck::index_t nrepeat)
{ {
using namespace ck; using namespace ck;
...@@ -35,11 +35,6 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( ...@@ -35,11 +35,6 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
...@@ -53,8 +48,8 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( ...@@ -53,8 +48,8 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 1 #if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32 // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmMPerBlock = 256;
...@@ -77,12 +72,12 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( ...@@ -77,12 +72,12 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1 #elif 0
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32 // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
...@@ -105,16 +100,16 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( ...@@ -105,16 +100,16 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0 #elif 1
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16 // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerWave = 32;
...@@ -122,7 +117,7 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( ...@@ -122,7 +117,7 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4; constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 4; constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
...@@ -130,18 +125,46 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( ...@@ -130,18 +125,46 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1 #elif 0
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16 // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
...@@ -149,10 +172,10 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( ...@@ -149,10 +172,10 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4; constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2; constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
...@@ -161,80 +184,139 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( ...@@ -161,80 +184,139 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif #elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 128;
const auto descs = constexpr index_t GemmMPerBlock = 128;
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, constexpr index_t GemmNPerBlock = 64;
in_n_hi_wi_c_desc, constexpr index_t GemmKPerBlock = 4;
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; constexpr index_t GemmMPerWave = 32;
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; constexpr index_t GemmNPerWave = 32;
const auto out_gemmm_gemmn_grid_desc = descs[I2]; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
// HACK: hacks that control index calculation when iterating over A, B, C matrix // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0>{}, // 1+: gemmm
Sequence<0, 0, 0>{}), // 2+: gemmk1
make_tuple(Sequence<0, 0, 0>{}, // 0-: gemmk0
Sequence<0, 0, 0>{}, // 1-: gemmm
Sequence<0, 0, 0>{})); // 2-: gemmk1
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0>{}, // 1+: gemmn
Sequence<0, 0, 0>{}), // 2+: gemmk1
make_tuple(Sequence<0, 0, 0>{}, // 0-: Gemmk0
Sequence<0, 0, 0>{}, // 1-: Gemmn
Sequence<0, 0, 0>{})); // 2-: Gemmk1
// clang-format off
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple( make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // clang-format on
constexpr auto out_m0_m1_m2_n_grid_step_hacks = constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
const auto descs = transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
out_n_ho_wo_k_desc,
wei_k_y_x_c_desc,
in_n_hi_wi_c_desc,
conv_strides,
Number<GemmK1>{});
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto in_gemmm_gemmn_grid_desc = descs[I2];
float ave_time = driver_gemm_xdlops_v2r3< float ave_time = driver_gemm_xdlops_v2r3<
BlockSize, BlockSize,
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(in_gemmm_gemmn_grid_desc),
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
GemmMPerWave, GemmMPerWave,
GemmNPerWave, GemmNPerWave,
GemmK1,
MRepeat, MRepeat,
NRepeat, NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
...@@ -247,32 +329,40 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( ...@@ -247,32 +329,40 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>, Sequence<2, 0, 1>,
Sequence<1, 0, 2>, Sequence<0, 2, 1>,
2, 1,
GemmBBlockTransferSrcScalarPerVector_GemmK1, GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1, GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>, #if 0
6, Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
#else
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
#endif
7,
GemmCThreadTransferDstScalarPerVector, GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks), decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat true, // CAccessOrderMRepeatNRepeat
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()), false, // ABlockLdsExtraM
false // BBlockLdsExtraN
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc, in_gemmm_gemmn_grid_desc,
out_gemmm_gemmn_grid_desc, debug::debug_driver_gemm_xdlops_v2r3::M01,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks, debug::debug_driver_gemm_xdlops_v2r3::N01,
in_gemmk0_gemmn_gemmk1_grid_step_hacks, out_gemmk0_gemmm_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat); nrepeat);
{ {
...@@ -280,16 +370,13 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( ...@@ -280,16 +370,13 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
const auto K = out_n_ho_wo_k_lengths[I3]; const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3]; const auto C = wei_k_y_x_c_lengths[I3];
const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];
const auto Ho = out_n_ho_wo_k_lengths[I1]; const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2]; const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1]; const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2]; const auto X = wei_k_y_x_c_lengths[I2];
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
(std::size_t(1000) * 1000 * 1000) / ave_time; (std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
...@@ -298,5 +385,5 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( ...@@ -298,5 +385,5 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
} }
// copy result back to host // copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
} }
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"
template <typename TIn,
typename TWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
typename GridSizeType>
void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TIn>& in_n_c_hi_wi,
Tensor<TWei>& wei_k_c_y_x,
const Tensor<TOut>& out_n_k_ho_wo,
GridSizeType desired_grid_size,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;
// using vector load 4, so config's wo*ho must be a multiple of 4
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
const auto N = in_n_c_hi_wi_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_desc.GetLength(I1);
const auto Ho = out_n_k_ho_wo_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_desc.GetLength(I2);
const auto X = wei_k_c_y_x_desc.GetLength(I3);
const auto GemmM = K;
const auto GemmN = Y * X * C;
const auto GemmKTotal = N * Ho * Wo;
const auto GemmK = GemmKTotal / GemmK1;
const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
const index_t GemmK0 = BatchLen * GemmKPerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;
std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
<< " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
<< std::endl;
const auto descs =
transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{},
GemmKBatch,
GemmKPad);
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto wei_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB
Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmM
Sequence<0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemB
Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmM
Sequence<0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmB
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 1, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};
const auto driver_gemm_xdlops =
driver_gemm_xdlops_v2r4<BlockSize,
TIn,
TAcc,
TWei,
InMemoryDataOperationEnum_t::AtomicAdd,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1,
Sequence<0, 2, 1, 3>,
Sequence<0, 2, 1, 3>,
3,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1,
Sequence<0, 2, 1, 3>,
Sequence<0, 2, 1, 3>,
3,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
7,
GemmCThreadTransferDstScalarPerVector,
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false,
true,
true>;
for(index_t i = 0; i < 5; ++i)
{
float ave_time =
driver_gemm_xdlops(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
out_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
float perf = static_cast<float>(calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
driver_gemm_xdlops(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
out_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
0);
// copy result back to host
wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data());
}
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename TIn,
typename TWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TIn>& in_n_c_hi_wi,
Tensor<TWei>& wei_k_c_y_x,
const Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 0
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
// using vector load 4, so config's wo*ho must be a multiple of 4
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
// using vector load 4, so config's wo*ho must be a multiple of 4
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto wei_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM
Sequence<0, 0, 1, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 2, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM
Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 1, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TIn,
TAcc,
TWei,
InMemoryDataOperationEnum_t::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
7,
GemmCThreadTransferDstScalarPerVector,
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false, // CAccessOrderMRepeatNRepeat
true, // ABlockLdsExtraM
true // BBlockLdsExtraN
>(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
out_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
float perf = static_cast<float>(calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data());
}
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"
template <typename TIn,
typename TWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
typename GridSizeType>
void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TIn>& in_n_hi_wi_c,
Tensor<TWei>& wei_k_y_x_c,
const Tensor<TOut>& out_n_ho_wo_k,
GridSizeType desired_grid_size,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
const auto N = in_n_hi_wi_c_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_desc.GetLength(I3);
const auto Ho = out_n_ho_wo_k_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_desc.GetLength(I1);
const auto X = wei_k_y_x_c_desc.GetLength(I2);
const auto GemmM = Y * X * C;
const auto GemmN = K;
const auto GemmKTotal = N * Ho * Wo;
const auto GemmK = GemmKTotal / GemmK1;
const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
const index_t GemmK0 = BatchLen * GemmKPerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;
std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
<< " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
<< std::endl;
const auto descs =
transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{},
GemmKBatch,
GemmKPad);
const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto wei_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmKBatch
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmKBatch
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};
constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4<
BlockSize,
TIn,
TAcc,
TWei,
InMemoryDataOperationEnum_t::AtomicAdd,
decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerXDL,
GemmNPerXDL,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>,
2,
GemmABlockTransferSrcScalarPerVector_GemmM,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
6,
GemmCThreadTransferDstScalarPerVector,
decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false, // CAccessOrderMRepeatNRepeat
true,
true>;
for(index_t i = 0; i < 5; ++i)
{
float ave_time =
driver_gemm_xdlops(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
driver_gemm_xdlops(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
0);
// copy result back to host
wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
}
#include <unistd.h> #include <unistd.h>
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" #include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r2.hpp" #include "driver_gemm_xdlops_v2r3.hpp"
#include "debug.hpp"
template <typename TInWei, template <typename TIn,
typename TWei,
typename TAcc, typename TAcc,
typename TOut, typename TOut,
typename InLengths, typename InLengths,
...@@ -14,7 +16,7 @@ template <typename TInWei, ...@@ -14,7 +16,7 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths, const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths, const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths, const OutLengths& out_n_ho_wo_k_lengths,
...@@ -22,9 +24,9 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( ...@@ -22,9 +24,9 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c, const Tensor<TIn>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c, Tensor<TWei>& wei_k_y_x_c,
Tensor<TOut>& out_n_ho_wo_k, const Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat) ck::index_t nrepeat)
{ {
using namespace ck; using namespace ck;
...@@ -36,8 +38,8 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( ...@@ -36,8 +38,8 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
...@@ -48,7 +50,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( ...@@ -48,7 +50,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 1 #if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32 // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -56,154 +58,199 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( ...@@ -56,154 +58,199 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 64; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4; constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 1; constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1 #elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16 // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerWave = 64; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1; constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif #endif
const auto descs = const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, in_n_hi_wi_c_desc,
in_n_hi_wi_c_desc, wei_k_y_x_c_desc,
out_n_ho_wo_k_desc, out_n_ho_wo_k_desc,
conv_strides, conv_strides,
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
in_right_pads, in_right_pads,
Number<GemmK1>{}); Number<GemmK1>{});
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2]; const auto wei_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
make_tuple( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmM
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmM
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
constexpr auto out_m0_m1_m2_n_grid_step_hacks = Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
Sequence<0, 0, 1, 0, 0>{}), make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 2, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = driver_gemm_xdlops_v2r2< float ave_time = driver_gemm_xdlops_v2r3<
BlockSize, BlockSize,
TInWei, TIn,
TAcc, TAcc,
TOut, TWei,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(wei_gemmm_gemmn_grid_desc),
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
GemmMPerWave, GemmMPerXDL,
GemmNPerWave, GemmNPerXDL,
GemmK1,
MRepeat, MRepeat,
NRepeat, NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>, Sequence<0, 2, 1>,
Sequence<1, 0, 2>, Sequence<0, 2, 1>,
2, 1,
GemmABlockTransferSrcScalarPerVector_GemmK1, GemmABlockTransferSrcScalarPerVector_GemmM,
GemmABlockTransferDstScalarPerVector_GemmK1, GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>, Sequence<0, 2, 1>,
Sequence<1, 0, 2>, Sequence<0, 2, 1>,
2, 1,
GemmBBlockTransferSrcScalarPerVector_GemmK1, GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1, GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1>, Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
2, 7,
GemmCThreadTransferDstScalarPerVector, GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks), decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>( decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()), false, // CAccessOrderMRepeatNRepeat
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), true,
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), true>(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc, static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
in_gemmk0_gemmn_gemmk1_grid_desc, static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
out_gemmm_gemmn_grid_desc, in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks, out_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_step_hacks, wei_gemmm_gemmn_grid_desc,
out_m0_m1_m2_n_grid_step_hacks, debug::debug_driver_gemm_xdlops_v2r3::M01,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, debug::debug_driver_gemm_xdlops_v2r3::N01,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, in_gemmk0_gemmm_gemmk1_grid_step_hacks,
nrepeat); out_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{ {
const auto N = out_n_ho_wo_k_lengths[I0]; const auto N = out_n_ho_wo_k_lengths[I0];
...@@ -216,7 +263,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( ...@@ -216,7 +263,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
const auto Y = wei_k_y_x_c_lengths[I1]; const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2]; const auto X = wei_k_y_x_c_lengths[I2];
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
(std::size_t(1000) * 1000 * 1000) / ave_time; (std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
...@@ -225,5 +272,5 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( ...@@ -225,5 +272,5 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
} }
// copy result back to host // copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
} }
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"
template <typename TIn,
typename TWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
typename GridSizeType>
void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TIn>& in_n_hi_wi_c,
Tensor<TWei>& wei_k_y_x_c,
const Tensor<TOut>& out_n_ho_wo_k,
GridSizeType desired_grid_size,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4], C 128, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 4], C 128, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8], C 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8], C 64, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C 64, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [64, 128, 4, 8], C 64, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 16, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [64, 64, 4, 8], C 32, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
const auto N = in_n_hi_wi_c_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_desc.GetLength(I3);
const auto Ho = out_n_ho_wo_k_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_desc.GetLength(I1);
const auto X = wei_k_y_x_c_desc.GetLength(I2);
const auto GemmM = K;
const auto GemmN = Y * X * C;
const auto GemmKTotal = N * Ho * Wo;
const auto GemmK = GemmKTotal / GemmK1;
const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
const index_t GemmK0 = BatchLen * GemmKPerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;
std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
<< " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
<< std::endl;
const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad(
in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{},
GemmKBatch,
GemmKPad);
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto wei_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};
const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4<
BlockSize,
TIn,
TAcc,
TWei,
InMemoryDataOperationEnum_t::AtomicAdd,
decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerXDL,
GemmNPerXDL,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>,
2,
GemmABlockTransferSrcScalarPerVector_GemmM,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<0, 1, 2, 3>,
Sequence<0, 1, 3, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
7,
GemmCThreadTransferDstScalarPerVector,
decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false, // CAccessOrderMRepeatNRepeat
true,
true>;
// timing
for(index_t i = 0; i < 5; ++i)
{
float ave_time =
driver_gemm_xdlops(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// verification
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
driver_gemm_xdlops(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
0);
// copy result back to host
wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment