Commit b5d6100b authored by letaoqin's avatar letaoqin
Browse files

change file name

parent f912ca40
......@@ -38,7 +38,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
f_traits>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmGl<f_problem>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_General<f_problem>;
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
using f_kernel = ck_tile::FusedMoeGemmGlKernel<f_partitioner, f_pipeline, void>;
......
......@@ -256,7 +256,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// }
// std::cout << std::endl;
// }
std::cout << sorted_token_ids_host << std::endl;
// std::cout << sorted_token_ids_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl;
// std::cout << sorted_expert_ids_host << std::endl;
// std::cout << topk_weight_host << std::endl;
......
......@@ -12,7 +12,8 @@
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
......
......@@ -5,7 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl_policy.hpp"
namespace ck_tile {
......@@ -18,8 +18,8 @@ we need to design the pipeline such that all waves along gemm-N dim (gemm-m only
| w0 | w1 | w2 | w3 | gemm-m
+----+----+----+----+
*/
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
struct FusedMoeGemmPipeline_FlatmmGl
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineGeneralPolicy>
struct FusedMoeGemmPipeline_General
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
......@@ -70,14 +70,15 @@ struct FusedMoeGemmPipeline_FlatmmGl
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
// matrix a or tokens smem
constexpr index_t smem_mat_a =
BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType);
// shuffle C matrix
constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
return max(smem_mat_a, smem_bridge);
// // matrix a or tokens smem
// constexpr index_t smem_mat_a =
// BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType);
// // shuffle C matrix
// constexpr index_t smem_bridge =
// BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
// return max(smem_mat_a, smem_bridge);
return Policy::template GetSmemSize<Problem>();
}
// this is the thread-offset along row/col
......@@ -104,12 +105,19 @@ struct FusedMoeGemmPipeline_FlatmmGl
ignore = hidden_size;
ignore = intermediate_size;
auto a_copy_dram_window = make_tile_window(
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
auto a_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsStoreDesc_A<Problem>());
auto a_lds_win = make_tile_window(a_lds_view, make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}), {0, 0});
auto a_global_to_dram_window = make_tile_window(
a_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
a_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_A<Problem>());
auto a_dram = load_tile(a_copy_dram_window);
auto a_dram_block = load_tile(a_global_to_dram_window);
store_tile(a_lds_win, a_dram_block);
#if 0
//check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram)::get_distributed_spans();
......@@ -126,7 +134,6 @@ struct FusedMoeGemmPipeline_FlatmmGl
});
});
#endif
ignore = a_dram;
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment