Commit 4dd77195 authored by letaoqin's avatar letaoqin
Browse files

add gelu to kernel

parent 072dfbfe
...@@ -208,11 +208,18 @@ struct FusedMoeGemmPipeline_General ...@@ -208,11 +208,18 @@ struct FusedMoeGemmPipeline_General
block_sync_lds(); block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block); gemm_0(s_acc, a_lds_win, g_dram_block);
} }
#if 0 // relu
const auto activation = ck_tile::element_wise::Gelu{};
// constexpr index_t thread_buffer_size = SaccBlockTileType::get_thread_buffer_size();
// static_for<0, thread_buffer_size, 1>{}([&](auto i) {
// activation(s_acc.get_thread_buffer()(i),s_acc.get_thread_buffer()[i]);
// });
tile_elementwise_inout(activation, s_acc, s_acc);
#if 1
{ {
constexpr auto a_spans = decltype(s_acc)::get_distributed_spans(); constexpr auto a_spans = decltype(s_acc)::get_distributed_spans();
int counter = 0; int counter = 0;
//a_spans[0] = 1; // a_spans[0] = 1;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) { sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxn) { sweep_tile_span(a_spans[number<1>{}], [&](auto idxn) {
constexpr auto i_j_idx = make_tuple(idxm, idxn); constexpr auto i_j_idx = make_tuple(idxm, idxn);
...@@ -235,7 +242,6 @@ struct FusedMoeGemmPipeline_General ...@@ -235,7 +242,6 @@ struct FusedMoeGemmPipeline_General
} }
#endif #endif
// move sacc to LDS // move sacc to LDS
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>( auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>()); smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment