Commit 9fb64dae authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of...

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into e2e_kernellib
parents e330961d fe96e8fb
......@@ -1003,7 +1003,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization)
<< getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">";
return str.str();
......
......@@ -1203,7 +1203,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization)
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
<< K1
<< ">";
// clang-format on
......
......@@ -1231,7 +1231,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization)
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
<< K1 << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">";
// clang-format on
......
......@@ -1092,7 +1092,15 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization)
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">";
// clang-format on
......
......@@ -618,7 +618,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
namespace ctc = tensor_layout::convolution;
// check device
if(get_device_name() == "gfx1100")
if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......@@ -876,7 +877,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization)
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< K1 << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector
<< ">";
// clang-format on
......
......@@ -939,7 +939,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization)
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">";
// clang-format on
......
......@@ -381,6 +381,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_;
a_mtx_mraw_kraw_.emplace_back(M, K);
b_mtx_nraw_kraw_.emplace_back(N, K);
if(M == 0)
{
skipped_group_count_++;
......@@ -485,6 +488,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
CDEElementwiseOperation c_element_op_;
std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
index_t grid_size_;
};
......@@ -599,7 +604,28 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
return false;
}
return true;
bool supported = true;
// If we use padding we do not support vector loads for dimensions not divisible by vector
// load size.
if constexpr(GemmSpec != GemmSpecialization::Default)
{
// [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
// thus we have to adapt it to the {M,K} or {N,K} layout.
const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
for(index_t i = 0; i < arg.group_count_; ++i)
{
const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
}
}
return supported;
}
// polymorphic
......@@ -661,7 +687,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">";
// clang-format on
......
#pragma once
#include "ck/utility/data_type.hpp"
// #include "ck/utility/get_id.hpp"
namespace ck {
namespace tensor_operation {
......@@ -17,18 +18,27 @@ struct Activation_Mul_Clamp
__host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
{
float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32);
float y_fp32 = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
float y_fp32 = ck::type_convert<float>(x);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void operator()(float& y, const int32_t& x) const
__device__ constexpr void operator()(int32_t& y, const int32_t& x) const
{
// We might type_convert to int8 after lambda in someplace
float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32);
y = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
__host__ constexpr void operator()(float& y, const float& x) const
{
// CAUSION - We might float in & float out in reference code
activationOp_(y, x);
y = math::clamp(requantScale_ * y, -128.f, 127.f);
}
float requantScale_;
......@@ -51,6 +61,17 @@ struct Activation_Mul2_Clamp
y = ck::type_convert<int8_t>(y_fp32);
}
__device__ constexpr void
operator()(int32_t& y, const int32_t& x, const float& requantScale) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
Activation activationOp_;
};
......@@ -72,6 +93,17 @@ struct Add_Activation_Mul_Clamp
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void
operator()(int32_t& y, const int32_t& x, const int32_t& bias) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
float requantScale_;
Activation activationOp_;
};
......@@ -92,6 +124,17 @@ struct Add_Activation_Mul2_Clamp
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void
operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
Activation activationOp_;
};
......
......@@ -185,8 +185,10 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
return b_grid_desc_k0_n0_n1_k1;
}
// E desc for destination in blockwise copy
template <typename CGridDesc_M_N_>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
......@@ -238,19 +240,19 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
using DsGridPointer = decltype(MakeDsGridPointer());
template <typename DsGridDesc_M0_M10_M11_N0_N10_N11,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
bool HasDoubleTailKBlockLoop,
typename Block2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
void* __restrict__ p_shared_block,
const AElementwiseOperation&,
const BElementwiseOperation&,
const CDEElementwiseOperation& cde_element_op,
......@@ -399,8 +401,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
FloatAB* p_a_block_double = static_cast<FloatAB*>(p_shared_block);
FloatAB* p_b_block_double =
static_cast<FloatAB*>(p_shared_block) + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
......
......@@ -54,7 +54,8 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -147,8 +148,10 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
// printf("entry kernel launch");
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -236,8 +239,9 @@ __global__ void
const CDEElementwiseOperation cde_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
......@@ -265,7 +269,7 @@ __global__ void
ignore = b_element_op;
ignore = cde_element_op;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx1100__))
#endif // end of if (defined(__gfx1100__ ))
}
template < // DataType Family
......
......@@ -45,8 +45,9 @@ __global__ void
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
......
......@@ -1030,7 +1030,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff;
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
return amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
......@@ -1091,7 +1091,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_store_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
......@@ -1126,7 +1126,7 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
......@@ -1161,7 +1161,7 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_max_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
......
......@@ -220,8 +220,8 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0
"0"(c0),
"1"(c1));
#else
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
#endif
}
......@@ -257,10 +257,10 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
"2"(c2),
"3"(c3));
#else
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
#endif
}
......@@ -358,7 +358,13 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
// Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
{
#if defined(__gfx11__)
asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
#else
ignore = a;
ignore = b;
ignore = c;
#endif
}
} // namespace ck
......
......@@ -21,17 +21,18 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16, AssemblyBackend>
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
if constexpr(AssemblyBackend)
{
amd_assembly_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
}
else
{
reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
}
// * Inline assembly need to elimate the duplicated data load, compiler won't help you
// delete them.
// amd_assembly_wmma_f32_16x16x16_f16_w32(
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
......@@ -45,9 +46,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
......@@ -64,8 +71,14 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
......@@ -82,9 +95,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
......@@ -98,6 +117,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<int32x8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
neg_a,
......@@ -106,6 +126,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
bit_cast<int32x4_t>(reg_b),
reg_c.template AsType<int32x8_t>()[Number<0>{}],
clamp);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
......@@ -120,8 +145,14 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
......@@ -135,9 +166,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
......@@ -154,8 +191,14 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
......@@ -172,9 +215,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
......@@ -188,6 +237,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
neg_a,
......@@ -196,6 +246,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
bit_cast<int32x4_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}],
clamp);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
......
......@@ -1022,38 +1022,36 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
uint32_t int32;
} u = {x};
if(~u.int32 & 0x7f800000)
{
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
}
else if(u.int32 & 0xffff)
{
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bloat16's mantissa bits are all 0.
u.int32 |= 0x10000; // Preserve signaling NaN
}
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool flag0 = ~u.int32 & 0x7f800000;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool flag1 = !flag0 && (u.int32 & 0xffff);
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
return uint16_t(u.int32 >> 16);
}
......
......@@ -135,6 +135,28 @@ __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const h
c);
}
template <>
__device__ void inner_product<int8_t, int8_t, int32_t>(const int8_t& a, const int8_t& b, int32_t& c)
{
c += type_convert<int32_t>(a) * type_convert<int32_t>(b);
}
template <>
__device__ void
inner_product<int8x2_t, int8x2_t, int32_t>(const int8x2_t& a, const int8x2_t& b, int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
inner_product(vector_type<int8_t, 2>{a}.AsType<int8_t>()[I0],
vector_type<int8_t, 2>{b}.AsType<int8_t>()[I0],
c);
inner_product(vector_type<int8_t, 2>{a}.AsType<int8_t>()[I1],
vector_type<int8_t, 2>{b}.AsType<int8_t>()[I1],
c);
}
template <>
__device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
......
......@@ -93,6 +93,7 @@ using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu;
template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
......
......@@ -74,18 +74,17 @@ template <typename ALayout,
typename ADataType,
typename BDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemm<
ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
PassThrough>>
{
using DeviceOp = DeviceGroupedGemm<ALayout,
BLayout,
......@@ -95,9 +94,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
BDataType,
Empty_Tuple,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
PassThrough,
PassThrough,
PassThrough>;
static auto GetInstances()
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <memory>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances);
void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances);
void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances);
void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances);
// GroupedGEMM + GELU
template <typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
FastGelu>>
{
using DeviceOp = DeviceGroupedGemm<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
FastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment