"...composable_kernel_rocm.git" did not exist on "4ceba063652c36cd9de3a2b06f7f75e608109c27"
Commit cae751d1 authored by carlushuang's avatar carlushuang
Browse files

wip

parent 41659ab1
......@@ -661,6 +661,108 @@ CK_TILE_DEVICE auto async_load_fence(number<cnt>)
buffer_load_fence(number<cnt>{});
}
namespace impl {
// below type indicate the data type used for buffer load inline asm
// clang-format off
template<index_t N, typename T> struct smem_load_trait;
template<typename T> struct smem_load_trait<16, T> { using payload_t = fp32x4_t; };
template<typename T> struct smem_load_trait<8 , T> { using payload_t = fp32x2_t; };
template<typename T> struct smem_load_trait<4 , T> { using payload_t = float; };
template<typename T> struct smem_load_trait<2 , T> { using payload_t = float; };
template<typename T> struct smem_load_trait<1 , T> { using payload_t = float; };
// clang-format on
} // namespace impl
// NOTE: smem load/store no need pre_nop to make sure dependency by sw, happy :)
template<index_t>
struct smem_load ;
template<>
struct smem_load<16>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
index_t v_offset,
index_t i_offset)
{
static_assert(sizeof(T) == 16);
using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t
asm volatile("ds_read_b128 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct smem_load<8>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
index_t v_offset,
index_t i_offset)
{
static_assert(sizeof(T) == 8);
using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t;
asm volatile("ds_read_b64 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct smem_load<4>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
index_t v_offset,
index_t i_offset)
{
static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t;
asm volatile("ds_read_b32 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct smem_load<2>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
index_t v_offset,
index_t i_offset)
{
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
asm volatile("ds_read_u16 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct smem_load<1>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
index_t v_offset,
index_t i_offset)
{
static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
asm volatile("ds_read_u8 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
: "memory");
}
};
// clang-format off
namespace impl{
......@@ -1365,6 +1467,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_linear_addr_offset,
index_t flag = 0,
bool_constant<pre_nop> = {})
{
......@@ -1379,7 +1482,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
0,
src_linear_addr_offset,
flag,
bool_constant<pre_nop>{});
}
......@@ -1389,7 +1492,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
0,
src_linear_addr_offset,
flag,
bool_constant<pre_nop>{});
}
......@@ -2105,6 +2208,7 @@ template <typename T,
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
index_t src_element_space_size,
index_t is_valid_element = 0,
bool_constant<pre_nop> = {})
......@@ -2113,12 +2217,14 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
is_valid_element,
bool_constant<pre_nop>{});
}
......@@ -2132,16 +2238,19 @@ template <typename T,
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
index_t is_valid_element = 0,
bool_constant<pre_nop> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
is_valid_element,
bool_constant<pre_nop>{});
}
......
......@@ -352,7 +352,8 @@ struct buffer_view<address_space_enum::global,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
index_t i,
index_t v_offset,
index_t i_offset,
bool is_valid_element,
bool_constant<pre_nop> = {}) const
{
......@@ -366,7 +367,7 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
dst, cached_buf_res_, v_offset, i_offset, is_valid_element, bool_constant<pre_nop>{});
}
// i is offset of T, not X. i should be aligned to X
......@@ -733,6 +734,33 @@ struct buffer_view<address_space_enum::lds,
}
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
index_t v_offset,
index_t i_offset,
bool is_valid_element,
bool_constant<pre_nop> = {}) const
{
#if 0
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
#endif
smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
......
......@@ -234,6 +234,7 @@ adaptor_coordinate_is_valid_assuming_top_index_is_valid(const Adaptor& adaptor,
return valid;
}
// TODO: not actually used in ck_tile, maybe can deprecate this
template <typename Adaptor, typename AdpatorCoord>
CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor& adaptor,
const AdpatorCoord& coord)
......
......@@ -82,6 +82,7 @@ coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc& tensor
return adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord);
}
// TODO: not actually used in ck_tile, maybe can deprecate this
template <typename TensorDesc, typename TensorCoord>
CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
const TensorCoord& coord)
......
......@@ -94,12 +94,14 @@ struct tensor_view
bool>::type = false>
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
const TensorCoord& coord,
index_t linear_offset,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
dst,
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
}
......
......@@ -398,6 +398,7 @@ struct tile_window_with_static_distribution
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
bottom_tensor_thread_coord,
/**/,
bool_constant<oob_conditional_check>{},
pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
......
......@@ -33,6 +33,8 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
......
......@@ -51,10 +51,13 @@ struct WarpGemmAtrributeMfma
sequence<0, 2>>;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
Impl{}(c_vec, a_vec, b_vec);
Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
}
// c_vec = a_vec * b_vec
......@@ -111,8 +114,11 @@ struct WarpGemmAtrributeMfmaIterateK
sequence<0, 2>>;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
......@@ -122,10 +128,33 @@ struct WarpGemmAtrributeMfmaIterateK
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]);
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
});
}
template <index_t iKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
......@@ -194,11 +223,14 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
sequence<0, 2>>;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
// swap A and B
Impl{}(c_vec, b_vec, a_vec);
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
}
// c_vec = a_vec * b_vec
......@@ -255,12 +287,15 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
sequence<2, 2>,
sequence<0, 2>>;
template <bool post_nop_ = false>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
// swap A and B
Impl{}(c_vec, b_vec, a_vec);
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
}
// c_vec = a_vec * b_vec
......@@ -316,9 +351,12 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
sequence<2, 2>,
sequence<0, 2>>;
template <bool post_nop_ = false>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
......@@ -328,10 +366,34 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter]);
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
});
}
template <index_t iKIter, bool post_nop_ = false>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// swap A and B, value and type
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
......@@ -429,8 +491,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
sequence<0, 2>>;
#endif
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
......@@ -440,10 +505,33 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter]);
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
});
}
template <index_t kKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<kKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// swap A and B, value and type
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
......@@ -518,8 +606,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
sequence<0, 2>>;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
......@@ -529,10 +620,33 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]);
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
});
}
template <index_t iKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
......
......@@ -7,12 +7,39 @@
namespace ck_tile {
enum class WGAttrCtlEnum
{
Default_ = 0,
Raw_vvv = 1, // c-vgpr, a-vgpr, b-vgpr
Raw_vaa = 2, // c-vgpr, a-agpr, b-agpr
// raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr
};
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
if constexpr(post_nop_) \
{ \
asm volatile(mfma_ " %0, %1, %2, %3\n" \
"s_nop 16" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
} \
else \
{ \
asm volatile(mfma_ " %0, %1, %2, %3\n" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
}
// FP16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
{
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
using AVecType = ext_vector_t<fp16_t, 4>;
using BVecType = ext_vector_t<fp16_t, 4>;
......@@ -33,16 +60,30 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8f16", "+v", "v", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8f16", "+v", "a", "a", "v")
}
else
{
#if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
......@@ -59,11 +100,13 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
{
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
using AVecType = ext_vector_t<fp16_t, 4>;
using BVecType = ext_vector_t<fp16_t, 4>;
......@@ -84,16 +127,30 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "v", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "a", "b", "v")
}
else
{
#if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
......@@ -111,11 +168,13 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
};
// Bf16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
{
using ADataType = bf16_t;
using BDataType = bf16_t;
using CDataType = float;
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = bf16_t;
using BDataType = bf16_t;
using CDataType = float;
using AVecType = ext_vector_t<bf16_t, 4>;
using BVecType = ext_vector_t<bf16_t, 4>;
......@@ -136,28 +195,42 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8bf16_1k", "+v", "v", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
DISPATCH_MFMA_("v_mfma_f32_32x32x8bf16_1k", "+v", "a", "a", "v")
}
else
{
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__)
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
......@@ -188,11 +261,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
{
using ADataType = bf16_t;
using BDataType = bf16_t;
using CDataType = float;
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = bf16_t;
using BDataType = bf16_t;
using CDataType = float;
using AVecType = ext_vector_t<bf16_t, 4>;
using BVecType = ext_vector_t<bf16_t, 4>;
......@@ -213,28 +288,42 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16bf16_1k", "+v", "v", "v", "v")
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16bf16_1k", "+v", "a", "a", "v")
}
else
{
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__)
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
......@@ -266,12 +355,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
};
// FP8
template <typename AType_, typename BType_>
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
{
using ADataType = AType_;
using BDataType = BType_;
using CDataType = float;
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = AType_;
using BDataType = BType_;
using CDataType = float;
using AVecType = ext_vector_t<ADataType, 8>;
using BVecType = ext_vector_t<BDataType, 8>;
......@@ -292,38 +382,82 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "v", "v")
}
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "v", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "v", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "v", "v")
}
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "a", "v")
}
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "a", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "a", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "a", "v")
}
}
else
{
#if defined(__gfx94__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for<0, 8, 1>{}([&](auto k) {
float a_f32 =
type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
.template get_as<ADataType>()[number<k>{}]);
float b_f32 =
type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
.template get_as<BDataType>()[number<k>{}]);
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
});
static_for<0, 8, 1>{}([&](auto k) {
float a_f32 =
type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
.template get_as<ADataType>()[number<k>{}]);
float b_f32 =
type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
.template get_as<BDataType>()[number<k>{}]);
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
});
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
......@@ -363,13 +497,22 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 =
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t>;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 =
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, bf8_t>;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, bf8_t, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 =
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, fp8_t>;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, fp8_t, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 =
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, bf8_t>;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, bf8_t, Ctrl_>;
#undef DISPATCH_MFMA_
} // namespace ck_tile
......@@ -31,15 +31,18 @@ struct WarpGemmImpl
using BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>;
using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;
template <typename CTensor, typename ATensor, typename BTensor>
CK_TILE_DEVICE void operator()(CTensor& c, const ATensor& a, const BTensor& b) const
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CTensor& c,
const ATensor& a,
const BTensor& b,
bool_constant<post_nop_> = {}) const
{
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>;
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CTensor> &&
detail::is_similiar_distributed_tensor_v<ATensor, ATensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BTensor>);
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
......@@ -48,7 +51,30 @@ struct WarpGemmImpl
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
// c_vec += a_vec * b_vec
WarpGemmAttribute{}(c_vec, a_vec, b_vec);
WarpGemmAttribute{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
}
template <<typename CTensor, typename ATensor, typename BTensor, index_t i_subk, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CTensor& c,
const ATensor& a,
const BTensor& b,
number<i_subk>,
bool_constant<post_nop_> = {}) const
{
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
// c_vec += a_vec * b_vec
WarpGemmAttribute{}(c_vec, a_vec, b_vec, number<i_subk>{}, bool_constant<post_nop_>{});
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
}
......@@ -56,13 +82,13 @@ struct WarpGemmImpl
template <typename ATensor, typename BTensor>
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
{
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
CWarpTensor c;
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, ATensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BTensor>);
CTensor c;
using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment