"docs/en/vscode:/vscode.git/clone" did not exist on "0d1b224fb1bd6aedffbdd54999d7ef8370aacee9"
Commit f23a2e2a authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

resolved conflicts

parents f3eb5a18 c0adab48
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = BatchedTransposePolicy>
struct BatchedTransposePipeline
{
// TODO: this kernel only supports a warp-per-row mapping
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using InputType = ck_tile::remove_cvref_t<typename Problem::InputType>;
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t AlignmentM = Problem::AlignmentM;
static constexpr index_t AlignmentN = Problem::AlignmentN;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
template <typename InputWindow, typename OutputWindow>
CK_TILE_DEVICE auto operator()(const InputWindow& input_window, OutputWindow& out_window)
{
auto inp_win =
make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
auto out_win =
make_tile_window(out_window, Policy::template MakeOutputDistribution<Problem>());
auto x = load_tile(inp_win); // x: per-thread data; inp_win: block-scope input window
auto y = make_static_distributed_tensor<InputType>(
Policy::template MakeOutputDistribution<Problem>());
constexpr auto span_2d_x = decltype(x)::get_distributed_spans();
sweep_tile_span(span_2d_x[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d_x[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx1, idx0);
y(i_j_idx) = x(i_j_idx);
});
});
store_tile(out_win, y);
}
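// Sketch of the intended element-level effect (hypothetical shapes, not part of the original
// source): for an M x N input tile the pipeline produces the N x M transpose, i.e. logically
// out[n][m] = in[m][n]; the swap is realized by the input/output tile distributions supplied
// by the Policy rather than by explicit index arithmetic in the sweep above.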
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"
namespace ck_tile {
struct BatchedTransposePolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
{
using S = Problem;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>,
sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
{
using S = Problem;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>,
sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>>,
tuple<sequence<2, 1>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<2, 1>,
sequence<2, 2>>{});
}
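// note: the output encoding above mirrors MakeInputDistribution with the M- and N-major
// sequences exchanged (and the P/Y dimension mappings swapped accordingly), which is what
// places each thread's loaded fragment at the transposed position in the output window.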
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
#define VectorLoadSize 16
namespace ck_tile {
template <typename InputType_,
typename BlockTile, // Sequence<...
typename WarpTile, // Sequence<...
typename ThreadTile, // Sequence<...
bool kPadM_ = true,
bool kPadN_ = true>
struct BatchedTransposeProblem
{
using InputType = remove_cvref_t<InputType_>;
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
static constexpr index_t kNPerThread = ThreadTile::at(number<1>{});
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
static constexpr index_t kNPerWarp = WarpTile::at(number<1>{});
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread;
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
static constexpr index_t kNPerBlock = BlockTile::at(number<1>{});
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp;
static constexpr index_t kBlockSize =
kMThreadPerWarp * kNThreadPerWarp * kMWarpPerBlock * kNWarpPerBlock;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr index_t AlignmentM = kPadM ? VectorLoadSize / sizeof(InputType) : 1; // TODO
static constexpr index_t AlignmentN = kPadN ? VectorLoadSize / sizeof(InputType) : 1;
};
} // namespace ck_tile
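// A minimal usage sketch (hypothetical tile sizes, not taken from the original sources):
// with fp16_t input, a 32x32 block tile, 16x16 warp tile and 1x8 thread tile, the derived
// constants evaluate to kMThreadPerWarp = 16, kNThreadPerWarp = 2, kBlockSize = 128 and,
// since kPadM/kPadN default to true, AlignmentM = AlignmentN = VectorLoadSize / sizeof(fp16_t) = 8.
//
// using ExampleTransposeProblem = ck_tile::BatchedTransposeProblem<ck_tile::fp16_t,
//                                                                  ck_tile::sequence<32, 32>,  // block tile
//                                                                  ck_tile::sequence<16, 16>,  // warp tile
//                                                                  ck_tile::sequence<1, 8>>;   // thread tile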
@@ -5,3 +5,4 @@
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
namespace ck_tile {
// clang-format off
template <typename T> struct typeToStr;
template <> struct typeToStr<float> { static constexpr const char * name = "fp32"; };
template <> struct typeToStr<fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct typeToStr<bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct typeToStr<fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct typeToStr<bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct typeToStr<int8_t> { static constexpr const char * name = "int8"; };
// clang-format on
template <typename ADataType_, typename BDataType_>
std::string gemm_prec_str()
{
std::string base_str = std::string(typeToStr<ADataType_>::name);
if(!std::is_same_v<ADataType_, BDataType_>)
{
base_str += "_" + std::string(typeToStr<BDataType_>::name);
}
return base_str;
}
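// Example (sketch): the precision string is the A-type name, with the B-type name appended
// only when the two types differ, e.g.
//   gemm_prec_str<fp16_t, fp16_t>()  -> "fp16"
//   gemm_prec_str<fp16_t, int8_t>()  -> "fp16_int8"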
} // namespace ck_tile
@@ -6,3 +6,4 @@
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
@@ -8,3 +8,4 @@
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
// memory.
template <typename AccDataType_,
typename ODataType_,
typename CLayout_,
index_t kBlockSize_,
index_t kM_,
index_t kN_,
index_t kMWave_,
index_t kNWave_,
index_t kMPerXdl_,
index_t kNPerXdl_,
index_t kKPerXdl_,
bool isCTransposed_>
struct CShuffleEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using CLayout = remove_cvref_t<CLayout_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t kMWave = kMWave_;
static constexpr index_t kNWave = kNWave_;
static constexpr index_t kMPerXdl = kMPerXdl_;
static constexpr index_t kNPerXdl = kNPerXdl_;
static constexpr index_t kKPerXdl = kKPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
};
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t kMWave = Problem::kMWave;
static constexpr index_t kNWave = Problem::kNWave;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
using WG = WarpGemmMfmaDispatcher<ODataType,
ODataType,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
/**
 * @brief Get the vector store size for C tensor.
 *
 * @note The vector store size for output C tensor would depend on multiple factors
 *       like its data layout and warp gemm C transposition. In general it would
 *       be the number of consecutive elements in contiguous C dimension hold by
 *       single thread.
 *
 * @return The vector store size for C tensor.
 */
template <typename ODataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
constexpr index_t MaxVectorStoreSize = 16;
return MaxVectorStoreSize / sizeof(ODataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
make_tuple(number<kNWave * kNPerXdl>{}, number<1>{}));
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
make_tuple(number<1>{}, number<kMWave * kMPerXdl>{}));
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType);
}
template <typename ODramWindow,
typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto
operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem)
{
const index_t iMWarp = get_warp_id() / kNWave;
const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_block_desc);
auto in_lds_window =
make_tile_window(o_lds_block,
make_tuple(number<kMPerXdl>{}, number<kNPerXdl>{}),
{number<kMPerXdl>{} * iMWarp, number<kNPerXdl>{} * iNWarp});
auto out_lds_window =
make_tile_window(o_lds_block,
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
{0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<kMPerXdl * kMWave, kNPerXdl * kNWave>>;
constexpr index_t num_access = SFC::get_num_of_access();
using TileEncodingPattern =
TileDistributionEncodingPattern2D<kBlockSize,
kMPerIteration,
kNPerIteration,
GetVectorSizeC<ODataType>(),
tile_distribution_pattern::thread_raked>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
CWarpTensor c_warp_in_tensor;
static_for<0, num_access, 1>{}([&](auto iAccess) {
constexpr auto idx_y_start = SFC::get_index(iAccess);
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (kMPerXdl * kMWave)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (kNPerXdl * kNWave)>{};
c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
const auto c_warp_in_tensor_casted = cast_tile<ODataType>(c_warp_in_tensor);
block_sync_lds();
store_tile(in_lds_window, c_warp_in_tensor_casted);
block_sync_lds();
const auto c_out_tensor =
load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
if constexpr(out_memory_data_op == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
}
});
}
};
} // namespace ck_tile
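// Worked example (sketch, values are assumptions, not from the original source): with
// ODataType = fp16_t the epilogue above stores MaxVectorStoreSize / sizeof(fp16_t) = 16 / 2 = 8
// consecutive elements per thread, and GetSmemSize() for kMWave = kNWave = 2,
// kMPerXdl = kNPerXdl = 32 evaluates to 2 * 2 * 32 * 32 * sizeof(fp16_t) = 8192 bytes of LDS.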
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
@@ -23,6 +25,26 @@ struct Default2DEpilogueProblem
static constexpr bool UseRawStore = UseRawStore_;
};
template <typename AccDataType_,
typename ODataType_,
typename CLayout_,
bool kPadM_,
bool kPadN_,
index_t kMPerXdl_,
index_t kNPerXdl_,
index_t kKPerXdl_,
bool isCTransposed_,
bool UseRawStore_ = true>
struct DefaultGemm2DEpilogueProblem
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
{
using CLayout = remove_cvref_t<CLayout_>;
static constexpr index_t kMPerXdl = kMPerXdl_;
static constexpr index_t kNPerXdl = kNPerXdl_;
static constexpr index_t kKPerXdl = kKPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
};
template <typename Problem_, typename Policy_ = void>
struct Default2DEpilogue
{
@@ -35,14 +57,13 @@ struct Default2DEpilogue
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
template <typename ODramWindowTmp,
typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
{
// TODO: this is ugly
@@ -71,4 +92,76 @@ struct Default2DEpilogue
}
}
};
template <typename Problem_, typename Policy_ = void>
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
using WG = WarpGemmMfmaDispatcher<ODataType,
ODataType,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(isCTransposed)
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
else
{
// In this case each thread has just a single item in Ndim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(isCTransposed)
{
// In this case each thread has just a single item in Mdim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
else
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
};
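// A hypothetical instantiation of the problem type above (parameter values are examples only,
// not taken from the original source):
// using ExampleGemmEpilogueProblem = DefaultGemm2DEpilogueProblem<float,   // AccDataType
//                                                                 fp16_t,  // ODataType
//                                                                 tensor_layout::gemm::RowMajor,
//                                                                 true, true,   // kPadM, kPadN
//                                                                 32, 32, 8,    // kMPerXdl, kNPerXdl, kKPerXdl
//                                                                 false>;       // isCTransposed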
} // namespace ck_tile
@@ -9,3 +9,4 @@
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
@@ -824,4 +824,4 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
#undef CK_TILE_FLATMM_UK_MFMA
// clang-format on
@@ -722,4 +722,4 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
#undef CK_TILE_FLATMM_UK_MFMA
// clang-format on
@@ -771,4 +771,4 @@
#undef _UK_MFMA_
#undef CK_TILE_FLATMM_UK_2B
#undef CK_TILE_FLATMM_UK_MFMA
// clang-format on
@@ -44,3 +44,4 @@
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
@@ -7,6 +7,7 @@
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
@@ -14,6 +15,6 @@
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
@@ -22,7 +22,7 @@
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
...
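// Worked example for the padded-size bound above (illustrative numbers, not from the source):
// with input_tokens = 32, topk = 4, num_experts = 8 and unit_size M_a = 32,
// max_num_tokens_padded = 4 * 32 + 8 * 32 - 4 = 380.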
@@ -15,6 +15,10 @@ namespace ck_tile {
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
@@ -28,7 +32,7 @@ namespace ck_tile {
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
@@ -55,6 +59,34 @@ namespace ck_tile {
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// skip_experts_with_zero_tokens(SkipExpertsWithZeroTokens)
// if enabled, an expert with no tokens will be skipped, instead of being padded to at least 1 unit_size (M_a)
//
// (pack below tensor, skip element marked with `-`)
// Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 5]
// num_tokens_post_padded_ptr : [24]
//
// * local_expert_mask : indicates the local expert mask used on the current GPU (used for the EP case)
//   and modifies the output expert-ID, because only enabled experts exist on a specific GPU.
//   we call the expert id input to this kernel the "global expert id", and the output the "local expert id"
//
// * local_expert_mask : [1, 0, 1, 1, 0, 1] (mask out expert-id=1, 4)
//
// (pack below tensor, skip element marked with `-`)
// Y Y Y Y - - - - Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// sorted_expert_ids_ptr : [0, 1, 2, 2, 3] (note: originally it was expert-id = 0, 2, 3, 5, but we produce the "local expert id")
// num_tokens_post_padded_ptr : [20]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2)need sorted_weight_ptr
@@ -67,10 +99,80 @@ namespace ck_tile {
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int num_experts_)
{
/* num_experts + 1
* +--------------------------------------+
* | |
* | |
* | | * -> sub-tokens
* | |
* | |
* +--------------------------------------+
* | | 2 -> cumsum buffer
* +--------------------------------------+
*
*/
int smem_cols = num_experts_ + 1; // num_experts is usually a power of 2, so pad by one column here
int smem_rows = [&](){
index_t target_occupancy_ = 2;
constexpr index_t total_ = 65536 / sizeof(int);
constexpr index_t sub_unroll = 8;
constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt
// at least 2 rows, one for the sub_token unroll, one for the cumsum
// should be enough
if ((total_ / target_occupancy_) < ((cumsum_bufs+sub_unroll) * smem_cols)) {
if ((total_ / 1) < ((cumsum_bufs+sub_unroll) * smem_cols))
throw std::runtime_error("too many num_experts, can't allocate smem");
target_occupancy_ = 1;
}
int r = total_ / target_occupancy_ / smem_cols;
// round to a multiple of sub_unroll
int r_for_sub_token = r - cumsum_bufs;
r_for_sub_token = min(r_for_sub_token, num_tokens_);
r_for_sub_token = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll;
r_for_sub_token = max(r_for_sub_token, 1);
if(r_for_sub_token > 1)
{
int r_unroll_ = r_for_sub_token / sub_unroll;
// round to 1x/2x/4x/8x number of sub_unroll
int clz_ = __builtin_clz(r_unroll_); // 0b1:31 0b2:30, 0b3:30, 0b4:29
int mask_ = (1 << (31 - clz_)) - 1;
mask_ = mask_ > 0b111 ? 0b111 : mask_; //clamp to 8x at most
mask_ = ~mask_;
//printf("r_unroll_:%d, clz:%d, mask:%x\n", r_unroll_, clz_, mask_); fflush(stdout);
r_for_sub_token = (r_unroll_ & mask_) * sub_unroll;
}
// final check
if( (r_for_sub_token + cumsum_bufs * smem_cols * target_occupancy_ ) >= total_ ) {
throw std::runtime_error("can't run this kernel, request LDS over size");
}
return r_for_sub_token + cumsum_bufs;
}();
// printf("r:%d, c:%d\n", smem_rows, smem_cols);
return ck_tile::make_tuple(smem_rows, smem_cols);
}
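// Usage sketch (assumed values for illustration): the helper is queried on the host when sizing LDS, e.g.
//   auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(tokens, num_experts);
//   // smem_cols == num_experts + 1; smem_rows == rows used for sub-token counting + 2 bookkeeping rows
//   index_t lds_bytes = smem_rows * smem_cols * sizeof(int);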
struct MoeSortingHostArgs
{
const void* p_topk_ids; // [token, topk]
const void* p_weights;  // [token, topk]
const void* p_local_expert_mask;
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_sorted_expert_ids;
@@ -101,6 +203,7 @@ struct MoeSortingKernel
{
const void* p_topk_ids;
const void* p_weights;
const void* p_local_expert_mask;
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_sorted_expert_ids;
@@ -111,8 +214,11 @@ struct MoeSortingKernel
index_t moe_buf_bytes;
index_t tokens_per_thread;
index_t smem_rows;
mdiv unit_size_mdiv;
mdiv topk_mdiv;
mdiv expert_mdiv;
// mdiv sub_tokens_mdiv;
};
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
@@ -123,15 +229,25 @@
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
{
#if MOE_SORTING_USE_EX_KERNEL
(void)h;
return dim3(256);
#else
return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size()));
#endif
}
// in byte
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
{
#if MOE_SORTING_USE_EX_KERNEL
auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
return smem_rows * smem_cols * sizeof(int);
#else
const auto blocks = BlockSize(h);
// usually num_experts is power of 2, we pad 1 dword here for the row-size
return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t);
#endif
}
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
@@ -139,6 +255,7 @@
Kargs k;
k.p_topk_ids = h.p_topk_ids;
k.p_weights = h.p_weights;
k.p_local_expert_mask = h.p_local_expert_mask;
k.p_sorted_token_ids = h.p_sorted_token_ids;
k.p_sorted_weights = h.p_sorted_weights;
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
@@ -152,10 +269,18 @@
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
k.smem_rows = [&](){
auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
(void) c_;
return r_;
}();
k.expert_mdiv = mdiv{static_cast<uint32_t>(h.num_experts)};
// k.sub_tokens_mdiv = mdiv{static_cast<uint32_t>(k.smem_rows - 1)};
return k;
}
// [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
// NOTE: wave_size need at least be 16!! dpp 16 is one row
template <typename data_t, int wave_size>
__device__ inline void wave_cumsum(data_t& thread_data) const
{
@@ -196,6 +321,40 @@ struct MoeSortingKernel
bank_mask,
bound_ctrl))); // row_shr:4
}
if constexpr(wave_size == 8) {
// wave-size=8 need one extra shift
thread_data =
reduce_op(thread_data,
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x118,
row_mask,
bank_mask,
bound_ctrl))); // row_shr:8
#if 0
constexpr int bank_mask_0_7 = 0b1100;
auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; };
thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t,
__builtin_amdgcn_update_dpp(0, /* old value */
__builtin_bit_cast(int, thread_data),
0x157,
row_mask,
bank_mask_0_7,
bound_ctrl))// row_newbcast:7
);
#else
data_t xxx =__builtin_bit_cast(data_t,
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x157,
row_mask,
bank_mask,
bound_ctrl)); // row_newbcast:7
data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
thread_data = thread_data - yyy;
#endif
}
if constexpr(wave_size > 8)
{
thread_data =
@@ -224,6 +383,36 @@ struct MoeSortingKernel
}
}
}
// reduce single pixel within a wave
template <typename T, typename F, index_t wave_size_ = warpSize>
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
{
// constexpr int wave_size = 64;
// constexpr int reduce_stage = 6; // 1<<6=64
// clang-format off
constexpr int reduce_stage = [](){
if constexpr(wave_size_ == 2) return 1;
else if constexpr(wave_size_ == 4) return 2;
else if constexpr(wave_size_ == 8) return 3;
else if constexpr(wave_size_ == 16) return 4;
else if constexpr(wave_size_ == 32) return 5;
else if constexpr(wave_size_ == 64) return 6;
else return 0;
}();
// clang-format on
T v_local = local;
#pragma unroll reduce_stage
for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
{
int src_lane = __lane_id() ^ (1 << i_stage);
int32_t v_remote_tmp =
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
T v_remote = bit_cast<T>(v_remote_tmp);
v_local = reduce_f(v_local, v_remote);
}
return v_local;
}
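// Example (sketch): a lane-group-of-8 butterfly sum; every lane in each group of 8 ends up
// holding the sum of the group's values, matching how the counting loop below uses it.
//   auto f_sum = [](int a, int b) { return a + b; };
//   int group_total = wave_reduce(per_lane_value, f_sum, number<8>{});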
CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
{
return row * total_col + col;
@@ -257,37 +446,37 @@ struct MoeSortingKernel
index_t* shared_mem = reinterpret_cast<index_t*>(smem);
index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1); // 1: (num_experts + 1)
for(int i = 0; i < num_experts; ++i)
{
tokens_cnts[calc_index(num_experts + 1, tid + 1, i)] = 0;
}
#pragma unroll Problem_::InternalLoadUnroll
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
++tokens_cnts[calc_index(num_experts + 1, tid + 1, topk_id[i])];
}
__syncthreads();
#if 1
if(tid < num_experts)
{
tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
index_t local_c[8];
index_t prev_c = 0;
// TODO: manually unroll. pragma unroll does not work well when we have dependency
for(int i = 1; i <= static_cast<index_t>(blockDim.x); i += 8)
{
local_c[0] = tokens_cnts[calc_index(num_experts + 1, i + 0, tid)];
local_c[1] = tokens_cnts[calc_index(num_experts + 1, i + 1, tid)];
local_c[2] = tokens_cnts[calc_index(num_experts + 1, i + 2, tid)];
local_c[3] = tokens_cnts[calc_index(num_experts + 1, i + 3, tid)];
local_c[4] = tokens_cnts[calc_index(num_experts + 1, i + 4, tid)];
local_c[5] = tokens_cnts[calc_index(num_experts + 1, i + 5, tid)];
local_c[6] = tokens_cnts[calc_index(num_experts + 1, i + 6, tid)];
local_c[7] = tokens_cnts[calc_index(num_experts + 1, i + 7, tid)];
local_c[0] += prev_c;
local_c[1] += local_c[0];
@@ -299,51 +488,57 @@ struct MoeSortingKernel
local_c[7] += local_c[6];
prev_c = local_c[7];
tokens_cnts[calc_index(num_experts + 1, i + 0, tid)] = local_c[0];
tokens_cnts[calc_index(num_experts + 1, i + 1, tid)] = local_c[1];
tokens_cnts[calc_index(num_experts + 1, i + 2, tid)] = local_c[2];
tokens_cnts[calc_index(num_experts + 1, i + 3, tid)] = local_c[3];
tokens_cnts[calc_index(num_experts + 1, i + 4, tid)] = local_c[4];
tokens_cnts[calc_index(num_experts + 1, i + 5, tid)] = local_c[5];
tokens_cnts[calc_index(num_experts + 1, i + 6, tid)] = local_c[6];
tokens_cnts[calc_index(num_experts + 1, i + 7, tid)] = local_c[7];
}
}
#else
// TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future
// heuristic
{
if(tid < num_experts)
tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
for(int i = 0; i < num_experts; i += 8)
{
index_t local_c[8];
#pragma unroll
for(int j = 0; j < 8; j++)
{
local_c[j] = tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)];
}
#pragma unroll
for(int j = 0; j < 8; j++)
{
wave_cumsum<int, 64>(local_c[j]);
}
#pragma unroll
for(int j = 0; j < 8; j++)
{
tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j];
}
}
}
#endif
__syncthreads();
if constexpr(Problem::ExpertTile == 0)
{
if(tid == 0)
{
cumsum[0] = 0;
for(int i = 1; i <= num_experts; ++i)
{
auto current_units = [&]() {
index_t x_ = tokens_cnts[calc_index(num_experts + 1, blockDim.x, i - 1)] +
unit_size_mdiv.divisor - 1;
index_t y_ = unit_size_mdiv.div(x_);
return max(y_, 1) * unit_size_mdiv.divisor;
}();
@@ -351,20 +546,24 @@ struct MoeSortingKernel
}
*p_total_tokens_post_pad = cumsum[num_experts];
}
}
else
{
// TODO: we have out-of-bound read here. But result is still OK (will ignore tid >=
// expert) for simplicity, not check experts here.
int local_cnt = tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
int local_cumsum = padded_tokens_per_expert;
wave_cumsum<int, 64>(local_cumsum);
if(tid == (num_experts - 1))
{
cumsum[0] = 0;
*p_total_tokens_post_pad = local_cumsum;
}
if(tid < num_experts)
{
cumsum[tid + 1] = local_cumsum;
}
}
@@ -373,7 +572,7 @@ struct MoeSortingKernel
if(tid < num_experts)
{
int e_start = cumsum[tid];
int e_end = cumsum[tid + 1];
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
{
p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
@@ -383,8 +582,8 @@ struct MoeSortingKernel
#pragma unroll Problem_::InternalLoadUnroll
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
index_t expert_id = topk_id[i];
index_t local_cnt = tokens_cnts[calc_index(num_experts + 1, tid, expert_id)];
index_t rank_post_pad = local_cnt + cumsum[expert_id];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
uint32_t curr_token_id, curr_topk_id;
@@ -393,16 +592,17 @@ struct MoeSortingKernel
#else
p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
#endif
p_sorted_weights[rank_post_pad] = weights[i];
tokens_cnts[calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1;
}
if constexpr(Problem::ExpertTile == 0)
{
const index_t prefill_token = topk_mdiv.div(numel);
if(tid < num_experts)
{
index_t expert_offset =
cumsum[tid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
index_t expert_end = cumsum[tid + 1];
while(expert_offset < expert_end)
{
@@ -417,16 +617,19 @@ struct MoeSortingKernel
}
}
}
else
{
const index_t prefill_token = topk_mdiv.div(numel);
// TODO: only support expert-tile like 8, 16, 32
static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
{
index_t eid = tid / experts_per_wave;
index_t expert_offset = cumsum[eid] +
tokens_cnts[calc_index(num_experts + 1, blockDim.x, eid)] +
tid % experts_per_wave;
index_t expert_end = cumsum[eid + 1];
if(eid < num_experts)
{
while(expert_offset < expert_end)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
@@ -436,10 +639,363 @@ struct MoeSortingKernel
#else
p_sorted_token_ids[expert_offset] = prefill_token;
#endif
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
expert_offset += experts_per_wave;
}
}
}
}
}
// only support index_t, and single pixel access
struct simple_smem_indexer
{
index_t* smem;
index_t row_stride;
// this is 2D
CK_TILE_DEVICE simple_smem_indexer(index_t* smem_, index_t row_stride_)
: smem(smem_), row_stride(row_stride_)
{
}
CK_TILE_DEVICE const index_t& operator()(index_t i_row, index_t i_col) const
{
return smem[i_row * row_stride + i_col];
}
CK_TILE_DEVICE index_t& operator()(index_t i_row, index_t i_col)
{
return smem[i_row * row_stride + i_col];
}
// this is 1D or linear
CK_TILE_DEVICE simple_smem_indexer(index_t* smem_) : smem(smem_), row_stride(0) {}
CK_TILE_DEVICE const index_t& operator()(index_t idx) const { return smem[idx]; }
CK_TILE_DEVICE index_t& operator()(index_t idx) { return smem[idx]; }
};
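// Example (sketch): the same raw LDS area is viewed both linearly and as a 2D grid, as done in
// moe_align_block_size_kernel_ex below:
//   simple_smem_indexer cumsum{reinterpret_cast<index_t*>(smem)};              // 1D (linear) view
//   simple_smem_indexer counts{reinterpret_cast<index_t*>(smem) + cols, cols}; // 2D view, row stride = cols
//   counts(row, col) = 0;  // element access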
CK_TILE_DEVICE void
moe_align_block_size_kernel_ex(const IndexType* __restrict__ topk_id,
const WeightType* __restrict__ weights,
const IndexType* __restrict__ local_expert_mask,
index_t* p_sorted_token_ids,
WeightType* p_sorted_weights,
index_t* p_sorted_expert_ids,
index_t* p_total_tokens_post_pad,
const index_t num_experts,
const index_t tokens,
const mdiv unit_size_mdiv,
const mdiv topk_mdiv,
const mdiv expert_mdiv,
const index_t smem_rows,
void* smem) const
{
const index_t tid = static_cast<index_t>(threadIdx.x);
const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize);
const index_t lid = __lane_id();
constexpr index_t block_size = 256; // blockDim.x;
const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
const index_t topk = topk_mdiv.divisor;
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
const index_t smem_cols = num_experts + 1;
simple_smem_indexer smem_cumsum{reinterpret_cast<index_t*>(smem) + 0};
simple_smem_indexer smem_cumdup{reinterpret_cast<index_t*>(smem) + smem_cols};
simple_smem_indexer smem_tokens{reinterpret_cast<index_t*>(smem) + 2 * smem_cols,
smem_cols};
// #pragma unroll 8
for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
{
uint32_t curr_token_id, curr_expert_id;
expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
smem_tokens(curr_token_id, curr_expert_id) = 0;
}
__syncthreads();
for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
{
// NOTE: below for loop can't have barrier inside!!
for(int i = tid; i < (sub_tokens * topk); i += block_size)
{
uint32_t curr_token_id, curr_topk_id;
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
int i_t = i_token + curr_token_id;
if(i_t < tokens)
{
int eid = topk_id[i_t * topk + curr_topk_id];
if constexpr(Problem::SubTokenOneShot)
smem_tokens(curr_token_id, eid) = curr_topk_id + 1;
else
smem_tokens(curr_token_id, eid)++;
}
__builtin_amdgcn_s_waitcnt(0xc07f);
}
__syncthreads(); // make sure different i_token iteration not overlap by different wave
}
// counting
if(tid == 0)
{
smem_cumsum(0) = 0;
// smem_cumdup(0) = 0;
}
{
constexpr int lane_group_sz = 8;
int lane_group_id = tid / lane_group_sz;
int lane_group_os = tid % lane_group_sz;
constexpr int lane_group_nm = block_size / lane_group_sz;
for(int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm)
{
index_t local_c[Problem::SubTokenTile];
index_t cnt = 0;
for(int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile)
{
#pragma unroll Problem::SubTokenTile
for(int j = 0; j < Problem::SubTokenTile; j++)
{
local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e);
if constexpr(Problem::SubTokenOneShot)
{
local_c[j] = local_c[j] != 0 ? 1 : 0;
}
}
#pragma unroll Problem::SubTokenTile
for(int j = 0; j < Problem::SubTokenTile; j++)
{
cnt += wave_reduce(local_c[j], f_sum, number<8>{});
}
}
if(lane_group_os == 0)
smem_cumsum(i_e + 1) = cnt;
}
}
if constexpr(Problem::LocalExpertMasking)
{
smem_cumdup(0) = 0;
for(int i_e = tid; i_e < num_experts; i_e += block_size)
{
// reuse this buffer
smem_cumdup(i_e + 1) = local_expert_mask[i_e];
}
}
__syncthreads();
{
if(wid == 0)
{
// NOTE: __syncthreads must never be used anywhere under this block!
int i_e_ = 0;
int local_cumsum_ = 0;
for(; i_e_ < num_experts; i_e_ += warpSize)
{
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
int local_cnt = smem_cumsum(i_e_ + lid + 1);
int blocks_pers_expert =
unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
int pre_cumsum_masking = [&]() {
if constexpr(Problem::LocalExpertMasking)
return smem_cumdup(lid == 0 ? i_e_ : 0);
else
return 0; // not used
}();
int local_masking = [&]() {
if constexpr(Problem::LocalExpertMasking)
return smem_cumdup(i_e_ + lid + 1);
else
return 0; // not used
}();
int padded_tokens_per_expert = [&]() {
int x_ = [&]() {
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
// if local_cnt is zero, blocks_pers_expert will be zero
// this is what we want to achieve
return blocks_pers_expert * unit_size_mdiv.divisor;
}
else
{
return max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
}
}();
if constexpr(Problem::LocalExpertMasking)
{
return local_masking ? x_ : 0;
}
else
return x_;
}();
local_cumsum_ = padded_tokens_per_expert;
local_cumsum_ += pre_cumsum_; // note: pre_cumsum must be added after the local
// cumsum has been padded; if the local cumsum is zero while
// pre_cumsum has a value, adding it first would yield a
// zero local cumsum (but we want at least the padded size)
wave_cumsum<int, warpSize>(local_cumsum_);
if((i_e_ + lid) < num_experts)
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
if constexpr(Problem::LocalExpertMasking)
{
local_masking += pre_cumsum_masking;
wave_cumsum<int, warpSize>(local_masking);
if((i_e_ + lid) < num_experts)
smem_cumdup(i_e_ + lid + 1) = local_masking;
}
                // NOTE: this waitcnt is required: the compiler does not emit a waitcnt lgkmcnt()
                // for the LDS write above, and __syncthreads() would introduce a barrier
                // involving waves other than wave 0 (which is not what we want here)
__builtin_amdgcn_s_waitcnt(0xc07f);
}
if((lid + i_e_ - warpSize) == (num_experts - 1))
{
*p_total_tokens_post_pad = local_cumsum_;
}
}
__syncthreads();
}
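    // Emit one p_sorted_expert_ids entry per unit_size block assigned to each expert, using the
    // (optionally remapped) expert id.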
for(int i_e = tid; i_e < num_experts; i_e += block_size)
{
int e_start = smem_cumsum(i_e);
int e_end = smem_cumsum(i_e + 1);
int expert_id = [&]() {
if constexpr(Problem::LocalExpertMasking)
{
// local expert id from cumsum
return smem_cumdup(i_e);
}
else
return i_e;
}();
smem_cumdup(i_e) = e_start; // duplicate cumsum for later use
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
if(e_start == e_end) // skip zero token expert
continue;
}
if constexpr(Problem::LocalExpertMasking)
{
if(local_expert_mask[i_e] == 0)
continue;
}
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
{
p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
}
}
smem_cumdup(num_experts) = smem_cumsum(num_experts);
// fill the p_sorted_token_ids/p_sorted_weights
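    // Second pass over the token chunks: for each expert, compact its tokens with an 8-lane
    // cumsum and write token ids/weights starting at smem_cumsum(eid), which from here on acts
    // as the running write position for that expert.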
for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
{
if constexpr(!Problem::SubTokenOneShot)
{
// clear every time
for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
{
uint32_t curr_token_id, curr_expert_id;
expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
smem_tokens(curr_token_id, curr_expert_id) = 0;
}
__syncthreads();
// load again
for(int i = tid; i < (sub_tokens * topk); i += block_size)
{
uint32_t curr_token_id_, curr_topk_id_;
topk_mdiv.divmod(i, curr_token_id_, curr_topk_id_);
int curr_token_id = static_cast<int>(curr_token_id_);
int curr_topk_id = static_cast<int>(curr_topk_id_);
int i_t = i_token + curr_token_id;
if(i_t < tokens)
{
int eid = topk_id[i_t * topk + curr_topk_id];
smem_tokens(curr_token_id, eid) = curr_topk_id + 1; // at least 1
}
}
__syncthreads();
}
{
constexpr int lane_group_sz = 8;
int lane_group_id = tid / lane_group_sz;
int lane_group_os = tid % lane_group_sz;
constexpr int lane_group_nm = block_size / lane_group_sz;
for(int eid = lane_group_id; eid < num_experts; eid += lane_group_nm)
{
if constexpr(Problem::LocalExpertMasking)
{
if(local_expert_mask[eid] == 0)
continue;
}
int position = smem_cumsum(eid);
for(int i_sub_token = lane_group_os; i_sub_token < sub_tokens;
i_sub_token += lane_group_sz)
{
auto x = smem_tokens(i_sub_token, eid);
int local_cnt_cache = x != 0 ? 1 : 0;
int local_cnt = local_cnt_cache;
wave_cumsum<int, lane_group_sz>(local_cnt);
if(x != 0)
{
                        // x is the one-based topk slot of this token for expert eid
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[position + local_cnt - 1] =
MOE_SORTING_MOCK_ID(i_token + i_sub_token, x - 1);
#else
p_sorted_token_ids[position + local_cnt - 1] = i_token + i_sub_token;
#endif
p_sorted_weights[position + local_cnt - 1] =
weights[(i_token + i_sub_token) * topk + x - 1];
}
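                    // broadcast the group's total (held by the last lane of this 8-lane group
                    // after wave_cumsum) to all lanes via ds_bpermute (byte addressing, hence
                    // the << 2), then advance the write position for the next stripe of tokens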
int remote_cnt = __builtin_amdgcn_ds_bpermute(
(lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt);
position += remote_cnt;
                }
                smem_cumsum(eid) = position;
            }
        }
__syncthreads();
}
    // pad each expert's segment: fill the remaining slots up to the padded end with the
    // sentinel token id and zero weight
for(int eid = tid; eid < num_experts; eid += block_size)
{
int e_start = smem_cumsum(eid);
int e_end = smem_cumdup(eid + 1);
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
if(e_start == e_end) // skip zero token expert
continue;
}
while(e_start < e_end)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[e_start] = MOE_SORTING_MOCK_ID(tokens, topk);
#else
p_sorted_token_ids[e_start] = tokens;
#endif
p_sorted_weights[e_start] = static_cast<WeightType>(0.0);
e_start++;
}
    }
}
@@ -456,6 +1012,24 @@ struct MoeSortingKernel
        }
        const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
        extern __shared__ char smem[];
#if MOE_SORTING_USE_EX_KERNEL
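        // numel is only needed by the legacy kernel below; cast it to void so the extended-kernel
        // path compiles without an unused-variable warning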
(void)numel;
return moe_align_block_size_kernel_ex(
static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights),
static_cast<const IndexType*>(kargs.p_local_expert_mask),
static_cast<IndexType*>(kargs.p_sorted_token_ids),
static_cast<WeightType*>(kargs.p_sorted_weights),
static_cast<IndexType*>(kargs.p_sorted_expert_ids),
static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
kargs.num_experts,
kargs.tokens,
kargs.unit_size_mdiv,
kargs.topk_mdiv,
kargs.expert_mdiv,
kargs.smem_rows,
smem);
#else
        return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
                                           static_cast<const WeightType*>(kargs.p_weights),
                                           static_cast<IndexType*>(kargs.p_sorted_token_ids),
@@ -468,6 +1042,7 @@ struct MoeSortingKernel
                                           kargs.unit_size_mdiv,
                                           kargs.topk_mdiv,
                                           smem);
#endif
    }
};
...
@@ -25,4 +25,28 @@ struct MoeSortingProblem
        InternalLoadUnroll_; // TODO: need better design(like tile size)
    static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};
template <typename IndexType_,
typename WeightType_,
          index_t SubTokenTile_,    // 1, 2, 4, 8, or 0 in the future
          bool SubTokenOneShot_,    // whether we loop over the tokens only once
          bool LocalExpertMasking_, // used in the expert-parallel (EP) case
bool SkipExpertsWithZeroTokens_ = true,
index_t ExpertTile_ = 0>
struct MoeSortingProblemEx
{
    // TODO: this kernel only supports warp-per-row
using WeightType = remove_cvref_t<WeightType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t WarpSize = get_warp_size();
static constexpr index_t WarpsPerBlock = 1;
static constexpr index_t SubTokenTile = SubTokenTile_;
static constexpr bool SubTokenOneShot = SubTokenOneShot_;
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};
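
// A minimal instantiation sketch (illustrative only; the index/weight types and tile values
// below are assumptions, not requirements of this header):
//   using ExProblem = MoeSortingProblemEx<index_t,
//                                         float,
//                                         /*SubTokenTile_*/ 4,
//                                         /*SubTokenOneShot_*/ true,
//                                         /*LocalExpertMasking_*/ false>;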
} // namespace ck_tile
@@ -46,3 +46,4 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
@@ -79,7 +79,10 @@ struct BlockUniversalGemmAsBsCr
    // TODO: Should we have two policies? Interwave & Intrawave ??
    static constexpr index_t InterWaveSchedulingMacClusters = 1;
    static constexpr index_t KPack = WarpGemm::kKPerThread;
    // should be at least equal to: WarpGemm::Impl::kABKPerLane
// and the question is how to assess upper limit or exact value?
// TODO: Should we introduce AK1/BK1 parameters ?
static constexpr index_t KPack = 8;
    static constexpr index_t KPerThread = KIterPerWarp * KPack;
    static constexpr index_t KRepeat    = KPerThread / KPack;
};
...