Commit 40df5c8b authored by letaoqin
Browse files

add weight

parent b616b254
......@@ -375,6 +375,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << sorted_token_ids_host << std::endl;
std::cout << num_sorted_tiles_host << std::endl;
// output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
// output_matrix_3d(d_host, experts, hidden_size, shared_intermediate_size_1);
std::cout << sorted_expert_ids_host << std::endl;
// std::cout << topk_weight_host << std::endl;
......
......@@ -228,5 +228,5 @@
#endif
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 0
#endif
......@@ -113,6 +113,8 @@ void reference_fused_moe(
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
// first gemm
// if(i_expert == 0)
// printf("ie:%2d, it:%3d \n", i_expert, i_token);
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
{
AccDataType acc = static_cast<AccDataType>(0);
......@@ -122,7 +124,8 @@ void reference_fused_moe(
type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
}
acc_0(0, i_n) = acc;
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
// if(i_expert == 0)
// printf("in:%d, %f\t", i_n, acc);
}
ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
......@@ -135,6 +138,8 @@ void reference_fused_moe(
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
{
Activation{}(y(0, i_n), acc_0(0, i_n));
// if(i_expert == 0)
// printf("in:%d, %f\t", i_n, y(0, i_n));
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
}
}
......@@ -161,6 +166,8 @@ void reference_fused_moe(
{
acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
}
// if(i_expert == 0)
// printf("in:%d, %f\t", i_n, acc);
acc_1(0, i_n) = acc * weight; // multiple weight here
}
......
......@@ -247,22 +247,13 @@ struct FusedMoeGemmGlKernel
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
// index along intermediate_size
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
// BlockShape::Block_N0);
index_t idx_m0 = __builtin_amdgcn_readfirstlane(sorted_tile_id * BlockShape::Block_M0);
index_t idx_n0 = __builtin_amdgcn_readfirstlane(sorted_tile_id * BlockShape::Block_N0);
// const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
// if(threadIdx.x == 200 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){
// printf("\n*************a_coord[0]: %d, a_coord[1]: %d size: %d \n",
// a_coord[number<0>{}], a_coord[number<1>{}], a_coord.size());
// }
// const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id *
// BlockShape::Block_M0; //not block pos?
const auto sorted_token_id = sorted_tile_id * BlockShape::Block_M0; // start block_m
// position
index_t idx_n0 =
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_N0);
const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
const auto sorted_token_id = a_coord[number<0>{}] + idx_m0; // start block_m
// position
// index_t token_id =
// reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
......
......@@ -90,7 +90,8 @@ struct FusedMoeGemmPipeline_General
}
template <typename T>
CK_TILE_HOST_DEVICE static void PrintMem(T& tensor)
CK_TILE_HOST_DEVICE static void
PrintMem(T& tensor, const char* pstr, unsigned int threadid = 0, unsigned int blockid = 0)
{
constexpr auto spans = T::get_distributed_spans();
int counter = 0;
......@@ -99,12 +100,14 @@ struct FusedMoeGemmPipeline_General
constexpr auto i_j_idx = make_tuple(idxn, idxk);
const auto tile_idx =
get_x_indices_from_distributed_indices(tensor.get_tile_distribution(), i_j_idx);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
if(threadIdx.x == threadid && blockIdx.x == 0 && blockIdx.y == blockid &&
blockIdx.z == 0)
{
const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{});
printf("in G row is %d , col is %d, counter is %d, value is: %f"
printf("in %s row is %d , col is %d, counter is %d, value is: %f"
" \n",
pstr,
row,
col,
counter,
......@@ -119,14 +122,11 @@ struct FusedMoeGemmPipeline_General
const GWindow& g_window_,
const DWindow& d_window_,
OWindow& o_window_,
TopkWeightDataType /*topk_weight*/,
TopkWeightDataType topk_weight,
CK_TILE_LDS_ADDR void* smem,
index_t hidden_size,
index_t intermediate_size)
{
ignore = d_window_;
ignore = hidden_size;
ignore = intermediate_size;
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
auto a_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsBlockDesc_A<Problem>());
......@@ -157,13 +157,13 @@ struct FusedMoeGemmPipeline_General
auto a_dram_block = load_tile(a_global_to_dram_window);
store_tile(a_lds_win, a_dram_block);
#if 0
PrintMem(a_dram_block);
PrintMem(a_dram_block,"A", 0, 1);
#endif
auto g_dram_block = load_tile(g_global_to_dram_window);
#if 0
PrintMem(g_dram_block);
#if 1
PrintMem(g_dram_block, "G", 0, 1);
#endif
clear_tile(s_acc); // initialize C
......@@ -191,7 +191,7 @@ struct FusedMoeGemmPipeline_General
block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block);
}
#if 1
#if 0
PrintMem(s_acc);
#endif
// relu
......@@ -233,6 +233,25 @@ struct FusedMoeGemmPipeline_General
d_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_D<Problem>());
auto d = load_tile(d_global_to_dram_window);
#if 0
PrintMem(d,"D",64);
#endif
// add to LDS
auto o_alds_view =
make_naive_tensor_view<address_space_enum::lds, memory_operation_enum::atomic_add>(
smem_0,
make_tuple(number<32>{}, number<32>{}),
make_tuple(32, 1),
number<8>{},
number<1>{});
auto o_alds_win =
make_tile_window(o_alds_view, make_tuple(number<32>{}, number<32>{}), {0, 0});
auto o_olds_win =
make_tile_window(o_alds_view,
make_tuple(number<32>{}, number<32>{}),
{0, 0},
Policy::template MakeGlobalTileDistribution_O<Problem>());
ignore = o_alds_win;
constexpr index_t kN1 = BlockShape::Block_N1;
const index_t n1_loops = ck_tile::integer_divide_ceil(hidden_size, kN1);
......@@ -258,12 +277,32 @@ struct FusedMoeGemmPipeline_General
block_sync_lds();
gemm_1(o_acc, y, d);
// block_sync_lds();
tile_elementwise_inout(
[&topk_weight](auto& x) { x = x * type_convert<float>(topk_weight); }, o_acc);
auto o = cast_tile<ODataType>(o_acc);
store_tile(o_window_, o);
}
store_tile(o_alds_win, o);
block_sync_lds();
// if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0)
// {
// for(int i = 0; i < 42; i++)
// {
// printf("\n%d value is %f\t", i, type_convert<float>(smem_0[i]));
// }
// }
if(threadIdx.x < 64)
{
auto o_out = load_tile(o_olds_win);
block_sync_lds();
store_tile(o_window_, o_out);
}
// ignore = o_olds_win;
// store_tile(o_window_, o);
#if 0
PrintMem(o_acc);
PrintMem(o,"O");
#endif
}
// store_tile(o_window_, a_dram_block);
}
};
......
......@@ -189,6 +189,18 @@ struct FusedMoeGemmPipelineGeneralPolicy
return make_static_tile_distribution(g_block_dstr_encode);
}
// Returns a static tile distribution for the O tile built from literal lengths.
//
// The two hierarchical length sequences multiply out to 1 * 2 * 16 = 32 and
// 4 * 8 = 32, i.e. the distribution spans a 32 x 32 tile. `Problem` is not
// referenced by the body; every length is hard-coded.
//
// NOTE(review): assuming the usual ck_tile argument order for
// tile_distribution_encoding (R lengths, H lengths, P->RH major, P->RH minor,
// Y->RH major, Y->RH minor), the second partition dimension covers
// 16 * 4 = 64 lanes and each lane holds 2 * 8 elements — confirm against the
// tile_distribution_encoding declaration.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
{
    // Name the encoding first, then hand it to the distribution factory.
    constexpr auto o_block_dstr_encode =
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<1, 2, 16>, sequence<4, 8>>,
                                   tuple<sequence<0, 1>, sequence<1, 2>>,
                                   tuple<sequence<0, 0>, sequence<2, 0>>,
                                   sequence<1, 2>,
                                   sequence<0, 1>>{};
    return make_static_tile_distribution(o_block_dstr_encode);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm0()
{
......@@ -276,27 +288,27 @@ struct FusedMoeGemmPipelineGeneralPolicy
return d_block_dstr;
}
// Builds the tile distribution for the O tile from the block shape and the
// warp GEMM returned by GetWarpGemm1<Problem>().
//
// The outer (block-level) encoding is parameterized by Problem::BlockShape's
// Repeat_M1 / WarpPerBlock_M1 and Repeat_N1 / WarpPerBlock_N1. The warp-level
// part is WarpGemm::CWarpDstrEncoding, which is embedded into the outer
// encoding before the distribution is created.
//
// Returns the value produced by make_static_tile_distribution for the
// embedded encoding.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// using CDataType = typename WarpGemm::CDataType;
// Block-level encoding: one length sequence per tile dimension, each made of
// (repeat count, warps per block) taken from the block shape.
// NOTE(review): the roles of the remaining template arguments (partition and
// per-thread index mappings) follow the tile_distribution_encoding parameter
// order — confirm against its declaration.
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
// Combine the block-level encoding with the warp GEMM's C-tile encoding.
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
// {
// using S_ = remove_cvref_t<typename Problem::BlockShape>;
// using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// // using CDataType = typename WarpGemm::CDataType;
// constexpr auto c_block_outer_dstr_encoding =
// tile_distribution_encoding<sequence<>,
// tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
// sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
// tuple<sequence<1, 2>>,
// tuple<sequence<1, 1>>,
// sequence<1, 2>,
// sequence<0, 0>>{};
// constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
// c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
// constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
// return c_block_dstr;
// }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDesc_A()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment