Unverified Commit b8d11559 authored by amd-khushbu's avatar amd-khushbu Committed by GitHub
Browse files

Merge branch 'develop' into ck_profiler_m_instances

parents 7f3fe4e7 3b230208
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
template <typename T>
struct IsCharArray : std::false_type
{
};
template <std::size_t N>
struct IsCharArray<char[N]> : std::true_type
{
};
template <std::size_t N>
struct IsCharArray<const char[N]> : std::true_type
{
};
template <std::size_t N>
struct IsCharArray<char (&)[N]> : std::true_type
{
};
template <std::size_t N>
struct IsCharArray<const char (&)[N]> : std::true_type
{
};
template <typename... Ts>
inline constexpr bool AllConvertibleToStringView = ((std::is_convertible_v<Ts, std::string_view> ||
IsCharArray<Ts>::value ||
std::is_same_v<Ts, char>)&&...);
template <typename... Ts>
[[nodiscard]] auto concat(const Ts&... xs)
-> std::enable_if_t<!AllConvertibleToStringView<Ts...>, std::string>
{
using ::operator<<;
thread_local std::ostringstream oss;
oss.str("");
(oss << ... << xs);
return oss.str();
}
template <std::size_t N>
[[nodiscard]] constexpr inline std::size_t getSize(char (&)[N]) noexcept
{
return N;
}
template <std::size_t N>
[[nodiscard]] constexpr inline std::size_t getSize(const char (&)[N]) noexcept
{
return N;
}
[[nodiscard]] constexpr inline std::size_t getSize(const char* s) noexcept
{
const char* end = s;
while(*end++ != 0) {}
return end - s - 1;
}
[[nodiscard]] constexpr inline std::size_t getSize(const char&) noexcept { return 1; }
[[nodiscard]] inline std::size_t getSize(const std::string& s) noexcept { return s.size(); }
[[nodiscard]] constexpr inline std::size_t getSize(const std::string_view& s) noexcept
{
return s.size();
}
template <typename... Ts>
auto concatInto(std::string& result, const Ts&... xs)
-> std::enable_if_t<AllConvertibleToStringView<Ts...>, void>
{
const std::size_t space = (1 + ... + getSize(xs));
result.reserve(result.size() + space);
((result += xs), ...);
}
template <typename... Ts>
[[nodiscard]] auto concat(const Ts&... xs)
-> std::enable_if_t<AllConvertibleToStringView<Ts...>, std::string>
{
std::string result;
concatInto(result, xs...);
return result;
}
// Function for types convertible to std::string_view
template <typename Sep, typename First, typename... Rest>
[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... rest)
-> std::enable_if_t<AllConvertibleToStringView<First, Rest...>, std::string>
{
std::string result;
result += first;
((result += sep, result += rest), ...);
return result;
}
// Function for other types
template <typename Sep, typename First, typename... Rest>
[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... rest)
-> std::enable_if_t<!AllConvertibleToStringView<First, Rest...>, std::string>
{
using ::operator<<;
thread_local std::ostringstream oss;
oss.str("");
oss << first;
((oss << sep << rest), ...);
return oss.str();
}
} // namespace ck_tile
......@@ -14,12 +14,15 @@ namespace ck_tile {
template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
const HostTensor<WeightType>& weights,
const HostTensor<IndexType>& local_expert_mask,
HostTensor<IndexType>& p_sorted_token_ids,
HostTensor<WeightType>& sorted_weight,
HostTensor<IndexType>& sorted_expert_ids,
index_t& unit_cnt,
const index_t experts,
const index_t unit_size)
const index_t unit_size,
bool local_expert_masking,
bool skip_experts_with_zero_token = true)
{
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
const index_t topk = topk_ids.mDesc.get_lengths()[1];
......@@ -33,8 +36,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
#endif
std::vector<std::vector<WeightType>> expert_token_weights(
experts, std::vector<WeightType>(unit_size, 0));
// count number of unit-size slices in this expert
std::vector<IndexType> expert_slices(experts, 1);
// count the tokens used in this expert
std::vector<IndexType> expert_slice_idxs(experts, 0);
// TODO: above 2 buffer seems duplicated
for(index_t t = 0; t < num_token; t++)
{
......@@ -72,8 +78,23 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
IndexType* out_tokens = p_sorted_token_ids.data();
WeightType* out_weights = sorted_weight.data();
IndexType* out_expert_id = sorted_expert_ids.data();
int curr_expert_id = 0;
for(index_t e = 0; e < experts; e++)
{
if(local_expert_masking)
{
if(local_expert_mask(e) == 0)
continue;
}
if(skip_experts_with_zero_token)
{
if(expert_slice_idxs[e] == 0)
{
curr_expert_id++;
continue;
}
}
memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
out_tokens += expert_slices[e] * unit_size;
memcpy(out_weights,
......@@ -83,10 +104,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
for(index_t s = 0; s < expert_slices[e]; s++)
{
out_expert_id[s] = e;
out_expert_id[s] = curr_expert_id;
unit_cnt++;
}
out_expert_id += expert_slices[e];
curr_expert_id++;
}
unit_cnt *= unit_size;
return;
......
......@@ -10,3 +10,4 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
......@@ -9,3 +9,4 @@
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
......@@ -5,3 +5,4 @@
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
namespace ck_tile {
// clang-format off
template <typename T> struct typeToStr;
template <> struct typeToStr<float> { static constexpr const char * name = "fp32"; };
template <> struct typeToStr<fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct typeToStr<bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct typeToStr<fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct typeToStr<bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct typeToStr<int8_t> { static constexpr const char * name = "int8"; };
// clang-format on
template <typename ADataType_, typename BDataType_>
std::string gemm_prec_str()
{
std::string base_str = std::string(typeToStr<ADataType_>::name);
if(!std::is_same_v<ADataType_, BDataType_>)
{
base_str += "_" + std::string(typeToStr<BDataType_>::name);
}
return base_str;
}
} // namespace ck_tile
......@@ -6,3 +6,4 @@
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
......@@ -8,3 +8,4 @@
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
......@@ -9,3 +9,4 @@
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
......@@ -44,3 +44,4 @@
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
......@@ -310,7 +310,7 @@ struct SimplifiedGenericAttentionMask
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
const index_t split_start = x_per_split * i_split;
const index_t split_end = split_start + x_per_split;
const index_t split_end = ck_tile::min(x_total, split_start + x_per_split);
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
ck_tile::min(origin_end, split_end));
......
......@@ -742,7 +742,7 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, false>{});
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
else
{
......
......@@ -343,6 +343,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window(k_dram_window, {0, kK0});
// ensure LDS access by Q is done before the over-writting by K
block_sync_lds();
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
do
......
......@@ -7,6 +7,7 @@
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
......@@ -14,6 +15,6 @@
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
......@@ -22,7 +22,7 @@
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
......
......@@ -25,4 +25,28 @@ struct MoeSortingProblem
InternalLoadUnroll_; // TODO: need better design(like tile size)
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};
template <typename IndexType_,
typename WeightType_,
index_t SubTokenTile_, // 1,2,4,8, or 0 in the future
bool SubTokenOneShot_, // if we only loop over once or not
bool LocalExpertMasking_, // used in EP case
bool SkipExpertsWithZeroTokens_ = true,
index_t ExpertTile_ = 0>
struct MoeSortingProblemEx
{
// TODO: this kernel only support warp per row
using WeightType = remove_cvref_t<WeightType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t WarpSize = get_warp_size();
static constexpr index_t WarpsPerBlock = 1;
static constexpr index_t SubTokenTile = SubTokenTile_;
static constexpr bool SubTokenOneShot = SubTokenOneShot_;
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};
} // namespace ck_tile
......@@ -29,6 +29,8 @@
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
......@@ -46,3 +48,4 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
......@@ -14,24 +14,54 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV1
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
private:
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t KPack = WarpGemm::kKPerThread;
};
public:
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using Traits = GemmTraits_<Problem, Policy>;
using WarpGemm = typename Traits::WarpGemm;
using BlockGemmShape = typename Traits::BlockGemmShape;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
......@@ -43,7 +73,7 @@ struct BlockGemmARegBRegCRegV1
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
......@@ -58,7 +88,7 @@ struct BlockGemmARegBRegCRegV1
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
......@@ -73,7 +103,7 @@ struct BlockGemmARegBRegCRegV1
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
return c_block_dstr_encode;
}
......@@ -112,13 +142,13 @@ struct BlockGemmARegBRegCRegV1
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
......@@ -157,7 +187,7 @@ struct BlockGemmARegBRegCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
......@@ -180,7 +210,7 @@ struct BlockGemmARegBRegCRegV1
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
......@@ -57,6 +59,18 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using BLayout = typename Base::BLayout;
using CLayout = typename Base::CLayout;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
using P_ = GemmPipeline;
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
// clang-format on
}
struct BatchedGemmKernelArgs : GemmKernelArgs
{
index_t batch_stride_A;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment