Commit 4dd77195 authored by letaoqin's avatar letaoqin
Browse files

add gelu to kernel

parent 072dfbfe
...@@ -208,19 +208,26 @@ struct FusedMoeGemmPipeline_General ...@@ -208,19 +208,26 @@ struct FusedMoeGemmPipeline_General
block_sync_lds(); block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block); gemm_0(s_acc, a_lds_win, g_dram_block);
} }
#if 0 // relu
const auto activation = ck_tile::element_wise::Gelu{};
// constexpr index_t thread_buffer_size = SaccBlockTileType::get_thread_buffer_size();
// static_for<0, thread_buffer_size, 1>{}([&](auto i) {
// activation(s_acc.get_thread_buffer()(i),s_acc.get_thread_buffer()[i]);
// });
tile_elementwise_inout(activation, s_acc, s_acc);
#if 1
{ {
constexpr auto a_spans = decltype(s_acc)::get_distributed_spans(); constexpr auto a_spans = decltype(s_acc)::get_distributed_spans();
int counter = 0; int counter = 0;
//a_spans[0] = 1; // a_spans[0] = 1;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) { sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxn) { sweep_tile_span(a_spans[number<1>{}], [&](auto idxn) {
constexpr auto i_j_idx = make_tuple(idxm, idxn); constexpr auto i_j_idx = make_tuple(idxm, idxn);
const auto tile_idx = get_x_indices_from_distributed_indices( const auto tile_idx = get_x_indices_from_distributed_indices(
g_dram_block.get_tile_distribution(), i_j_idx); g_dram_block.get_tile_distribution(), i_j_idx);
if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{ {
counter = counter + 1; counter = counter + 1;
const auto row = tile_idx.at(number<0>{}); const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{}); const auto col = tile_idx.at(number<1>{});
printf("in c row is %d , col is %d, counter is %d, value is: " printf("in c row is %d , col is %d, counter is %d, value is: "
...@@ -235,7 +242,6 @@ struct FusedMoeGemmPipeline_General ...@@ -235,7 +242,6 @@ struct FusedMoeGemmPipeline_General
} }
#endif #endif
// move sacc to LDS // move sacc to LDS
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>( auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>()); smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment