Commit 33aff2ef authored by carlushuang's avatar carlushuang
Browse files

modify format

parent 483ad69a
......@@ -224,7 +224,8 @@ bool test_topk_softmax(ck_tile::ArgParser args)
// constexpr auto uf = ck_tile::static_uford<sss, pks, ord>{};
// ck_tile::static_for<0, uf.get_num_of_access(), 1>{}([&](auto i_access){
// uf([&](auto i_0, auto i_1, auto i_2, auto i_3, auto i_4, auto i_5, auto i_6, auto i_7) {
// uf([&](auto i_0, auto i_1, auto i_2, auto i_3, auto i_4, auto i_5, auto i_6, auto
// i_7) {
// decltype(i_0)::push_front(i_access).fo_0();
// decltype(i_1)::push_front(i_access).fo_1();
// decltype(i_2)::push_front(i_access).fo_2();
......
......@@ -28,22 +28,23 @@ struct BlockSoftmax2D
CK_TILE_DEVICE void
operator()(const DistributedTensor& x, DistributedTensor& y, number<dim> = {})
{
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
#if _BLOCK_SOFTMAX_USE_UNPACK2
const auto f_max3 = [](auto e0, auto e1, auto e2) {
float rtn;
asm volatile("v_max3_f32 %0, %1, %2, %3" : "=v"(rtn) : "v"(e0), "v"(e1), "v"(e2));
return rtn;};
return rtn;
};
const auto f_sum3 = [](auto e0, auto e1, auto e2) { return e0 + e1 + e2; };
#endif
// compute row max
auto reduce_row_max = BlockReduce2D{x, -numeric<DataType>::infinity()};
#if _BLOCK_SOFTMAX_USE_UNPACK2
auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{});
auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{});
#else
auto row_max = reduce_row_max(f_max);
auto row_max = reduce_row_max(f_max);
#endif
// compute elementwise softmax
constexpr auto span_2d = DistributedTensor::get_distributed_spans();
......@@ -59,9 +60,9 @@ struct BlockSoftmax2D
// compute row sum
auto reduce_row_sum = BlockReduce2D<decltype(y)>{y, DataType{0}};
#if _BLOCK_SOFTMAX_USE_UNPACK2
auto row_sum = reduce_row_sum(f_sum3, f_sum, sequence<1, 2>{});
auto row_sum = reduce_row_sum(f_sum3, f_sum, sequence<1, 2>{});
#else
auto row_sum = reduce_row_sum(f_sum);
auto row_sum = reduce_row_sum(f_sum);
#endif
// reciprocal
auto r = make_static_distributed_tensor<DataType>(row_sum.get_tile_distribution());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment