Commit f912ca40 authored by letaoqin
Browse files

Fix indexing adaptor call issue

parent 1561fc22
...@@ -65,8 +65,8 @@ struct indexing_adaptor ...@@ -65,8 +65,8 @@ struct indexing_adaptor
CK_TILE_HOST_DEVICE constexpr indexing_adaptor() = default; CK_TILE_HOST_DEVICE constexpr indexing_adaptor() = default;
CK_TILE_HOST_DEVICE constexpr indexing_adaptor(const IndexingType* idx) : cached_idx_(idx) {} CK_TILE_HOST_DEVICE constexpr indexing_adaptor(const IndexingType* idx) : cached_idx_(idx) {}
const IndexingType* cached_idx_; const IndexingType* cached_idx_;
mutable index_t preUpIndex = 0; mutable index_t pre_up_index_ = 0;
mutable index_t preLowIndex = 0; mutable index_t pre_low_index_ = 0;
template <typename LowIdx, typename UpIdx> template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
...@@ -77,12 +77,14 @@ struct indexing_adaptor ...@@ -77,12 +77,14 @@ struct indexing_adaptor
idx_low(number<0>{}) = *(cached_idx_ + idx_up[number<0>{}]); idx_low(number<0>{}) = *(cached_idx_ + idx_up[number<0>{}]);
preUpIndex = idx_up[number<0>{}]; pre_up_index_ = idx_up[number<0>{}];
preLowIndex = idx_low(number<0>{}); pre_low_index_ = idx_low(number<0>{});
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) #if 0
if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0)
{ {
printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{})); printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{}));
} }
#endif
} }
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx> template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
...@@ -95,15 +97,15 @@ struct indexing_adaptor ...@@ -95,15 +97,15 @@ struct indexing_adaptor
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1, UpIdx::size() == 1,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
int up_index = idx_diff_up[number<0>{}] + preUpIndex;
int low_index = *(cached_idx_ + up_index);
idx_diff_low(number<0>{}) = low_index - preLowIndex;
preUpIndex = up_index; int up_index = idx_diff_up[number<0>{}] + pre_up_index_;
preLowIndex = low_index; int low_index = *(cached_idx_ + up_index);
idx_diff_low(number<0>{}) = low_index - pre_low_index_;
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) pre_up_index_ = up_index;
pre_low_index_ = low_index;
#if 0
if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0)
{ {
printf("\n index form %d to %d, diff from %d to %d \n", printf("\n index form %d to %d, diff from %d to %d \n",
up_index, up_index,
...@@ -111,6 +113,7 @@ struct indexing_adaptor ...@@ -111,6 +113,7 @@ struct indexing_adaptor
idx_diff_up[number<0>{}], idx_diff_up[number<0>{}],
idx_diff_low(number<0>{})); idx_diff_low(number<0>{}));
} }
#endif
// pass the diff to lower, but not changing the actually index // pass the diff to lower, but not changing the actually index
} }
......
...@@ -268,8 +268,8 @@ struct FusedMoeGemmGlKernel ...@@ -268,8 +268,8 @@ struct FusedMoeGemmGlKernel
auto topk_weight = auto topk_weight =
reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id]; reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
const index_t* sorted_token_ids_ptr = reinterpret_cast<const index_t*>( const index_t* sorted_token_ids_ptr =
&(reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id])); reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr);
const auto a_window = [&]() { const auto a_window = [&]() {
// A is already pre-padded in previous kernel // A is already pre-padded in previous kernel
......
...@@ -104,19 +104,20 @@ struct FusedMoeGemmPipeline_FlatmmGl ...@@ -104,19 +104,20 @@ struct FusedMoeGemmPipeline_FlatmmGl
ignore = hidden_size; ignore = hidden_size;
ignore = intermediate_size; ignore = intermediate_size;
auto a_copy_dram_window = auto a_copy_dram_window = make_tile_window(
make_tile_window(a_window_.get_bottom_tensor_view(), a_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}), make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
a_window_.get_window_origin(), a_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_A<Problem>()); Policy::template MakeGlobalTileDistribution_A<Problem>());
auto a_dram = load_tile(a_copy_dram_window); auto a_dram = load_tile(a_copy_dram_window);
#if 0
//check a matrix gather right or not //check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram)::get_distributed_spans(); constexpr auto a_spans = decltype(a_dram)::get_distributed_spans();
int counter = 0; int counter = 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) { sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk){ sweep_tile_span(a_spans[number<1>{}], [&](auto idxk){
constexpr auto i_j_idx = make_tuple(idxm, idxk); constexpr auto i_j_idx = make_tuple(idxm, idxk);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0){
counter = counter + 1; counter = counter + 1;
index_t idm_0 = idxm.impl_.at(0); index_t idm_0 = idxm.impl_.at(0);
index_t idn_0 = idxk.impl_.at(0); index_t idn_0 = idxk.impl_.at(0);
...@@ -124,8 +125,8 @@ struct FusedMoeGemmPipeline_FlatmmGl ...@@ -124,8 +125,8 @@ struct FusedMoeGemmPipeline_FlatmmGl
} }
}); });
}); });
#endif
ignore = a_spans; ignore = a_dram;
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment