Commit 8c4e33f1 authored by Chao Liu

Merge remote-tracking branch 'origin/develop' into v5r1_add

parents 5aed38d4 3737bb03
@@ -30,7 +30,8 @@ struct PassThrough
     __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

     template <typename LowIdx, typename UpIdx>
-    __host__ __device__ static void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up)
+    __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low,
+                                                                  const UpIdx& idx_up)
     {
         static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                       "wrong! inconsistent # of dimension");
@@ -1708,7 +1709,8 @@ struct Vectorize
     __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

     template <typename LowIdx, typename UpIdx>
-    __host__ __device__ void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const
+    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
+                                                           const UpIdx& idx_up) const
     {
         static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                       "wrong! inconsistent # of dimension");
...
#ifndef CK_STATIC_TENSOR_HPP
#define CK_STATIC_TENSOR_HPP
#include "ignore.hpp"
namespace ck {
// StaticTensor for Scalar
template <AddressSpaceEnum_t AddressSpace,
typename T,
typename TensorDesc,
bool InvalidElementUseNumericalZeroValue,
typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
struct StaticTensor
{
static constexpr auto desc_ = TensorDesc{};
static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension();
static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();
__host__ __device__ constexpr StaticTensor() : invalid_element_value_{0} {}
__host__ __device__ constexpr StaticTensor(T invalid_element_value)
: invalid_element_value_{invalid_element_value}
{
}
// read access
template <typename Idx,
typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr const T& operator[](Idx) const
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
constexpr index_t offset = coord.GetOffset();
constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
if constexpr(is_valid)
{
return data_[Number<offset>{}];
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return T{0};
}
else
{
return invalid_element_value_;
}
}
}
// write access
template <typename Idx,
typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr T& operator()(Idx)
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
constexpr index_t offset = coord.GetOffset();
constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
if constexpr(is_valid)
{
return data_(Number<offset>{});
}
else
{
return ignore;
}
}
StaticBuffer<AddressSpace, T, element_space_size_, true> data_;
T invalid_element_value_ = T{0};
};
// StaticTensor for vector
template <AddressSpaceEnum_t AddressSpace,
typename S,
index_t ScalarPerVector,
typename TensorDesc,
bool InvalidElementUseNumericalZeroValue,
typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
struct StaticTensorTupleOfVectorBuffer
{
static constexpr auto desc_ = TensorDesc{};
static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension();
static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();
static constexpr index_t num_of_vector_ =
math::integer_divide_ceil(element_space_size_, ScalarPerVector);
using V = vector_type<S, ScalarPerVector>;
__host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() : invalid_element_value_{0} {}
__host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value)
: invalid_element_value_{invalid_element_value}
{
}
// Get S
// Idx is for S, not V
template <typename Idx,
typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr const S& operator[](Idx) const
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
constexpr index_t offset = coord.GetOffset();
constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
if constexpr(is_valid)
{
return data_[Number<offset>{}];
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return S{0};
}
else
{
return invalid_element_value_;
}
}
}
// Set S
// Idx is for S, not V
template <typename Idx,
typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr S& operator()(Idx)
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
constexpr index_t offset = coord.GetOffset();
constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
if constexpr(is_valid)
{
return data_(Number<offset>{});
}
else
{
return ignore;
}
}
// Get X
// Idx is for S, not X. Idx should be aligned with X
template <typename X,
typename Idx,
typename enable_if<has_same_scalar_type<S, X>::value &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr X GetAsType(Idx) const
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
constexpr index_t offset = coord.GetOffset();
constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
if constexpr(is_valid)
{
return data_.template GetAsType<X>(Number<offset>{});
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
// TODO: is this right way to initialize a vector?
return X{0};
}
else
{
// TODO: is this right way to initialize a vector?
return X{invalid_element_value_};
}
}
}
// Set X
// Idx is for S, not X. Idx should be aligned with X
template <typename X,
typename Idx,
typename enable_if<has_same_scalar_type<S, X>::value &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr void SetAsType(Idx, X x)
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
constexpr index_t offset = coord.GetOffset();
constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
if constexpr(is_valid)
{
data_.template SetAsType<X>(Number<offset>{}, x);
}
}
// Get read access to V. No is_valid check
// Idx is for S, not V. Idx should be aligned with V
template <typename Idx>
__host__ __device__ constexpr const V& GetVectorTypeReference(Idx) const
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
constexpr index_t offset = coord.GetOffset();
return data_.GetVectorTypeReference(Number<offset>{});
}
// Get read access to V. No is_valid check
// Idx is for S, not V. Idx should be aligned with V
template <typename Idx>
__host__ __device__ constexpr V& GetVectorTypeReference(Idx)
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
constexpr index_t offset = coord.GetOffset();
return data_.GetVectorTypeReference(Number<offset>{});
}
StaticBufferTupleOfVector<AddressSpace, S, num_of_vector_, ScalarPerVector, true> data_;
S invalid_element_value_ = S{0};
};
template <AddressSpaceEnum_t AddressSpace,
typename T,
typename TensorDesc,
typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
__host__ __device__ constexpr auto make_static_tensor(TensorDesc)
{
return StaticTensor<AddressSpace, T, TensorDesc, true>{};
}
template <
AddressSpaceEnum_t AddressSpace,
typename T,
typename TensorDesc,
typename X,
typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false,
typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type = false>
__host__ __device__ constexpr auto make_static_tensor(TensorDesc, X invalid_element_value)
{
// pass false: a user-supplied invalid_element_value is only honored when
// InvalidElementUseNumericalZeroValue is off
return StaticTensor<AddressSpace, T, TensorDesc, false>{invalid_element_value};
}
} // namespace ck
#endif
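A minimal usage sketch for the new StaticTensor (hedged: assumes a packed compile-time descriptor from make_naive_tensor_descriptor_packed; the function and variable names are illustrative):

__device__ void static_tensor_sketch()
{
    // 2x4 tensor held entirely in VGPRs; lengths are Number<>s, so
    // TensorDesc::IsKnownAtCompileTime() holds and offsets fold at compile time
    constexpr auto desc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<2>{}, Number<4>{}));

    auto t = make_static_tensor<AddressSpaceEnum_t::Vgpr, float>(desc);

    // indices are Tuples of Number<>, so read/write addresses resolve statically
    t(make_tuple(Number<1>{}, Number<3>{})) = 2.0f;
    float v = t[make_tuple(Number<1>{}, Number<3>{})];
    ignore = v;
}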
@@ -151,6 +151,20 @@ struct TensorAdaptor
     __host__ __device__ constexpr auto GetElementSize() const { return element_size_; }

+#if 0 // debug
+    template <index_t I>
+    __host__ __device__ constexpr index_t GetTopDimensionLength(Number<I> idim) const
+    {
+        // TODO: not implemented
+    }
+
+    template <index_t I>
+    __host__ __device__ constexpr index_t GetBottomDimensionLength(Number<I> idim) const
+    {
+        // TODO: not implemented
+    }
+#endif
+
     template <typename TopIdx>
     __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
     {
...
@@ -37,7 +37,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

-    StaticBufferV2<AddressSpaceEnum_t::Vgpr, vector_type<FloatAcc, 16>, MRepeat * NRepeat, true>
+    StaticBufferOfVectorTypeV2<AddressSpaceEnum_t::Vgpr,
+                               vector_type<FloatAcc, 16>,
+                               MRepeat * NRepeat,
+                               true>
         c_thread_buf_;

     __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
...
@@ -5,7 +5,7 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "cluster_descriptor.hpp"
-#include "threadwise_tensor_slice_transfer.hpp"
+#include "threadwise_tensor_slice_transfer_v3r2.hpp"

 namespace ck {
@@ -146,22 +146,22 @@ struct BlockwiseTensorSliceTransfer_v4
         make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

     using ThreadwiseTransfer =
-        ThreadwiseTensorSliceTransfer_v3<ThreadSliceLengths,
-                                         DstInMemOp,
-                                         SrcData,
-                                         DstData,
-                                         SrcDesc,
-                                         DstDesc,
-                                         SrcDimAccessOrder,
-                                         DstDimAccessOrder,
-                                         SrcVectorDim,
-                                         DstVectorDim,
-                                         SrcScalarPerVector,
-                                         DstScalarPerVector,
-                                         SrcScalarStrideInVector,
-                                         DstScalarStrideInVector,
-                                         ThreadTransferSrcResetCoordinateAfterRun,
-                                         ThreadTransferDstResetCoordinateAfterRun>;
+        ThreadwiseTensorSliceTransfer_v3r2<ThreadSliceLengths,
+                                           DstInMemOp,
+                                           SrcData,
+                                           DstData,
+                                           SrcDesc,
+                                           DstDesc,
+                                           SrcDimAccessOrder,
+                                           DstDimAccessOrder,
+                                           SrcVectorDim,
+                                           DstVectorDim,
+                                           SrcScalarPerVector,
+                                           DstScalarPerVector,
+                                           SrcScalarStrideInVector,
+                                           DstScalarStrideInVector,
+                                           ThreadTransferSrcResetCoordinateAfterRun,
+                                           ThreadTransferDstResetCoordinateAfterRun>;

     ThreadwiseTransfer threadwise_transfer_;
 };
...
@@ -12,18 +12,19 @@ enum struct MfmaInstr
     mfma_f32_32x32x1xf32 = 0,
     mfma_f32_16x16x1xf32,
     mfma_f32_4x4x1xf32,
-    mfma_f32_32x32x2xf32, // k reduction
-    mfma_f32_16x16x4xf32, // k reduction
+    mfma_f32_32x32x2xf32,
+    mfma_f32_16x16x4xf32,
     mfma_f32_32x32x4f16,
     mfma_f32_16x16x4f16,
     mfma_f32_4x4x4f16,
-    mfma_f32_32x32x8f16,  // k reduction
-    mfma_f32_16x16x16f16, // k reduction
-    mfma_f32_32x32x2bf16,
-    mfma_f32_16x16x2bf16,
-    mfma_f32_4x4x2bf16,
-    mfma_f32_32x32x4bf16, // k reduction
-    mfma_f32_16x16x8bf16, // k reduction
+    mfma_f32_32x32x8f16,
+    mfma_f32_16x16x16f16,
+    mfma_f32_32x32x8bf16_1k,
+    mfma_f32_16x16x16bf16_1k,
+    mfma_f32_32x32x4bf16,
+    mfma_f32_16x16x8bf16,
+    mfma_i32_32x32x8i8,
+    mfma_i32_16x16x16i8,
 };

 template <MfmaInstr instr>
@@ -250,9 +251,8 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
     }
 };
-#if 0
 template <>
-struct mfma_type<MfmaInstr::mfma_f32_32x32x2bf16>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
 {
     static constexpr index_t group_size = 4;
     static constexpr index_t num_groups_per_blk = 4;
@@ -260,26 +260,38 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x2bf16>
     static constexpr index_t num_threads_per_blk = 32;
     static constexpr index_t wave_size = 64;
     static constexpr index_t num_input_blks = 2;
-    static constexpr index_t num_output_blks = 2;
+    static constexpr index_t num_output_blks = 1;
     static constexpr index_t m_per_blk = 32;
     static constexpr index_t n_per_blk = 32;
-    static constexpr index_t k_per_blk = 2;
-    static constexpr bool is_k_reduction = false;
+    static constexpr index_t k_per_blk = 4;
+    static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-        return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
-            p_a, p_b, reg_c);
+        intrin_mfma_f32_32x32x8bf16_1k<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x16bf16_1k>
+{
+    static constexpr index_t group_size = 4;
+    static constexpr index_t num_groups_per_blk = 1;
+    static constexpr index_t num_regs_per_blk = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size = 64;
+    static constexpr index_t num_input_blks = 4;
+    static constexpr index_t num_output_blks = 1;
+    static constexpr index_t m_per_blk = 16;
+    static constexpr index_t n_per_blk = 16;
+    static constexpr index_t k_per_blk = 4;
+    static constexpr bool is_k_reduction = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x16bf16_1k<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
@@ -298,19 +310,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4bf16>
     static constexpr index_t k_per_blk = 2;
     static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-        return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
+        intrin_mfma_f32_32x32x4bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
@@ -329,84 +332,56 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x8bf16>
     static constexpr index_t k_per_blk = 2;
     static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-        return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
+        intrin_mfma_f32_16x16x8bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
 template <>
-struct mfma_type<MfmaInstr::mfma_f32_16x16x2bf16>
+struct mfma_type<MfmaInstr::mfma_i32_32x32x8i8>
 {
     static constexpr index_t group_size = 4;
-    static constexpr index_t num_groups_per_blk = 1;
-    static constexpr index_t num_regs_per_blk = 4;
-    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t num_groups_per_blk = 4;
+    static constexpr index_t num_regs_per_blk = 16;
+    static constexpr index_t num_threads_per_blk = 32;
     static constexpr index_t wave_size = 64;
-    static constexpr index_t num_input_blks = 4;
-    static constexpr index_t num_output_blks = 4;
-    static constexpr index_t m_per_blk = 16;
-    static constexpr index_t n_per_blk = 16;
-    static constexpr index_t k_per_blk = 2;
-    static constexpr bool is_k_reduction = false;
+    static constexpr index_t num_input_blks = 2;
+    static constexpr index_t num_output_blks = 1;
+    static constexpr index_t m_per_blk = 32;
+    static constexpr index_t n_per_blk = 32;
+    static constexpr index_t k_per_blk = 4;
+    static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-        return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
+        intrin_mfma_i32_32x32x8i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
 template <>
-struct mfma_type<MfmaInstr::mfma_f32_4x4x2bf16>
+struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
 {
     static constexpr index_t group_size = 4;
     static constexpr index_t num_groups_per_blk = 1;
     static constexpr index_t num_regs_per_blk = 4;
-    static constexpr index_t num_threads_per_blk = 64;
+    static constexpr index_t num_threads_per_blk = 16;
     static constexpr index_t wave_size = 64;
-    static constexpr index_t num_input_blks = 1;
+    static constexpr index_t num_input_blks = 4;
     static constexpr index_t num_output_blks = 1;
-    static constexpr index_t m_per_blk = 4;
-    static constexpr index_t n_per_blk = 64;
-    static constexpr index_t k_per_blk = 2;
-    static constexpr bool is_k_reduction = false;
+    static constexpr index_t m_per_blk = 16;
+    static constexpr index_t n_per_blk = 16;
+    static constexpr index_t k_per_blk = 4;
+    static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-        return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
+        intrin_mfma_i32_16x16x16i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
-#endif
 template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
 struct MfmaSelector
@@ -498,73 +473,37 @@ struct MfmaSelector
         return MfmaInstr::mfma_f32_4x4x4f16;
     }
-#if 0
-    template <>
-    static constexpr auto GetMfma<ushort, 128, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 64, 128>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 64, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 64, 32>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 32, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 64, 16>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 16, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 8, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 4, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
-    }
-
     template <>
     static constexpr auto GetMfma<ushort, 32, 32>()
     {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
+#if defined(CK_AMD_GPU_GFX90A)
+        return MfmaInstr::mfma_f32_32x32x8bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_32x32x4bf16;
+#endif
     }

     template <>
     static constexpr auto GetMfma<ushort, 16, 16>()
     {
-        return xdlops_info<MfmaInstr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
+#if defined(CK_AMD_GPU_GFX90A)
+        return MfmaInstr::mfma_f32_16x16x16bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_16x16x8bf16;
+#endif
+    }
+
+    template <>
+    static constexpr auto GetMfma<int8_t, 32, 32>()
+    {
+        return MfmaInstr::mfma_i32_32x32x8i8;
+    }
+
+    template <>
+    static constexpr auto GetMfma<int8_t, 16, 16>()
+    {
+        return MfmaInstr::mfma_i32_16x16x16i8;
     }
-#endif
     static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
@@ -686,8 +625,8 @@ struct XdlopsGemm
     __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
     {
         static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
-                          is_same<base_type, ushort>::value,
-                      "base_type must be float, half, ushort!");
+                          is_same<base_type, ushort>::value || is_same<base_type, int8_t>::value,
+                      "base_type must be float, half, ushort or int8_t!");

         static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
             mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
...
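For orientation, the static_for above issues one MFMA per k_per_blk chunk of KPack; a worked example under assumed numbers:

// With base_type = ushort on gfx90a, selected_mfma is mfma_f32_32x32x8bf16_1k,
// whose k_per_blk is 4. Assuming KPack = 8, the loop unrolls into two issues:
//   k = 0: mfma(p_a_wave[0], p_b_wave[0], c)  // K elements 0..3 (one ushort4_t)
//   k = 1: mfma(p_a_wave[1], p_b_wave[1], c)  // K elements 4..7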
@@ -50,11 +50,24 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
                                  index_t voffset,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");

-__device__ int16_t
+__device__ ushort
 llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
                                 index_t voffset,
                                 index_t soffset,
-                                index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
+                                index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
+
+__device__ ushort2_t
+llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
+                                  index_t voffset,
+                                  index_t soffset,
+                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
+
+__device__ ushort4_t
+llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
+                                  index_t voffset,
+                                  index_t soffset,
+                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");

 __device__ int32_t
 llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
                                 index_t voffset,
@@ -133,12 +146,26 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");

 __device__ void
-llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
+llvm_amdgcn_raw_buffer_store_i16(ushort vdata,
                                  int32x4_t rsrc,
                                  index_t voffset,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");

+__device__ void
+llvm_amdgcn_raw_buffer_store_i16x2(ushort2_t vdata,
+                                   int32x4_t rsrc,
+                                   index_t voffset,
+                                   index_t soffset,
+                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
+
+__device__ void
+llvm_amdgcn_raw_buffer_store_i16x4(ushort4_t vdata,
+                                   int32x4_t rsrc,
+                                   index_t voffset,
+                                   index_t soffset,
+                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
+
 __device__ void
 llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
                                  int32x4_t rsrc,
@@ -228,6 +255,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
                       (is_same<T, double>::value && (N == 1 || N == 2 || N == 4)) ||
                       (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                       (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+                      (is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                       (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                       (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
                       "wrong! not implemented");
@@ -326,6 +354,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
             return as_type<half8_t>(tmp);
         }
     }
+    else if constexpr(is_same<T, ushort>::value)
+    {
+        if constexpr(N == 1)
+        {
+            return llvm_amdgcn_raw_buffer_load_i16(
+                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+        }
+        else if constexpr(N == 2)
+        {
+            return llvm_amdgcn_raw_buffer_load_i16x2(
+                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+        }
+        else if constexpr(N == 4)
+        {
+            return llvm_amdgcn_raw_buffer_load_i16x4(
+                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+        }
+        else if constexpr(N == 8)
+        {
+            int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
+                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+
+            return as_type<ushort8_t>(tmp);
+        }
+    }
     else if constexpr(is_same<T, int32_t>::value)
     {
         if constexpr(N == 1)
@@ -458,6 +511,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                       (is_same<T, double>::value && (N == 1 || N == 2)) ||
                       (is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
                       (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+                      (is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                       (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
                       (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
                       "wrong! not implemented");
@@ -560,6 +614,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
 #endif
         }
     }
+    else if constexpr(is_same<T, ushort>::value)
+    {
+        // use the i16 store intrinsics declared above; the original draft
+        // mistakenly routed ushort data through the fp16 store overloads
+        if constexpr(N == 1)
+        {
+            llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
+                                             dst_wave_buffer_resource,
+                                             dst_thread_addr_offset,
+                                             dst_wave_addr_offset,
+                                             0);
+        }
+        else if constexpr(N == 2)
+        {
+            llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data,
+                                               dst_wave_buffer_resource,
+                                               dst_thread_addr_offset,
+                                               dst_wave_addr_offset,
+                                               0);
+        }
+        else if constexpr(N == 4)
+        {
+            llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data,
+                                               dst_wave_buffer_resource,
+                                               dst_thread_addr_offset,
+                                               dst_wave_addr_offset,
+                                               0);
+        }
+        else if constexpr(N == 8)
+        {
+            vector_type<ushort, 8> tmp{src_thread_data};
+
+            llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<ushort4_t>()[Number<0>{}],
+                                               dst_wave_buffer_resource,
+                                               dst_thread_addr_offset,
+                                               dst_wave_addr_offset,
+                                               0);
+
+            llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<ushort4_t>()[Number<1>{}],
+                                               dst_wave_buffer_resource,
+                                               dst_thread_addr_offset,
+                                               dst_wave_addr_offset + 4 * sizeof(ushort),
+                                               0);
+        }
+    }
     else if constexpr(is_same<T, int32_t>::value)
     {
         if constexpr(N == 1)
...
@@ -6,6 +6,7 @@
 namespace ck {

 // A, B, C, cbsz, abid, blgp
+// fp32
 extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
     float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");
@@ -21,6 +22,7 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
 extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
     float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32");

+// fp16
 extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
     half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");
@@ -36,6 +38,13 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
 extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
     half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16");

+// bfp16
+extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k(
+    ushort4_t, ushort4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8bf16.1k");
+
+extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k(
+    ushort4_t, ushort4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16bf16.1k");
+
 extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
     ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
@@ -51,6 +60,23 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
 extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
     ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");

+// int8
+extern "C" __device__ int32x32_t llvm_intrin_amdgcn_mfma_i32_32x32x4i8(
+    int, int, int32x32_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x4i8");
+
+extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_16x16x4i8(
+    int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x4i8");
+
+extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_4x4x4i8(
+    int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.4x4x4i8");
+
+extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_32x32x8i8(
+    int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x8i8");
+
+extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_16x16x16i8(
+    int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x16i8");
+
+// fp32
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x1f32;
@@ -148,6 +174,7 @@ struct intrin_mfma_f32_4x4x1f32<8, 64>
     }
 };

+// fp16
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x4f16;
@@ -244,147 +271,102 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
     }
 };
-#if 0
-template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x2bf16;
-
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride>
-{
-    __device__ static c_vec32_4_t::VecType
-    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
-        reg_c.s.z =
-            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
-        reg_c.s.w =
-            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
-        return reg_c;
-    }
-};
-
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride>
-{
-    __device__ static c_vec32_4_t::VecType
-    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
-        reg_c.s.z =
-            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
-        reg_c.s.w =
-            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
-        return reg_c;
-    }
-};
-
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride>
-{
-    __device__ static c_vec32_2_t::VecType
-    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
-        return reg_c;
-    }
-};
-
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride>
-{
-    __device__ static c_vec32_1_t::VecType
-    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
-        return reg_c;
-    }
-};
-
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride>
-{
-    __device__ static c_vec32_1_t::VecType
-    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        return reg_c;
-    }
-};
-
-__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a,
-                                                            const ushort2_t* reg_b,
-                                                            c_vec16_1_t::VecType reg_c)
-{
-    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
-    return reg_c;
-}
-
-__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a,
-                                                           const ushort2_t* reg_b,
-                                                           c_vec4_1_t::VecType reg_c)
-{
-    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
-    return reg_c;
-}
-
-template <index_t MPerWave, index_t NPerWave>
-__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
-                                                            const ushort2_t* reg_b,
-                                                            c_vec16_1_t::VecType reg_c);
-
-template <>
-__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
-                                                                    const ushort2_t* reg_b,
-                                                                    c_vec16_1_t::VecType reg_c)
-{
-    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
-    return reg_c;
-}
-
-template <>
-__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a,
-                                                                    const ushort2_t* reg_b,
-                                                                    c_vec16_1_t::VecType reg_c)
-{
-    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
-    return reg_c;
-}
-
-template <index_t MPerWave, index_t NPerWave>
-struct intrin_mfma_f32_4x4x2bf16;
-
-template <>
-struct intrin_mfma_f32_4x4x2bf16<4, 64>
-{
-    __device__ static c_vec4_1_t::VecType
-    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
-        return reg_c;
-    }
-};
-
-template <>
-struct intrin_mfma_f32_4x4x2bf16<8, 64>
-{
-    __device__ static c_vec4_2_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
-        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
-        return reg_c;
-    }
-};
-#endif
+// bfp16
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_32x32x8bf16_1k;
+
+template <>
+struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float16_t>()(Number<0>{}) =
+            llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k(
+                reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_16x16x16bf16_1k;
+
+template <>
+struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float4_t>()(Number<0>{}) =
+            llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k(
+                reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_32x32x4bf16;
+
+template <>
+struct intrin_mfma_f32_32x32x4bf16<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_16x16x8bf16;
+
+template <>
+struct intrin_mfma_f32_16x16x8bf16<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_i32_32x32x8i8;
+
+template <>
+struct intrin_mfma_i32_32x32x8i8<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<int32x16_t>()(Number<0>{}) =
+            llvm_intrin_amdgcn_mfma_i32_32x32x8i8(as_type<int>(reg_a),
+                                                  as_type<int>(reg_b),
+                                                  reg_c.template AsType<int32x16_t>()[Number<0>{}],
+                                                  0,
+                                                  0,
+                                                  0);
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_i32_16x16x16i8;
+
+template <>
+struct intrin_mfma_i32_16x16x16i8<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
+            llvm_intrin_amdgcn_mfma_i32_16x16x16i8(as_type<int>(reg_a),
+                                                   as_type<int>(reg_b),
+                                                   reg_c.template AsType<int32x4_t>()[Number<0>{}],
+                                                   0,
+                                                   0,
+                                                   0);
+    }
+};

 } // namespace ck
 #endif
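A hedged sketch of calling the new struct-style wrappers directly (it mirrors the mfma_type<>::run bodies above; the accumulator type is assumed to be the vector_type whose float16_t part the wrapper writes through AsType<>):

__device__ void mfma_bf16_1k_sketch(const ushort4_t& a,
                                    const ushort4_t& b,
                                    vector_type<float, 16>& acc)
{
    // one gfx90a 32x32x8 bf16 MFMA: 4 bf16 K-elements per thread per operand,
    // 16 f32 accumulator registers per thread
    intrin_mfma_f32_32x32x8bf16_1k<32, 32>::Run(a, b, acc);
}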
@@ -30,7 +30,11 @@
 #include "amd_address_space.hpp"
 #include "amd_buffer_addressing.hpp"
 #include "static_buffer.hpp"
+// TODO remove this
+#include "static_buffer_of_vector_type_v2.hpp"
 #include "dynamic_buffer.hpp"
+#include "is_known_at_compile_time.hpp"
+#include "transpose_vectors.hpp"
 #include "inner_product.hpp"
...
@@ -76,7 +76,7 @@
 #define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
 #endif

-// experimental implementation
+// experimental implementation for buffer load/store/atomic
 #ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
 #define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
 #endif
@@ -89,6 +89,11 @@
 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
 #endif

+// experimental implementation for in-register sub-dword transpose
+#ifndef CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
+#define CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE 1
+#endif
+
 // pass tensor descriptor by value or void*
 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0
 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
...
@@ -373,19 +373,6 @@ set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>& x)
     static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
 }

-template <typename Container>
-__host__ __device__ constexpr auto to_tuple_of_number(const Container&)
-{
-    static_assert(is_known_at_compile_time<Container>::value, "wrong!");
-
-    return generate_tuple(
-        [&](auto i) {
-            constexpr index_t tmp = Container::At(i);
-            return Number<tmp>{};
-        },
-        Container::Size());
-}
-
 template <index_t... Is>
 __host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
 {
...
@@ -58,6 +58,18 @@ __host__ __device__ constexpr auto make_vector_type(Number<N>)
 template <typename TV>
 struct scalar_type;

+// is_scalar_type
+template <typename TV>
+struct is_scalar_type
+{
+    static constexpr bool value = (scalar_type<remove_cvref_t<TV>>::vector_size == 1);
+};
+
+// has_same_scalar_type
+template <typename X, typename Y>
+using has_same_scalar_type = is_same<typename scalar_type<remove_cvref_t<X>>::type,
+                                     typename scalar_type<remove_cvref_t<Y>>::type>;
+
 template <typename T, index_t N>
 struct scalar_type<T __attribute__((ext_vector_type(N)))>
 {
...
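Quick sanity checks for the two new traits (assuming CK's half2_t / half4_t vector aliases from this header):

static_assert(is_scalar_type<float>::value, "scalar");
static_assert(!is_scalar_type<half2_t>::value, "2-wide vector");
static_assert(has_same_scalar_type<half2_t, half4_t>::value,
              "both are built from half_t");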
#ifndef CK_IGNORE_HPP
#define CK_IGNORE_HPP
// https://en.cppreference.com/w/cpp/utility/tuple/ignore
namespace ck {
namespace detail {
struct ignore_t
{
template <typename T>
constexpr void operator=(T&&) const noexcept
{
}
};
} // namespace detail
inline constexpr detail::ignore_t ignore;
} // namespace ck
#endif
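As with std::ignore, anything assigned to ck::ignore is discarded; the invalid-element write paths in static_tensor.hpp return it so that out-of-range writes compile to nothing:

ck::ignore = 42; // assignment of any type, any value: a constexpr no-op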
#ifndef IS_KNOWN_AT_COMPILE_TIME_HPP
#define IS_KNOWN_AT_COMPILE_TIME_HPP
#include "config.hpp"
#include "integral_constant.hpp"
#include "sequence.hpp"
#include "tuple.hpp"
namespace ck {
template <typename T>
struct is_known_at_compile_time;
template <>
struct is_known_at_compile_time<index_t>
{
static constexpr bool value = false;
};
template <typename T, T X>
struct is_known_at_compile_time<integral_constant<T, X>>
{
static constexpr bool value = true;
};
template <index_t... Is>
struct is_known_at_compile_time<Sequence<Is...>>
{
static constexpr bool value = true;
};
template <typename... Ts>
struct is_known_at_compile_time<Tuple<Ts...>>
{
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return container_reduce(
Tuple<Ts...>{},
[](auto x, bool r) {
return is_known_at_compile_time<remove_cvref_t<decltype(x)>>::value & r;
},
true);
}
static constexpr bool value = IsKnownAtCompileTime();
};
} // namespace ck
#endif
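The trait in action: a plain index_t is a runtime value, while Number<>, Sequence<> and Tuples built from them are compile-time:

static_assert(!ck::is_known_at_compile_time<ck::index_t>::value, "");
static_assert(ck::is_known_at_compile_time<ck::Number<3>>::value, "");
static_assert(ck::is_known_at_compile_time<ck::Sequence<0, 1, 2>>::value, "");
static_assert(ck::is_known_at_compile_time<ck::Tuple<ck::Number<0>, ck::Number<1>>>::value, "");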
@@ -5,158 +5,156 @@
 namespace ck {

-template <AddressSpaceEnum_t BufferAddressSpace,
+// static buffer for scalar
+template <AddressSpaceEnum_t AddressSpace,
           typename T,
           index_t N,
-          bool InvalidElementUseNumericalZeroValue>
+          bool InvalidElementUseNumericalZeroValue> // TODO remove this bool, no longer needed
 struct StaticBuffer : public StaticallyIndexedArray<T, N>
 {
     using type = T;
     using base = StaticallyIndexedArray<T, N>;

-    T invalid_element_value_ = T{0};
-
     __host__ __device__ constexpr StaticBuffer() : base{} {}

-    __host__ __device__ constexpr StaticBuffer(T invalid_element_value)
-        : base{}, invalid_element_value_{invalid_element_value}
-    {
-    }
-
     __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
     {
-        return BufferAddressSpace;
+        return AddressSpace;
     }

+    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
+
+    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
+
+    // read access
     template <index_t I>
-    __host__ __device__ constexpr auto Get(Number<I> i, bool is_valid_element) const
+    __host__ __device__ constexpr const T& operator[](Number<I> i) const
     {
-        if constexpr(InvalidElementUseNumericalZeroValue)
-        {
-            return is_valid_element ? At(i) : T{0};
-        }
-        else
-        {
-            return is_valid_element ? At(i) : invalid_element_value_;
-        }
+        return base::operator[](i);
     }

+    // write access
     template <index_t I>
-    __host__ __device__ void Set(Number<I> i, bool is_valid_element, const T& x)
+    __host__ __device__ constexpr T& operator()(Number<I> i)
     {
-        if(is_valid_element)
-        {
-            At(i) = x;
-        }
+        return base::operator()(i);
     }
-
-    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
-
-    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
 };

-template <AddressSpaceEnum_t BufferAddressSpace,
-          typename T,
-          index_t N,
-          bool InvalidElementUseNumericalZeroValue>
-struct StaticBufferV2 : public StaticallyIndexedArray<T, N>
-{
-    using type = T;
-    using base = StaticallyIndexedArray<T, N>;
-
-    using VecBaseType = typename T::d1_t;
-
-    __host__ __device__ static constexpr index_t GetVectorSize()
-    {
-        return sizeof(typename T::type) / sizeof(VecBaseType);
-    }
-
-    static constexpr index_t vector_size = GetVectorSize();
-
-    VecBaseType invalid_element_value_ = VecBaseType{0};
-    T invalid_vec_value_ = T{0};
-
-    __host__ __device__ constexpr StaticBufferV2() : base{} {}
-
-    __host__ __device__ constexpr StaticBufferV2(VecBaseType invalid_element_value)
-        : base{},
-          invalid_vec_value_{invalid_element_value},
-          invalid_element_value_{invalid_element_value}
-    {
-    }
-
-    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
-    {
-        return BufferAddressSpace;
-    }
-
-    template <index_t I>
-    __host__ __device__ constexpr auto& GetVector(Number<I> vec_id)
-    {
-        return this->At(vec_id);
-    }
-
-    template <index_t I>
-    __host__ __device__ constexpr const auto& GetVector(Number<I> vec_id) const
-    {
-        return this->At(vec_id);
-    }
-
-    template <index_t I>
-    __host__ __device__ constexpr auto& GetElement(Number<I> i, bool)
-    {
-        constexpr auto vec_id  = Number<i / vector_size>{};
-        constexpr auto vec_off = Number<i % vector_size>{};
-
-        return this->At(vec_id).template AsType<VecBaseType>()(vec_off);
-    }
-
-    template <index_t I>
-    __host__ __device__ constexpr auto GetElement(Number<I> i, bool is_valid_element) const
-    {
-        constexpr auto vec_id  = Number<i / vector_size>{};
-        constexpr auto vec_off = Number<i % vector_size>{};
-
-        if constexpr(InvalidElementUseNumericalZeroValue)
-        {
-            return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
-                                    : VecBaseType{0};
-        }
-        else
-        {
-            return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
-                                    : invalid_element_value_;
-        }
-    }
-
-    template <index_t I>
-    __host__ __device__ constexpr auto operator[](Number<I> i) const
-    {
-        return GetElement(i, true);
-    }
-
-    template <index_t I>
-    __host__ __device__ constexpr auto& operator()(Number<I> i)
-    {
-        return GetElement(i, true);
-    }
-
-    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
-
-    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
-};
+// static buffer for vector
+template <AddressSpaceEnum_t AddressSpace,
+          typename S,
+          index_t NumOfVector,
+          index_t ScalarPerVector,
+          bool InvalidElementUseNumericalZeroValue, // TODO remove this bool, no longer needed
+          typename enable_if<is_scalar_type<S>::value, bool>::type = false>
+struct StaticBufferTupleOfVector
+    : public StaticallyIndexedArray<vector_type<S, ScalarPerVector>, NumOfVector>
+{
+    using V    = typename vector_type<S, ScalarPerVector>::type;
+    using base = StaticallyIndexedArray<vector_type<S, ScalarPerVector>, NumOfVector>;
+
+    static constexpr auto s_per_v   = Number<ScalarPerVector>{};
+    static constexpr auto num_of_v_ = Number<NumOfVector>{};
+
+    __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}
+
+    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
+    {
+        return AddressSpace;
+    }
+
+    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
+
+    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
+
+    // Get S
+    // i is offset of S
+    template <index_t I>
+    __host__ __device__ constexpr const S& operator[](Number<I> i) const
+    {
+        constexpr auto i_v = i / s_per_v;
+        constexpr auto i_s = i % s_per_v;
+
+        return base::operator[](i_v).template AsType<S>()[i_s];
+    }
+
+    // Set S
+    // i is offset of S
+    template <index_t I>
+    __host__ __device__ constexpr S& operator()(Number<I> i)
+    {
+        constexpr auto i_v = i / s_per_v;
+        constexpr auto i_s = i % s_per_v;
+
+        return base::operator()(i_v).template AsType<S>()(i_s);
+    }
+
+    // Get X
+    // i is offset of S, not X. i should be aligned to X
+    template <typename X,
+              index_t I,
+              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
+    __host__ __device__ constexpr auto GetAsType(Number<I> i) const
+    {
+        constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
+
+        static_assert(s_per_v % s_per_x == 0, "wrong! V must contain one or multiple X");
+        static_assert(i % s_per_x == 0, "wrong!");
+
+        constexpr auto i_v = i / s_per_v;
+        constexpr auto i_x = (i % s_per_v) / s_per_x;
+
+        return base::operator[](i_v).template AsType<X>()[i_x];
+    }
+
+    // Set X
+    // i is offset of S, not X. i should be aligned to X
+    template <typename X,
+              index_t I,
+              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
+    __host__ __device__ constexpr void SetAsType(Number<I> i, X x)
+    {
+        constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
+
+        static_assert(s_per_v % s_per_x == 0, "wrong! V must contain one or multiple X");
+        static_assert(i % s_per_x == 0, "wrong!");
+
+        constexpr auto i_v = i / s_per_v;
+        constexpr auto i_x = (i % s_per_v) / s_per_x;
+
+        base::operator()(i_v).template AsType<X>()(i_x) = x;
+    }
+
+    // Get read access to vector_type V
+    // i is offset of S, not V. i should be aligned to V
+    template <index_t I>
+    __host__ __device__ constexpr const auto& GetVectorTypeReference(Number<I> i) const
+    {
+        static_assert(i % s_per_v == 0, "wrong!");
+
+        constexpr auto i_v = i / s_per_v;
+
+        return base::operator[](i_v);
+    }
+
+    // Get write access to vector_type V
+    // i is offset of S, not V. i should be aligned to V
+    template <index_t I>
+    __host__ __device__ constexpr auto& GetVectorTypeReference(Number<I> i)
+    {
+        static_assert(i % s_per_v == 0, "wrong!");
+
+        constexpr auto i_v = i / s_per_v;
+
+        return base::operator()(i_v);
+    }
+};

-template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
+template <AddressSpaceEnum_t AddressSpace, typename T, index_t N>
 __host__ __device__ constexpr auto make_static_buffer(Number<N>)
 {
-    return StaticBuffer<BufferAddressSpace, T, N, true>{};
-}
-
-template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
-__host__ __device__ constexpr auto make_static_buffer(Number<N>, T invalid_element_value)
-{
-    return StaticBuffer<BufferAddressSpace, T, N, false>{invalid_element_value};
+    return StaticBuffer<AddressSpace, T, N, true>{};
 }

 } // namespace ck
...
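The scalar-offset arithmetic above is easiest to see with concrete numbers; a hedged sketch assuming float4_t is CK's 4-wide float alias:

__device__ void static_buffer_tuple_sketch()
{
    // 2 vectors x 4 scalars = 8 float slots, addressed by scalar offset
    StaticBufferTupleOfVector<AddressSpaceEnum_t::Vgpr, float, 2, 4, true> buf;

    buf(Number<6>{}) = 3.0f;    // 6 / 4 -> vector 1, 6 % 4 -> lane 2
    float s = buf[Number<6>{}]; // reads the same lane back

    // X-typed access: offset 4 is float4_t-aligned, so this is vector 1 whole
    float4_t x = buf.GetAsType<float4_t>(Number<4>{});
    buf.SetAsType<float4_t>(Number<0>{}, x); // overwrite vector 0
    ignore = s;
}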
#ifndef CK_STATIC_BUFFER_OF_VECTOR_TYPE_V2_HPP
#define CK_STATIC_BUFFER_OF_VECTOR_TYPE_V2_HPP
#include "statically_indexed_array.hpp"
namespace ck {
template <AddressSpaceEnum_t BufferAddressSpace,
typename T,
index_t N,
bool InvalidElementUseNumericalZeroValue>
struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
{
using type = T;
using base = StaticallyIndexedArray<T, N>;
using VecBaseType = typename T::d1_t;
__host__ __device__ static constexpr index_t GetVectorSize()
{
return sizeof(typename T::type) / sizeof(VecBaseType);
}
static constexpr index_t vector_size = GetVectorSize();
VecBaseType invalid_element_value_ = VecBaseType{0};
T invalid_vec_value_ = T{0};
__host__ __device__ constexpr StaticBufferOfVectorTypeV2() : base{} {}
__host__ __device__ constexpr StaticBufferOfVectorTypeV2(VecBaseType invalid_element_value)
: base{},
invalid_vec_value_{invalid_element_value},
invalid_element_value_{invalid_element_value}
{
}
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
{
return BufferAddressSpace;
}
template <index_t I>
__host__ __device__ constexpr auto& GetVector(Number<I> vec_id)
{
return this->At(vec_id);
}
template <index_t I>
__host__ __device__ constexpr const auto& GetVector(Number<I> vec_id) const
{
return this->At(vec_id);
}
template <index_t I>
__host__ __device__ constexpr auto& GetElement(Number<I> i, bool)
{
constexpr auto vec_id = Number<i / vector_size>{};
constexpr auto vec_off = Number<i % vector_size>{};
return this->At(vec_id).template AsType<VecBaseType>()(vec_off);
}
    // read access by scalar offset; returns the invalid-element value
    // when is_valid_element is false
    template <index_t I>
    __host__ __device__ constexpr auto GetElement(Number<I> i, bool is_valid_element) const
{
constexpr auto vec_id = Number<i / vector_size>{};
constexpr auto vec_off = Number<i % vector_size>{};
if constexpr(InvalidElementUseNumericalZeroValue)
{
return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
: VecBaseType{0};
}
else
{
return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
: invalid_element_value_;
}
}
template <index_t I>
__host__ __device__ constexpr auto operator[](Number<I> i) const
{
return GetElement(i, true);
}
template <index_t I>
__host__ __device__ constexpr auto& operator()(Number<I> i)
{
return GetElement(i, true);
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};
} // namespace ck
#endif
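A short usage sketch of the buffer above. This is illustrative only: it assumes vector_type<half_t, 4> and the Vgpr address-space enumerator from data_type.hpp and config.hpp, and the Number<5> offset is arbitrary.

// Illustrative only: a register buffer of two 4-wide half vectors,
// addressed by scalar offset (vector_size == 4 here).
using V4  = ck::vector_type<ck::half_t, 4>;
using Buf = ck::StaticBufferOfVectorTypeV2<ck::AddressSpaceEnum_t::Vgpr, V4, 2, true>;

__device__ void example(Buf& buf)
{
    buf(ck::Number<5>{}) = ck::half_t{1}; // write: vec_id = 5 / 4 = 1, vec_off = 5 % 4 = 1
    ck::half_t s = buf[ck::Number<5>{}];  // read back the same lane
    (void)s;
}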
@@ -8,20 +8,38 @@
 namespace ck {

 namespace detail {

-template <typename T, index_t NSize>
-__host__ __device__ constexpr auto generate_same_type_tuple()
-{
-    return generate_tuple([](auto) -> T { return T{}; }, Number<NSize>{});
-}
+template <typename X, typename Y>
+struct tuple_concat;

-template <typename T, index_t NSize>
-using same_type_tuple = decltype(generate_same_type_tuple<T, NSize>());
+template <typename... Xs, typename... Ys>
+struct tuple_concat<Tuple<Xs...>, Tuple<Ys...>>
+{
+    using type = Tuple<Xs..., Ys...>;
+};
+
+template <typename T, index_t N>
+struct StaticallyIndexedArrayImpl
+{
+    using type =
+        typename tuple_concat<typename StaticallyIndexedArrayImpl<T, N / 2>::type,
+                              typename StaticallyIndexedArrayImpl<T, N - N / 2>::type>::type;
+};
+
+template <typename T>
+struct StaticallyIndexedArrayImpl<T, 0>
+{
+    using type = Tuple<>;
+};
+
+template <typename T>
+struct StaticallyIndexedArrayImpl<T, 1>
+{
+    using type = Tuple<T>;
+};

 } // namespace detail

-template <typename T, index_t NSize>
-using StaticallyIndexedArray = detail::same_type_tuple<T, NSize>;
+template <typename T, index_t N>
+using StaticallyIndexedArray = typename detail::StaticallyIndexedArrayImpl<T, N>::type;

 template <typename X, typename... Xs>
 __host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
...
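The point of the replacement Impl is that it builds the N-element Tuple by recursive halving, so template instantiation depth grows logarithmically in N instead of linearly. The same idea, sketched standalone with std::tuple (illustrative only; CK uses its own Tuple):

#include <tuple>
#include <type_traits>

template <typename X, typename Y>
struct tuple_concat;

template <typename... Xs, typename... Ys>
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
{
    using type = std::tuple<Xs..., Ys...>;
};

template <typename T, int N>
struct repeat_impl
{
    // split N into halves so instantiation depth is O(log N)
    using type = typename tuple_concat<typename repeat_impl<T, N / 2>::type,
                                       typename repeat_impl<T, N - N / 2>::type>::type;
};

template <typename T>
struct repeat_impl<T, 0>
{
    using type = std::tuple<>;
};

template <typename T>
struct repeat_impl<T, 1>
{
    using type = std::tuple<T>;
};

static_assert(std::is_same<typename repeat_impl<int, 5>::type,
                           std::tuple<int, int, int, int, int>>::value,
              "five-element repeat");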
#ifndef CK_TRANSPOSE_VECTORS_AMD_HPP
#define CK_TRANSPOSE_VECTORS_AMD_HPP
#include "config.hpp"
#include "statically_indexed_array.hpp"
#include "data_type.hpp"
namespace ck {
template <typename S,
index_t NX,
index_t NY,
typename enable_if<is_scalar_type<S>::value, bool>::type = false>
struct transpose_vectors;
// transpose fp16 2x2
__device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t& y0, half2_t& y1)
{
#if 0 // generic reference implementation, kept for documentation
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
const vector_type<half_t, 2> vx0{x0}, vx1{x1};
vector_type<half_t, 2> vy0, vy1;
vy0.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I0];
vy0.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I0];
vy1.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I1];
vy1.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I1];
y0 = vy0.template AsType<half2_t>()[I0];
y1 = vy1.template AsType<half2_t>()[I0];
#else // v_pack_b32_f16 packs the two low halves into y0; op_sel:[1, 1] selects the high halves for y1
asm volatile("\n \
v_pack_b32_f16 %0, %1, %2 \n \
"
: "=v"(y0)
: "v"(x0), "v"(x1));
asm volatile("\n \
v_pack_b32_f16 %0, %1, %2, op_sel:[1, 1] \n \
"
: "=v"(y1)
: "v"(x0), "v"(x1));
#endif
}
template <index_t NX, index_t NY>
struct transpose_vectors<half_t, NX, NY>
{
    // we have NX * NY scalars of type S to be transposed
static constexpr index_t s_per_x = NY;
static constexpr index_t s_per_y = NX;
using S = half_t;
using VX = vector_type<half_t, s_per_x>;
using VY = vector_type<half_t, s_per_y>;
__device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
StaticallyIndexedArray<VY&, NY>& vy_tuple)
{
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
// loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
static_for<0, NY, 2>{}([&](auto iy) {
static_for<0, NX, 2>{}([&](auto ix) {
// reference to 2 half2_t data from vx_tuple
const auto& x_s2_0 = vx_tuple[ix].template AsType<half2_t>()[iy / I2];
const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<half2_t>()[iy / I2];
// reference to 2 half2_t data from vy_tuple
auto& y_s2_0 = vy_tuple(iy).template AsType<half2_t>()(ix / I2);
auto& y_s2_1 = vy_tuple(iy + I1).template AsType<half2_t>()(ix / I2);
// transpose
transpose_fp16_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1);
});
});
}
};
} // namespace ck
#endif
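For reference, the 2x2 kernel above computes y0 = (x0.lo, x1.lo) and y1 = (x0.hi, x1.hi). A host-side sketch of the same shuffle using plain floats, to show the data movement without the device types (illustrative, not the device code):

#include <array>
#include <cassert>

// Each "vector" holds two scalars, mimicking half2_t lanes {lo, hi}.
using pair2 = std::array<float, 2>;

// y0 takes the low lanes of x0/x1, y1 takes the high lanes: a 2x2 transpose.
void transpose_2x2(const pair2& x0, const pair2& x1, pair2& y0, pair2& y1)
{
    y0 = {x0[0], x1[0]};
    y1 = {x0[1], x1[1]};
}

int main()
{
    pair2 x0{1.f, 2.f}, x1{3.f, 4.f}, y0{}, y1{};
    transpose_2x2(x0, x1, y0, y1);
    assert(y0[0] == 1.f && y0[1] == 3.f);
    assert(y1[0] == 2.f && y1[1] == 4.f);
    return 0;
}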
@@ -117,6 +117,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
     __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }

+    // read access
     template <index_t I>
     __host__ __device__ constexpr const auto& At(Number<I>) const
     {
@@ -124,6 +125,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
         return base::GetElementByKey(detail::TupleElementKey<I>{});
     }

+    // write access
     template <index_t I>
     __host__ __device__ constexpr auto& At(Number<I>)
     {
@@ -131,12 +133,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
         return base::GetElementByKey(detail::TupleElementKey<I>{});
     }

+    // read access
     template <index_t I>
     __host__ __device__ constexpr const auto& operator[](Number<I> i) const
     {
         return At(i);
     }

+    // write access
     template <index_t I>
     __host__ __device__ constexpr auto& operator()(Number<I> i)
     {
@@ -162,5 +166,12 @@ __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
     return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
 }

+// https://en.cppreference.com/w/cpp/utility/tuple/tie
+template <typename... Args>
+constexpr Tuple<Args&...> tie(Args&... args) noexcept
+{
+    return {args...};
+}
+
 } // namespace ck
 #endif