Commit e97fdbc3 authored by letaoqin
Browse files

change gather index adaptor

parent 727f201d
......@@ -501,7 +501,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto c_dev = c_buf.ToHost<ADataType>();
std::cout << std::endl;
// std::cout << c_dev << std::endl;
std::cout << o_dev << std::endl;
// std::cout << o_dev << std::endl;
// int count = 0;
// std::cout << "[";
// for(int i = 0; i < tokens; i++)
......
......@@ -81,7 +81,7 @@ struct indexing_adaptor
#if Using_Gather
pre_up_index_ = idx_up[number<0>{}];
pre_low_index_ = idx_low(number<0>{});
#if 1
#if 0
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{}));
......@@ -100,30 +100,30 @@ struct indexing_adaptor
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
(void)idx_up;
#if !Using_Gather
idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
idx_low += idx_diff_low;
#else
int up_index = idx_diff_up[number<0>{}] + pre_up_index_;
int low_index = *(cached_idx_ + up_index);
idx_low(number<0>{}) = low_index;
idx_diff_low(number<0>{}) = low_index - pre_low_index_;
pre_up_index_ = up_index;
pre_low_index_ = low_index;
#if 1
#if 0
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
printf("\n index form %d to %d, idx_diff_low %d, idx_diff_up: %d, idx_low: %d, idx_up: %d \n",
printf("\n end index form %d to %d, idx_diff_low %d, idx_diff_up: %d, idx_low: %d, idx_up: %d, pre_low_index_: %d pre_up_index_: %d\n",
up_index,
low_index,
idx_diff_low(number<0>{}),
idx_diff_up[number<0>{}],
idx_low(number<0>{}),
idx_up.at(number<0>{}));
idx_up.at(number<0>{}),
pre_low_index_,
pre_up_index_);
}
#endif
#endif
// pass the diff to lower, but do not change the actual index
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
......
......@@ -274,7 +274,6 @@ struct FusedMoeGemmPipeline_General
}
}
#endif
ignore = w;
// y data
auto bridge_llds_win =
make_tile_window(bridge_lds_view,
......@@ -339,9 +338,6 @@ struct FusedMoeGemmPipeline_General
type_convert<float>(o1.get_thread_buffer()[i]));
});
});
// tile_elementwise_inout([&weight](auto& x) { x = x *
// type_convert<float>(weight); },
// o0);
auto o = cast_tile<ODataType>(o0);
update_tile(o_window_, o);
// restore pos
......
......@@ -21,17 +21,11 @@ namespace ck_tile {
struct FusedMoeGemmPipelineGeneralPolicy
{
// Dword count for an async copy. GetAlignment_A multiplies this value by 4
// to obtain the copy size in bytes.
// TODO: always 1 dword
// NOTE(review): the TODO mentions 1 dword while the value returned is 2 —
// confirm which one is intended.
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
{
    constexpr index_t async_copy_dwords = 2;
    return async_copy_dwords;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
{
// using async
constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
constexpr index_t copy_bytes = 8;
constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
......@@ -196,7 +190,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<1, 1, 32>, sequence<2, 16>>,
tuple<sequence<1, 2, 16>, sequence<4, 8>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<2, 0>>,
sequence<1, 2>,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment