Unverified commit 440e28b0, authored by carlushuang, committed by GitHub

[CK_TILE] fused-moe first version (#1634)



* moe pipeline

* update code

* compile OK

* update

* update cpu reference

* update pipeline_gemm0

* compiler ok

* update pipeline

* rename to ex pipeline

* block-asm

* update

* update

* update first gemm ok

* compute correct

* update file structure

* update README

* update

* update

* update code

* update API

* return unsupported case

* add comment

* update readme

* update

* uncomment

* update

* fix build err

---------
Co-authored-by: valarLip <340077269@qq.com>
parent 645fe812
@@ -14,3 +14,5 @@ add_subdirectory(11_add_rmsnorm2d_rdquant)
add_subdirectory(12_smoothquant)
add_subdirectory(13_moe_sorting)
add_subdirectory(14_moe_smoothquant)
add_subdirectory(15_fused_moe)
@@ -52,6 +52,7 @@
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
@@ -62,6 +63,7 @@
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/static_counter.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
...
@@ -621,6 +621,65 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}

CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0)
{
    asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
}

template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add_if;

template <bool pre_nop>
struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 1)
    {
        static_assert(sizeof(T) == 4);
        auto save_exec = __builtin_amdgcn_read_exec();
        using mbuf_t   = float;
        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
                     "global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n"
                     "s_mov_b64 exec %5"
                     :
                     : "v"(v_offset),
                       "v"(bit_cast<mbuf_t>(value)),
                       "s"(res.xy),
                       "n"(i_offset),
                       "v"(flag),
                       "s"(save_exec)
                     : "memory");
    }
};

template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add;

template <bool pre_nop>
struct buffer_atomic_add<bf16_t, 2, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag = 1*/)
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = float;
        asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3"
                     :
                     : "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
                     : "memory");
    }
};

namespace impl {

// the types below indicate the data type used for buffer-load inline asm
// clang-format off
...
@@ -810,6 +869,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}

CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}

// buffer load i8
CK_TILE_DEVICE_EXTERN int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
...
@@ -2378,6 +2442,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
#endif
}

template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check          = true,
          bool pre_nop                        = false>
CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer<T, N>& src_thread_data,
                                              T* p_dst_wave,
                                              const index_t dst_thread_element_offset,
                                              const index_t dst_linear_element_offset,
                                              const bool dst_thread_element_valid,
                                              const index_t dst_element_space_size,
                                              bool_constant<pre_nop> = {})
{
    const int32x4_t dst_wave_buffer_resource =
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));

    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
    index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);

    if constexpr(oob_conditional_check)
    {
        buffer_atomic_add_if<T, N, pre_nop>{}(src_thread_data,
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              0,
                                              dst_linear_addr_offset,
                                              dst_thread_element_valid);
    }
    else
    {
        buffer_atomic_add<T, N, pre_nop>{}(src_thread_data,
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           0,
                                           dst_linear_addr_offset,
                                           1);
    }
}

// buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
...
@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds()
#endif
}

CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
{
#ifdef __gfx12__
    asm volatile("s_wait_loadcnt %0 \n"
                 "s_barrier_signal -1 \n"
                 "s_barrier_wait -1"
                 :
                 : "n"(cnt)
                 : "memory");
#else
    asm volatile("s_waitcnt vmcnt(%0) \n"
                 "s_barrier"
                 :
                 : "n"(cnt)
                 : "memory");
#endif
}

CK_TILE_DEVICE void block_sync_lds_direct_load()
{
    asm volatile("\
...
@@ -102,4 +102,28 @@ CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
#endif
}

template <typename T>
CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
{
    static_assert(sizeof(T) == 4);
    // compare the per-thread v_flag and store the wave-wide result into 2x SGPRs
    uint32x2_t exec_flag;
    asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
                 : [s_exec_flag] "=s"(exec_flag)
                 : [v_flag] "v"(v_flag));
    return exec_flag;
}

template <typename X, typename Y>
CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
{
    static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
    // compare x < y per thread and store the wave-wide result into 2x SGPRs
    uint32x2_t exec_flag;
    asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
                 : [s_exec_flag] "=s"(exec_flag)
                 : [v_x] "v"(x), [v_y] "v"(y));
    return exec_flag;
}

} // namespace ck_tile
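A host-side analogy of what flag_to_exec computes per wave (an editorial sketch under assumptions: 64-lane waves, illustrative function name; not part of this commit): bit i of the result is set iff lane i's v_flag >= 1, and on the GPU that mask lives in a 2x32-bit SGPR pair that can be swapped into exec.

#include <cstdint>
#include <cstdio>

// host analogy: build the 64-bit lane mask that v_cmp_ge_u32 would produce
uint64_t flag_to_exec_host(const uint32_t (&v_flag)[64])
{
    uint64_t mask = 0;
    for(int lane = 0; lane < 64; ++lane)
        if(v_flag[lane] >= 1)
            mask |= uint64_t{1} << lane;
    return mask;
}

int main()
{
    uint32_t flags[64] = {};
    flags[0]  = 1;
    flags[63] = 2;
    std::printf("0x%016llx\n", (unsigned long long)flag_to_exec_host(flags));
    // prints 0x8000000000000001: only lanes 0 and 63 stay active
}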
@@ -437,34 +437,74 @@ struct buffer_view<address_space_enum::global,
    // i is offset of T, not X. i should be aligned to X
    template <memory_operation_enum Op,
              typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void update(index_t i,
                               index_t linear_offset,
                               bool is_valid_element,
                               const X& x,
                               bool_constant<oob_conditional_check> = {})
    {
        if constexpr(Op == memory_operation_enum::set)
        {
            this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_add)
        {
            this->template atomic_add<X, oob_conditional_check>(
                i, linear_offset, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_max)
        {
            this->template atomic_max<X, oob_conditional_check>(
                i, linear_offset, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
            auto tmp =
                this->template get<X, oob_conditional_check>(i, linear_offset, is_valid_element);
            this->template set<X, oob_conditional_check>(
                i, linear_offset, is_valid_element, x + tmp);
            // tmp += x;
            // this->template set<X>(i, is_valid_element, tmp);
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <memory_operation_enum Op,
              typename X,
              bool oob_conditional_check = true,
              bool pre_nop               = false,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void update_raw(index_t i,
                                   index_t linear_offset,
                                   bool is_valid_element,
                                   const X& x,
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop> = {})
    {
        if constexpr(Op == memory_operation_enum::set)
        {
            this->template set_raw<X, oob_conditional_check>(
                i, linear_offset, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_add)
        {
            this->template atomic_add_raw<X, oob_conditional_check, pre_nop>(
                i, linear_offset, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_max)
        {
            // this->template atomic_max_raw<X>(i, linear_offset, is_valid_element, x);
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
...
@@ -533,6 +573,7 @@ struct buffer_view<address_space_enum::global,
    }

    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
...
@@ -585,6 +626,39 @@ struct buffer_view<address_space_enum::global,
    }

    template <typename X,
              bool oob_conditional_check = true,
              bool pre_nop               = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void
    atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        // using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;

        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        static_assert(get_address_space() == address_space_enum::global,
                      "only support global mem");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        amd_buffer_atomic_add_raw<remove_cvref_t<T>,
                                  t_per_x,
                                  Coherence,
                                  oob_conditional_check,
                                  pre_nop>(
            x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
    }

    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
...
@@ -22,28 +22,32 @@ template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          index_t i_access           = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
                                                                         WindowLengths_,
                                                                         TileDistribution_,
                                                                         NumCoord>& tile_window,
                              number<i_access> = {},
                              bool_constant<oob_conditional_check> = {})
{
    return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          index_t i_access           = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
                                                       WindowLengths_,
                                                       TileDistribution_,
                                                       LinearBottomDims_>& tile_window,
                              number<i_access> = {},
                              bool_constant<oob_conditional_check> = {})
{
    return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}

@@ -51,15 +55,35 @@
template <typename DistributedTensor_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          index_t i_access           = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
                              const tile_window_with_static_distribution<BottomTensorView_,
                                                                         WindowLengths_,
                                                                         TileDistribution_,
                                                                         NumCoord>& tile_window,
                              number<i_access> = {},
                              bool_constant<oob_conditional_check> = {})
{
    return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

template <typename DistributedTensor_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          index_t i_access           = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
                              const tile_window_linear<BottomTensorView_,
                                                       WindowLengths_,
                                                       TileDistribution_,
                                                       LinearBottomDims_>& tile_window,
                              number<i_access> = {},
                              bool_constant<oob_conditional_check> = {})
{
    return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
/**
...
@@ -76,6 +100,7 @@ template <typename T,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          index_t i_access           = -1,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                  const tile_window_with_static_distribution<BottomTensorView_,
                                                                             WindowLengths_,
                                                                             TileDistribution_,
                                                                             NumCoord>& tile_window,
                                  number<i_access> = {},
                                  bool_constant<oob_conditional_check> = {},
                                  bool_constant<pre_nop> = {})
{
    tile_window.load_raw(
        tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

@@ -95,6 +121,7 @@ template <typename T,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          index_t i_access           = -1,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                  const tile_window_linear<BottomTensorView_,
                                                           WindowLengths_,
                                                           TileDistribution_,
                                                           LinearBottomDims_>& tile_window,
                                  number<i_access> = {},
                                  bool_constant<oob_conditional_check> = {},
                                  bool_constant<pre_nop> = {})
{
    tile_window.load_raw(
        tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

@@ -114,6 +142,7 @@ template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          index_t i_access           = -1,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile,
                    const tile_window_with_static_distribution<BottomTensorView_,
                                                               WindowLengths_,
                                                               TileDistribution_,
                                                               NumCoord>& tile_window,
                    number<i_access> = {},
                    bool_constant<oob_conditional_check> = {},
                    bool_constant<pre_nop> = {})
{
    return tile_window.async_load_raw(lds_tile,
                                      number<i_access>{},
                                      bool_constant<oob_conditional_check>{},
                                      bool_constant<pre_nop>{});
}

@@ -134,6 +166,7 @@ template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          index_t i_access           = -1,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
                                        const tile_window_linear<BottomTensorView_,
                                                                 WindowLengths_,
                                                                 TileDistribution_,
                                                                 LinearBottomDims_>& tile_window,
                                        number<i_access> = {},
                                        bool_constant<oob_conditional_check> = {},
                                        bool_constant<pre_nop> = {})
{
    return tile_window.async_load_raw(lds_tile,
                                      number<i_access>{},
                                      bool_constant<oob_conditional_check>{},
                                      bool_constant<pre_nop>{});
}

CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
...
@@ -201,4 +201,30 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
    return unpacks;
}

namespace detail {

// check whether two static_distributed_tensors share the same data type and
// thread-buffer size, differing only in distribution
template <typename X, typename Y>
struct is_similiar_distributed_tensor
{
    static constexpr bool value = false;
};

template <typename TypeX, typename DistX, typename TypeY, typename DistY>
struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
                                      static_distributed_tensor<TypeY, DistY>>
{
    using Tx = static_distributed_tensor<TypeX, DistX>;
    using Ty = static_distributed_tensor<TypeY, DistY>;
    static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
                                  Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
};

template <typename X, typename Y>
inline constexpr bool is_similiar_distributed_tensor_v =
    is_similiar_distributed_tensor<X, Y>::value;

} // namespace detail
} // namespace ck_tile
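A compile-time sketch of how the trait above behaves (an editorial aside; assumes "ck_tile/core.hpp" compiles in the including translation unit, and DistA/DistB in the commented lines are hypothetical distributions):

#include "ck_tile/core.hpp"

// the primary template rejects anything that is not a pair of static_distributed_tensors
static_assert(!ck_tile::detail::is_similiar_distributed_tensor_v<int, int>,
              "non-tensor types fall back to the primary template");

// for two tensors sharing DataType and thread-buffer size but built from different
// (hypothetical) distributions DistA / DistB, the specialization would yield true:
// static_assert(ck_tile::detail::is_similiar_distributed_tensor_v<
//     ck_tile::static_distributed_tensor<ck_tile::fp16_t, DistA>,
//     ck_tile::static_distributed_tensor<ck_tile::fp16_t, DistB>>);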
@@ -333,6 +333,48 @@ struct tensor_view
            coord.get_offset(), linear_offset, is_valid_element, x);
    }

    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              bool pre_nop               = false,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    update_vectorized_elements_raw(const TensorCoord& coord,
                                   index_t linear_offset,
                                   const X& x,
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop> = {})
    {
        buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
            coord.get_offset(),
            linear_offset,
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }

    template <typename X,
              bool oob_conditional_check = true,
              bool pre_nop               = false,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    update_vectorized_elements_raw(const TensorCoord& coord,
                                   index_t linear_offset,
                                   bool is_valid_element,
                                   const X& x,
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop> = {})
    {
        buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
            coord.get_offset(), linear_offset, is_valid_element, x);
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_view{");
...
@@ -292,12 +292,15 @@ struct tile_window_with_static_distribution
    {
        constexpr auto tile_dstr = TileDstr{};
        auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
        load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
        return dst_tensor;
    }

    template <typename DistributedTensor,
              index_t i_access_unsupport_ = -1,
              bool oob_conditional_check  = true>
    CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
                             number<i_access_unsupport_> = {},
                             bool_constant<oob_conditional_check> = {}) const
    {
        using Traits = load_store_traits;
...
@@ -785,6 +788,73 @@ struct tile_window_with_static_distribution
        });
    }

    template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
    CK_TILE_DEVICE void
    update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
               number<i_access_unsupport_> = {},
               bool_constant<oob_conditional_check> = {},
               bool_constant<pre_nop> = {}) const
    {
        using Traits = load_store_traits;

        using vector_t = typename Traits::vector_t;
        using SFC_Ys   = typename Traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structured bindings (to be captured later) when compiling as C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

                // read from distributed tensor
                vector_t vec_value;

                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                    constexpr auto idx_ys = generate_tuple(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                            : idx_ys_start[jj];
                        },
                        number<NDimY>{});
                    constexpr index_t d =
                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
                    vec_value.template get_as<DataType>()(j) =
                        dstr_tensor.get_thread_buffer().template at<d>();
                });

                // write into bottom tensor
                get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
                    bottom_tensor_thread_coord,
                    0,
                    vec_value,
                    bool_constant<oob_conditional_check>{},
                    bool_constant<pre_nop>{});

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto idx_diff_ps_ys = container_concat(
                        generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
                        idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
    }

    // move the thread's bottom tensor coordinate
    // [x0', x1', ... ] ==> [offset]
    // also move the window origin
...
...@@ -432,23 +432,38 @@ struct tile_window_linear ...@@ -432,23 +432,38 @@ struct tile_window_linear
CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>) CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>)
{ {
constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{}); constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});
// since this is linear offset, we assum bottom X tensor is always linear constexpr auto is_pure_linear_tensor =
constexpr index_t linear_offset = [&]() { reduce_on_sequence(LinearBottomDims{}, multiplies{}, number<1>{});
constexpr auto x_idx_ = linear_coord; if constexpr(is_pure_linear_tensor)
constexpr auto x_len_ = TileDstr{}.get_lengths(); {
static_assert(x_idx_.size() == x_len_.size()); // this case usually is a LDS window, everything is known at compile tile.
constexpr index_t x_dims_ = x_idx_.size(); // we directly use BottomTensorView transform to compute the offset, in case padding
index_t cu_stride_ = 1; auto bottom_tensor_coord =
index_t cu_offset_ = 0; make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
static_for<0, x_dims_, 1>{}([&](auto i_) { return bottom_tensor_coord.get_offset();
auto r_i_ = number<x_dims_ - i_ - 1>{}; }
cu_offset_ += x_idx_[r_i_] * cu_stride_; else
cu_stride_ *= x_len_[r_i_]; {
}); // this case usually is a global window, where last dim can be linear
return cu_offset_; // we hack here, that use the original TileDstr to compute the linear offset
}(); // ... hoping that there is no extra padding between other dims, which make sense
// since that would introduce runtime length (so can't use linear offset)
return linear_offset; constexpr index_t linear_offset = [&]() {
constexpr auto x_idx_ = linear_coord;
constexpr auto x_len_ = TileDstr{}.get_lengths();
static_assert(x_idx_.size() == x_len_.size());
constexpr index_t x_dims_ = x_idx_.size();
index_t cu_stride_ = 1;
index_t cu_offset_ = 0;
static_for<0, x_dims_, 1>{}([&](auto i_) {
auto r_i_ = number<x_dims_ - i_ - 1>{};
cu_offset_ += x_idx_[r_i_] * cu_stride_;
cu_stride_ *= x_len_[r_i_];
});
return cu_offset_;
}();
return linear_offset;
}
} }
    CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
@@ -509,6 +524,64 @@ struct tile_window_linear
        return dst_tensor;
    }

    template <typename DstTile, index_t i_access = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE auto load(DstTile& dst_tensor,
                             number<i_access> = {},
                             bool_constant<oob_conditional_check> = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};
        // auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);

        auto issue = [&](auto i_access_) {
            constexpr auto IAccess          = number<i_access_>{};
            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            auto bottom_tensor_flag         = cached_flags_[IAccess];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);

            // read from bottom tensor
            const vector_t vec_value =
                get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                    bottom_tensor_thread_coord,
                    linear_offset,
                    bottom_tensor_flag,
                    bool_constant<oob_conditional_check>{});
#if 1
            // data index [y0, y1, ...]
            constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);

            // write into distributed tensor
            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                dst_tensor.get_thread_buffer().template at<d>() =
                    vec_value.template get_as<DataType>()[j];
            });
#else
            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
            static_assert(d % traits::ScalarPerVector == 0);

            dst_tensor.get_thread_buffer().template get_as<vector_t>()(
                number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
        };

        WINDOW_DISPATCH_ISSUE();

        return dst_tensor;
    }

    template <typename DstTile,
              index_t i_access           = -1,
              bool oob_conditional_check = true,
...
@@ -849,6 +922,58 @@ struct tile_window_linear
        WINDOW_DISPATCH_ISSUE();
    }

    template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
    CK_TILE_DEVICE void
    update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
               number<i_access> = {},
               bool_constant<oob_conditional_check> = {},
               bool_constant<pre_nop> = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        auto issue = [&](auto i_access_) {
            constexpr auto IAccess          = number<i_access_>{};
            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

            // read from distributed tensor
            vector_t vec_value;
            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                    },
                    number<NDimY>{});
                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
                vec_value.template get_as<DataType>()(j) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
                bottom_tensor_thread_coord,
                linear_offset,
                bottom_tensor_flag,
                vec_value,
                bool_constant<oob_conditional_check>{},
                bool_constant<pre_nop>{});
        };

        WINDOW_DISPATCH_ISSUE();
    }

    // move the thread's bottom tensor coordinate
    // [x0', x1', ... ] ==> [offset]
    // also move the window origin
...
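A small host-side analogy of the fallback branch in get_bottom_linear_offset above (an editorial sketch, not ck_tile code): the linear offset is a plain row-major fold of the coordinate against the tile lengths, last dimension fastest.

#include <array>
#include <cstddef>

constexpr std::size_t row_major_offset(const std::array<std::size_t, 3>& idx,
                                       const std::array<std::size_t, 3>& len)
{
    std::size_t stride = 1, offset = 0;
    for(int i = 2; i >= 0; --i) // walk dims from fastest to slowest
    {
        offset += idx[i] * stride;
        stride *= len[i];
    }
    return offset;
}

// e.g. lengths {4, 5, 6} imply strides {30, 6, 1}
static_assert(row_major_offset({1, 2, 3}, {4, 5, 6}) == 1 * 30 + 2 * 6 + 3, "");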
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

// takes an LDS store-tile window and extracts some information from it,
// used to set the m0 value for the gfx9 series
template <typename LdsTileWindow_>
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_&& lds_tile)
{
    using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
    using LdsDataType   = typename LdsTileWindow::DataType;

    // issues * warps * lanes
    static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded

    const index_t size_per_buf =
        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
            make_tuple(number<0>{}, number<0>{}, number<0>{})) *
        sizeof(LdsDataType);

    const index_t size_per_wave =
        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
            make_tuple(number<0>{}, number<1>{}, number<0>{})) *
            sizeof(LdsDataType) -
        size_per_buf;

    const index_t size_per_issue =
        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
            make_tuple(number<1>{}, number<0>{}, number<0>{})) *
            sizeof(LdsDataType) -
        size_per_buf;

    const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();

    return make_tuple(m0_init_value, size_per_issue);
}

} // namespace ck_tile
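A hedged usage sketch for the helper above (assumptions: m0_set_with_memory/m0_inc_with_memory are the m0 helpers in ck_tile's arch utilities, and lds_window/gmem_window are pre-built tile windows; all names here are illustrative, not from this commit):

// seed m0 with this wave's LDS base, then advance it between async store issues
auto [m0_init, size_per_issue] = ck_tile::get_async_store_smem_info(lds_window);
ck_tile::m0_set_with_memory(m0_init);                    // per-wave LDS base for the first issue
ck_tile::async_load_tile_raw(lds_window, gmem_window);   // async global -> LDS issue
ck_tile::m0_inc_with_memory(size_per_issue);             // step m0 to the next issue's LDS slot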
@@ -41,15 +41,65 @@ template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_,
          index_t i_access           = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_distribution<BottomTensorView_,
                                                 WindowLengths_,
                                                 TileDistribution_,
                                                 NumCoord>& tile_window,
            const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
            number<i_access> = {},
            bool_constant<oob_conditional_check> = {})
{
    tile_window.update(dstr_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_,
          index_t i_access           = -1,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE void
update_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
                                                     WindowLengths_,
                                                     TileDistribution_,
                                                     NumCoord>& tile_window,
                const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
                number<i_access> = {},
                bool_constant<oob_conditional_check> = {},
                bool_constant<pre_nop> = {})
{
    tile_window.update_raw(dstr_tensor,
                           number<i_access>{},
                           bool_constant<oob_conditional_check>{},
                           bool_constant<pre_nop>{});
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          typename DataType_,
          index_t i_access           = -1,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE auto update_tile_raw(
    tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
        tile_window,
    const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
    number<i_access> = {},
    bool_constant<oob_conditional_check> = {},
    bool_constant<pre_nop> = {})
{
    tile_window.update_raw(dstr_tensor,
                           number<i_access>{},
                           bool_constant<oob_conditional_check>{},
                           bool_constant<pre_nop>{});
}

} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"

namespace ck_tile {

template <typename Context, index_t Start = 0, index_t Step = 1>
struct static_counter
{
    public:
    template <typename Unique>
    static constexpr index_t next()
    {
        return next<Unique>(0) * Step + Start;
    }

    template <unsigned long long>
    static constexpr index_t next()
    {
        struct Unique
        {
        };
        return next<Unique>(0) * Step + Start;
    }

    template <typename Unique>
    static constexpr index_t current()
    {
        return current<Unique>(0) * Step + Start;
    }

    template <unsigned long long>
    static constexpr index_t current()
    {
        struct Unique
        {
        };
        return current<Unique>(0) * Step + Start;
    }

    private:
    template <index_t I>
    struct slot
    {
        _Pragma("GCC diagnostic push");
        _Pragma("GCC diagnostic ignored \"-Wundefined-internal\"");
        friend constexpr bool slot_allocated(slot<I>);
        _Pragma("GCC diagnostic pop");
    };

    template <index_t I>
    struct allocate_slot
    {
        friend constexpr bool slot_allocated(slot<I>) { return true; }
        enum
        {
            value = I
        };
    };

    // If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
    // the overload set...
    template <typename Unique, index_t I = 0, bool = slot_allocated(slot<I>())>
    static constexpr index_t next(index_t)
    {
        return next<Unique, I + 1>(0);
    }

    // ...And this function will be used, instead, which will define slot_allocated(slot<I>) via
    // allocate_slot<I>.
    template <typename Unique, index_t I = 0>
    static constexpr index_t next(double)
    {
        return allocate_slot<I>::value;
    }

    // If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
    // the overload set...
    template <typename Unique, index_t I = Start, bool = slot_allocated(slot<I>())>
    static constexpr index_t current(index_t)
    {
        return current<Unique, I + 1>(0);
    }

    // ...And this function will be used, instead, which will return the current counter, or assert
    // in case next() hasn't been called yet.
    template <typename Unique, index_t I = Start>
    static constexpr index_t current(double)
    {
        static_assert(I != 0, "You must invoke next() first");
        return I - 1;
    }
};

namespace impl {
template <int I>
struct static_counter_uniq_;
}

#define MAKE_SC() \
    ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>> {}
#define MAKE_SC_WITH(start_, step_) \
    ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_> {}
#define NEXT_SC(c_) c_.next<__COUNTER__>()
#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>()

// Usage:
// constexpr auto c = MAKE_SC()
// NEXT_SC(c) // -> constexpr 0
// NEXT_SC(c) // -> constexpr 1
// NEXT_SC(c) // -> constexpr 2

} // namespace ck_tile
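A minimal host-side example of the counter in action (an editorial sketch; assumes this header is on the include path and a C++17 compiler that tolerates the stateful-metaprogramming trick, as clang and gcc do):

#include "ck_tile/core/utility/static_counter.hpp"

int main()
{
    constexpr auto c  = MAKE_SC();
    constexpr int  i0 = NEXT_SC(c); // 0
    constexpr int  i1 = NEXT_SC(c); // 1
    constexpr int  i2 = NEXT_SC(c); // 2
    static_assert(i0 == 0 && i1 == 1 && i2 == 2, "each NEXT_SC yields a fresh constexpr index");
    return 0;
}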
@@ -11,6 +11,7 @@
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
@@ -20,6 +21,7 @@
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
...
@@ -7,6 +7,7 @@
#include <stdint.h>
#include <stdexcept>
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

template <typename T>
@@ -36,6 +37,19 @@ struct DeviceMem
            mpDeviceBuf = nullptr;
        }
    }

    template <typename T>
    DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
    {
        if(mMemSize != 0)
        {
            HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
        }
        else
        {
            mpDeviceBuf = nullptr;
        }
        ToDevice(t.data());
    }

    void Realloc(std::size_t mem_size)
    {
        if(mpDeviceBuf)
@@ -92,6 +106,27 @@ struct DeviceMem
            HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
        }
    }

    // construct a host tensor with type T
    template <typename T>
    HostTensor<T> ToHost(std::size_t cpySize)
    {
        // TODO: the host tensor could be slightly larger than the device tensor;
        // we just copy all data from the GPU buffer
        std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
        HostTensor<T> h_({host_elements});
        if(mpDeviceBuf)
        {
            HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
        }
        return h_;
    }

    template <typename T>
    HostTensor<T> ToHost()
    {
        return ToHost<T>(mMemSize);
    }
    void SetZero() const
    {
        if(mpDeviceBuf)
...
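A hedged host-side sketch of the new DeviceMem helpers above: upload a filled tensor in one step and copy the whole buffer back (tensor shape, fill parameters, and the function name are illustrative):

#include "ck_tile/host.hpp"

void round_trip_example()
{
    ck_tile::HostTensor<float> x({16, 16});
    ck_tile::FillUniformDistribution<float>{-1.f, 1.f}(x);

    ck_tile::DeviceMem x_dev(x);       // allocates mMemSize bytes and calls ToDevice()
    // ... launch kernels that read/write x_dev.GetDeviceBuffer() ...
    auto y = x_dev.ToHost<float>();    // flat HostTensor holding the copied-back data
    (void)y;
}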
@@ -13,6 +13,7 @@
#include <unordered_set>

#include "ck_tile/core.hpp"
#include "ck_tile/host/joinable_thread.hpp"

namespace ck_tile {

@@ -22,13 +23,44 @@ struct FillUniformDistribution
    float a_{-5.f};
    float b_{5.f};
    std::optional<uint32_t> seed_{11939};
    // ATTENTION: threaded mode does not guarantee the same value sequence across thread counts
    bool threaded = false;

    template <typename ForwardIter>
    void operator()(ForwardIter first, ForwardIter last) const
    {
        if(threaded)
        {
            uint32_t num_thread = std::thread::hardware_concurrency();
            auto total          = static_cast<std::size_t>(std::distance(first, last));
            auto work_per_thread =
                static_cast<std::size_t>((total + num_thread - 1) / num_thread);
            std::vector<joinable_thread> threads(num_thread);

            for(std::size_t it = 0; it < num_thread; ++it)
            {
                std::size_t iw_begin = it * work_per_thread;
                std::size_t iw_end   = std::min((it + 1) * work_per_thread, total);

                auto thread_f = [this, total, iw_begin, iw_end, &first] {
                    if(iw_begin > total || iw_end > total)
                        return;
                    // make each thread's generator unique: add an offset to the current seed
                    std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
                                                       : std::random_device{}());
                    std::uniform_real_distribution<float> dis(a_, b_);
                    std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
                        return ck_tile::type_convert<T>(dis(gen));
                    });
                };
                threads[it] = joinable_thread(thread_f);
            }
        }
        else
        {
            std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
            std::uniform_real_distribution<float> dis(a_, b_);
            std::generate(
                first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
        }
    }

    template <typename ForwardRange>
@@ -115,13 +147,44 @@ struct FillNormalDistribution
    float mean_{0.f};
    float variance_{1.f};
    std::optional<uint32_t> seed_{11939};
    // ATTENTION: threaded mode does not guarantee the same value sequence across thread counts
    bool threaded = false;

    template <typename ForwardIter>
    void operator()(ForwardIter first, ForwardIter last) const
    {
        if(threaded)
        {
            uint32_t num_thread = std::thread::hardware_concurrency();
            auto total          = static_cast<std::size_t>(std::distance(first, last));
            auto work_per_thread =
                static_cast<std::size_t>((total + num_thread - 1) / num_thread);
            std::vector<joinable_thread> threads(num_thread);

            for(std::size_t it = 0; it < num_thread; ++it)
            {
                std::size_t iw_begin = it * work_per_thread;
                std::size_t iw_end   = std::min((it + 1) * work_per_thread, total);

                auto thread_f = [this, total, iw_begin, iw_end, &first] {
                    if(iw_begin > total || iw_end > total)
                        return;
                    // make each thread's generator unique: add an offset to the current seed
                    std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
                                                       : std::random_device{}());
                    std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
                    std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
                        return ck_tile::type_convert<T>(dis(gen));
                    });
                };
                threads[it] = joinable_thread(thread_f);
            }
        }
        else
        {
            std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
            std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
            std::generate(
                first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
        }
    }

    template <typename ForwardRange>
@@ -235,6 +298,44 @@ struct FillMonotonicSeq
    }
};
template <typename T, bool IsAscending = true>
struct FillStepRange
{
    float start_value_{0};
    float end_value_{3};
    float step_{1};

    template <typename ForwardIter>
    void operator()(ForwardIter first, ForwardIter last) const
    {
        std::generate(first, last, [=, n = start_value_]() mutable {
            auto tmp = n;
            n += step_;
            if constexpr(IsAscending)
            {
                if(n > end_value_)
                    n = start_value_;
            }
            else
            {
                if(n < end_value_)
                    n = start_value_;
            }
            return type_convert<T>(tmp);
        });
    }

    template <typename ForwardRange>
    auto operator()(ForwardRange&& range) const -> std::void_t<
        decltype(std::declval<const FillStepRange&>()(std::begin(std::forward<ForwardRange>(range)),
                                                      std::end(std::forward<ForwardRange>(range))))>
    {
        (*this)(std::begin(std::forward<ForwardRange>(range)),
                std::end(std::forward<ForwardRange>(range)));
    }
};
template <typename T>
struct FillConstant
{
...
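A small sketch of the FillStepRange initializer above: with start 0, end 3, step 1 it produces an ascending ramp that wraps at the end value (vector size is illustrative).

#include <vector>
#include "ck_tile/host/fill.hpp"

void step_range_example()
{
    std::vector<float> v(8);
    ck_tile::FillStepRange<float>{0.f, 3.f, 1.f}(v);
    // v == {0, 1, 2, 3, 0, 1, 2, 3}
}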
@@ -8,12 +8,13 @@
#include <iostream>
#include <iomanip>
#include <numeric>
#include <utility>
#include <vector>
#include <functional>
#include <fstream>

#include "ck_tile/core.hpp"
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/ranges.hpp"

namespace ck_tile {

@@ -213,23 +214,6 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
    return HostTensorDescriptor(new_lengths, new_strides);
}

template <typename F, typename... Xs>
struct ParallelTensorFunctor
{
@@ -590,6 +574,107 @@ struct HostTensor
                size() * FromSize / ToSize};
    }
    friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
    {
        os << t.mDesc;
        os << "[";
        for(typename Data::size_type idx = 0; idx < t.mData.size(); ++idx)
        {
            if(0 < idx)
            {
                os << ", ";
            }
            if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t>)
            {
                os << type_convert<float>(t.mData[idx]) << " #### ";
            }
            else
            {
                os << t.mData[idx];
            }
        }
        os << "]";
        return os;
    }

    // read data from a file, as dtype
    // the file could be dumped from torch (with the target tensor called t) as:
    //    numpy.savetxt("f.txt", t.view(-1).numpy())
    //    numpy.savetxt("f.txt", t.cpu().view(-1).numpy())            # from cuda to cpu, then save
    //    numpy.savetxt("f.txt", t.cpu().view(-1).numpy(), fmt="%d")  # save as int
    // this produces f.txt with one value per line
    // dtype is "float" or "int"; each value is cast to T internally
    void loadtxt(std::string file_name, std::string dtype = "float")
    {
        std::ifstream file(file_name);

        if(file.is_open())
        {
            std::string line;

            index_t cnt = 0;
            while(std::getline(file, line))
            {
                if(cnt >= static_cast<index_t>(mData.size()))
                {
                    throw std::runtime_error(std::string("data read from file:") + file_name +
                                             " is too big");
                }

                if(dtype == "float")
                {
                    mData[cnt] = type_convert<T>(std::stof(line));
                }
                else if(dtype == "int" || dtype == "int32")
                {
                    mData[cnt] = type_convert<T>(std::stoi(line));
                }
                cnt++;
            }
            file.close();
            if(cnt < static_cast<index_t>(mData.size()))
            {
                std::cerr << "Warning! reading from file:" << file_name
                          << " does not match the size of this tensor" << std::endl;
            }
        }
        else
        {
            // throw if the file cannot be opened
            throw std::runtime_error(std::string("unable to open file:") + file_name);
        }
    }

    // can save to a txt file, to be read back from torch as:
    //    torch.from_numpy(np.loadtxt('f.txt', dtype=np.int32/np.float32...)).view([...]).contiguous()
    void savetxt(std::string file_name, std::string dtype = "float")
    {
        std::ofstream file(file_name);

        if(file.is_open())
        {
            for(auto& itm : mData)
            {
                if(dtype == "float")
                    file << type_convert<float>(itm) << std::endl;
                else if(dtype == "int")
                    file << type_convert<int>(itm) << std::endl;
                else
                    // TODO: we didn't implement operator<< for all custom
                    // data types; fall back to float to avoid a compile error
                    file << type_convert<float>(itm) << std::endl;
            }
            file.close();
        }
        else
        {
            // throw if the file cannot be opened
            throw std::runtime_error(std::string("unable to open file:") + file_name);
        }
    }
    Descriptor mDesc;
    Data mData;
};
...
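A hedged round-trip sketch for the HostTensor text I/O helpers above (file name, shape, and fill choice are illustrative):

#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/fill.hpp"

void text_io_example()
{
    ck_tile::HostTensor<float> t({4, 8});
    ck_tile::FillMonotonicSeq<float>{}(t);
    t.savetxt("f.txt");  // one float per line; numpy.loadtxt("f.txt") can read it
    t.loadtxt("f.txt");  // reads the values back, casting line by line via type_convert
}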
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <thread>
#include <utility>

namespace ck_tile {

struct joinable_thread : std::thread
{
    template <typename... Xs>
    joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
    {
    }

    joinable_thread(joinable_thread&&) = default;
    joinable_thread& operator=(joinable_thread&&) = default;

    ~joinable_thread()
    {
        if(this->joinable())
            this->join();
    }
};

} // namespace ck_tile
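A minimal usage sketch for joinable_thread: because the destructor joins, a scope-bound vector of workers needs no explicit join loop (worker count and body are illustrative).

#include <vector>
#include "ck_tile/host/joinable_thread.hpp"

void parallel_example()
{
    std::vector<ck_tile::joinable_thread> workers;
    for(int i = 0; i < 4; ++i)
        workers.emplace_back([i] { /* per-thread work using i */ });
}   // all four threads are joined here by ~joinable_thread()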