Commit eed60199 authored by carlushuang

more robust api

parent cae751d1
@@ -25,7 +25,7 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
        using ts_input_type = ck_tile::fp16_t;
        using ts_weight_type = float;
        using ts_index_type = ck_tile::index_t;
+#if 1
        if(t.experts <= 8)
        {
            TOPK_SOFTMAX_DISPATCH(8)
@@ -42,9 +42,24 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
        {
            TOPK_SOFTMAX_DISPATCH(64)
        }
+       else if(t.experts <= 128)
+       {
+           TOPK_SOFTMAX_DISPATCH(128)
+       }
+       else if(t.experts <= 192)
+       {
+           TOPK_SOFTMAX_DISPATCH(192)
+       }
+#else
+       if(t.experts <= 16)
+       {
+           TOPK_SOFTMAX_DISPATCH(16)
+       }
+#endif
    }
    else if(t.input_type == "bf16" && t.weight_type == "fp32")
    {
+#if 1
        using ts_input_type = ck_tile::bf16_t;
        using ts_weight_type = float;
        using ts_index_type = ck_tile::index_t;
@@ -64,6 +79,15 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
        {
            TOPK_SOFTMAX_DISPATCH(64)
        }
+       else if(t.experts <= 128)
+       {
+           TOPK_SOFTMAX_DISPATCH(128)
+       }
+       else if(t.experts <= 192)
+       {
+           TOPK_SOFTMAX_DISPATCH(192)
+       }
+#endif
    }
    return -1;
}
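The ladder above rounds the runtime `t.experts` value up to the smallest supported compile-time expert count and instantiates the kernel through TOPK_SOFTMAX_DISPATCH (the macro itself is outside this diff). A minimal host-side sketch of the same pattern, with a hypothetical run_topk_softmax<Experts>() launcher standing in for the macro body:

    #include <cstdio>

    // Hypothetical stand-in for the kernel launcher that TOPK_SOFTMAX_DISPATCH expands to.
    template <int Experts>
    float run_topk_softmax()
    {
        std::printf("dispatching kernel built for %d experts\n", Experts);
        return 0.0f;
    }

    #define DISPATCH_EXPERTS(n_) return run_topk_softmax<n_>();

    float dispatch(int experts)
    {
        // Mirror of the commit's if/else ladder: pick the smallest compile-time bucket
        // that can hold the runtime expert count.
        if(experts <= 8)        { DISPATCH_EXPERTS(8) }
        else if(experts <= 64)  { DISPATCH_EXPERTS(64) }
        else if(experts <= 128) { DISPATCH_EXPERTS(128) }
        else if(experts <= 192) { DISPATCH_EXPERTS(192) }
        return -1.0f; // unsupported expert count, matching the -1 fallthrough above
    }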
@@ -50,6 +50,7 @@
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
+#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
......
@@ -66,6 +66,20 @@ struct space_filling_curve
        return idx_tail - idx_head;
    }

+   template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
+   static CK_TILE_HOST_DEVICE constexpr auto get_step_between_static(number<AccessIdx1dHead>,
+                                                                     number<AccessIdx1dTail>)
+   {
+       static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(),
+                     "1D index out of range");
+       static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(),
+                     "1D index out of range");
+       constexpr auto idx_head = get_index_static(number<AccessIdx1dHead>{});
+       constexpr auto idx_tail = get_index_static(number<AccessIdx1dTail>{});
+       return idx_tail - idx_head;
+   }
+
    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number<AccessIdx1d>)
    {
@@ -73,6 +87,13 @@ struct space_filling_curve
        return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
    }

+   template <index_t AccessIdx1d>
+   static CK_TILE_HOST_DEVICE constexpr auto get_forward_step_static(number<AccessIdx1d>)
+   {
+       static_assert(AccessIdx1d < get_num_of_access(), "1D index should be larger than 0");
+       return get_step_between_static(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
+   }
+
    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr auto get_backward_step(number<AccessIdx1d>)
    {
@@ -153,9 +174,9 @@ struct space_filling_curve
        return idx_md;
    }

-   // FIXME: rename this function
+   // FIXME: return tuple of number<>, which is compile time only variable
    template <index_t AccessIdx1d>
-   static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
+   static CK_TILE_HOST_DEVICE constexpr auto get_index_static(number<AccessIdx1d>)
    {
        constexpr auto idx = get_index(number<AccessIdx1d>{});
......
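get_index_static and get_step_between_static return tuples of number<> rather than runtime index_t values, so the step between two accesses stays a compile-time constant that can feed template arguments or further constexpr arithmetic (for example via the element-wise tuple operators added later in this commit). A self-contained sketch of the idea using std::integral_constant in place of ck_tile's number<>; the 2x4 row-major curve below is a made-up example, not the library's mapping:

    #include <type_traits>
    #include <utility>

    template <int I>
    using num = std::integral_constant<int, I>;

    // Hypothetical 2x4 row-major curve: 1D access id -> (row, col), both carried as types.
    template <int AccessIdx1d>
    constexpr auto get_index_static_sketch(num<AccessIdx1d>)
    {
        return std::pair<num<AccessIdx1d / 4>, num<AccessIdx1d % 4>>{};
    }

    // The step between two accesses is again carried in the type, so it is usable
    // wherever a compile-time constant is required.
    template <int Head, int Tail>
    constexpr auto get_step_between_static_sketch(num<Head>, num<Tail>)
    {
        return std::pair<num<(Tail / 4) - (Head / 4)>, num<(Tail % 4) - (Head % 4)>>{};
    }

    // Moving from access 3 to access 4 steps one row down and three columns back.
    using step_3_to_4 = decltype(get_step_between_static_sketch(num<3>{}, num<4>{}));
    static_assert(step_3_to_4::first_type::value == 1, "");
    static_assert(step_3_to_4::second_type::value == -3, "");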
@@ -156,8 +156,8 @@ struct buffer_load<2, pre_nop>
                                   index_t /*flag*/ = 0,
                                   bool_constant<pre_nop> = {})
    {
-       static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
-       using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
+       // static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
+       using mbuf_t = ushort; // typename impl::buffer_load_trait<2, T>::payload_t;
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
@@ -315,9 +315,9 @@ struct buffer_load_if<2, pre_nop>
                                   index_t flag = 0,
                                   bool_constant<pre_nop> = {})
    {
-       static_assert(sizeof(T) == 4);
+       // static_assert(sizeof(T) == 4);
        auto saved_exec = __builtin_amdgcn_read_exec();
-       using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
+       using mbuf_t = ushort; // typename impl::buffer_load_trait<2, T>::payload_t;
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "v_cmpx_le_u32 exec, 1, %4\n"
@@ -676,23 +676,21 @@ template<typename T> struct smem_load_trait<1 , T> { using payload_t = float; };
} // namespace impl

// NOTE: smem load/store no need pre_nop to make sure dependency by sw, happy :)
-template<index_t>
-struct smem_load ;
+template <index_t>
+struct smem_load;

-template<>
+template <>
struct smem_load<16>
{
    template <typename T>
-   CK_TILE_DEVICE void operator()(T& value,
-                                  index_t v_offset,
-                                  index_t i_offset)
+   CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 16);
-       using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t
+       using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t;
        asm volatile("ds_read_b128 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
@@ -700,16 +698,14 @@ template <>
struct smem_load<8>
{
    template <typename T>
-   CK_TILE_DEVICE void operator()(T& value,
-                                  index_t v_offset,
-                                  index_t i_offset)
+   CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 8);
        using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t;
        asm volatile("ds_read_b64 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
@@ -717,16 +713,14 @@ template <>
struct smem_load<4>
{
    template <typename T>
-   CK_TILE_DEVICE void operator()(T& value,
-                                  index_t v_offset,
-                                  index_t i_offset)
+   CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t;
        asm volatile("ds_read_b32 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
@@ -734,15 +728,14 @@ template <>
struct smem_load<2>
{
    template <typename T>
-   CK_TILE_DEVICE void operator()(T& value,
-                                  index_t v_offset,
-                                  index_t i_offset)
+   CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
+       using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
        asm volatile("ds_read_u16 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
@@ -750,16 +743,14 @@ template <>
struct smem_load<1>
{
    template <typename T>
-   CK_TILE_DEVICE void operator()(T& value,
-                                  index_t v_offset,
-                                  index_t i_offset)
+   CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
        asm volatile("ds_read_u8 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
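Each smem_load<N> specialization picks the ds_read instruction that matches the byte width N, and for the sub-dword widths (2 and 1) the value is read through a 32-bit payload register and narrowed afterwards, hence the sizeof(T) == 4 assertions. A portable stand-in for the smem_load<2> case, with plain memcpy in place of the ds_read_u16 inline asm:

    #include <cstdint>
    #include <cstring>

    template <int Bytes>
    struct smem_load_sketch;

    template <>
    struct smem_load_sketch<2>
    {
        template <typename T>
        void operator()(T& value, const unsigned char* smem_base, int v_offset, int i_offset) const
        {
            static_assert(sizeof(T) == 4, "sub-dword loads go through a dword buffer");
            std::uint16_t raw;
            std::memcpy(&raw, smem_base + v_offset + i_offset, sizeof(raw));
            std::uint32_t widened = raw; // zero-extend, as ds_read_u16 does into a 32-bit VGPR
            std::memcpy(&value, &widened, sizeof(widened));
        }
    };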
@@ -1879,6 +1870,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
                                              int32x4_t dst_wave_buffer_resource,
                                              index_t dst_thread_addr_offset,
                                              index_t dst_wave_addr_offset,
+                                             index_t dst_linear_addr_offset,
                                              index_t is_valid_element = 1)
{
    constexpr index_t bytes = sizeof(T) * N;
@@ -1892,7 +1884,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
            dst_wave_buffer_resource,
            dst_thread_addr_offset,
            dst_wave_addr_offset,
-           0,
+           dst_linear_addr_offset,
            is_valid_element);
    }
    else
@@ -1901,7 +1893,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
            dst_wave_buffer_resource,
            dst_thread_addr_offset,
            dst_wave_addr_offset,
-           0);
+           dst_linear_addr_offset);
    }
}
@@ -2266,6 +2258,7 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
                                                        const T* p_src_wave,
                                                        index_t src_thread_element_offset,
+                                                       index_t src_linear_element_offset,
                                                        index_t src_element_space_size,
                                                        bool_constant<pre_nop> = {})
{
@@ -2273,9 +2266,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
        make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+   index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
-   amd_async_buffer_load_impl<T, N, coherence>(
-       smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
+   amd_async_buffer_load_impl<T, N, coherence>(smem,
+                                               src_wave_buffer_resource,
+                                               src_thread_addr_offset,
+                                               0,
+                                               src_linear_addr_offset,
+                                               bool_constant<pre_nop>{});
}

// This version support buffer resource as input arg
@@ -2286,12 +2284,18 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
                                                        const int32x4_t src_wave_buffer_resource,
                                                        index_t src_thread_element_offset,
+                                                       index_t src_linear_element_offset,
                                                        bool_constant<pre_nop> = {})
{
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+   index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
-   amd_async_buffer_load_impl<T, N, coherence>(
-       smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
+   amd_async_buffer_load_impl<T, N, coherence>(smem,
+                                               src_wave_buffer_resource,
+                                               src_thread_addr_offset,
+                                               0,
+                                               src_linear_addr_offset,
+                                               bool_constant<pre_nop>{});
}

// This version support buffer resource as input arg
@@ -2302,16 +2306,18 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
                                                   const int32x4_t src_wave_buffer_resource,
                                                   index_t src_thread_element_offset,
+                                                  index_t src_linear_element_offset,
                                                   bool is_valid_element,
                                                   bool_constant<oob_conditional_check> = {})
{
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+   index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
    amd_async_buffer_load<T, N, coherence>(smem,
                                           src_wave_buffer_resource,
                                           src_thread_addr_offset,
                                           0,
-                                          0,
+                                          src_linear_addr_offset,
                                           is_valid_element,
                                           bool_constant<oob_conditional_check>{});
}
@@ -2368,6 +2374,7 @@ template <typename T,
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
                                         T* p_dst_wave,
                                         const index_t dst_thread_element_offset,
+                                        const index_t dst_linear_element_offset,
                                         const bool dst_thread_element_valid,
                                         const index_t dst_element_space_size)
{
@@ -2375,11 +2382,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
+   index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
    amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(src_thread_data,
                                                                      dst_wave_buffer_resource,
                                                                      dst_thread_addr_offset,
                                                                      0,
+                                                                     dst_linear_addr_offset,
                                                                      dst_thread_element_valid);
}
......
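The new *_linear_element_offset parameters replace the 0 that these wrappers previously hard-coded when forwarding to the *_impl functions, and they get the same element-to-byte scaling as the thread offset. A small sketch of the resulting address composition (that the linear part is the statically known portion of the address is an assumption about how the new tile_window_linear uses it, not something stated in this file):

    #include <cstdint>

    template <typename T>
    std::uint32_t effective_byte_offset(std::uint32_t thread_element_offset,
                                        std::uint32_t linear_element_offset)
    {
        // Both components are element counts and get the same sizeof(T) scaling;
        // before this commit the second term was always 0.
        std::uint32_t thread_addr_offset = static_cast<std::uint32_t>(thread_element_offset * sizeof(T));
        std::uint32_t linear_addr_offset = static_cast<std::uint32_t>(linear_element_offset * sizeof(T));
        return thread_addr_offset + linear_addr_offset;
    }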
@@ -635,6 +635,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
    return r;
}

+template <typename... Xs, typename... Ys>
+CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const tuple<Ys...>& y)
+{
+    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
+    constexpr index_t NSize = sizeof...(Xs);
+    return generate_tuple([&](auto i) { return x[i] + y[i]; }, number<NSize>{});
+}
+
template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
@@ -649,6 +657,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
    return r;
}

+template <typename... Xs, typename... Ys>
+CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const tuple<Ys...>& y)
+{
+    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
+    constexpr index_t NSize = sizeof...(Xs);
+    return generate_tuple([&](auto i) { return x[i] - y[i]; }, number<NSize>{});
+}
+
template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
@@ -686,6 +702,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
    return a * x;
}

+template <typename... Xs, typename... Ys>
+CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const tuple<Ys...>& y)
+{
+    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
+    constexpr index_t NSize = sizeof...(Xs);
+    return generate_tuple([&](auto i) { return x[i] * y[i]; }, number<NSize>{});
+}
+
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
......
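The three new overloads make tuple arithmetic element-wise when both operands are tuples of the same size, matching the existing operator/ directly below them. A stand-alone analogue written against std::tuple so it compiles on its own; ck_tile's versions use generate_tuple and number<> instead of index_sequence:

    #include <tuple>
    #include <utility>

    template <typename... Xs, typename... Ys, std::size_t... Is>
    constexpr auto tuple_add_impl(const std::tuple<Xs...>& x, const std::tuple<Ys...>& y,
                                  std::index_sequence<Is...>)
    {
        return std::make_tuple((std::get<Is>(x) + std::get<Is>(y))...);
    }

    template <typename... Xs, typename... Ys>
    constexpr auto tuple_add(const std::tuple<Xs...>& x, const std::tuple<Ys...>& y)
    {
        static_assert(sizeof...(Xs) == sizeof...(Ys), "element-wise ops need equal sizes");
        return tuple_add_impl(x, y, std::index_sequence_for<Xs...>{});
    }

    // (1, 2) + (10, 20) == (11, 22), evaluated entirely at compile time.
    static_assert(std::get<1>(tuple_add(std::make_tuple(1, 2), std::make_tuple(10, 20))) == 22, "");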
@@ -91,8 +91,10 @@ struct buffer_view<address_space_enum::generic,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE constexpr auto
-   get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
+   CK_TILE_DEVICE constexpr auto get(index_t i,
+                                     index_t linear_offset,
+                                     bool is_valid_element,
+                                     bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -107,11 +109,11 @@ struct buffer_view<address_space_enum::generic,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp;
-           __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
+           __builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
            return tmp;
#else
-           return *c_style_pointer_cast<const X*>(&p_data_[i]);
+           return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
#endif
        }
        else
@@ -134,17 +136,17 @@ struct buffer_view<address_space_enum::generic,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
-           this->template set<X>(i, is_valid_element, x);
+           this->template set<X>(i, linear_offset, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
-           auto tmp = this->template get<X>(i, is_valid_element);
-           this->template set<X>(i, is_valid_element, x + tmp);
+           auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
+           this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
        }
    }
@@ -154,7 +156,7 @@ struct buffer_view<address_space_enum::generic,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -169,9 +171,9 @@ struct buffer_view<address_space_enum::generic,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp = x;
-           __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
+           __builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
#else
-           *c_style_pointer_cast<X*>(&p_data_[i]) = x;
+           *c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
#endif
        }
    }
@@ -276,8 +278,10 @@ struct buffer_view<address_space_enum::global,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE constexpr auto
-   get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
+   CK_TILE_DEVICE constexpr auto get(index_t i,
+                                     index_t linear_offset,
+                                     bool is_valid_element,
+                                     bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -303,7 +307,7 @@ struct buffer_view<address_space_enum::global,
                                             t_per_x,
                                             Coherence,
                                             oob_conditional_check>(
-                   p_data_, i, is_valid_element, buffer_size_);
+                   p_data_, i + linear_offset, is_valid_element, buffer_size_);
            }
            else
            {
@@ -311,8 +315,11 @@ struct buffer_view<address_space_enum::global,
                                             remove_cvref_t<T>,
                                             t_per_x,
                                             Coherence,
-                                            oob_conditional_check>(
-                   p_data_, i, is_valid_element, buffer_size_, invalid_element_value_);
+                                            oob_conditional_check>(p_data_,
+                                                                   i + linear_offset,
+                                                                   is_valid_element,
+                                                                   buffer_size_,
+                                                                   invalid_element_value_);
            }
        }
        else
@@ -322,11 +329,11 @@ struct buffer_view<address_space_enum::global,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp;
-           __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
+           __builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
            return tmp;
#else
-           return *c_style_pointer_cast<const X*>(&p_data_[i]);
+           return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
#endif
        }
        else
@@ -379,6 +386,7 @@ struct buffer_view<address_space_enum::global,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto async_get(CK_TILE_LDS_ADDR remove_cvref_t<T>* smem,
                                            index_t i,
+                                           index_t linear_offset,
                                            bool is_valid_element,
                                            bool_constant<oob_conditional_check> = {}) const
    {
@@ -392,7 +400,12 @@ struct buffer_view<address_space_enum::global,
        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
        amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
-           smem, cached_buf_res_, i, is_valid_element, bool_constant<oob_conditional_check>{});
+           smem,
+           cached_buf_res_,
+           i,
+           linear_offset,
+           is_valid_element,
+           bool_constant<oob_conditional_check>{});
    }

    // i is offset of T, not X. i should be aligned to X
@@ -404,6 +417,7 @@ struct buffer_view<address_space_enum::global,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t<T>* smem,
                                                index_t i,
+                                               index_t linear_offset,
                                                bool /*is_valid_element*/,
                                                bool_constant<pre_nop> = {}) const
    {
@@ -417,7 +431,7 @@ struct buffer_view<address_space_enum::global,
        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
        amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
-           smem, cached_buf_res_, i, bool_constant<pre_nop>{});
+           smem, cached_buf_res_, i, linear_offset, bool_constant<pre_nop>{});
    }

    // i is offset of T, not X. i should be aligned to X
@@ -427,11 +441,11 @@ struct buffer_view<address_space_enum::global,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
-           this->template set<X>(i, is_valid_element, x);
+           this->template set<X>(i, linear_offset, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_add)
        {
@@ -458,7 +472,7 @@ struct buffer_view<address_space_enum::global,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -479,7 +493,7 @@ struct buffer_view<address_space_enum::global,
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
            amd_buffer_store<remove_cvref_t<T>, t_per_x, Coherence>(
-               x, p_data_, i, is_valid_element, buffer_size_);
+               x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
        }
        else
        {
@@ -488,9 +502,9 @@ struct buffer_view<address_space_enum::global,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp = x;
-               __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
+               __builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
#else
-               *c_style_pointer_cast<X*>(&p_data_[i]) = x;
+               *c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
#endif
        }
    }
@@ -503,7 +517,7 @@ struct buffer_view<address_space_enum::global,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void set_raw(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void set_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -515,7 +529,7 @@ struct buffer_view<address_space_enum::global,
        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
        amd_buffer_store_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
-           x, p_data_, i, is_valid_element, buffer_size_);
+           x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
    }

    template <typename X,
@@ -523,7 +537,8 @@ struct buffer_view<address_space_enum::global,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void
+   atomic_add(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
@@ -558,13 +573,13 @@ struct buffer_view<address_space_enum::global,
        if constexpr(use_amd_buffer_addressing)
        {
            amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
-               x, p_data_, i, is_valid_element, buffer_size_);
+               x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
        }
        else
        {
            if(is_valid_element)
            {
-               atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
+               atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
            }
        }
    }
@@ -574,7 +589,8 @@ struct buffer_view<address_space_enum::global,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void
+   atomic_max(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -598,11 +614,11 @@ struct buffer_view<address_space_enum::global,
        if constexpr(use_amd_buffer_addressing)
        {
            amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
-               x, p_data_, i, is_valid_element, buffer_size_);
+               x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
        }
        else if(is_valid_element)
        {
-           atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
+           atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
        }
    }
@@ -694,8 +710,10 @@ struct buffer_view<address_space_enum::lds,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE constexpr auto
-   get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
+   CK_TILE_DEVICE constexpr auto get(index_t i,
+                                     index_t linear_offset,
+                                     bool is_valid_element,
+                                     bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -710,14 +728,14 @@ struct buffer_view<address_space_enum::lds,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp;
-           __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
+           __builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
            return tmp;
#else
            using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
                                       scalar_per_t_vector * scalar_per_x_vector>;
            // using buf_t = ushort __attribute__((ext_vector_type(8)));
-           auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i]);
+           auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i + linear_offset]);
            return bit_cast<X>(rtn);
#endif
        }
@@ -745,7 +763,7 @@ struct buffer_view<address_space_enum::lds,
    CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
                                          index_t v_offset,
                                          index_t i_offset,
-                                         bool is_valid_element,
+                                         bool /*is_valid_element*/,
                                          bool_constant<pre_nop> = {}) const
    {
#if 0
@@ -768,17 +786,17 @@ struct buffer_view<address_space_enum::lds,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
-           this->template set<X>(i, is_valid_element, x);
+           this->template set<X>(i, linear_offset, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
-           auto tmp = this->template get<X>(i, is_valid_element);
-           this->template set<X>(i, is_valid_element, x + tmp);
+           auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
+           this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
        }
    }
@@ -788,7 +806,7 @@ struct buffer_view<address_space_enum::lds,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -804,6 +822,7 @@ struct buffer_view<address_space_enum::lds,
        bool constexpr workaround_int8_ds_write_issue = false;
#endif
+       i += linear_offset; // simplicity
        if constexpr(std::is_same<typename vector_traits<remove_cvref_t<T>>::scalar_type,
                                  int8_t>::value &&
                     workaround_int8_ds_write_issue)
@@ -1005,8 +1024,10 @@ struct buffer_view<address_space_enum::vgpr,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE constexpr auto
-   get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
+   CK_TILE_DEVICE constexpr auto get(index_t i,
+                                     index_t /*linear_offset*/,
+                                     bool is_valid_element,
+                                     bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -1048,17 +1069,17 @@ struct buffer_view<address_space_enum::vgpr,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
-           this->template set<X>(i, is_valid_element, x);
+           this->template set<X>(i, linear_offset, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
-           auto tmp = this->template get<X>(i, is_valid_element);
-           this->template set<X>(i, is_valid_element, x + tmp);
+           auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
+           this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
        }
    }
@@ -1068,7 +1089,7 @@ struct buffer_view<address_space_enum::vgpr,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -1083,9 +1104,9 @@ struct buffer_view<address_space_enum::vgpr,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp = x;
-           __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
+           __builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
#else
-           *c_style_pointer_cast<X*>(&p_data_[i]) = x;
+           *c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
#endif
        }
    }
......
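Every buffer_view accessor now takes a linear_offset next to the coordinate offset i: the generic, global and lds views simply add the two at the point of the memory access, while the vgpr view ignores the extra argument in get. A toy illustration of the new call shape (not ck_tile's real buffer_view):

    #include <cstddef>

    template <typename T>
    struct toy_buffer_view
    {
        T* p_data_;
        std::size_t size_;

        // The dynamic coordinate offset and the linear offset are kept separate
        // by the caller and only summed where the element is actually touched.
        T get(std::size_t i, std::size_t linear_offset, bool is_valid_element) const
        {
            return is_valid_element ? p_data_[i + linear_offset] : T{};
        }

        void set(std::size_t i, std::size_t linear_offset, bool is_valid_element, const T& x)
        {
            if(is_valid_element)
                p_data_[i + linear_offset] = x;
        }
    };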
@@ -11,7 +11,7 @@
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
-#include "ck_tile/core/tensor/tile_window.hpp"
+#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
@@ -31,6 +31,20 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT
    return tile_window.load(bool_constant<oob_conditional_check>{});
}

+template <typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          typename LinearBottomDims_,
+          bool oob_conditional_check = true>
+CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
+                                                       WindowLengths_,
+                                                       TileDistribution_,
+                                                       LinearBottomDims_>& tile_window,
+                              bool_constant<oob_conditional_check> = {})
+{
+    return tile_window.load(bool_constant<oob_conditional_check>{});
+}
+
template <typename T,
          typename BottomTensorView_,
          typename WindowLengths_,
@@ -49,6 +63,24 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

+template <typename T,
+          typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          typename LinearBottomDims_,
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
+CK_TILE_DEVICE auto load_tile_raw(T& tile,
+                                  const tile_window_linear<BottomTensorView_,
+                                                           WindowLengths_,
+                                                           TileDistribution_,
+                                                           LinearBottomDims_>& tile_window,
+                                  bool_constant<oob_conditional_check> = {},
+                                  bool_constant<pre_nop> = {})
+{
+    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+}
+
// for this API we force user to use CK_TILE_LDS_ADDR attribute specified smem
// while creating the smem window, which can enable compiler properly detect the
// dependency if using multiple smem window (multiple buffer)
@@ -69,6 +101,22 @@ async_load_tile(LdsTileWindow_&& lds_tile,
    return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
}

+template <typename LdsTileWindow_,
+          typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          typename LinearBottomDims_,
+          bool oob_conditional_check = true>
+CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
+                                    const tile_window_linear<BottomTensorView_,
+                                                             WindowLengths_,
+                                                             TileDistribution_,
+                                                             LinearBottomDims_>& tile_window,
+                                    bool_constant<oob_conditional_check> = {})
+{
+    return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
+}
+
template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
@@ -89,6 +137,25 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

+template <typename LdsTileWindow_,
+          typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          typename LinearBottomDims_,
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
+CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
+                                        const tile_window_linear<BottomTensorView_,
+                                                                 WindowLengths_,
+                                                                 TileDistribution_,
+                                                                 LinearBottomDims_>& tile_window,
+                                        bool_constant<oob_conditional_check> = {},
+                                        bool_constant<pre_nop> = {})
+{
+    return tile_window.async_load_raw(
+        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+}
+
template <typename WindowLengths>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
{
@@ -100,4 +167,20 @@ CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<Windo
{
}

+// TODO: this function requires some sub-fields exist for the target tile window
+template <typename TileWindow, bool oob_conditional_check = true, bool pre_nop = false>
+CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
+                                  bool_constant<oob_conditional_check> = {},
+                                  bool_constant<pre_nop> = {})
+{
+    using TileDstr = typename TileWindow::TileDstr;
+    using DataType = typename TileWindow::DataType;
+
+    auto t = make_static_distributed_tensor<DataType>(TileDstr{});
+    load_tile_raw(t, w, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+    return t;
+}
+
} // namespace ck_tile
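The new overloads only differ from the existing ones in the window type they accept; they forward to the same load/load_raw/async_load members, so kernel code written against the free functions does not change when a window becomes a tile_window_linear. A minimal sketch of that dispatch shape with stand-in window types (not the real ck_tile classes):

    // Stand-in window types; the real ones carry tensor views and distributions.
    struct plain_window  { int load() const { return 1; } };
    struct linear_window { int load() const { return 2; } };

    template <typename Window>
    int load_tile_sketch(const Window& w)
    {
        return w.load(); // same call site for both window flavours
    }

    int use_both()
    {
        return load_tile_sketch(plain_window{}) + load_tile_sketch(linear_window{});
    }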
@@ -10,6 +10,7 @@
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
+#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {
@@ -90,4 +91,30 @@ store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
    tile_window.store_raw(dstr_tensor);
}

+template <typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          typename LinearBottomDims_,
+          typename DataType_>
+CK_TILE_DEVICE void store_tile(
+    tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
+        tile_window,
+    const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+{
+    tile_window.store(dstr_tensor);
+}
+
+template <typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          typename LinearBottomDims_,
+          typename DataType_>
+CK_TILE_DEVICE void store_tile_raw(
+    tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
+        tile_window,
+    const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+{
+    tile_window.store_raw(dstr_tensor);
+}
+
} // namespace ck_tile
@@ -75,14 +75,34 @@ struct tensor_view
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
    get_vectorized_elements(const TensorCoord& coord,
+                           index_t linear_offset,
                            bool_constant<oob_conditional_check> = {}) const
    {
        return buf_.template get<X>(
            coord.get_offset(),
+           linear_offset,
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            bool_constant<oob_conditional_check>{});
    }

+   template <typename X,
+             bool oob_conditional_check = true,
+             typename std::enable_if<
+                 std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                 bool>::type = false>
+   CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
+   get_vectorized_elements(const TensorCoord& coord,
+                           index_t linear_offset,
+                           bool is_valid_element, // flag
+                           bool_constant<oob_conditional_check> = {}) const
+   {
+       return buf_.template get<X>(coord.get_offset(),
+                                   linear_offset,
+                                   is_valid_element,
+                                   bool_constant<oob_conditional_check>{});
+   }
+
    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,
@@ -106,6 +126,24 @@ struct tensor_view
                          bool_constant<pre_nop>{});
    }

+   template <typename X,
+             bool oob_conditional_check = true,
+             bool pre_nop = false,
+             typename std::enable_if<
+                 std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                 bool>::type = false>
+   CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
+                                                        const TensorCoord& coord,
+                                                        index_t linear_offset,
+                                                        bool is_valid_element,
+                                                        bool_constant<oob_conditional_check> = {},
+                                                        bool_constant<pre_nop> = {}) const
+   {
+       return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
+           dst, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
+   }
+
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
@@ -114,26 +152,71 @@ struct tensor_view
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
-                                 const TensorCoord& coord) const
+                                 const TensorCoord& coord,
+                                 index_t linear_offset) const
    {
        return buf_.template async_get<X>(
            smem,
            coord.get_offset(),
+           linear_offset,
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            bool_constant<oob_conditional_check>{});
    }

+   template <typename X,
+             bool oob_conditional_check = true,
+             typename std::enable_if<
+                 std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                 bool>::type = false>
+   CK_TILE_HOST_DEVICE constexpr void
+   async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
+                                 const TensorCoord& coord,
+                                 index_t linear_offset,
+                                 bool is_valid_element) const
+   {
+       return buf_.template async_get<X>(smem,
+                                         coord.get_offset(),
+                                         linear_offset,
+                                         is_valid_element,
+                                         bool_constant<oob_conditional_check>{});
+   }
+
    template <typename X,
              bool pre_nop = false,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
-   CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw(
-       remove_cvref_t<DataType>* smem, const TensorCoord& coord, bool_constant<pre_nop> = {}) const
+   CK_TILE_HOST_DEVICE constexpr void
+   async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
+                                     const TensorCoord& coord,
+                                     index_t linear_offset,
+                                     bool_constant<pre_nop> = {}) const
    {
        return buf_.template async_get_raw<X>(
-           smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
+           smem,
+           coord.get_offset(),
+           linear_offset,
+           coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
+           bool_constant<pre_nop>{});
+   }
+
+   template <typename X,
+             bool pre_nop = false,
+             typename std::enable_if<
+                 std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                 bool>::type = false>
+   CK_TILE_HOST_DEVICE constexpr void
+   async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
+                                     const TensorCoord& coord,
+                                     index_t linear_offset,
+                                     bool is_valid_element,
+                                     bool_constant<pre_nop> = {}) const
+   {
+       return buf_.template async_get_raw<X>(
+           smem, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
    }

    // X is vector of DataType.
@@ -144,11 +227,15 @@ struct tensor_view
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
-   CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(
-       const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
+   CK_TILE_HOST_DEVICE constexpr void
+   set_vectorized_elements(const TensorCoord& coord,
+                           index_t linear_offset,
+                           const X& x,
+                           bool_constant<oob_conditional_check> = {})
    {
        buf_.template set<X, oob_conditional_check>(
            coord.get_offset(),
+           linear_offset,
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }
...@@ -159,15 +246,53 @@ struct tensor_view ...@@ -159,15 +246,53 @@ struct tensor_view
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>, typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false> bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw( CK_TILE_HOST_DEVICE constexpr void
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {}) set_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {})
{
buf_.template set<X, oob_conditional_check>(
coord.get_offset(), linear_offset, is_valid_element, x);
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
set_vectorized_elements_raw(const TensorCoord& coord,
index_t linear_offset,
const X& x,
bool_constant<oob_conditional_check> = {})
{ {
buf_.template set_raw<X, oob_conditional_check>( buf_.template set_raw<X, oob_conditional_check>(
coord.get_offset(), coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x); x);
} }
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
set_vectorized_elements_raw(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {})
{
buf_.template set_raw<X, oob_conditional_check>(
coord.get_offset(), linear_offset, is_valid_element, x);
}
// X is vector of DataType. // X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X, template <typename X,
...@@ -176,15 +301,36 @@ struct tensor_view ...@@ -176,15 +301,36 @@ struct tensor_view
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>, typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false> bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements( CK_TILE_HOST_DEVICE constexpr void
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {}) update_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
const X& x,
bool_constant<oob_conditional_check> = {})
{ {
buf_.template update<DstInMemOp, X, oob_conditional_check>( buf_.template update<DstInMemOp, X, oob_conditional_check>(
coord.get_offset(), coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x); x);
} }
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
update_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {})
{
buf_.template update<DstInMemOp, X, oob_conditional_check>(
coord.get_offset(), linear_offset, is_valid_element, x);
}
CK_TILE_HOST_DEVICE void print() const CK_TILE_HOST_DEVICE void print() const
{ {
printf("tensor_view{"); printf("tensor_view{");
......
...@@ -454,6 +454,7 @@ struct tile_distribution_detail ...@@ -454,6 +454,7 @@ struct tile_distribution_detail
} // namespace detail } // namespace detail
#if 0
// this returns a constexpr tile_distribution // this returns a constexpr tile_distribution
template <typename StaticTileDistributionEncoding_> template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_) CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
...@@ -490,6 +491,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistribution ...@@ -490,6 +491,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistribution
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{ detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
ps_ys_to_xs_adaptor, ys_to_d_descriptor}; ps_ys_to_xs_adaptor, ys_to_d_descriptor};
} }
#endif
// this returns a static tile_distribution // this returns a static tile_distribution
template <typename StaticTileDistributionEncoding_> template <typename StaticTileDistributionEncoding_>
......
...@@ -223,10 +223,11 @@ struct tile_window_with_static_distribution ...@@ -223,10 +223,11 @@ struct tile_window_with_static_distribution
// move thread's window adaptor coordinate and bottom tensor coordinate // move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
WindowAdaptorCoord& window_adaptor_thread_coord, WindowAdaptorCoord& window_adaptor_thread_coord,
BottomTensorCoord& bottom_tensor_thread_coord, BottomTensorCoord& bottom_tensor_thread_coord,
const AdaptorTopIndex& idx_diff_adaptor_top) const const ATopIndex& idx_diff_adaptor_top) const
{ {
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom; array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
...@@ -309,7 +310,7 @@ struct tile_window_with_static_distribution ...@@ -309,7 +310,7 @@ struct tile_window_with_static_distribution
// read from bottom tensor // read from bottom tensor
const vector_t vec_value = const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>( get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{}); bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
#if 1 #if 1
// write into distributed tensor // write into distributed tensor
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
...@@ -337,10 +338,11 @@ struct tile_window_with_static_distribution ...@@ -337,10 +338,11 @@ struct tile_window_with_static_distribution
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{ {
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ys = SFC_Ys::get_forward_step_static(iAccess);
constexpr auto idx_diff_ps_ys = constexpr auto idx_diff_ps_ys = container_concat(
container_concat(array<index_t, NDimP>{0}, idx_diff_ys); generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate( move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
...@@ -398,7 +400,7 @@ struct tile_window_with_static_distribution ...@@ -398,7 +400,7 @@ struct tile_window_with_static_distribution
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>( get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(), dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
bottom_tensor_thread_coord, bottom_tensor_thread_coord,
/**/, 0 /**/,
bool_constant<oob_conditional_check>{}, bool_constant<oob_conditional_check>{},
pre_nop_); pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \ #if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
...@@ -484,7 +486,7 @@ struct tile_window_with_static_distribution ...@@ -484,7 +486,7 @@ struct tile_window_with_static_distribution
// read from bottom tensor // read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>( get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, pre_nop_); smem, bottom_tensor_thread_coord, 0, pre_nop_);
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...@@ -552,7 +554,7 @@ struct tile_window_with_static_distribution ...@@ -552,7 +554,7 @@ struct tile_window_with_static_distribution
// read from bottom tensor // read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>( get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem, bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{}); smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...@@ -618,7 +620,10 @@ struct tile_window_with_static_distribution ...@@ -618,7 +620,10 @@ struct tile_window_with_static_distribution
// write into bottom tensor // write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>( get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{}); bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{});
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...@@ -676,7 +681,7 @@ struct tile_window_with_static_distribution ...@@ -676,7 +681,7 @@ struct tile_window_with_static_distribution
// write into bottom tensor // write into bottom tensor
get_bottom_tensor_view() get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>( .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, vec_value); bottom_tensor_thread_coord, 0, vec_value);
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...@@ -736,7 +741,10 @@ struct tile_window_with_static_distribution ...@@ -736,7 +741,10 @@ struct tile_window_with_static_distribution
// write into bottom tensor // write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements<vector_t>( get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{}); bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{});
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...@@ -868,6 +876,27 @@ make_tile_window(const TensorView_& tensor_view, ...@@ -868,6 +876,27 @@ make_tile_window(const TensorView_& tensor_view,
tensor_view, window_lengths, origin, tile_distribution}; tensor_view, window_lengths, origin, tile_distribution};
} }
// this version must not be called in a constexpr context
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
index_t NumCoord = 1>
CK_TILE_DEVICE auto
make_tile_window_raw(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
number<NumCoord> = {})
{
auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
NumCoord>{
tensor_view, window_lengths, origin, tile_distribution};
w.init_raw();
return w;
}
template <typename TensorView_, template <typename TensorView_,
typename WindowLengths_, typename WindowLengths_,
typename StaticTileDistribution_, typename StaticTileDistribution_,
...@@ -992,6 +1021,19 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths ...@@ -992,6 +1021,19 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
tile_distribution); tile_distribution);
} }
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
CK_TILE_DEVICE constexpr auto
make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution)
{
auto w = make_tile_window(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution);
w.init_raw();
return w;
}
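// a minimal usage sketch (hedged; `win` and `dstr` are placeholders for an existing
// tile_window_with_static_lengths and a static tile_distribution):
//   auto w = make_tile_window_raw(win, dstr); // same as make_tile_window(...) followed by w.init_raw()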
template <typename TensorView_, typename WindowLengths_> template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE void move_tile_window( CK_TILE_DEVICE void move_tile_window(
tile_window_with_static_lengths<TensorView_, WindowLengths_>& window, tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
//
// This version of the tile window pre-caches offsets/flags as needed.
//
// LinearBottomDims_: e.g. seq<0, 1> for a 2d tensor, where the last dim is the linear dim,
// so that dim can be indexed with an immediate offset, which saves registers.
// TODO: when using this struct, prefer load_raw()/store_raw(), which can control
// the immediate offset on the fly.
// Note: the space-filling curve is non-snaked here!
//
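// illustrative sketch (hedged, assuming a row-major 2d global tensor with LinearBottomDims_ = seq<0, 1>):
// accesses that differ only along the last (linear) dim reuse one cached coordinate/vgpr and get a
// constexpr linear_offset instead, while each step along dim 0 keeps its own cached coordinate.
//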
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename LinearBottomDims_>
struct tile_window_linear
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
using LinearBottomDims = remove_cvref_t<LinearBottomDims_>;
static_assert(LinearBottomDims::size() == BottomTensorView::get_num_of_dimension());
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
// TODO: check WindowLengths and StaticTileDistribution are consistent
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
static_assert(TileDstr::is_static(), "wrong!");
static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
"wrong! inconsistent # of diemsnions");
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
using WindowAdaptorCoord =
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
using BottomTensorCoord =
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
struct traits
{
private:
// return vector dimension among [y0, y1, ...]
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
{
// bottom tensor top dimension vector lengths and strides
const auto [bottom_tensor_top_dim_vector_lengths,
bottom_tensor_top_dim_vector_strides] =
BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
// window vector lengths/strides
const auto window_adaptor_bottom_dim_vector_lengths =
bottom_tensor_top_dim_vector_lengths;
const auto window_adaptor_bottom_dim_vector_strides =
bottom_tensor_top_dim_vector_strides;
// window adaptor [p0, p1, ..., y0, y1, ...]
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()>
window_adaptor_vector_lengths{-1};
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()>
window_adaptor_vector_strides{-1};
constexpr auto window_adaptor_bottom_dims =
WindowAdaptor::get_bottom_dimension_hidden_ids();
set_container_subset(window_adaptor_vector_lengths,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_lengths);
set_container_subset(window_adaptor_vector_strides,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_strides);
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
window_adaptor_vector_lengths, window_adaptor_vector_strides);
// [y0, y1, ...]
constexpr auto y_dims =
typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
NDimWindowAdaptorTop,
1>::type{};
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
}
static constexpr auto get_vector_dim_y_scalar_per_vector()
{
const auto [ys_vector_lengths, ys_vector_strides] =
get_window_adaptor_ys_safe_vector_length_strides();
index_t VectorDimY_ = 0;
index_t ScalarPerVector_ = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
{
ScalarPerVector_ = ys_vector_lengths[i];
VectorDimY_ = i;
}
}
return make_tuple(VectorDimY_, ScalarPerVector_);
}
public:
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
static constexpr index_t ScalarPerVector =
get_vector_dim_y_scalar_per_vector().template at<1>();
using vector_t = thread_buffer<DataType, ScalarPerVector>;
private:
static constexpr auto scalars_per_access_ = [] {
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
constexpr auto NDimY_ = NDimY;
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
}();
static constexpr auto get_space_filling_curve()
{
constexpr auto thread_tensor_lengths_ys =
to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());
// FIXME: need logic to judge dim access order
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
return space_filling_curve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access_),
false /*!!! no snaked curve! */>{};
}
public:
using SFC_Ys = decltype(get_space_filling_curve());
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
private:
static constexpr auto get_num_non_linear_access()
{
constexpr auto sfc_access_lens = SFC_Ys::access_lengths;
using ys_to_rhs_major =
typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
constexpr auto non_linear = [&]() {
index_t cnt = 1;
static_for<0, NDimY, 1>{}([&](auto i_dim_y) {
constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y];
constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
if constexpr(LinearBottomDims{}[target_h_dim] == 0)
{
cnt *= sfc_access_lens[i_dim_y];
}
});
return cnt;
}();
return non_linear;
}
// example:
// non_linear_access_map: sequence<0, 0, 0, 0, 1, 1, 1, 1> for 8 accesses, 2 registers
// used in total, will pre-cache 2
// -> histogram : sequence<4, 4>
// -> prefixsum : sequence<0, 4, 8>
// non_linear_access_map: sequence<0, 1, 2, 3, 4, 5, 6, 7> for 8 accesses, 8 registers
// used in total, will pre-cache 8
// -> histogram : sequence<1, 1, 1, 1, 1, 1, 1, 1>
// -> prefixsum : sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>
// non_linear_access_map: sequence<0, 0, 1, 1, 2, 2, 3, 3> for 8 accesses, 4 registers
// used in total, will pre-cache 4
// -> histogram : sequence<2, 2, 2, 2>
// -> prefixsum : sequence<0, 2, 4, 6, 8>
static constexpr auto get_non_linear_access_map()
{
constexpr auto sfc_access_lens = SFC_Ys::access_lengths;
using ys_to_rhs_major =
typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
constexpr auto non_linear_map = [&]() {
array<index_t, NumAccess> m_{0};
index_t cumulative_len_ = 1;
index_t cumulative_non_linear_len_ = 1;
static_for<0, NDimY, 1>{}([&](auto i_y) {
constexpr auto i_dim_y = number<NDimY - i_y - 1>{}; // from right to left
constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y];
constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
constexpr auto is_linear_dim = LinearBottomDims{}[target_h_dim];
array<index_t, NumAccess> current_m_{0};
constexpr auto current_len_ = sfc_access_lens[i_dim_y];
// copy cumulative length as current pattern
for(auto i_ = 0; i_ < cumulative_len_; i_++)
{
current_m_(i_) = m_[i_];
}
for(auto j_ = 0; j_ < current_len_; j_++)
{
auto j_offset_ = is_linear_dim ? 0 : j_ * cumulative_non_linear_len_;
for(auto i_ = 0; i_ < cumulative_len_; i_++)
{
m_(j_ * cumulative_len_ + i_) = current_m_[i_] + j_offset_;
}
}
cumulative_len_ *= current_len_;
if(!is_linear_dim)
cumulative_non_linear_len_ *= current_len_;
});
return m_;
}();
return TO_SEQUENCE(non_linear_map, NumAccess);
}
static constexpr auto get_non_linear_access_histogram()
{
constexpr auto m_ = get_non_linear_access_map();
// m_.foo();
constexpr auto r_ =
typename arithmetic_sequence_gen<0, get_num_non_linear_access() + 1, 1>::type{};
constexpr auto h_ = histogram_sorted_sequence(m_, r_);
return h_;
}
static constexpr auto get_non_linear_access_histogram_prefix_sum()
{
constexpr auto h_ = get_non_linear_access_histogram();
constexpr auto h_prefix_sum_ = prefix_sum_sequence(h_);
return h_prefix_sum_;
}
public:
static constexpr index_t NumAccess_NonLinear = get_num_non_linear_access();
using AccessMap_NonLinear = decltype(get_non_linear_access_map()); // sequence
using AccessHistogram_NonLinear = decltype(get_non_linear_access_histogram());
using AccessPrefixSum_NonLinear = decltype(get_non_linear_access_histogram_prefix_sum());
};
static constexpr index_t NumAccess = traits::NumAccess;
static constexpr index_t NumAccess_NonLinear = traits::NumAccess_NonLinear;
using AccessMap_NonLinear = typename traits::AccessMap_NonLinear;
using AccessHistogram_NonLinear = typename traits::AccessHistogram_NonLinear;
using AccessPrefixSum_NonLinear = typename traits::AccessPrefixSum_NonLinear;
CK_TILE_DEVICE constexpr tile_window_linear() = default;
CK_TILE_DEVICE constexpr tile_window_linear(const BottomTensorView& bottom_tensor_view,
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin,
const TileDstr& tile_distribution)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin},
tile_dstr_{tile_distribution},
cached_coords_{},
cached_flags_{}
{
auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(make_tuple(get_warp_id(), get_lane_id()),
generate_tuple([&](auto) { return number<0>{}; }, number<NDimY>{})));
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-cache coordinates and flags to speed up future load/store() calls (might allocate more registers)
using SFC_Ys = typename traits::SFC_Ys;
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
constexpr auto need_save_non_linear_coord =
bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};
if constexpr(need_save_non_linear_coord)
{
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
}
// TODO: need pad_tensor_view to know which dims require a flag check
// the cached flag is independent from the non-linear coord,
// but it needs to be updated in move_tile with the proper dims
cached_flags_(i_access) = coordinate_has_valid_offset_assuming_top_index_is_valid(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp);
if constexpr(i_access != (NumAccess - 1))
{
constexpr auto idx_diff_ys =
SFC_Ys::get_forward_step_static(i_access); // tuple of number
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord_tmp,
bottom_tensor_thread_coord_tmp,
idx_diff_ps_ys);
}
});
}
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
{
return TileDstr::is_static();
}
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
CK_TILE_DEVICE constexpr void
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
{
bottom_tensor_view_.buf_.p_data_ = data;
}
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
WindowAdaptorCoord& window_adaptor_thread_coord,
BottomTensorCoord& bottom_tensor_thread_coord,
const ATopIndex& idx_diff_adaptor_top) const
{
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
window_adaptor_thread_coord,
idx_diff_adaptor_top,
idx_diff_adaptor_bottom);
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
bottom_tensor_thread_coord,
idx_diff_adaptor_bottom);
}
template <index_t i_access>
CK_TILE_DEVICE static constexpr auto get_bottom_linear_coordinate(number<i_access>)
{
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto idx_ys = SFC_Ys::get_index_static(number<i_access>{});
using ys_to_rhs_major =
typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
constexpr auto modified_idx_ys = generate_tuple(
[&](auto i_dim_y) {
constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y];
constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
if constexpr(LinearBottomDims{}[target_h_dim] == 0)
{
return number<0>{};
}
else
{
return number<idx_ys[i_dim_y]>{};
}
},
number<NDimY>{});
constexpr auto adaptor_ = TileDstr{}.get_ps_ys_to_xs_adaptor();
constexpr auto idx_ =
container_concat(make_tuple(number<0>{}, number<0>{}), modified_idx_ys);
return adaptor_.calculate_bottom_index(idx_);
}
template <index_t i_access>
CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>)
{
constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});
// since this is a linear offset, we assume the bottom X tensor is always linear
constexpr index_t linear_offset = [&]() {
constexpr auto x_idx_ = linear_coord;
constexpr auto x_len_ = TileDstr{}.get_lengths();
static_assert(x_idx_.size() == x_len_.size());
constexpr index_t x_dims_ = x_idx_.size();
index_t cu_stride_ = 1;
index_t cu_offset_ = 0;
static_for<0, x_dims_, 1>{}([&](auto i_) {
auto r_i_ = number<x_dims_ - i_ - 1>{};
cu_offset_ += x_idx_[r_i_] * cu_stride_;
cu_stride_ *= x_len_[r_i_];
});
return cu_offset_;
}();
return linear_offset;
}
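// e.g. (hedged sketch) for a 2d X tensor with lengths {M, N} and a linear coordinate {i, j},
// the loop above yields cu_offset = i * N + j; with LinearBottomDims_ = seq<0, 1> the i part is
// zeroed out by get_bottom_linear_coordinate, so the result is just the constexpr column offset j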
CK_TILE_DEVICE constexpr auto get_num_access() const { return traits::NumAccess; }
template <bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
#if 1
// data index [y0, y1, ...]
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
});
#else
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start_static);
static_assert(d % traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
});
return dst_tensor;
}
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
static constexpr index_t YElementSize =
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
static_assert(YElementSize % traits::ScalarPerVector == 0);
using vectorized_tbuf = array<vector_t, YElementSize / traits::ScalarPerVector>;
constexpr auto tile_dstr = TileDstr{};
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && i_access == 0 &&
BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
auto bottom_tensor_flag = cached_flags_[IAccess];
// data index [y0, y1, ...]
constexpr auto idx_ys_start_static = SFC_Ys::get_index_static(IAccess);
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start_static);
static_assert(d % traits::ScalarPerVector == 0);
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / traits::ScalarPerVector>(),
bottom_tensor_thread_coord,
linear_offset /**/,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{},
pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag
#endif
});
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile("; this inline asm is workaround to prevent compiler from using too much "
"scratch memory" ::);
#endif
}
// TODO: currently async load is only implemented via inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// currently we only support the case where every dim is non-linear;
// it's not performant anyway if we have a linear (e.g. fast-changing) dim here
static_assert(NumAccess_NonLinear == NumAccess);
static_assert(BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global);
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
const index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType);
const index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
m0_set_with_memory(m0_init_value); // This should be wave independent
using vector_t = typename traits::vector_t;
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && i_access == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess]; // get this flag anyway
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);
// move thread coordinate
if constexpr(i_access != (NumAccess - 1))
{
m0_inc_with_memory(size_per_issue);
}
});
}
template <typename LdsTileWindow_, bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
bool_constant<oob_conditional_check> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// currently we only support the case where every dim is non-linear;
// it's not performant anyway if we have a linear (e.g. fast-changing) dim here
static_assert(NumAccess_NonLinear == NumAccess);
static_assert(BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global);
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
// TODO: an LDS offset is not good for the intrinsic-based implementation (the compiler can't
// figure out the dependency), hence avoid the offset-based solution. size_per_buf should be
// zero (how to check?)
constexpr index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{}));
constexpr index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
size_per_buf;
constexpr index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
using vector_t = typename traits::vector_t;
// TODO: we force CK_TILE_LDS_ADDR
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
bottom_tensor_thread_coord,
0,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(i_access != (NumAccess - 1))
{
smem += size_per_issue; // Note we manually increase the per-issue offset
}
});
}
template <bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
auto bottom_tensor_flag = cached_flags_[IAccess];
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
vec_value,
bool_constant<oob_conditional_check>{});
});
}
CK_TILE_DEVICE void
store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
static constexpr bool oob_conditional_check = true;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
auto bottom_tensor_flag = cached_flags_[IAccess];
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value);
});
}
template <bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
auto bottom_tensor_flag = cached_flags_[IAccess];
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
vec_value,
bool_constant<oob_conditional_check>{});
});
}
// move thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset]
// also move window-origin
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
{
window_origin_ += step;
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
constexpr auto need_update_non_linear_coord =
bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};
if constexpr(need_update_non_linear_coord)
{
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
cached_coords_(non_linear_id),
step);
}
// move the current coord with linear_coords
auto tmp_coords = cached_coords_[non_linear_id];
constexpr auto linear_coord = get_bottom_linear_coordinate(IAccess);
move_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord);
cached_flags_(IAccess) = coordinate_has_valid_offset_assuming_top_index_is_valid(
bottom_tensor_view_.get_tensor_descriptor(), tmp_coords);
});
}
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
TileDstr{}.get_ps_ys_to_xs_adaptor(),
container_concat(make_tuple(get_warp_id(), get_lane_id()),
generate_tuple([&](auto) { return number<0>{}; }, number<NDimY>{})));
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-cache coordinates and flags to speed up future load/store() calls (might allocate more registers)
using SFC_Ys = typename traits::SFC_Ys;
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
constexpr auto need_save_non_linear_coord =
bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};
if constexpr(need_save_non_linear_coord)
{
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
}
if constexpr(i_access != (NumAccess - 1))
{
constexpr auto idx_diff_ys =
SFC_Ys::get_forward_step_static(i_access); // tuple of number
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord_tmp,
bottom_tensor_thread_coord_tmp,
idx_diff_ps_ys);
}
});
}
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
//
WindowLengths window_lengths_;
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_;
// Tile tensor distribution, which contains:
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;
// pre-cached per-access state: non-linear bottom-tensor coordinates and per-access validity flags
array<BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
array<bool, traits::NumAccess> cached_flags_;
};
namespace impl {
template <address_space_enum, index_t len_>
struct default_linear_bottom_dims_impl
{
using type = typename uniform_sequence_gen<len_, 0>::type;
};
template <index_t len_>
struct default_linear_bottom_dims_impl<address_space_enum::global, len_>
{
// global defaults to seq<0, 0, ..., 1>
using type = typename sequence_merge<typename uniform_sequence_gen<len_ - 1, 0>::type,
sequence<1>>::type;
};
template <index_t len_>
struct default_linear_bottom_dims_impl<address_space_enum::lds, len_>
{
// lds defaults to seq<1, 1, ..., 1>
using type = typename uniform_sequence_gen<len_, 1>::type;
};
} // namespace impl
template <typename TensorView_>
using default_linear_bottom_dims =
typename impl::default_linear_bottom_dims_impl<TensorView_::buffer_view::get_address_space(),
TensorView_::get_num_of_dimension()>::type;
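// e.g. for a 2d global-memory tensor view this alias resolves to sequence<0, 1>,
// for a 2d LDS view to sequence<1, 1>, and to an all-zero sequence otherwise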
// this API creates a tile_window_linear
// this structure has the chance to use immediate values and save registers
// LinearBottomDims_ needs to be passed in properly to control which dim is linear,
// so that a constexpr offset can be generated as the linear_offset for that dim
// (and finally passed to the immediate offset of the buffer/lds instruction)
//
// Note: there is no internal check for which dim is OK to use a linear offset;
// users must make sure of this themselves
//
// e.g.
// 2d global matrix: set LinearBottomDims_=seq<0, 1>; the last dim will generate an
// immediate offset if each thread has multiple issues along that dim
//
// 2d LDS buffer: set LinearBottomDims_=seq<1, 1>; then only one vgpr is used as an offset,
// everything else just uses immediate offsets.
//
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename LinearBottomDims_ = default_linear_bottom_dims<TensorView_>>
CK_TILE_DEVICE constexpr auto
make_tile_window_linear(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
LinearBottomDims_ = {})
{
static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension());
return tile_window_linear<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
remove_cvref_t<LinearBottomDims_>>{
tensor_view, window_lengths, origin, tile_distribution};
}
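// a minimal usage sketch (hedged; `view`, `lengths` and `dstr` are placeholders for an existing
// 2d tensor_view, compile-time window lengths and a static tile_distribution):
//   auto win  = make_tile_window_linear(view, lengths, {0, 0}, dstr);                   // default dims
//   auto win2 = make_tile_window_linear(view, lengths, {0, 0}, dstr, sequence<0, 1>{}); // explicit dims
//   auto tile = win.load();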
template <
typename TileWindow_,
typename StaticTileDistribution_,
typename LinearBottomDims_ = default_linear_bottom_dims<typename TileWindow_::BottomTensorView>>
CK_TILE_DEVICE constexpr auto
make_tile_window_linear(const TileWindow_& tile_window,
const StaticTileDistribution_& tile_distribution,
LinearBottomDims_ = {})
{
return make_tile_window_linear(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution,
LinearBottomDims_{});
}
// this version must not be called in a constexpr context
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename LinearBottomDims_ = default_linear_bottom_dims<TensorView_>>
CK_TILE_DEVICE auto
make_tile_window_linear_raw(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
LinearBottomDims_ = {})
{
static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension());
auto w = tile_window_linear<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
remove_cvref_t<LinearBottomDims_>>{
tensor_view, window_lengths, origin, tile_distribution};
w.init_raw();
return w;
}
template <
typename TileWindow_,
typename StaticTileDistribution_,
typename LinearBottomDims_ = default_linear_bottom_dims<typename TileWindow_::BottomTensorView>>
CK_TILE_DEVICE constexpr auto
make_tile_window_linear_raw(const TileWindow_& tile_window,
const StaticTileDistribution_& tile_distribution,
LinearBottomDims_ = {})
{
return make_tile_window_linear_raw(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution,
LinearBottomDims_{});
}
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename LinearBottomDims_>
CK_TILE_DEVICE void move_tile_window(
tile_window_linear<TensorView_, WindowLengths_, StaticTileDistribution_, LinearBottomDims_>&
window,
const typename tile_window_linear<TensorView_,
WindowLengths_,
StaticTileDistribution_,
LinearBottomDims_>::BottomTensorIndex& step)
{
window.move(step);
}
} // namespace ck_tile
...@@ -59,8 +59,16 @@ struct magic_division32_bit_range ...@@ -59,8 +59,16 @@ struct magic_division32_bit_range
CK_TILE_DEVICE static constexpr uint32_t CK_TILE_DEVICE static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift) do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{ {
uint32_t tmp = __umulhi(dividend, multiplier); if(__builtin_is_constant_evaluated())
return (tmp + dividend) >> shift; {
uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
return (tmp + dividend) >> shift;
}
else
{
uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift;
}
} }
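// e.g. (hedged sketch) with a magic pair encoding divisor 3 — multiplier 0x55555556, shift 2 —
// both branches agree: __umulhi(10, 0x55555556) == 3, and the constexpr widening multiply yields
// the same high word, so do_magic_division(10u, 0x55555556u, 2u) == (3 + 10) >> 2 == 3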
CK_TILE_HOST static constexpr uint32_t CK_TILE_HOST static constexpr uint32_t
...@@ -77,9 +85,18 @@ struct magic_division32_bit_range ...@@ -77,9 +85,18 @@ struct magic_division32_bit_range
CK_TILE_DEVICE static constexpr int32_t CK_TILE_DEVICE static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{ {
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32); if(__builtin_is_constant_evaluated())
uint32_t tmp = __umulhi(dividend_u32, multiplier); {
return (tmp + dividend_u32) >> shift; uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
return (tmp + dividend_u32) >> shift;
}
else
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
}
} }
CK_TILE_HOST static constexpr int32_t CK_TILE_HOST static constexpr int32_t
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <thread> #include <thread>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <functional>
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host/ranges.hpp" #include "ck_tile/host/ranges.hpp"
...@@ -532,6 +533,28 @@ struct HostTensor ...@@ -532,6 +533,28 @@ struct HostTensor
typename Data::size_type size() const { return mData.size(); } typename Data::size_type size() const { return mData.size(); }
// return a slice of this tensor
// for simplicity we just copy the data and return a new tensor
auto slice(std::vector<size_t> s_begin, std::vector<size_t> s_end) const
{
assert(s_begin.size() == s_end.size());
assert(s_begin.size() == get_num_of_dimension());
std::vector<size_t> s_len(s_begin.size());
std::transform(
s_end.begin(), s_end.end(), s_begin.begin(), s_len.begin(), std::minus<size_t>{});
HostTensor<T> sliced_tensor(s_len);
sliced_tensor.ForEach([&](auto& self, auto idx) {
std::vector<size_t> src_idx(idx.size());
std::transform(
idx.begin(), idx.end(), s_begin.begin(), src_idx.begin(), std::plus<size_t>{});
self(idx) = operator()(src_idx);
});
return sliced_tensor;
}
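// a minimal usage sketch (hedged; assumes a 2d HostTensor<float> named t with lengths {4, 8}):
//   auto s = t.slice({1, 2}, {3, 6}); // copies a 2x4 sub-tensor, so s(i, j) == t(i + 1, j + 2)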
template <typename U = T> template <typename U = T>
auto AsSpan() const auto AsSpan() const
{ {
......
...@@ -229,7 +229,7 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -229,7 +229,7 @@ struct BlockFmhaPipelineQRAsyncEx
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type // infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc)); using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_accs));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>( using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
...@@ -336,7 +336,7 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -336,7 +336,7 @@ struct BlockFmhaPipelineQRAsyncEx
do do
{ {
// STAGE 1, QK gemm // STAGE 1, QK gemm
clear_tile(s_acc); // initialize C clear_tile(s_accs); // initialize C
if constexpr(k0_loops > 1) if constexpr(k0_loops > 1)
{ {
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
...@@ -350,7 +350,7 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -350,7 +350,7 @@ struct BlockFmhaPipelineQRAsyncEx
async_load_fence(k_dram_window.get_num_access()); async_load_fence(k_dram_window.get_num_access());
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc, gemm_0(s_accs,
get_slice_tile( get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}), q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
...@@ -373,7 +373,7 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -373,7 +373,7 @@ struct BlockFmhaPipelineQRAsyncEx
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
{ // tail { // tail
gemm_0( gemm_0(
s_acc, s_accs,
get_slice_tile( get_slice_tile(
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}), q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
get_slice_tile(k_lds_load, get_slice_tile(k_lds_load,
...@@ -385,8 +385,8 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -385,8 +385,8 @@ struct BlockFmhaPipelineQRAsyncEx
// STAGE 2, scale_s, add bias, mask, softmax // STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
s_acc = tile_elementwise_in(s_acc_element_func, s_acc); s_accs = tile_elementwise_in(s_acc_element_func, s_accs);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_accs);
tile_elementwise_inout( tile_elementwise_inout(
[&](auto& x, const auto& y) { [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2 #if !CK_TILE_FMHA_FWD_FAST_EXP2
...@@ -396,33 +396,33 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -396,33 +396,33 @@ struct BlockFmhaPipelineQRAsyncEx
type_convert<SaccDataType>(bias_element_func(y)); type_convert<SaccDataType>(bias_element_func(y));
#endif #endif
}, },
s_acc, s_accs,
bias_tile); bias_tile);
} }
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
const auto k_origin = k_dram_block_window.get_window_origin(); const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); constexpr auto s_spans = decltype(s_accs)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc); s_accs = tile_elementwise_in(s_acc_element_func, s_accs);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices( const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); s_accs.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s; s_accs(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col); position_encoding.update(s_accs(i_j_idx), row, col);
}); });
}); });
} }
else else
{ {
s_acc = tile_elementwise_in(s_acc_element_func, s_acc); s_accs = tile_elementwise_in(s_acc_element_func, s_accs);
#if !CK_TILE_FMHA_FWD_FAST_EXP2 #if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_accs);
#endif #endif
} }
move_tile_window(bias_dram_window, {0, kN0}); move_tile_window(bias_dram_window, {0, kN0});
...@@ -437,7 +437,7 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -437,7 +437,7 @@ struct BlockFmhaPipelineQRAsyncEx
if(need_perpixel_check) if(need_perpixel_check)
{ {
set_tile_if( set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) { s_accs, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col); return mask.IsOutOfBound(row, col);
...@@ -445,7 +445,7 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -445,7 +445,7 @@ struct BlockFmhaPipelineQRAsyncEx
} }
} }
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j} const auto s = cast_tile<SMPLComputeDataType>(s_accs); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>( auto m_local = block_tile_reduce<SMPLComputeDataType>(
s, s,
sequence<1>{}, sequence<1>{},
......
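For reference only (not part of this change): the renamed s_accs tile is the per-iteration score block S_j, and the surrounding row-max / row-sum bookkeeping is the usual online-softmax recurrence that flash-attention style pipelines track,

    m_j = \max\big(m_{j-1}, \operatorname{rowmax}(S_j)\big), \qquad
    l_j = e^{\,m_{j-1}-m_j}\, l_{j-1} + \operatorname{rowsum}\big(e^{\,S_j-m_j}\big), \qquad
    O_j = \operatorname{diag}\big(e^{\,m_{j-1}-m_j}\big)\, O_{j-1} + e^{\,S_j-m_j} V_j.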
...@@ -10,114 +10,134 @@ ...@@ -10,114 +10,134 @@
namespace ck_tile { namespace ck_tile {
// fp16 // fp16
using WarpGemmMfmaF16F16F32M32N32K8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
using WarpGemmMfmaF16F16F32M16N16K16 = using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>; WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M32N32K16 = using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>; WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K32 = using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>; WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl< using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 1>>; WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl< using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>; WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl< using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>; WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl< using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution =
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>; WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8, WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>; 2>>;
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16, WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>; 2>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8, WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>; 2>>;
// bf16 // bf16
using WarpGemmMfmaBf16Bf16F32M32N32K8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16 = using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>; WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16 = using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>; WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32 = using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>; WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl< using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 1>>; WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl< using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA =
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>; WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl< using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution =
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>; WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl< using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>; WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>; 2>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>; 2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>; 2>>;
// fp8 // fp8
using WarpGemmMfma_f32_32x32x16_fp8_fp8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8>>; WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8 = using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8>>; WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>; WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl< using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8>>; WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = WarpGemmImpl< using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed =
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8>>; WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl< using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed =
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8>>; WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl< using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>; WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
template <index_t swizzle_factor = 2> template <index_t swizzle_factor = 2>
using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution = using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t>, WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, WGAttrCtlEnum::Default_>,
2, 2,
swizzle_factor>>; swizzle_factor>>;
......
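All of the aliases above now forward a WGAttrCtlEnum argument to the MFMA impl types. A caller wanting non-default register control could, hypothetically, compose its own alias the same way; the sketch below only reuses names visible in this diff, and picks the Raw_vaa value handled in the M16N16K16 impl dispatch further down without asserting it is the right choice for any particular kernel:

    using MyWarpGemmF16_16x16x32_RawVaa = ck_tile::WarpGemmImpl<ck_tile::WarpGemmAtrributeMfmaIterateK<
        ck_tile::WarpGemmAttributeMfmaImplF16F16F32M16N16K16<ck_tile::WGAttrCtlEnum::Raw_vaa>,
        2>>;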
...@@ -510,11 +510,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB ...@@ -510,11 +510,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
}); });
} }
template <index_t kKIter, bool post_nop_ = false> template <index_t iKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec, CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec, const AVecType& a_vec,
const BVecType& b_vec, const BVecType& b_vec,
number<kKIter>, number<iKIter>,
bool_constant<post_nop_> = {}) const bool_constant<post_nop_> = {}) const
{ {
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>; using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
......
...@@ -139,7 +139,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 ...@@ -139,7 +139,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
} }
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa) else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{ {
DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "a", "b", "v") DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "a", "a", "v")
} }
else else
{ {
......
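An assumed reading of the one-character fix above: the two middle string arguments of DISPATCH_MFMA_ are taken here to be the LLVM AMDGPU inline-asm constraints for the A and B inputs ("v" = VGPR, "a" = AGPR), so the change keeps both MFMA inputs in AGPRs. A standalone illustration of that constraint usage, not the macro's actual expansion:

    #include <hip/hip_runtime.h>

    using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
    using fp32x4_t = float __attribute__((ext_vector_type(4)));

    __device__ void mfma_16x16x16_f16_agpr_inputs(fp32x4_t& c, const fp16x4_t& a, const fp16x4_t& b)
    {
        asm volatile("v_mfma_f32_16x16x16f16 %0, %1, %2, %0"
                     : "+v"(c)          // accumulator constrained to VGPRs ("+v")
                     : "a"(a), "a"(b)); // both inputs constrained to AGPRs ("a")
    }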
...@@ -32,10 +32,8 @@ struct WarpGemmImpl ...@@ -32,10 +32,8 @@ struct WarpGemmImpl
using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>; using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false> template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CTensor& c, CK_TILE_DEVICE void
const ATensor& a, operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
const BTensor& b,
bool_constant<post_nop_> = {}) const
{ {
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CTensor> && static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CTensor> &&
detail::is_similiar_distributed_tensor_v<ATensor, ATensor> && detail::is_similiar_distributed_tensor_v<ATensor, ATensor> &&
...@@ -56,7 +54,11 @@ struct WarpGemmImpl ...@@ -56,7 +54,11 @@ struct WarpGemmImpl
c.get_thread_buffer().template set_as<CVec>(I0, c_vec); c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
} }
template <<typename CTensor, typename ATensor, typename BTensor, index_t i_subk, bool post_nop_ = false> template <typename CTensor,
typename ATensor,
typename BTensor,
index_t i_subk,
bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CTensor& c, CK_TILE_DEVICE void operator()(CTensor& c,
const ATensor& a, const ATensor& a,
const BTensor& b, const BTensor& b,
...@@ -82,6 +84,7 @@ struct WarpGemmImpl ...@@ -82,6 +84,7 @@ struct WarpGemmImpl
template <typename ATensor, typename BTensor> template <typename ATensor, typename BTensor>
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
{ {
using CTensor = CWarpTensor;
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, ATensor> && static_assert(detail::is_similiar_distributed_tensor_v<ATensor, ATensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BTensor>); detail::is_similiar_distributed_tensor_v<BTensor, BTensor>);
CTensor c; CTensor c;
......
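The two-argument overload above now spells out that its result type is CWarpTensor. A rough call-site sketch of the two forms, where WG, wg, a, b and c_acc are placeholder names and AWarpTensor/BWarpTensor are assumed member aliases analogous to CWarpTensor:

    WG wg{};
    auto c_new = wg(a, b); // two-argument form: computes a fresh CWarpTensor result from a and b
    wg(c_acc, a, b);       // three-argument form: accumulates into an existing c_acc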
...@@ -160,10 +160,9 @@ CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor ...@@ -160,10 +160,9 @@ CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor
// reduction sweep forward // reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) { static_for<0, nstage, 1>{}([&](auto istage) {
// TODO: lid_over_rid_derivative not ok in xor? maybe need limit the usage of
// xor // xor
index_t src_lane = (__lane_id() * lid_over_rid_derivative) ^ index_t src_lane =
(number<1 << istage.value>{}.value); __lane_id() ^ (number<lid_over_rid_derivative << istage.value>{}.value);
// pull data from remote lane // pull data from remote lane
const auto v_remote = warp_shuffle(v_local, src_lane); const auto v_remote = warp_shuffle(v_local, src_lane);
......
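A small host-side model of the revised pairing rule (lane count, derivative value and integer sums are assumptions, illustrative only) shows that src_lane = lane ^ (lid_over_rid_derivative << istage) still forms a complete butterfly, so every lane ends up holding the full reduction:

    #include <array>
    #include <cstdio>

    int main()
    {
        constexpr int lanes                   = 8; // toy "wave"; real waves are 32/64 lanes
        constexpr int lid_over_rid_derivative = 1; // assumed lane stride between reduced rows
        constexpr int nstage                  = 3; // log2(lanes)

        std::array<int, lanes> v{};
        for(int l = 0; l < lanes; ++l)
            v[l] = l + 1; // per-lane partial value, total = 36

        for(int istage = 0; istage < nstage; ++istage)
        {
            std::array<int, lanes> next = v;
            for(int l = 0; l < lanes; ++l)
            {
                // the revised pairing rule from the hunk above
                int src_lane = l ^ (lid_over_rid_derivative << istage);
                next[l]      = v[l] + v[src_lane]; // f_sum, mirroring warp_shuffle + reduce
            }
            v = next;
        }

        for(int l = 0; l < lanes; ++l)
            std::printf("lane %d -> %d\n", l, v[l]); // every lane prints 36
        return 0;
    }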