Commit 580d93dc authored by letaoqin
Browse files

rewrite save o

parent d4a0a8ee
...@@ -500,8 +500,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -500,8 +500,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto o_dev = o_buf.ToHost<ODataType>(); auto o_dev = o_buf.ToHost<ODataType>();
auto c_dev = c_buf.ToHost<ADataType>(); auto c_dev = c_buf.ToHost<ADataType>();
std::cout << std::endl; std::cout << std::endl;
// std::cout << o_dev << std::endl;
// std::cout << c_dev << std::endl; // std::cout << c_dev << std::endl;
std::cout << o_dev << std::endl;
// int count = 0; // int count = 0;
// std::cout << "["; // std::cout << "[";
// for(int i = 0; i < tokens; i++) // for(int i = 0; i < tokens; i++)
......
...@@ -81,7 +81,7 @@ struct indexing_adaptor ...@@ -81,7 +81,7 @@ struct indexing_adaptor
#if Using_Gather #if Using_Gather
pre_up_index_ = idx_up[number<0>{}]; pre_up_index_ = idx_up[number<0>{}];
pre_low_index_ = idx_low(number<0>{}); pre_low_index_ = idx_low(number<0>{});
#if 0 #if 1
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{ {
printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{})); printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{}));
...@@ -93,8 +93,8 @@ struct indexing_adaptor ...@@ -93,8 +93,8 @@ struct indexing_adaptor
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx> template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up, const UpIdxDiff& idx_diff_up,
LowIdx& /*idx_low*/, LowIdx& idx_low,
const UpIdx& /*idx_up*/) const const UpIdx& idx_up) const
{ {
// TODO: nonthing changed here // TODO: nonthing changed here
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
...@@ -109,14 +109,16 @@ struct indexing_adaptor ...@@ -109,14 +109,16 @@ struct indexing_adaptor
pre_up_index_ = up_index; pre_up_index_ = up_index;
pre_low_index_ = low_index; pre_low_index_ = low_index;
#if 0 #if 1
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{ {
printf("\n index form %d to %d, diff from %d to %d \n", printf("\n index form %d to %d, idx_diff_low %d, idx_diff_up: %d, idx_low: %d, idx_up: %d \n",
up_index, up_index,
low_index, low_index,
idx_diff_low(number<0>{}),
idx_diff_up[number<0>{}], idx_diff_up[number<0>{}],
idx_diff_low(number<0>{})); idx_low(number<0>{}),
idx_up.at(number<0>{}));
} }
#endif #endif
#endif #endif
......
...@@ -252,13 +252,6 @@ struct FusedMoeGemmGlKernel ...@@ -252,13 +252,6 @@ struct FusedMoeGemmGlKernel
index_t idx_n0 = index_t idx_n0 =
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_N0); __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_N0);
// const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
// const auto sorted_token_id = a_coord[number<0>{}] + idx_m0; // start block_m
// // position
// auto topk_weight =
// reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
const index_t* sorted_token_ids_ptr = const index_t* sorted_token_ids_ptr =
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr); reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr);
...@@ -375,18 +368,17 @@ struct FusedMoeGemmGlKernel ...@@ -375,18 +368,17 @@ struct FusedMoeGemmGlKernel
}(); }();
const auto w_window = [&]() { const auto w_window = [&]() {
const TopkWeightDataType* w_ptr = reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr); const TopkWeightDataType* w_ptr =
const auto w_view_ = make_naive_tensor_view<address_space_enum::global>( reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr);
const auto w_view_ = make_naive_tensor_view<address_space_enum::global>(
w_ptr, w_ptr,
make_tuple(kargs.max_num_tokens_padded), make_tuple(kargs.max_num_tokens_padded),
make_tuple(1), make_tuple(1),
number<1>{}, number<1>{},
number<1>{}); number<1>{});
const auto w_window_ = make_tile_window( const auto w_window_ =
w_view_, make_tile_window(w_view_, make_tuple(number<BlockShape::Block_M0>{}), {idx_m0});
make_tuple(number<BlockShape::Block_M0>{}),
{idx_m0});
return w_window_; return w_window_;
}(); }();
......
...@@ -348,22 +348,28 @@ struct FusedMoeGemmPipeline_General ...@@ -348,22 +348,28 @@ struct FusedMoeGemmPipeline_General
while(iCounter1 > 0) while(iCounter1 > 0)
{ {
clear_tile(o_acc); clear_tile(o_acc);
block_sync_lds(); block_sync_lds_direct_load();
gemm_1(o_acc, y, d); gemm_1(o_acc, y, d);
block_sync_lds();
move_tile_window(d_global_to_dram_window, {kN1, 0}); move_tile_window(d_global_to_dram_window, {kN1, 0});
d = load_tile(d_global_to_dram_window); d = load_tile(d_global_to_dram_window);
// move out window and save data // move out window and save data
tile_elementwise_inout([&weight](auto& x) { x = x * type_convert<float>(weight); },
o_acc);
auto o = cast_tile<ODataType>(o_acc); auto o = cast_tile<ODataType>(o_acc);
store_tile(o_window_, o); store_tile(o_alds_win, o);
move_tile_window(o_window_, {kN1, 0}); block_sync_lds();
save_o();
move_tile_window(o_window_, {0, kN1});
iCounter1--; iCounter1--;
} }
// tail // tail
{ {
clear_tile(o_acc); clear_tile(o_acc);
block_sync_lds(); block_sync_lds_direct_load();
gemm_1(o_acc, y, d); gemm_1(o_acc, y, d);
// block_sync_lds(); // block_sync_lds();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment