Commit b5d6100b authored by letaoqin
Browse files

change file name

parent f912ca40
...@@ -38,7 +38,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) ...@@ -38,7 +38,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
f_traits>; f_traits>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>; // using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmGl<f_problem>; using f_pipeline = ck_tile::FusedMoeGemmPipeline_General<f_problem>;
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>; using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
using f_kernel = ck_tile::FusedMoeGemmGlKernel<f_partitioner, f_pipeline, void>; using f_kernel = ck_tile::FusedMoeGemmGlKernel<f_partitioner, f_pipeline, void>;
......
...@@ -256,7 +256,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -256,7 +256,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// } // }
// std::cout << std::endl; // std::cout << std::endl;
// } // }
std::cout << sorted_token_ids_host << std::endl; // std::cout << sorted_token_ids_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl; // std::cout << num_sorted_tiles_host << std::endl;
// std::cout << sorted_expert_ids_host << std::endl; // std::cout << sorted_expert_ids_host << std::endl;
// std::cout << topk_weight_host << std::endl; // std::cout << topk_weight_host << std::endl;
......
...@@ -12,7 +12,8 @@ ...@@ -12,7 +12,8 @@
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_gl_policy.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -18,8 +18,8 @@ we need to design the pipeline such that all waves along gemm-N dim (gemm-m only ...@@ -18,8 +18,8 @@ we need to design the pipeline such that all waves along gemm-N dim (gemm-m only
| w0 | w1 | w2 | w3 | gemm-m | w0 | w1 | w2 | w3 | gemm-m
+----+----+----+----+ +----+----+----+----+
*/ */
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy> template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineGeneralPolicy>
struct FusedMoeGemmPipeline_FlatmmGl struct FusedMoeGemmPipeline_General
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>; using Policy = remove_cvref_t<Policy_>;
...@@ -70,14 +70,15 @@ struct FusedMoeGemmPipeline_FlatmmGl ...@@ -70,14 +70,15 @@ struct FusedMoeGemmPipeline_FlatmmGl
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
// matrix a or tokens smem // // matrix a or tokens smem
constexpr index_t smem_mat_a = // constexpr index_t smem_mat_a =
BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType); // BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType);
// shuffle C matrix // // shuffle C matrix
constexpr index_t smem_bridge = // constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); // BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
return max(smem_mat_a, smem_bridge); // return max(smem_mat_a, smem_bridge);
return Policy::template GetSmemSize<Problem>();
} }
// this is the thread-offset along row/col // this is the thread-offset along row/col
...@@ -104,12 +105,19 @@ struct FusedMoeGemmPipeline_FlatmmGl ...@@ -104,12 +105,19 @@ struct FusedMoeGemmPipeline_FlatmmGl
ignore = hidden_size; ignore = hidden_size;
ignore = intermediate_size; ignore = intermediate_size;
auto a_copy_dram_window = make_tile_window( CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
auto a_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsStoreDesc_A<Problem>());
auto a_lds_win = make_tile_window(a_lds_view, make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}), {0, 0});
auto a_global_to_dram_window = make_tile_window(
a_window_.get_bottom_tensor_view(), a_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}), make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
a_window_.get_window_origin(), a_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_A<Problem>()); Policy::template MakeGlobalTileDistribution_A<Problem>());
auto a_dram = load_tile(a_copy_dram_window); auto a_dram_block = load_tile(a_global_to_dram_window);
store_tile(a_lds_win, a_dram_block);
#if 0 #if 0
//check a matrix gather right or not //check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram)::get_distributed_spans(); constexpr auto a_spans = decltype(a_dram)::get_distributed_spans();
...@@ -126,7 +134,6 @@ struct FusedMoeGemmPipeline_FlatmmGl ...@@ -126,7 +134,6 @@ struct FusedMoeGemmPipeline_FlatmmGl
}); });
}); });
#endif #endif
ignore = a_dram;
} }
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment