Commit a3b86965 authored by aska-0096

Merge branch 'develop' of...

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into lds_bypass_spilling
parents bdd0f64e fe96e8fb
@@ -54,7 +54,8 @@ __global__ void
        const Block2CTileMap block_2_ctile_map,
        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
+    defined(__gfx1102__))
    // offset base pointer for each work-group
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

@@ -147,7 +148,9 @@ __global__ void
        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
        const Block2CTileMap block_2_etile_map)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
+    defined(__gfx1102__))
+    // printf("entry kernel launch");
    __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];

    const index_t num_blocks_per_batch =

@@ -236,7 +239,8 @@ __global__ void
        const CDEElementwiseOperation cde_element_op,
        const Block2CTileMap block_2_ctile_map)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
+    defined(__gfx1102__))
    __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];

    GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,

@@ -265,7 +269,7 @@ __global__ void
    ignore = b_element_op;
    ignore = cde_element_op;
    ignore = block_2_ctile_map;
-#endif // end of if (defined(__gfx1100__))
+#endif // end of if (defined(__gfx1100__ ))
}

template < // DataType Family

@@ -673,7 +677,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

        auto blockwise_gemm =
-            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
+            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
                                                                       ADataType,
                                                                       BDataType,
                                                                       AccDataType,
...
@@ -45,8 +45,9 @@ __global__ void
        const CElementwiseOperation c_element_op,
        const Block2CTileMap block_2_ctile_map)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
-    __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
+    defined(__gfx1102__))
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
...
@@ -1201,7 +1201,12 @@ struct ThreadwiseTensorSliceTransfer_v4
    SrcCoord src_ref_coord_;
};

-// Do NOT involve any tensor coordinates with StaticBuffer
+/**
+ * @brief Threadwise data transfer
+ *
+ * Do NOT involve any tensor coordinates with StaticBuffer
+ *
+ */
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
...
@@ -1030,7 +1030,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
-    uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff;
+    uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;

    return amd_buffer_load_impl<scalar_t, vector_size>(
        src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);

@@ -1091,7 +1091,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
-    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
+    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

    amd_buffer_store_impl<scalar_t, vector_size>(
        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);

@@ -1126,7 +1126,7 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
-    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
+    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

    amd_buffer_atomic_add_impl<scalar_t, vector_size>(
        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);

@@ -1161,7 +1161,7 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr
    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
-    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
+    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

    amd_buffer_atomic_max_impl<scalar_t, vector_size>(
        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
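
All four buffer helpers change the invalid-element address shift from 0x7fffffff to 0x80000000. The trick itself is unchanged: a lane whose element is invalid keeps its offset, but shifted far beyond the range described by the buffer resource descriptor, so the hardware's out-of-bounds behaviour (loads return zero, stores and atomics are dropped) stands in for a per-lane branch. A minimal host-side analogy of the idea, assuming nothing beyond what the diff shows:

#include <cstdint>
#include <cstdio>

// Conceptual, host-side sketch of the OOB-offset trick (not the CK device code).
static float buffer_load_like(const float* data, uint64_t num_valid_elements, uint64_t element_offset)
{
    // Models the GPU buffer semantics: out-of-range offsets read as zero.
    return element_offset < num_valid_elements ? data[element_offset] : 0.0f;
}

int main()
{
    float buf[4]        = {1.0f, 2.0f, 3.0f, 4.0f};
    bool element_valid  = false;                        // e.g. a padded/tail element
    uint64_t addr_shift = element_valid ? 0u : 0x80000000u; // constant taken from this diff
    printf("%f\n", buffer_load_like(buf, 4, addr_shift + 2)); // prints 0.000000
    return 0;
}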
...
@@ -358,7 +358,13 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
// Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
{
+#if defined(__gfx11__)
    asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
+#else
+    ignore = a;
+    ignore = b;
+    ignore = c;
+#endif
}

} // namespace ck
...
@@ -21,17 +21,18 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16, AssemblyBackend>
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
-        if constexpr(AssemblyBackend)
-        {
-            amd_assembly_wmma_f32_16x16x16_f16_w32(
-                reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
-        }
-        else
-        {
-            reg_c.template AsType<float8_t>()(Number<0>{}) =
-                __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
-                    reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
-        }
+        // * Inline assembly need to elimate the duplicated data load, compiler won't help you
+        // delete them.
+        // amd_assembly_wmma_f32_16x16x16_f16_w32(
+        //     reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+        reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
+            reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
    }
};
@@ -45,9 +46,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<float8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
                reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
    }
};
@@ -64,8 +71,14 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
    }
};
@@ -82,9 +95,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
                reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
    }
};
@@ -98,6 +117,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<int32x8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
                neg_a,

@@ -106,6 +126,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
                bit_cast<int32x4_t>(reg_b),
                reg_c.template AsType<int32x8_t>()[Number<0>{}],
                clamp);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
    }
};
@@ -120,8 +145,14 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
    }
};
@@ -135,9 +166,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
                reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
    }
};
@@ -154,8 +191,14 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
            reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
    }
};
@@ -172,9 +215,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
                reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
    }
};
@@ -188,6 +237,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
                neg_a,

@@ -196,6 +246,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
                bit_cast<int32x4_t>(reg_b),
                reg_c.template AsType<int32x4_t>()[Number<0>{}],
                clamp);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
    }
};
...
@@ -1022,8 +1022,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
        uint32_t int32;
    } u = {x};

-    if(~u.int32 & 0x7f800000)
-    {
    // When the exponent bits are not all 1s, then the value is zero, normal,
    // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
    // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).

@@ -1035,15 +1033,13 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
    // has the value 0x7f, then incrementing it causes it to become 0x00 and
    // the exponent is incremented by one, which is the next higher FP value
    // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
-    // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
+    // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
    // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
    // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
    // incrementing it causes it to become an exponent of 0xFF and a mantissa
    // of 0x00, which is Inf, the next higher value to the unrounded value.
-        u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
-    }
-    else if(u.int32 & 0xffff)
-    {
+    bool flag0 = ~u.int32 & 0x7f800000;
    // When all of the exponent bits are 1, the value is Inf or NaN.
    // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
    // mantissa bit. Quiet NaN is indicated by the most significant mantissa

@@ -1051,9 +1047,11 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
    // mantissa bit being 0 but some other bit(s) being 1. If any of the
    // lower 16 bits of the mantissa are 1, we set the least significant bit
    // of the bfloat16 mantissa, in order to preserve signaling NaN in case
-    // the bloat16's mantissa bits are all 0.
-        u.int32 |= 0x10000; // Preserve signaling NaN
-    }
+    // the bfloat16's mantissa bits are all 0.
+    bool flag1 = !flag0 && (u.int32 & 0xffff);
+
+    u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
+    u.int32 |= flag1 ? 0x10000 : 0x0;                      // Preserve signaling NaN

    return uint16_t(u.int32 >> 16);
}
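
The float-to-bfloat16 conversion is now branch-free: the two conditions are evaluated up front as flags and applied with select-style arithmetic, which suits device code better than divergent if/else. A standalone, host-compilable sketch of the same round-to-nearest-even logic, written to mirror the new code above:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Sketch only: branchless float -> bfloat16 with round-to-nearest-even and
// signaling-NaN preservation, following the flag0/flag1 structure of the diff.
static uint16_t float_to_bf16_rtne(float x)
{
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bool not_inf_nan = (~bits & 0x7f800000u) != 0;            // exponent bits not all ones
    bool is_nan      = !not_inf_nan && (bits & 0xffffu) != 0; // all-ones exponent, nonzero low mantissa
    bits += not_inf_nan ? 0x7fffu + ((bits >> 16) & 1u) : 0u; // round to nearest, ties to even
    bits |= is_nan ? 0x10000u : 0u;                           // keep NaN a NaN after truncation
    return static_cast<uint16_t>(bits >> 16);
}

int main()
{
    printf("%04x\n", float_to_bf16_rtne(1.0f)); // prints 3f80
    return 0;
}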
...
@@ -135,6 +135,28 @@ __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const h
        c);
}

+template <>
+__device__ void inner_product<int8_t, int8_t, int32_t>(const int8_t& a, const int8_t& b, int32_t& c)
+{
+    c += type_convert<int32_t>(a) * type_convert<int32_t>(b);
+}
+
+template <>
+__device__ void
+inner_product<int8x2_t, int8x2_t, int32_t>(const int8x2_t& a, const int8x2_t& b, int32_t& c)
+{
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+
+    inner_product(vector_type<int8_t, 2>{a}.AsType<int8_t>()[I0],
+                  vector_type<int8_t, 2>{b}.AsType<int8_t>()[I0],
+                  c);
+
+    inner_product(vector_type<int8_t, 2>{a}.AsType<int8_t>()[I1],
+                  vector_type<int8_t, 2>{b}.AsType<int8_t>()[I1],
+                  c);
+}
+
template <>
__device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
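
The new specializations reduce the int8 and int8x2_t cases to scalar multiply-accumulates in int32. A trivial host-side check of that arithmetic, purely illustrative and not CK code:

#include <cstdint>
#include <cstdio>

int main()
{
    int8_t a[2] = {-3, 7};
    int8_t b[2] = {5, -2};
    int32_t c   = 10;
    for(int i = 0; i < 2; ++i)
        c += int32_t(a[i]) * int32_t(b[i]); // mirrors the per-element inner_product calls
    printf("%d\n", c); // 10 + (-15) + (-14) = -19
    return 0;
}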
...
@@ -93,6 +93,7 @@ using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
+using Gelu = ck::tensor_operation::element_wise::Gelu;

template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
...
@@ -74,8 +74,7 @@ template <typename ALayout,
          typename ADataType,
          typename BDataType,
          typename EDataType>
-struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemm<
-    ALayout,
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
    BLayout,
    Empty_Tuple,
    ELayout,

@@ -83,9 +82,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
    BDataType,
    Empty_Tuple,
    EDataType,
-    ck::tensor_operation::element_wise::PassThrough,
-    ck::tensor_operation::element_wise::PassThrough,
-    ck::tensor_operation::element_wise::PassThrough>>
+    PassThrough,
+    PassThrough,
+    PassThrough>>
{
    using DeviceOp = DeviceGroupedGemm<ALayout,
                                       BLayout,

@@ -95,9 +94,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                                       BDataType,
                                       Empty_Tuple,
                                       EDataType,
-                                      ck::tensor_operation::element_wise::PassThrough,
-                                      ck::tensor_operation::element_wise::PassThrough,
-                                      ck::tensor_operation::element_wise::PassThrough>;
+                                      PassThrough,
+                                      PassThrough,
+                                      PassThrough>;

    static auto GetInstances()
    {
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <memory>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances);
void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances);
void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances);
void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances);
// GroupedGEMM + GELU
template <typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
FastGelu>>
{
using DeviceOp = DeviceGroupedGemm<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
FastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
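
The block above (from the SPDX header to the closing namespaces) appears to be a newly added instance header that registers grouped GEMM + FastGelu instances and specializes DeviceOperationInstanceFactory for them. A hedged usage sketch follows; the GetTypeString() call (from the op's common base class) and the assumption that this header plus <iostream> are included are not confirmed by the diff, while the alias names are the ones the header itself uses:

// Enumerate the registered grouped GEMM + FastGelu instances for one of the four
// layout combinations the factory above dispatches on (row-major A and B, f16 in/out).
void list_grouped_gemm_fastgelu_instances()
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_operation::device::instance;

    using DeviceOp = DeviceGroupedGemm<Row, Row, Empty_Tuple, Row,
                                       F16, F16, Empty_Tuple, F16,
                                       PassThrough, PassThrough, FastGelu>;

    for(const auto& op : DeviceOperationInstanceFactory<DeviceOp>::GetInstances())
        std::cout << op->GetTypeString() << std::endl; // assumed to exist on the base operator
}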
@@ -18,7 +18,7 @@ namespace device {
namespace instance {

// grouped conv2d forward, GNHWC/GKYXC/GNHWK
-void add_device_conv2d_bias_perchannel_quantization_int8_instances(
+void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
    std::vector<
        std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
                                                      GNHWC,
@@ -34,7 +34,38 @@ void add_device_conv2d_bias_perchannel_quantization_int8_instances(
                                                      Add_Activation_Mul2_Clamp<PassThrough>>>>&
        instances);

-void add_device_conv2d_bias_relu_perchannel_quantization_int8_instances(
+void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+                                                              GNHWC,
+                                                              GKYXC,
+                                                              GK_GK_Tuple,
+                                                              GNHWK,
+                                                              int8_t,
+                                                              int8_t,
+                                                              I32_F32_Tuple,
+                                                              int8_t,
+                                                              PassThrough,
+                                                              PassThrough,
+                                                              Add_Activation_Mul2_Clamp<Relu>>>>&
+        instances);
+
+void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
+    std::vector<
+        std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+                                                      GNHWC,
+                                                      GKYXC,
+                                                      GK_GK_Tuple,
+                                                      GNHWK,
+                                                      int8_t,
+                                                      int8_t,
+                                                      I32_F32_Tuple,
+                                                      int8_t,
+                                                      PassThrough,
+                                                      PassThrough,
+                                                      Add_Activation_Mul2_Clamp<PassThrough>>>>&
+        instances);
+
+void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
                                                              GNHWC,
                                                              GKYXC,
@@ -98,9 +129,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                     is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>)
        {
            if constexpr(is_same_v<Activation, PassThrough>)
-                add_device_conv2d_bias_perchannel_quantization_int8_instances(op_ptrs);
+            {
+                add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(op_ptrs);
+                add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(op_ptrs);
+            }
            else if constexpr(is_same_v<Activation, Relu>)
-                add_device_conv2d_bias_relu_perchannel_quantization_int8_instances(op_ptrs);
+            {
+                add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(op_ptrs);
+                add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(op_ptrs);
+            }
        }
    }
...
@@ -14,6 +14,10 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
    }
}

+/**
+ * @brief Container for storing data in GPU device memory
+ *
+ */
struct DeviceMem
{
    DeviceMem() = delete;
...
@@ -100,6 +100,15 @@ struct FillMonotonicSeq
            return tmp;
        });
    }

+    template <typename ForwardRange>
+    auto operator()(ForwardRange&& range) const -> std::void_t<decltype(
+        std::declval<const FillMonotonicSeq&>()(std::begin(std::forward<ForwardRange>(range)),
+                                                std::end(std::forward<ForwardRange>(range))))>
+    {
+        (*this)(std::begin(std::forward<ForwardRange>(range)),
+                std::end(std::forward<ForwardRange>(range)));
+    }
};

template <typename T>

@@ -112,6 +121,15 @@ struct FillConstant
    {
        std::fill(first, last, value_);
    }

+    template <typename ForwardRange>
+    auto operator()(ForwardRange&& range) const -> std::void_t<
+        decltype(std::declval<const FillConstant&>()(std::begin(std::forward<ForwardRange>(range)),
+                                                     std::end(std::forward<ForwardRange>(range))))>
+    {
+        (*this)(std::begin(std::forward<ForwardRange>(range)),
+                std::end(std::forward<ForwardRange>(range)));
+    }
};

} // namespace utils
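
The new overloads let a fill functor take a whole range instead of an iterator pair, simply forwarding to std::begin/std::end. A short usage sketch, assuming the existing iterator-pair operator(), aggregate initialization of the functors, and "ck/library/utility/fill.hpp" as the include path:

#include <vector>
#include "ck/library/utility/fill.hpp" // assumed path of the file shown above

void fill_example()
{
    std::vector<float> weights(16);
    ck::utils::FillConstant<float>{1.0f}(weights); // previously: functor(weights.begin(), weights.end())

    std::vector<int> indices(8);
    ck::utils::FillMonotonicSeq<int>{}(indices);   // 0, 1, 2, ... assuming default start/step members
}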
...