Commit da1a2829 authored by dummycoderfe's avatar dummycoderfe
Browse files

Merge branch 'ck_tile/moe_sorting' of github.com:ROCm/composable_kernel into ck_tile/moe_sorting

parents fbfad6c2 e44e7a95
...@@ -23,12 +23,12 @@ ...@@ -23,12 +23,12 @@
#include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/reference/reference_permute.hpp" #include "ck_tile/host/reference/reference_permute.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp" #include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_topk.hpp" #include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp" #include "ck_tile/host/timer.hpp"
...@@ -16,8 +16,8 @@ namespace ck_tile { ...@@ -16,8 +16,8 @@ namespace ck_tile {
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension) // synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true> template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
const ReduceFunc& reduce_func, const ReduceFunc& reduce_func,
bool_constant<WithBroadcast> = {}) bool_constant<WithBroadcast> = {})
{ {
using Dstr = typename AccDistributedTensor_::StaticTileDistribution; using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode; using DstrEncode = typename Dstr::DstrEncode;
...@@ -116,7 +116,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, ...@@ -116,7 +116,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
*/ */
template <typename AccDistributedTensor_, typename ReduceFunc> template <typename AccDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor, CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
const ReduceFunc& reduce_func) const ReduceFunc& reduce_func)
{ {
using Dstr = typename AccDistributedTensor_::StaticTileDistribution; using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode; using DstrEncode = typename Dstr::DstrEncode;
...@@ -175,9 +175,9 @@ template <typename AccDistributedTensor_, ...@@ -175,9 +175,9 @@ template <typename AccDistributedTensor_,
index_t... InReduceDims, index_t... InReduceDims,
typename ReduceFunc> typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor, CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
const InDistributedTensor_& in_tensor, const InDistributedTensor_& in_tensor,
sequence<InReduceDims...>, sequence<InReduceDims...>,
const ReduceFunc& reduce_func) const ReduceFunc& reduce_func)
{ {
constexpr auto I0 = number<0>{}; constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{}; constexpr auto I1 = number<1>{};
...@@ -250,9 +250,9 @@ template <typename AccDataType_, ...@@ -250,9 +250,9 @@ template <typename AccDataType_,
typename ReduceFunc, typename ReduceFunc,
typename InDataType_> typename InDataType_>
CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor, CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
sequence<InReduceDims...> in_reduce_dims, sequence<InReduceDims...> in_reduce_dims,
const ReduceFunc& reduce_func, const ReduceFunc& reduce_func,
const InDataType_& reduce_init) const InDataType_& reduce_init)
{ {
using InDataType = typename InDistributedTensor_::DataType; using InDataType = typename InDistributedTensor_::DataType;
using AccDataType = remove_cvref_t<AccDataType_>; using AccDataType = remove_cvref_t<AccDataType_>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment