Commit eed60199 authored by carlushuang

more robust api

parent cae751d1
......@@ -25,7 +25,7 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
#if 1
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8)
......@@ -42,9 +42,24 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
{
TOPK_SOFTMAX_DISPATCH(64)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192)
}
#else
if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16)
}
#endif
}
else if(t.input_type == "bf16" && t.weight_type == "fp32")
{
#if 1
using ts_input_type = ck_tile::bf16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
......@@ -64,6 +79,15 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
{
TOPK_SOFTMAX_DISPATCH(64)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192)
}
#endif
}
return -1;
}
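For context, the dispatch above simply rounds the runtime expert count up to the next instantiated tile size; the new 128/192 branches extend the supported range. A standalone sketch of that pattern (the `run_topk_softmax_instance` helper is illustrative, standing in for what `TOPK_SOFTMAX_DISPATCH` expands to):

```cpp
#include <cstdio>

// Illustrative stand-in for the real kernel instantiation: the macro in the
// diff expands to a kernel launch templated on the padded expert count.
template <int PaddedExperts>
float run_topk_softmax_instance()
{
    std::printf("dispatching instance padded to %d experts\n", PaddedExperts);
    return 0.0f; // the real API returns elapsed time
}

// Round the runtime expert count up to the next supported instance,
// mirroring the if/else-if ladder in topk_softmax().
float dispatch_topk_softmax(int experts)
{
    if(experts <= 8)        return run_topk_softmax_instance<8>();
    else if(experts <= 16)  return run_topk_softmax_instance<16>();
    else if(experts <= 32)  return run_topk_softmax_instance<32>();
    else if(experts <= 64)  return run_topk_softmax_instance<64>();
    else if(experts <= 128) return run_topk_softmax_instance<128>();
    else if(experts <= 192) return run_topk_softmax_instance<192>();
    return -1.0f; // unsupported expert count
}

int main() { return dispatch_topk_softmax(100) < 0.0f ? 1 : 0; }
```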
......@@ -50,6 +50,7 @@
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
......
......@@ -66,6 +66,20 @@ struct space_filling_curve
return idx_tail - idx_head;
}
template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
static CK_TILE_HOST_DEVICE constexpr auto get_step_between_static(number<AccessIdx1dHead>,
number<AccessIdx1dTail>)
{
static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(),
"1D index out of range");
static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(),
"1D index out of range");
constexpr auto idx_head = get_index_static(number<AccessIdx1dHead>{});
constexpr auto idx_tail = get_index_static(number<AccessIdx1dTail>{});
return idx_tail - idx_head;
}
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number<AccessIdx1d>)
{
......@@ -73,6 +87,13 @@ struct space_filling_curve
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
}
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr auto get_forward_step_static(number<AccessIdx1d>)
{
static_assert(AccessIdx1d < get_num_of_access(), "1D index out of range");
return get_step_between_static(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
}
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr auto get_backward_step(number<AccessIdx1d>)
{
......@@ -153,9 +174,9 @@ struct space_filling_curve
return idx_md;
}
// FIXME: rename this function
// FIXME: return tuple of number<>, which is compile time only variable
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
static CK_TILE_HOST_DEVICE constexpr auto get_index_static(number<AccessIdx1d>)
{
constexpr auto idx = get_index(number<AccessIdx1d>{});
......
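The `_static` variants exist because the non-static `get_index`/`get_step_between` return runtime-typed indices (see the FIXME above), while these return a tuple of `number<>` whose values survive as compile-time constants. A standalone model of that distinction, using `std::integral_constant` in place of `number<>`:

```cpp
#include <array>
#include <tuple>
#include <type_traits>

// Runtime step: values are ordinary integers, unusable as template arguments.
constexpr std::array<int, 2> step_runtime(int head0, int head1, int tail0, int tail1)
{
    return {tail0 - head0, tail1 - head1};
}

// Compile-time step: each component is an integral_constant (analogue of number<>),
// so it can feed contexts that require constant expressions (template arguments,
// asm immediate offsets, ...).
template <int H0, int H1, int T0, int T1>
constexpr auto step_static()
{
    return std::make_tuple(std::integral_constant<int, T0 - H0>{},
                           std::integral_constant<int, T1 - H1>{});
}

int main()
{
    constexpr auto s = step_static<0, 4, 1, 0>();
    static_assert(std::get<0>(s) == 1, "");  // values are part of the type
    static_assert(std::get<1>(s) == -4, "");
    return step_runtime(0, 4, 1, 0)[0] == 1 ? 0 : 1;
}
```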
......@@ -156,8 +156,8 @@ struct buffer_load<2, pre_nop>
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
// static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
using mbuf_t = ushort; // typename impl::buffer_load_trait<2, T>::payload_t;
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
......@@ -315,9 +315,9 @@ struct buffer_load_if<2, pre_nop>
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4);
// static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
using mbuf_t = ushort; // typename impl::buffer_load_trait<2, T>::payload_t;
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"v_cmpx_le_u32 exec, 1, %4\n"
......@@ -676,19 +676,17 @@ template<typename T> struct smem_load_trait<1 , T> { using payload_t = float; };
} // namespace impl
// NOTE: smem load/store needs no pre_nop; the dependency is already guaranteed by software, happy :)
template<index_t>
struct smem_load ;
template <index_t>
struct smem_load;
template<>
template <>
struct smem_load<16>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
index_t v_offset,
index_t i_offset)
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
{
static_assert(sizeof(T) == 16);
using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t
using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t;
asm volatile("ds_read_b128 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
......@@ -700,9 +698,7 @@ template <>
struct smem_load<8>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
index_t v_offset,
index_t i_offset)
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
{
static_assert(sizeof(T) == 8);
using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t;
......@@ -717,9 +713,7 @@ template <>
struct smem_load<4>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
index_t v_offset,
index_t i_offset)
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
{
static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t;
......@@ -734,11 +728,10 @@ template <>
struct smem_load<2>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
index_t v_offset,
index_t i_offset)
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
{
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
asm volatile("ds_read_u16 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
......@@ -750,9 +743,7 @@ template <>
struct smem_load<1>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
index_t v_offset,
index_t i_offset)
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
{
static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
......@@ -1879,6 +1870,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset,
index_t dst_linear_addr_offset,
index_t is_valid_element = 1)
{
constexpr index_t bytes = sizeof(T) * N;
......@@ -1892,7 +1884,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0,
dst_linear_addr_offset,
is_valid_element);
}
else
......@@ -1901,7 +1893,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
dst_linear_addr_offset);
}
}
......@@ -2266,6 +2258,7 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
index_t src_element_space_size,
bool_constant<pre_nop> = {})
{
......@@ -2273,9 +2266,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>(
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
amd_async_buffer_load_impl<T, N, coherence>(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
bool_constant<pre_nop>{});
}
// This version support buffer resource as input arg
......@@ -2286,12 +2284,18 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
bool_constant<pre_nop> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>(
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
amd_async_buffer_load_impl<T, N, coherence>(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
bool_constant<pre_nop>{});
}
// This version support buffer resource as input arg
......@@ -2302,16 +2306,18 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_async_buffer_load<T, N, coherence>(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
0,
src_linear_addr_offset,
is_valid_element,
bool_constant<oob_conditional_check>{});
}
......@@ -2368,6 +2374,7 @@ template <typename T,
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const index_t dst_linear_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
......@@ -2375,11 +2382,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
dst_thread_element_valid);
}
......
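The new `*_linear_element_offset` parameter threads a second offset component down to the buffer intrinsics where a literal `0` used to be passed. As far as this diff shows, the intent is to keep the per-thread offset in a register while the separately supplied linear offset can be folded into the instruction's immediate offset field when it is a compile-time constant. A C-level model of the address split (names are illustrative):

```cpp
#include <cstdint>
#include <cstdio>

// Model: the effective address is base + thread_offset + linear_offset.
// In the real intrinsics, the thread offset lives in a VGPR, while the linear
// offset, being a compile-time constant for linear dimensions, can be encoded
// in the immediate "offset:" field of the buffer/ds instruction.
float load_split_offset(const float* base,
                        std::uint32_t thread_byte_offset, // runtime, per thread
                        std::uint32_t linear_byte_offset) // compile-time constant in practice
{
    const char* p = reinterpret_cast<const char*>(base);
    return *reinterpret_cast<const float*>(p + thread_byte_offset + linear_byte_offset);
}

int main()
{
    float data[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    // Same element either way; only where the offset is encoded differs.
    std::printf("%f %f\n",
                load_split_offset(data, 5 * sizeof(float), 0),
                load_split_offset(data, 4 * sizeof(float), 1 * sizeof(float)));
    return 0;
}
```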
......@@ -635,6 +635,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
return r;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] + y[i]; }, number<NSize>{});
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
......@@ -649,6 +657,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
return r;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] - y[i]; }, number<NSize>{});
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
......@@ -686,6 +702,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
return a * x;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] * y[i]; }, number<NSize>{});
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
......
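A hedged usage sketch of the new element-wise tuple operators (assuming the ck_tile core headers; division between tuples already existed below):

```cpp
#include "ck_tile/core.hpp"

// hedged sketch: element-wise arithmetic on ck_tile tuples of equal length
constexpr auto tuple_arith_demo()
{
    using namespace ck_tile;

    constexpr auto x = make_tuple(1, 2, 3);
    constexpr auto y = make_tuple(4, 5, 6);

    constexpr auto s = x + y; // element-wise sum:        (5, 7, 9)
    constexpr auto d = x - y; // element-wise difference: (-3, -3, -3)
    constexpr auto p = x * y; // element-wise product:    (4, 10, 18)

    static_assert(s[number<0>{}] == 5 && d[number<1>{}] == -3 && p[number<2>{}] == 18, "");
    return s;
}
```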
......@@ -11,7 +11,7 @@
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
......@@ -31,6 +31,20 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT
return tile_window.load(bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(bool_constant<oob_conditional_check>{});
}
template <typename T,
typename BottomTensorView_,
typename WindowLengths_,
......@@ -49,6 +63,24 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename T,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
// for this API we force the user to use CK_TILE_LDS_ADDR-attributed smem when
// creating the smem window, which lets the compiler properly detect the
// dependency when using multiple smem windows (multiple buffering)
......@@ -69,6 +101,22 @@ async_load_tile(LdsTileWindow_&& lds_tile,
return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
}
template <typename LdsTileWindow_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
bool_constant<oob_conditional_check> = {})
{
return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
}
template <typename LdsTileWindow_,
typename BottomTensorView_,
typename WindowLengths_,
......@@ -89,6 +137,25 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename WindowLengths>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
{
......@@ -100,4 +167,20 @@ CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<Windo
{
}
// TODO: this function requires that certain sub-fields (TileDstr, DataType) exist in the target tile window
template <typename TileWindow, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
using TileDstr = typename TileWindow::TileDstr;
using DataType = typename TileWindow::DataType;
auto t = make_static_distributed_tensor<DataType>(TileDstr{});
load_tile_raw(t, w, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return t;
}
} // namespace ck_tile
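The last overload is a convenience wrapper: it builds the destination tensor from the window's TileDstr/DataType sub-fields and returns it, instead of requiring the caller to allocate it first. A standalone model of that wrapper pattern with toy types (not the CK API):

```cpp
#include <array>
#include <cstdio>

// Standalone model of the convenience overload above: the window type exposes
// the aliases needed to build its own destination, so the wrapper can allocate
// and return it instead of asking the caller to.
struct toy_tile_window
{
    using DataType = float;
    static constexpr int size = 4;

    void load_raw(std::array<DataType, size>& dst) const
    {
        for(int i = 0; i < size; ++i)
            dst[i] = static_cast<DataType>(i); // stand-in for the raw per-thread loads
    }
};

// mirrors load_tile_raw(w): deduce the destination from the window's sub-fields
template <typename Window>
auto load_raw(const Window& w)
{
    std::array<typename Window::DataType, Window::size> t{};
    w.load_raw(t);
    return t;
}

int main()
{
    toy_tile_window w;
    auto t = load_raw(w); // destination created internally, like load_tile_raw(w)
    std::printf("%f\n", t[3]);
    return 0;
}
```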
......@@ -10,6 +10,7 @@
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
......@@ -90,4 +91,30 @@ store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
tile_window.store_raw(dstr_tensor);
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
typename DataType_>
CK_TILE_DEVICE void store_tile(
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
tile_window.store(dstr_tensor);
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
typename DataType_>
CK_TILE_DEVICE void store_tile_raw(
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
tile_window.store_raw(dstr_tensor);
}
} // namespace ck_tile
......@@ -75,14 +75,34 @@ struct tensor_view
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
get_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template get<X>(
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<oob_conditional_check>{});
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
get_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element, // flag
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template get<X>(coord.get_offset(),
linear_offset,
is_valid_element,
bool_constant<oob_conditional_check>{});
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
......@@ -106,6 +126,24 @@ struct tensor_view
bool_constant<pre_nop>{});
}
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
dst, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
......@@ -114,26 +152,71 @@ struct tensor_view
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
const TensorCoord& coord) const
const TensorCoord& coord,
index_t linear_offset) const
{
return buf_.template async_get<X>(
smem,
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<oob_conditional_check>{});
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element) const
{
return buf_.template async_get<X>(smem,
coord.get_offset(),
linear_offset,
is_valid_element,
bool_constant<oob_conditional_check>{});
}
template <typename X,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw(
remove_cvref_t<DataType>* smem, const TensorCoord& coord, bool_constant<pre_nop> = {}) const
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t linear_offset,
bool_constant<pre_nop> = {}) const
{
return buf_.template async_get_raw<X>(
smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
smem,
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
}
template <typename X,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
bool_constant<pre_nop> = {}) const
{
return buf_.template async_get_raw<X>(
smem, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
}
// X is vector of DataType.
......@@ -144,11 +227,15 @@ struct tensor_view
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
CK_TILE_HOST_DEVICE constexpr void
set_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
const X& x,
bool_constant<oob_conditional_check> = {})
{
buf_.template set<X, oob_conditional_check>(
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
......@@ -159,15 +246,53 @@ struct tensor_view
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
CK_TILE_HOST_DEVICE constexpr void
set_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {})
{
buf_.template set<X, oob_conditional_check>(
coord.get_offset(), linear_offset, is_valid_element, x);
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
set_vectorized_elements_raw(const TensorCoord& coord,
index_t linear_offset,
const X& x,
bool_constant<oob_conditional_check> = {})
{
buf_.template set_raw<X, oob_conditional_check>(
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
set_vectorized_elements_raw(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {})
{
buf_.template set_raw<X, oob_conditional_check>(
coord.get_offset(), linear_offset, is_valid_element, x);
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
......@@ -176,15 +301,36 @@ struct tensor_view
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements(
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
CK_TILE_HOST_DEVICE constexpr void
update_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
const X& x,
bool_constant<oob_conditional_check> = {})
{
buf_.template update<DstInMemOp, X, oob_conditional_check>(
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
update_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {})
{
buf_.template update<DstInMemOp, X, oob_conditional_check>(
coord.get_offset(), linear_offset, is_valid_element, x);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_view{");
......
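Besides the linear offset, each accessor gains an overload taking an explicit `is_valid_element` flag, so the caller can supply the out-of-bounds predicate instead of having the view recompute it from the coordinate on every access. A standalone model of the two flavors (plain C++, no CK types):

```cpp
#include <cstddef>
#include <cstdio>

struct toy_view
{
    const float* data;
    std::size_t  size;

    bool coord_is_valid(std::size_t coord) const { return coord < size; }

    // flavor 1: the view derives validity from the coordinate itself
    float get(std::size_t coord) const { return get(coord, coord_is_valid(coord)); }

    // flavor 2: the caller supplies the (possibly precomputed) predicate
    float get(std::size_t coord, bool is_valid_element) const
    {
        return is_valid_element ? data[coord] : 0.0f; // OOB reads return a neutral value
    }
};

int main()
{
    float buf[4] = {1, 2, 3, 4};
    toy_view v{buf, 4};
    bool valid = v.coord_is_valid(7); // computed once, outside the hot loop
    std::printf("%f %f\n", v.get(2), v.get(7, valid));
    return 0;
}
```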
......@@ -454,6 +454,7 @@ struct tile_distribution_detail
} // namespace detail
#if 0
// this returns a constexpr tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
......@@ -490,6 +491,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistribution
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
#endif
// this returns a static tile_distribution
template <typename StaticTileDistributionEncoding_>
......
......@@ -223,10 +223,11 @@ struct tile_window_with_static_distribution
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
WindowAdaptorCoord& window_adaptor_thread_coord,
BottomTensorCoord& bottom_tensor_thread_coord,
const AdaptorTopIndex& idx_diff_adaptor_top) const
const ATopIndex& idx_diff_adaptor_top) const
{
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
......@@ -309,7 +310,7 @@ struct tile_window_with_static_distribution
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
#if 1
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
......@@ -337,10 +338,11 @@ struct tile_window_with_static_distribution
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step_static(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
......@@ -398,7 +400,7 @@ struct tile_window_with_static_distribution
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
bottom_tensor_thread_coord,
/**/,
0 /**/,
bool_constant<oob_conditional_check>{},
pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
......@@ -484,7 +486,7 @@ struct tile_window_with_static_distribution
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, pre_nop_);
smem, bottom_tensor_thread_coord, 0, pre_nop_);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
......@@ -552,7 +554,7 @@ struct tile_window_with_static_distribution
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem, bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
......@@ -618,7 +620,10 @@ struct tile_window_with_static_distribution
// write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
......@@ -676,7 +681,7 @@ struct tile_window_with_static_distribution
// write into bottom tensor
get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, vec_value);
bottom_tensor_thread_coord, 0, vec_value);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
......@@ -736,7 +741,10 @@ struct tile_window_with_static_distribution
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
......@@ -868,6 +876,27 @@ make_tile_window(const TensorView_& tensor_view,
tensor_view, window_lengths, origin, tile_distribution};
}
// this version must not be called in a constexpr context
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
index_t NumCoord = 1>
CK_TILE_DEVICE auto
make_tile_window_raw(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
number<NumCoord> = {})
{
auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
NumCoord>{
tensor_view, window_lengths, origin, tile_distribution};
w.init_raw();
return w;
}
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
......@@ -992,6 +1021,19 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
tile_distribution);
}
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
CK_TILE_DEVICE constexpr auto
make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution)
{
auto w = make_tile_window(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution);
w.init_raw();
return w;
}
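make_tile_window_raw is the same construction as make_tile_window followed by init_raw() on the result, which is why the comment above warns against constexpr use. A standalone model of that create-then-init factory pattern (toy types, not the CK API):

```cpp
#include <cstdio>

// Standalone model: construct the window exactly as the plain factory does,
// then run an init step that prepares the *_raw access paths before returning.
struct toy_window
{
    int origin;
    int raw_state = 0;

    void init_raw() { raw_state = 1; } // e.g. precompute per-thread raw offsets
};

constexpr toy_window make_toy_window(int origin) { return toy_window{origin}; }

// not constexpr: mirrors the real helper, whose init_raw() is runtime-only
toy_window make_toy_window_raw(int origin)
{
    auto w = make_toy_window(origin);
    w.init_raw();
    return w;
}

int main()
{
    auto w = make_toy_window_raw(42);
    std::printf("origin=%d raw_state=%d\n", w.origin, w.raw_state);
    return 0;
}
```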
template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE void move_tile_window(
tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
......
......@@ -58,10 +58,18 @@ struct magic_division32_bit_range
// magic division for uint32_t
CK_TILE_DEVICE static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
if(__builtin_is_constant_evaluated())
{
uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
return (tmp + dividend) >> shift;
}
else
{
uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift;
}
}
CK_TILE_HOST static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
......@@ -76,11 +84,20 @@ struct magic_division32_bit_range
// TODO: figure out how to do magic number divison for int32_t as dividended
CK_TILE_DEVICE static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
if(__builtin_is_constant_evaluated())
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
return (tmp + dividend_u32) >> shift;
}
else
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
}
}
CK_TILE_HOST static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
......
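Both branches of each function compute the same quotient, q = (((n * multiplier) >> 32) + n) >> shift; the constant-evaluated branch only avoids __umulhi, which cannot be used during constant evaluation. A standalone host-side check of that identity; the multiplier/shift derivation below is the classic round-up magic-number formula and is assumed, not copied from ck_tile's host path:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Derive (multiplier, shift) for a divisor d so that
//   q = (((uint64(n) * multiplier) >> 32) + n) >> shift == n / d
// holds for the supported dividend range. Assumed formula, for illustration.
struct magic_t { std::uint32_t multiplier; std::uint32_t shift; };

magic_t make_magic(std::uint32_t d)
{
    std::uint32_t shift = 0;
    while((std::uint64_t{1} << shift) < d) { ++shift; }
    std::uint64_t multiplier =
        ((std::uint64_t{1} << 32) * ((std::uint64_t{1} << shift) - d)) / d + 1;
    return {static_cast<std::uint32_t>(multiplier), shift};
}

std::uint32_t magic_div(std::uint32_t n, magic_t m)
{
    std::uint32_t tmp = (std::uint64_t{n} * m.multiplier) >> 32; // __umulhi on device
    return (tmp + n) >> m.shift;
}

int main()
{
    for(std::uint32_t d : {1u, 3u, 7u, 192u, 1000u})
    {
        magic_t m = make_magic(d);
        for(std::uint32_t n = 0; n < 100000; ++n)
            assert(magic_div(n, m) == n / d);
    }
    std::printf("magic division matches integer division\n");
    return 0;
}
```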
......@@ -11,6 +11,7 @@
#include <thread>
#include <utility>
#include <vector>
#include <functional>
#include "ck_tile/core.hpp"
#include "ck_tile/host/ranges.hpp"
......@@ -532,6 +533,28 @@ struct HostTensor
typename Data::size_type size() const { return mData.size(); }
// return a slice of this tensor
// for simplicity we just copy the data and return a new tensor
auto slice(std::vector<size_t> s_begin, std::vector<size_t> s_end) const
{
assert(s_begin.size() == s_end.size());
assert(s_begin.size() == get_num_of_dimension());
std::vector<size_t> s_len(s_begin.size());
std::transform(
s_end.begin(), s_end.end(), s_begin.begin(), s_len.begin(), std::minus<size_t>{});
HostTensor<T> sliced_tensor(s_len);
sliced_tensor.ForEach([&](auto& self, auto idx) {
std::vector<size_t> src_idx(idx.size());
std::transform(
idx.begin(), idx.end(), s_begin.begin(), src_idx.begin(), std::plus<size_t>{});
self(idx) = operator()(src_idx);
});
return sliced_tensor;
}
template <typename U = T>
auto AsSpan() const
{
......
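slice() copies the half-open index ranges [s_begin, s_end) into a new HostTensor. A hedged usage sketch, assuming the ck_tile host utilities header:

```cpp
#include <cstdio>
#include "ck_tile/host.hpp"

int main()
{
    ck_tile::HostTensor<float> t({4, 6});
    t.ForEach([&](auto& self, auto idx) { self(idx) = static_cast<float>(idx[0] * 6 + idx[1]); });

    // copy rows [1, 3) and columns [2, 5) into a new 2x3 host tensor
    auto s = t.slice({1, 2}, {3, 5});

    // s(i, j) corresponds to t(i + 1, j + 2)
    std::printf("%f == %f\n", s(0, 0), t(1, 2));
    return 0;
}
```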
......@@ -229,7 +229,7 @@ struct BlockFmhaPipelineQRAsyncEx
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_accs));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
......@@ -336,7 +336,7 @@ struct BlockFmhaPipelineQRAsyncEx
do
{
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
clear_tile(s_accs); // initialize C
if constexpr(k0_loops > 1)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
......@@ -350,7 +350,7 @@ struct BlockFmhaPipelineQRAsyncEx
async_load_fence(k_dram_window.get_num_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc,
gemm_0(s_accs,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
......@@ -373,7 +373,7 @@ struct BlockFmhaPipelineQRAsyncEx
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(
s_acc,
s_accs,
get_slice_tile(
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
get_slice_tile(k_lds_load,
......@@ -385,8 +385,8 @@ struct BlockFmhaPipelineQRAsyncEx
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
s_accs = tile_elementwise_in(s_acc_element_func, s_accs);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_accs);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
......@@ -396,33 +396,33 @@ struct BlockFmhaPipelineQRAsyncEx
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
s_accs,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
constexpr auto s_spans = decltype(s_accs)::get_distributed_spans();
s_accs = tile_elementwise_in(s_acc_element_func, s_accs);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
s_accs.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
s_accs(i_j_idx) *= scale_s;
position_encoding.update(s_accs(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
s_accs = tile_elementwise_in(s_acc_element_func, s_accs);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_accs);
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
......@@ -437,7 +437,7 @@ struct BlockFmhaPipelineQRAsyncEx
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
s_accs, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
......@@ -445,7 +445,7 @@ struct BlockFmhaPipelineQRAsyncEx
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
const auto s = cast_tile<SMPLComputeDataType>(s_accs); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
......
......@@ -10,114 +10,134 @@
namespace ck_tile {
// fp16
using WarpGemmMfmaF16F16F32M32N32K8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
using WarpGemmMfmaF16F16F32M16N16K16 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>;
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M32N32K16 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K32 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 1>>;
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>;
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16,
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
// bf16
using WarpGemmMfmaBf16Bf16F32M32N32K8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 1>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8,
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16,
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8,
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
// fp8
using WarpGemmMfma_f32_32x32x16_fp8_fp8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8>>;
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8>>;
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
template <index_t swizzle_factor = 2>
using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t>,
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, WGAttrCtlEnum::Default_>,
2,
swizzle_factor>>;
......
......@@ -510,11 +510,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
});
}
template <index_t kKIter, bool post_nop_ = false>
template <index_t iKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<kKIter>,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
......
......@@ -139,7 +139,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "a", "b", "v")
DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "a", "a", "v")
}
else
{
......
......@@ -32,10 +32,8 @@ struct WarpGemmImpl
using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CTensor& c,
const ATensor& a,
const BTensor& b,
bool_constant<post_nop_> = {}) const
CK_TILE_DEVICE void
operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
{
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CTensor> &&
detail::is_similiar_distributed_tensor_v<ATensor, ATensor> &&
......@@ -56,7 +54,11 @@ struct WarpGemmImpl
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
}
template <typename CTensor, typename ATensor, typename BTensor, index_t i_subk, bool post_nop_ = false>
template <typename CTensor,
typename ATensor,
typename BTensor,
index_t i_subk,
bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CTensor& c,
const ATensor& a,
const BTensor& b,
......@@ -82,6 +84,7 @@ struct WarpGemmImpl
template <typename ATensor, typename BTensor>
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
{
using CTensor = CWarpTensor;
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, ATensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BTensor>);
CTensor c;
......
......@@ -160,10 +160,9 @@ CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor
// reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) {
// TODO: lid_over_rid_derivative not ok in xor? maybe need limit the usage of
// xor
index_t src_lane = (__lane_id() * lid_over_rid_derivative) ^
(number<1 << istage.value>{}.value);
index_t src_lane =
__lane_id() ^ (number<lid_over_rid_derivative << istage.value>{}.value);
// pull data from remote lane
const auto v_remote = warp_shuffle(v_local, src_lane);
......
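The fixed formula XORs the lane id with the stage distance scaled by lid_over_rid_derivative, rather than XOR-ing a scaled lane id with an unscaled distance. A standalone host-side simulation of the resulting butterfly reduction over 64 simulated lanes (warp_shuffle modeled as an array read; lid_over_rid_derivative is assumed to be the lane stride between neighbouring reduce elements):

```cpp
#include <array>
#include <cassert>
#include <cstdio>

int main()
{
    constexpr int warp_size               = 64;
    constexpr int lid_over_rid_derivative = 1; // lane stride between neighbouring elements (assumed)
    constexpr int nstage                  = 6; // log2(64) butterfly stages

    std::array<int, warp_size> v{};
    for(int lane = 0; lane < warp_size; ++lane)
        v[lane] = lane; // per-lane partial result

    // butterfly reduction: at every stage each lane adds the value of its XOR partner
    for(int istage = 0; istage < nstage; ++istage)
    {
        std::array<int, warp_size> next = v;
        for(int lane = 0; lane < warp_size; ++lane)
        {
            int src_lane = lane ^ (lid_over_rid_derivative << istage); // fixed partner rule
            next[lane]   = v[lane] + v[src_lane];                      // warp_shuffle + add
        }
        v = next;
    }

    // every lane ends up with the full sum 0 + 1 + ... + 63
    for(int lane = 0; lane < warp_size; ++lane)
        assert(v[lane] == 63 * 64 / 2);
    std::printf("all lanes hold %d\n", v[0]);
    return 0;
}
```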