Commit 4dd77195 authored by letaoqin's avatar letaoqin
Browse files

add gelu to kernel

parent 072dfbfe
......@@ -208,19 +208,26 @@ struct FusedMoeGemmPipeline_General
block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block);
}
#if 0
// gelu activation
const auto activation = ck_tile::element_wise::Gelu{};
// constexpr index_t thread_buffer_size = SaccBlockTileType::get_thread_buffer_size();
// static_for<0, thread_buffer_size, 1>{}([&](auto i) {
// activation(s_acc.get_thread_buffer()(i),s_acc.get_thread_buffer()[i]);
// });
tile_elementwise_inout(activation, s_acc, s_acc);
#if 1
{
constexpr auto a_spans = decltype(s_acc)::get_distributed_spans();
int counter = 0;
//a_spans[0] = 1;
// a_spans[0] = 1;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxn) {
constexpr auto i_j_idx = make_tuple(idxm, idxn);
const auto tile_idx = get_x_indices_from_distributed_indices(
g_dram_block.get_tile_distribution(), i_j_idx);
const auto tile_idx = get_x_indices_from_distributed_indices(
g_dram_block.get_tile_distribution(), i_j_idx);
if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
counter = counter + 1;
counter = counter + 1;
const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{});
printf("in c row is %d , col is %d, counter is %d, value is: "
......@@ -235,7 +242,6 @@ struct FusedMoeGemmPipeline_General
}
#endif
// move sacc to LDS
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment