Commit 4e911f3e authored by Jun Liu

Merge branch 'amd-develop' into amd-master

parents fa9da1a4 4939ee59
@@ -161,7 +161,7 @@ struct GridwiseMultiblockWelfordFirstHalf
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
-1,
+0,
1,
InMemoryDataOperationEnum::Set,
1,
@@ -180,7 +180,7 @@ struct GridwiseMultiblockWelfordFirstHalf
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
-1,
+0,
1,
InMemoryDataOperationEnum::Set,
1,
...
@@ -33,7 +33,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
const MeanVarGridDesc_M mean_var_grid_desc_m,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
-index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
@@ -59,7 +58,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
mean_var_grid_desc_m,
blkgroup_size,
num_xy_k_block_tile_iteration,
-num_mean_var_count_k_block_tile_iteration,
epsilon,
p_in_welford_mean,
p_in_welford_variance,
@@ -152,7 +150,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
const MeanVarGridDesc_M& mean_var_grid_desc_m,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
-index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
@@ -223,7 +220,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
-1,
+0,
1,
1,
true>(
@@ -239,7 +236,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
-1,
+0,
1,
1,
true>(
@@ -257,9 +254,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
-constexpr auto mean_var_count_thread_copy_step_m_k =
-make_multi_index(0, KThreadClusterSize * 1);
// Step 1: do final welford reduction to get mean and variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -268,8 +262,11 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
welford_count_thread_buf(I) = 0;
});
-for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
-++reducedTiles)
+constexpr auto mean_var_count_thread_copy_step_m_k =
+make_multi_index(0, KThreadClusterSize);
+int32_t reducedSize = 0;
+while(reducedSize < blkgroup_size)
{
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_mean_global_val_buf,
@@ -296,6 +293,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
welford_var_thread_buf,
welford_count_thread_buf);
+reducedSize += KThreadClusterSize;
threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_thread_copy_step_m_k);
threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
...
@@ -3,6 +3,8 @@
#pragma once
+#include <iostream>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
...
@@ -79,6 +79,10 @@ struct GridwiseGemmPipeline_v2
do
{
+#if CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
+__builtin_amdgcn_iglp_opt(CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT);
+#endif
block_sync_lds();
// GEMM i
...
@@ -27,6 +27,9 @@ template <typename GridwiseGemm,
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+#if CK_USE_WAVES_PER_EU
+__attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif
kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
@@ -60,6 +63,9 @@ template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+#if CK_USE_WAVES_PER_EU
+__attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif
kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg)
{
...
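Note, a hedged sketch and not part of the commit: CK_USE_WAVES_PER_EU and the CK_MIN/MAX_WAVES_PER_EU bounds are expected to be supplied at build time, e.g.

// hypothetical compile-time configuration for the new occupancy hint;
// amdgpu_waves_per_eu(2, 4) asks the compiler to budget registers/LDS so that
// 2 to 4 wavefronts can stay resident per execution unit
// hipcc ... -DCK_USE_WAVES_PER_EU=1 -DCK_MIN_WAVES_PER_EU=2 -DCK_MAX_WAVES_PER_EU=4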
@@ -29,7 +29,9 @@ enum struct MfmaInstr
mfma_i32_16x16x16i8,
mfma_i32_32x32x16i8,
mfma_i32_16x16x32i8,
-mfma_f64_16x16x4f64
+mfma_f64_16x16x4f64,
+mfma_f32_32x32x16f8f8,
+mfma_f32_16x16x32f8f8
};
template <MfmaInstr instr>
@@ -454,6 +456,50 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
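A hedged note, not part of the commit: with is_k_reduction == true the per-instruction K depth is num_input_blks * k_per_blk, which lines up with the new entries above:

// mfma_f32_32x32x16f8f8: num_input_blks(2) * k_per_blk(8) = K of 16
// mfma_f32_16x16x32f8f8: num_input_blks(4) * k_per_blk(8) = K of 32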
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
struct MfmaSelector
{
@@ -594,6 +640,18 @@ struct MfmaSelector
}
#endif
template <>
static constexpr auto GetMfma<f8_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x16f8f8;
}
template <>
static constexpr auto GetMfma<f8_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x32f8f8;
}
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
__host__ __device__ constexpr MfmaSelector()
@@ -794,7 +852,7 @@ struct XdlopsGemm
{
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
-is_same<base_type, int8_t>::value,
+is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
"base base_type must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
...
@@ -13,6 +13,61 @@
namespace ck {
namespace tensor_operation {
namespace {
template <
index_t NDimSpatial,
typename ALayout,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization>
constexpr auto
make_out_n_ho_wo_k_grid_desc(const index_t N,
const index_t Ho,
const index_t Wo,
const index_t K,
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides)
{
if constexpr(is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
{
const index_t NStride = out_g_n_k_wos_strides[1];
const index_t HiStride = out_g_n_k_wos_strides[3];
const index_t WiStride = out_g_n_k_wos_strides[4];
const auto CStride = Number<1>{};
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
make_tuple(WiStride, CStride));
}
else
{
return make_naive_tensor_descriptor(make_tuple(N, Ho, Wo, K),
make_tuple(NStride, HiStride, WiStride, CStride));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
}
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
}
}
} // namespace
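A hedged example, not part of the commit: the stride array is indexed as [G, N, K, Ho, Wo], so for an NHWGK output with N=2, Ho=4, Wo=4, G=1, K=8 the helper would read

// NHWGK memory order => K contiguous, then G, then Wo, Ho, N:
// out_g_n_k_wos_strides = {8, 128, 1, 32, 8};
// NStride = strides[1] = 128, HiStride = strides[3] = 32, WiStride = strides[4] = 8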
template <
index_t NDimSpatial,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
@@ -29,11 +84,12 @@ struct TransformConvBwdDataToGemm_v1
template <typename ALayout,
typename std::enable_if<NDimSpatial == 2 &&
-is_same_v<ALayout, tensor_layout::convolution::GNHWK>,
+(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
+is_same_v<ALayout, tensor_layout::convolution::NHWGK>),
bool>::type = false>
static auto MakeADescriptor_AK0_M_AK1(
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
-const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
+const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
@@ -70,9 +126,9 @@ struct TransformConvBwdDataToGemm_v1
const index_t AK0 = K / AK1;
-// assume packed
const auto out_n_ho_wo_k_grid_desc =
-make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
+make_out_n_ho_wo_k_grid_desc<NDimSpatial, ALayout, ConvBwdDataSpecialization>(
+N, Ho, Wo, K, out_g_n_k_wos_strides);
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
@@ -80,7 +136,7 @@
{
// A: output tensor
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
-make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
+out_n_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N * Ho * Wo),
make_unmerge_transform(make_tuple(AK0, AK1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
...
@@ -1114,13 +1114,30 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
-return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
-src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
+if constexpr(is_same<scalar_t, f8_t>::value)
+{
+auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
+src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
+return bit_cast<vector_t>(tmp);
+}
+else
+{
+return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
+src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
+}
#else
-vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
-src_wave_buffer_resource, src_thread_addr_offset, 0);
-return src_thread_element_valid ? tmp : vector_t(0);
+if constexpr(is_same<scalar_t, f8_t>::value)
+{
+auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
+src_wave_buffer_resource, src_thread_addr_offset, 0);
+return src_thread_element_valid ? bit_cast<vector_t>(tmp) : vector_t(0);
+}
+else
+{
+vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
+src_wave_buffer_resource, src_thread_addr_offset, 0);
+return src_thread_element_valid ? tmp : vector_t(0);
+}
#endif
}
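// A hedged note, not part of the commit: f8_t is byte-sized (an alias of
// uint8_t in this change), so the new branches reuse the existing int8
// buffer_load/buffer_store instantiations and bit_cast the payload instead of
// introducing dedicated f8 buffer intrinsics.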
@@ -1179,13 +1196,33 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
-amd_buffer_store_impl<scalar_t, vector_size, coherence>(
-src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
+if constexpr(is_same<scalar_t, f8_t>::value)
+{
+auto tmp =
+bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(src_thread_data);
+amd_buffer_store_impl<int8_t, vector_size, coherence>(
+tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
+}
+else
+{
+amd_buffer_store_impl<scalar_t, vector_size, coherence>(
+src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
+}
#else
if(dst_thread_element_valid)
{
-amd_buffer_store_impl<scalar_t, vector_size, coherence>(
-src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
+if constexpr(is_same<scalar_t, f8_t>::value)
+{
+auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
+src_thread_data);
+amd_buffer_store_impl<int8_t, vector_size, coherence>(
+tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
+}
+else
+{
+amd_buffer_store_impl<scalar_t, vector_size, coherence>(
+src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
+}
}
#endif
}
...
@@ -354,5 +354,68 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8;
template <>
struct intrin_mfma_f32_32x32x16f8f8<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f8f8;
template <>
struct intrin_mfma_f32_16x16x32f8f8<16, 16>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
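// A hedged note, not part of the commit: on targets other than gfx940/941/942
// the fallback paths above emulate one f8 MFMA by converting the 8 packed f8
// operands to f32 lane by lane and issuing 8 narrow f32 MFMAs (32x32x2 or
// 16x16x4), so the result matches but the throughput does not.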
} // namespace ck
#endif
@@ -24,6 +24,7 @@
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/type.hpp"
+#include "ck/utility/type_convert.hpp"
#include "ck/utility/magic_division.hpp"
#include "ck/utility/c_style_pointer_cast.hpp"
#include "ck/utility/is_known_at_compile_time.hpp"
...
@@ -12,6 +12,7 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4);
#endif
+using f8_t = uint8_t;
// vector_type
template <typename T, index_t N>
@@ -142,6 +143,13 @@ struct scalar_type<int4_t>
};
#endif
+template <>
+struct scalar_type<f8_t>
+{
+using type = f8_t;
+static constexpr index_t vector_size = 1;
+};
//
template <typename T>
struct vector_type<T, 1>
@@ -944,151 +952,13 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
-// Convert X to Y
-template <typename Y, typename X>
-__host__ __device__ constexpr Y type_convert(X x)
-{
-static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
-return static_cast<Y>(x);
-}
-// convert bfp16 to fp32
-template <>
-inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
-{
-union
-{
-uint32_t int32;
-float fp32;
-} u = {uint32_t(x) << 16};
-return u.fp32;
-}
-// convert fp32 to bfp16
-template <>
-inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
-{
-union
-{
-float fp32;
-uint32_t int32;
-} u = {x};
-return uint16_t(u.int32 >> 16);
-}
-// convert bfp16 to fp16 via fp32
-template <>
-inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
-{
-float x_fp32 = type_convert<float>(x);
-return static_cast<half_t>(x_fp32);
-}
-// convert fp16 to bfp16 via fp32
-template <>
-inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
-{
-float x_fp32 = static_cast<float>(x);
-return type_convert<bhalf_t>(x_fp32);
-}
-// convert bfp16 to int32 via fp32
-template <>
-inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
-{
-float x_fp32 = type_convert<float>(x);
-return static_cast<int32_t>(x_fp32);
-}
-// convert int32 to bfp16 via fp32
-template <>
-inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
-{
-float x_fp32 = static_cast<float>(x);
-return type_convert<bhalf_t>(x_fp32);
-}
-// convert bfp16 to int8 via fp32
-template <>
-inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
-{
-float x_fp32 = type_convert<float>(x);
-return static_cast<int8_t>(x_fp32);
-}
-// convert int8 to bfp16 via fp32
-template <>
-inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
-{
-float x_fp32 = static_cast<float>(x);
-return type_convert<bhalf_t>(x_fp32);
-}
-// Declare a template function for bf16 conversion using RTN
-template <typename Y, typename X>
-__host__ __device__ constexpr Y bf16_convert_rtn(X x);
-// Convert fp32 to bf16 with RTN if higher precision is needed
-template <>
-inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
-{
-union
-{
-float fp32;
-uint32_t int32;
-} u = {x};
-// When the exponent bits are not all 1s, then the value is zero, normal,
-// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
-// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
-// This causes the bfloat16's mantissa to be incremented by 1 if the 16
-// least significant bits of the float mantissa are greater than 0x8000,
-// or if they are equal to 0x8000 and the least significant bit of the
-// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
-// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
-// has the value 0x7f, then incrementing it causes it to become 0x00 and
-// the exponent is incremented by one, which is the next higher FP value
-// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
-// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
-// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
-// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
-// incrementing it causes it to become an exponent of 0xFF and a mantissa
-// of 0x00, which is Inf, the next higher value to the unrounded value.
-bool flag0 = ~u.int32 & 0x7f800000;
-// When all of the exponent bits are 1, the value is Inf or NaN.
-// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
-// mantissa bit. Quiet NaN is indicated by the most significant mantissa
-// bit being 1. Signaling NaN is indicated by the most significant
-// mantissa bit being 0 but some other bit(s) being 1. If any of the
-// lower 16 bits of the mantissa are 1, we set the least significant bit
-// of the bfloat16 mantissa, in order to preserve signaling NaN in case
-// the bfloat16's mantissa bits are all 0.
-bool flag1 = !flag0 && (u.int32 & 0xffff);
-u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
-u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
-return uint16_t(u.int32 >> 16);
-}
-// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
-template <>
-inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
-{
-float x_fp32 = static_cast<float>(x);
-return bf16_convert_rtn<bhalf_t>(x_fp32);
-}
+// f8
+using f8x2_t = typename vector_type<f8_t, 2>::type;
+using f8x4_t = typename vector_type<f8_t, 4>::type;
+using f8x8_t = typename vector_type<f8_t, 8>::type;
+using f8x16_t = typename vector_type<f8_t, 16>::type;
+using f8x32_t = typename vector_type<f8_t, 32>::type;
+using f8x64_t = typename vector_type<f8_t, 64>::type;
template <typename T>
struct NumericLimits
@@ -1136,4 +1006,21 @@ struct NumericLimits<int4_t>
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<f8_t>
{
static constexpr uint8_t binary_min = 0x08; // 0b00001000
static constexpr uint8_t binary_max = 0x77; // 0b01110111
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
__host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); }
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
namespace ck {
// fp8 rounding modes
// use standard for round-to-nearest (the faster option)
// use stochastic for stochastic rounding, which helps avoid error accumulation
enum class f8_rounding_mode
{
standard,
stochastic
};
} // namespace ck
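For illustration, a minimal dispatch sketch (not part of the commit; convert_f8_with_mode is a hypothetical helper) showing how the two modes map onto the conversion entry points introduced further below:

template <ck::f8_rounding_mode rm>
__host__ __device__ ck::f8_t convert_f8_with_mode(float x)
{
    if constexpr(rm == ck::f8_rounding_mode::stochastic)
        return ck::f8_convert_sr<ck::f8_t>(x); // stochastic rounding path
    else
        return ck::type_convert<ck::f8_t>(x); // standard round-to-nearest path
}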
namespace ck::utils {
namespace {
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
{
// check data type
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
// fp8 exponent/mantissa layout
constexpr int f8_exp = 4;
constexpr int f8_mant = 3;
// resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
int exponent;
uint32_t head, mantissa, sign;
// the NaN code is the same for float and half
constexpr uint8_t nan_code = 0x80;
constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000;
// convert to bitwise
typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type
T_bitwise;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
// unpack the input, depends on datatype
if constexpr(is_float)
{
head = x_bitwise & 0xFF800000;
mantissa = x_bitwise & 0x7FFFFF;
exponent = (head >> type_mant) & 0xFF;
sign = head >> (type_exp + type_mant);
}
else if constexpr(is_half)
{
head = x_bitwise & 0xFC00;
mantissa = x_bitwise & 0x3FF;
exponent = (head >> type_mant) & 0x1F;
sign = head >> (type_exp + type_mant);
}
uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1;
constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
if constexpr(negative_zero_nan)
{
if((x_bitwise & nan_mask) == nan_mask)
return nan_code;
}
else
{
if((x_bitwise & nan_mask) == nan_mask)
return signed_inf + (mantissa != 0 ? 1 : 0);
}
// check if x is 0.0
if(x_bitwise == 0)
return 0;
exponent -= exp_low_cutoff - 1;
if(exponent <= 0)
drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1;
mantissa += 1 << type_mant;
// apply random number if needed
mantissa += (stoch ? rng : mantissa) & drop_mask;
if(mantissa >= (2 << type_mant))
{
mantissa >>= 1;
exponent++;
}
mantissa >>= (type_mant - f8_mant);
// check negative exponent
if(exponent <= 0)
{
if(x_bitwise == 0)
return 0;
else
{
// subnormal range; represented by a subnormal float8 (exponent 0)
// and involves loss of accuracy
mantissa >>= 1 - exponent;
exponent = 0;
}
}
// above range: quantize to maximum possible float of the same sign
else if(exponent > max_exp)
{
if(clip)
{
mantissa = (1 << f8_mant) - 1;
exponent = max_exp;
}
else
{
return signed_inf;
}
}
// check if x is 0.0 or -0.0
if(exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant));
mantissa &= (1 << f8_mant) - 1;
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa;
}
template <typename T, bool negative_zero_nan>
__host__ __device__ T run_cast_from_f8(f8_t x)
{
// check data type
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
// fp8 exponent/mantissa layout
constexpr int f8_exp = 4;
constexpr int f8_mant = 3;
// resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
// prepare the codes
constexpr uint8_t nan_code = 0x80;
T fInf, fNegInf, fNaN, fNeg0;
if constexpr(is_half)
{
constexpr uint16_t ihInf = 0x7C00;
constexpr uint16_t ihNegInf = 0xFC00;
constexpr uint16_t ihNaN = 0x7C01;
constexpr uint16_t ihNeg0 = 0x8000;
fInf = *(reinterpret_cast<const half_t*>(&ihInf));
fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf));
fNaN = *(reinterpret_cast<const half_t*>(&ihNaN));
fNeg0 = *(reinterpret_cast<const half_t*>(&ihNeg0));
}
else if constexpr(is_float)
{
constexpr uint32_t ifInf = 0x7F800000;
constexpr uint32_t ifNegInf = 0xFF800000;
constexpr uint32_t ifNaN = 0x7F800001;
constexpr uint32_t ifNeg0 = 0x80000000;
fInf = *(reinterpret_cast<const float*>(&ifInf));
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
fNaN = *(reinterpret_cast<const float*>(&ifNaN));
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
}
// unpack the input
uint32_t sign = x >> (f8_exp + f8_mant);
uint32_t mantissa = x & ((1 << f8_mant) - 1);
int exponent = (x & 0x7F) >> f8_mant;
constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval;
if constexpr(negative_zero_nan)
{
if(x == nan_code)
return fNaN;
}
else
{
if(x == nan_code)
return fNeg0;
if(exponent == ((1 << f8_exp) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
}
// subnormal input
if(exponent == 0)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant);
mantissa <<= sh;
mantissa &= ((1 << f8_mant) - 1);
exponent += 1 - sh;
}
exponent += exp_low_cutoff - 1;
mantissa <<= type_mant - f8_mant;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if(exponent <= 0)
{
mantissa |= 1 << type_mant;
mantissa >>= 1 - exponent;
exponent = 0;
}
retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa;
return *(reinterpret_cast<const T*>(&retval));
}
} // namespace
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
{
// check datatype
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted to f8.");
return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
}
template <typename T, bool negative_zero_nan>
__host__ __device__ T cast_from_f8(f8_t x)
{
// check datatype
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_half || is_float, "only half and float are supported.");
// check if x is 0.0
if(x == 0)
return static_cast<T>(0);
return run_cast_from_f8<T, negative_zero_nan>(x);
}
} // namespace ck::utils
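A small usage sketch (not part of the commit; f8_cast_example is a hypothetical helper), exercising the wrappers above with round-to-nearest and a value that the 3-bit mantissa can represent exactly:

__host__ __device__ inline void f8_cast_example()
{
    float x = 1.5f; // exactly representable in fp8 (mantissa 0b100)
    ck::f8_t y = ck::utils::cast_to_f8<float, true, true, false>(x, /*rng=*/0);
    float z = ck::utils::cast_from_f8<float, true>(y); // z == 1.5f again
}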
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
template <index_t N>
static constexpr __device__ index_t get_shift()
{
return (get_shift<N / 2>() + 1);
};
template <>
constexpr __device__ index_t get_shift<1>()
{
return (0);
}
} // namespace ck
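A quick sanity sketch (not part of the commit): get_shift<N>() adds 1 per halving of N, so it evaluates to floor(log2(N)), e.g. inside device code:

__device__ inline void get_shift_example()
{
    static_assert(ck::get_shift<1>() == 0, "");
    static_assert(ck::get_shift<8>() == 3, "");
    static_assert(ck::get_shift<12>() == 3, ""); // non-powers of two round down
}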
@@ -3,6 +3,7 @@
#pragma once
#include "data_type.hpp"
+#include "type_convert.hpp"
namespace ck {
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
// Pseudo random number generator
// version for fp32
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits ^= x >> 16;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
drop_bits *= 0x7000149;
// NOTE: if id is 64-bit, only its lower 32 bits are used here, so a very
// large id can alias and produce the same rng value for multiple elements!
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
return rng;
}
// version for fp16
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
drop_bits *= 0x7000149;
// NOTE: if id is 64-bit, only its lower 32 bits are used here, so a very
// large id can alias and produce the same rng value for multiple elements!
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
return rng;
}
// return 0 if data is not fp16 or fp32
template <typename T,
uint32_t seed_t,
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{
std::ignore = id;
std::ignore = val;
std::ignore = seed;
return 0;
}
} // namespace ck
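A hedged sketch (not part of the commit; quantize_stochastic is a hypothetical helper) tying the generator to the fp8 cast:

__device__ inline ck::f8_t quantize_stochastic(ck::index_t element_id, float val)
{
    constexpr uint32_t seed = 1;
    // per-element rng derived from the value bits, the element id and the seed
    uint32_t rng = ck::prand_generator<float, seed>(element_id, val);
    return ck::utils::cast_to_f8<float, true, true, true>(val, rng);
}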
@@ -25,16 +25,4 @@ struct float_equal_zero
};
};
-template <index_t N>
-static constexpr __device__ index_t get_shift()
-{
-return (get_shift<N / 2>() + 1);
-};
-template <>
-constexpr __device__ index_t get_shift<1>()
-{
-return (0);
-}
} // namespace ck
@@ -6,6 +6,7 @@
#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
+#include "ck/utility/type_convert.hpp"
namespace ck {
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/random_gen.hpp"
namespace ck {
// Convert X to Y
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
return static_cast<Y>(x);
}
// convert bfp16 to fp32
template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
{
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(x) << 16};
return u.fp32;
}
// convert fp32 to bfp16
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
return uint16_t(u.int32 >> 16);
}
// convert bfp16 to fp16 via fp32
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<half_t>(x_fp32);
}
// convert fp16 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int32 via fp32
template <>
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int32_t>(x_fp32);
}
// convert int32 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int8 via fp32
template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int8_t>(x_fp32);
}
// convert int8 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
// convert fp8 to fp32
template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<float, negative_zero_nan>(x);
}
// convert fp16 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
{
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
// convert fp8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<half_t, negative_zero_nan>(x);
}
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool flag0 = ~u.int32 & 0x7f800000;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool flag1 = !flag0 && (u.int32 & 0xffff);
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
return uint16_t(u.int32 >> 16);
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return bf16_convert_rtn<bhalf_t>(x_fp32);
}
// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// thread id is not available on host, so use the address of x as the id for prn generation
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
// convert fp16 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
{
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// thread id is not available on host, so use the address of x as the id for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
}
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/host_utility/hip_check_error.hpp"
namespace ck {
// Initialization flag of Barrier object, can be any value except for zero
static constexpr int BarrierInitFlag = 0x7856;
// 1) Only the first thread-block in the synchronization group is supposed to call this function. It
// is the responsibility of the user to ensure the two integer values in p_control_bits are zeros
// before calling gms_init().
// 2) After calling gms_reset(), the two integer values in p_control_bits will be zeros, so no
// repeated initialization of the p_control_bits buffer is required
static __device__ void gms_init(int NumWarps, int* p_control_bits)
{
union
{
int two32[2];
unsigned long one64;
} regs;
regs.two32[0] = BarrierInitFlag;
regs.two32[1] = NumWarps;
if(threadIdx.x == 0)
atomicCAS(reinterpret_cast<unsigned long*>(p_control_bits), 0, regs.one64);
};
// all the workgroups in the synchronization group are supposed to call this function
static __device__ void gms_barrier(int* p_control_bits)
{
constexpr int mask = warpSize - 1;
if((threadIdx.x & mask) == 0)
{
// ensure the barrier object is initialized
do
{
const int r0 = __atomic_load_n(&p_control_bits[0], __ATOMIC_RELAXED);
if(r0 == BarrierInitFlag)
break;
} while(true);
// go ahead toward the barrier line
atomicSub(&p_control_bits[1], 1);
// wait until all warps have arrived
do
{
const int r1 = __atomic_load_n(&p_control_bits[1], __ATOMIC_RELAXED);
if(r1 == 0)
break;
} while(true);
};
};
// 1) Only the first thread-block in the synchronization group is supposed to call this function.
// 2) After calling gms_reset(), the two integer values in p_control_bits will be zeros, so no
// repeated initialization of the p_control_bits buffer is required
static __device__ void gms_reset(int* p_control_bits)
{
// reset the barrier object
if(threadIdx.x == 0)
(void)atomicCAS(&p_control_bits[0], BarrierInitFlag, 0);
};
} // namespace ck
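A hedged usage sketch (not part of the commit; two_phase_kernel is a hypothetical kernel), assuming p_control points to a zero-initialized int[2] in global memory and every workgroup of the launch participates:

__global__ void two_phase_kernel(int* p_control, int num_warps_in_grid)
{
    if(blockIdx.x == 0)
        ck::gms_init(num_warps_in_grid, p_control); // arms the barrier once
    // ... phase 1: produce partial results ...
    ck::gms_barrier(p_control); // spins until every warp has arrived
    // ... phase 2: safe to consume phase-1 results ...
    if(blockIdx.x == 0)
        ck::gms_reset(p_control); // control words return to zero for reuse
}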
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...