Commit 1561fc22 authored by “letaoqin”
Browse files

change indexing adapter to gather matrix

parent 1caa8198
......@@ -246,8 +246,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host.mData[0],
experts,
block_m);
// std::cout << sorted_token_ids_host << std::endl;
// std::cout << std::endl;
// for(int i = 0; i < tokens; i++)
// {
// std::cout << "Line " << i << "\t";
// for(int j = 0; j < hidden_size; j++)
// {
// std::cout << ck_tile::type_convert<float>(a_host(i,j)) << "\t";
// }
// std::cout << std::endl;
// }
std::cout << sorted_token_ids_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl;
// std::cout << sorted_expert_ids_host << std::endl;
// std::cout << topk_weight_host << std::endl;
......
......@@ -65,6 +65,8 @@ struct indexing_adaptor
CK_TILE_HOST_DEVICE constexpr indexing_adaptor() = default;
CK_TILE_HOST_DEVICE constexpr indexing_adaptor(const IndexingType* idx) : cached_idx_(idx) {}
const IndexingType* cached_idx_;
mutable index_t preUpIndex = 0;
mutable index_t preLowIndex = 0;
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
......@@ -74,6 +76,13 @@ struct indexing_adaptor
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = *(cached_idx_ + idx_up[number<0>{}]);
preUpIndex = idx_up[number<0>{}];
preLowIndex = idx_low(number<0>{});
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{}));
}
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
......@@ -86,8 +95,22 @@ struct indexing_adaptor
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
int up_index = idx_diff_up[number<0>{}] + preUpIndex;
int low_index = *(cached_idx_ + up_index);
idx_diff_low(number<0>{}) = low_index - preLowIndex;
preUpIndex = up_index;
preLowIndex = low_index;
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
printf("\n index form %d to %d, diff from %d to %d \n",
up_index,
low_index,
idx_diff_up[number<0>{}],
idx_diff_low(number<0>{}));
}
// pass the diff to lower, but do not change the actual index
}
......
......@@ -97,13 +97,35 @@ struct FusedMoeGemmPipeline_FlatmmGl
index_t hidden_size,
index_t intermediate_size)
{
ignore = a_window_;
ignore = g_window_;
ignore = d_window_;
ignore = o_window_;
ignore = smem;
ignore = hidden_size;
ignore = intermediate_size;
auto a_copy_dram_window =
make_tile_window(a_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
a_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_A<Problem>());
auto a_dram = load_tile(a_copy_dram_window);
// check whether the A matrix was gathered correctly
constexpr auto a_spans = decltype(a_dram)::get_distributed_spans();
int counter = 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk){
constexpr auto i_j_idx = make_tuple(idxm, idxk);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){
counter = counter + 1;
index_t idm_0 = idxm.impl_.at(0);
index_t idn_0 = idxk.impl_.at(0);
printf("in A idm is %d , idn_ is %d , counter is %d, value is: %f \n", idm_0, idn_0, counter, ck_tile::type_convert<float>(a_dram(i_j_idx)));
}
});
});
ignore = a_spans;
}
};
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment