Commit 40df5c8b authored by letaoqin's avatar letaoqin
Browse files

add weight

parent b616b254
...@@ -375,6 +375,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -375,6 +375,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << sorted_token_ids_host << std::endl; std::cout << sorted_token_ids_host << std::endl;
std::cout << num_sorted_tiles_host << std::endl; std::cout << num_sorted_tiles_host << std::endl;
// output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size); // output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
// output_matrix_3d(d_host, experts, hidden_size, shared_intermediate_size_1);
std::cout << sorted_expert_ids_host << std::endl; std::cout << sorted_expert_ids_host << std::endl;
// std::cout << topk_weight_host << std::endl; // std::cout << topk_weight_host << std::endl;
......
...@@ -228,5 +228,5 @@ ...@@ -228,5 +228,5 @@
#endif #endif
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID #ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1 #define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 0
#endif #endif
...@@ -113,6 +113,8 @@ void reference_fused_moe( ...@@ -113,6 +113,8 @@ void reference_fused_moe(
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0}); ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
// first gemm // first gemm
// if(i_expert == 0)
// printf("ie:%2d, it:%3d \n", i_expert, i_token);
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++) for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
{ {
AccDataType acc = static_cast<AccDataType>(0); AccDataType acc = static_cast<AccDataType>(0);
...@@ -122,7 +124,8 @@ void reference_fused_moe( ...@@ -122,7 +124,8 @@ void reference_fused_moe(
type_convert<AccDataType>(g_host(i_expert, i_n, i_k)); type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
} }
acc_0(0, i_n) = acc; acc_0(0, i_n) = acc;
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc); // if(i_expert == 0)
// printf("in:%d, %f\t", i_n, acc);
} }
ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1}); ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
...@@ -135,6 +138,8 @@ void reference_fused_moe( ...@@ -135,6 +138,8 @@ void reference_fused_moe(
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++) for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
{ {
Activation{}(y(0, i_n), acc_0(0, i_n)); Activation{}(y(0, i_n), acc_0(0, i_n));
// if(i_expert == 0)
// printf("in:%d, %f\t", i_n, y(0, i_n));
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n)); // printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
} }
} }
...@@ -161,6 +166,8 @@ void reference_fused_moe( ...@@ -161,6 +166,8 @@ void reference_fused_moe(
{ {
acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k)); acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
} }
// if(i_expert == 0)
// printf("in:%d, %f\t", i_n, acc);
acc_1(0, i_n) = acc * weight; // multiple weight here acc_1(0, i_n) = acc * weight; // multiple weight here
} }
......
...@@ -247,22 +247,13 @@ struct FusedMoeGemmGlKernel ...@@ -247,22 +247,13 @@ struct FusedMoeGemmGlKernel
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane( const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]); reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
// index along intermediate_size
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
// BlockShape::Block_N0);
index_t idx_m0 = __builtin_amdgcn_readfirstlane(sorted_tile_id * BlockShape::Block_M0); index_t idx_m0 = __builtin_amdgcn_readfirstlane(sorted_tile_id * BlockShape::Block_M0);
index_t idx_n0 = __builtin_amdgcn_readfirstlane(sorted_tile_id * BlockShape::Block_N0); index_t idx_n0 =
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_N0);
// const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
// if(threadIdx.x == 200 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ const auto sorted_token_id = a_coord[number<0>{}] + idx_m0; // start block_m
// printf("\n*************a_coord[0]: %d, a_coord[1]: %d size: %d \n", // position
// a_coord[number<0>{}], a_coord[number<1>{}], a_coord.size());
// }
// const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id *
// BlockShape::Block_M0; //not block pos?
const auto sorted_token_id = sorted_tile_id * BlockShape::Block_M0; // start block_m
// position
// index_t token_id = // index_t token_id =
// reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id]; // reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
......
...@@ -90,7 +90,8 @@ struct FusedMoeGemmPipeline_General ...@@ -90,7 +90,8 @@ struct FusedMoeGemmPipeline_General
} }
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE static void PrintMem(T& tensor) CK_TILE_HOST_DEVICE static void
PrintMem(T& tensor, const char* pstr, unsigned int threadid = 0, unsigned int blockid = 0)
{ {
constexpr auto spans = T::get_distributed_spans(); constexpr auto spans = T::get_distributed_spans();
int counter = 0; int counter = 0;
...@@ -99,12 +100,14 @@ struct FusedMoeGemmPipeline_General ...@@ -99,12 +100,14 @@ struct FusedMoeGemmPipeline_General
constexpr auto i_j_idx = make_tuple(idxn, idxk); constexpr auto i_j_idx = make_tuple(idxn, idxk);
const auto tile_idx = const auto tile_idx =
get_x_indices_from_distributed_indices(tensor.get_tile_distribution(), i_j_idx); get_x_indices_from_distributed_indices(tensor.get_tile_distribution(), i_j_idx);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) if(threadIdx.x == threadid && blockIdx.x == 0 && blockIdx.y == blockid &&
blockIdx.z == 0)
{ {
const auto row = tile_idx.at(number<0>{}); const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{}); const auto col = tile_idx.at(number<1>{});
printf("in G row is %d , col is %d, counter is %d, value is: %f" printf("in %s row is %d , col is %d, counter is %d, value is: %f"
" \n", " \n",
pstr,
row, row,
col, col,
counter, counter,
...@@ -119,14 +122,11 @@ struct FusedMoeGemmPipeline_General ...@@ -119,14 +122,11 @@ struct FusedMoeGemmPipeline_General
const GWindow& g_window_, const GWindow& g_window_,
const DWindow& d_window_, const DWindow& d_window_,
OWindow& o_window_, OWindow& o_window_,
TopkWeightDataType /*topk_weight*/, TopkWeightDataType topk_weight,
CK_TILE_LDS_ADDR void* smem, CK_TILE_LDS_ADDR void* smem,
index_t hidden_size, index_t hidden_size,
index_t intermediate_size) index_t intermediate_size)
{ {
ignore = d_window_;
ignore = hidden_size;
ignore = intermediate_size;
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem); CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
auto a_lds_view = make_tensor_view<address_space_enum::lds>( auto a_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsBlockDesc_A<Problem>()); smem_0, Policy::template MakeLdsBlockDesc_A<Problem>());
...@@ -157,13 +157,13 @@ struct FusedMoeGemmPipeline_General ...@@ -157,13 +157,13 @@ struct FusedMoeGemmPipeline_General
auto a_dram_block = load_tile(a_global_to_dram_window); auto a_dram_block = load_tile(a_global_to_dram_window);
store_tile(a_lds_win, a_dram_block); store_tile(a_lds_win, a_dram_block);
#if 0 #if 0
PrintMem(a_dram_block); PrintMem(a_dram_block,"A", 0, 1);
#endif #endif
auto g_dram_block = load_tile(g_global_to_dram_window); auto g_dram_block = load_tile(g_global_to_dram_window);
#if 0 #if 1
PrintMem(g_dram_block); PrintMem(g_dram_block, "G", 0, 1);
#endif #endif
clear_tile(s_acc); // initialize C clear_tile(s_acc); // initialize C
...@@ -191,7 +191,7 @@ struct FusedMoeGemmPipeline_General ...@@ -191,7 +191,7 @@ struct FusedMoeGemmPipeline_General
block_sync_lds(); block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block); gemm_0(s_acc, a_lds_win, g_dram_block);
} }
#if 1 #if 0
PrintMem(s_acc); PrintMem(s_acc);
#endif #endif
// relu // relu
...@@ -233,6 +233,25 @@ struct FusedMoeGemmPipeline_General ...@@ -233,6 +233,25 @@ struct FusedMoeGemmPipeline_General
d_window_.get_window_origin(), d_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_D<Problem>()); Policy::template MakeGlobalTileDistribution_D<Problem>());
auto d = load_tile(d_global_to_dram_window); auto d = load_tile(d_global_to_dram_window);
#if 0
PrintMem(d,"D",64);
#endif
// add to LDS
auto o_alds_view =
make_naive_tensor_view<address_space_enum::lds, memory_operation_enum::atomic_add>(
smem_0,
make_tuple(number<32>{}, number<32>{}),
make_tuple(32, 1),
number<8>{},
number<1>{});
auto o_alds_win =
make_tile_window(o_alds_view, make_tuple(number<32>{}, number<32>{}), {0, 0});
auto o_olds_win =
make_tile_window(o_alds_view,
make_tuple(number<32>{}, number<32>{}),
{0, 0},
Policy::template MakeGlobalTileDistribution_O<Problem>());
ignore = o_alds_win;
constexpr index_t kN1 = BlockShape::Block_N1; constexpr index_t kN1 = BlockShape::Block_N1;
const index_t n1_loops = ck_tile::integer_divide_ceil(hidden_size, kN1); const index_t n1_loops = ck_tile::integer_divide_ceil(hidden_size, kN1);
...@@ -258,12 +277,32 @@ struct FusedMoeGemmPipeline_General ...@@ -258,12 +277,32 @@ struct FusedMoeGemmPipeline_General
block_sync_lds(); block_sync_lds();
gemm_1(o_acc, y, d); gemm_1(o_acc, y, d);
// block_sync_lds();
tile_elementwise_inout(
[&topk_weight](auto& x) { x = x * type_convert<float>(topk_weight); }, o_acc);
auto o = cast_tile<ODataType>(o_acc); auto o = cast_tile<ODataType>(o_acc);
store_tile(o_window_, o); store_tile(o_alds_win, o);
} block_sync_lds();
// if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0)
// {
// for(int i = 0; i < 42; i++)
// {
// printf("\n%d value is %f\t", i, type_convert<float>(smem_0[i]));
// }
// }
if(threadIdx.x < 64)
{
auto o_out = load_tile(o_olds_win);
block_sync_lds();
store_tile(o_window_, o_out);
}
// ignore = o_olds_win;
// store_tile(o_window_, o);
#if 0 #if 0
PrintMem(o_acc); PrintMem(o,"O");
#endif #endif
}
// store_tile(o_window_, a_dram_block); // store_tile(o_window_, a_dram_block);
} }
}; };
......
...@@ -189,6 +189,18 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -189,6 +189,18 @@ struct FusedMoeGemmPipelineGeneralPolicy
return make_static_tile_distribution(g_block_dstr_encode); return make_static_tile_distribution(g_block_dstr_encode);
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<1, 2, 16>, sequence<4, 8>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm0() CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm0()
{ {
...@@ -276,27 +288,27 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -276,27 +288,27 @@ struct FusedMoeGemmPipelineGeneralPolicy
return d_block_dstr; return d_block_dstr;
} }
template <typename Problem> // template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O() // CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
{ // {
using S_ = remove_cvref_t<typename Problem::BlockShape>; // using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>; // using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// using CDataType = typename WarpGemm::CDataType; // // using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding = // constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>, // tile_distribution_encoding<sequence<>,
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>, // tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>, // sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
tuple<sequence<1, 2>>, // tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>, // tuple<sequence<1, 1>>,
sequence<1, 2>, // sequence<1, 2>,
sequence<0, 0>>{}; // sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( // constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); // c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); // constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr; // return c_block_dstr;
} // }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDesc_A() CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDesc_A()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment