Commit eac1753d authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Avoid convert to compType from dstDataType before writting the output value

parent 5a9f6308
...@@ -180,6 +180,10 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -180,6 +180,10 @@ struct GridwiseReduction_xy_to_x_blockwise
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha); accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta)) if(!float_equal_zero{}(beta))
{ {
auto threadwise_dst_load = auto threadwise_dst_load =
...@@ -200,11 +204,11 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -200,11 +204,11 @@ struct GridwiseReduction_xy_to_x_blockwise
threadwise_dst_load.Run( threadwise_dst_load.Run(
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta); dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
} }
auto threadwise_dst_store = auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<compType, ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType, dstDataType,
decltype(ReducedDataDesc), decltype(ReducedDataDesc),
dst1dDescType, dst1dDescType,
...@@ -218,7 +222,7 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -218,7 +222,7 @@ struct GridwiseReduction_xy_to_x_blockwise
make_multi_index(block_global_1d_id)); make_multi_index(block_global_1d_id));
threadwise_dst_store.Run( threadwise_dst_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_buf); ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
} }
}; };
...@@ -345,6 +349,10 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -345,6 +349,10 @@ struct GridwiseReduction_xy_to_x_blockwise
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha); accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta)) if(!float_equal_zero{}(beta))
{ {
auto threadwise_dst_load = auto threadwise_dst_load =
...@@ -368,11 +376,11 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -368,11 +376,11 @@ struct GridwiseReduction_xy_to_x_blockwise
make_tuple(I0), make_tuple(I0),
priorDstValue_buf); priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta); dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
} }
auto threadwise_dst_val_store = auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType, ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType, dstDataType,
decltype(ReducedDataDesc), decltype(ReducedDataDesc),
dst1dDescType, dst1dDescType,
...@@ -400,7 +408,7 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -400,7 +408,7 @@ struct GridwiseReduction_xy_to_x_blockwise
make_multi_index(block_global_1d_id)); make_multi_index(block_global_1d_id));
threadwise_dst_val_store.Run( threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf); ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run( threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
} }
...@@ -547,6 +555,10 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -547,6 +555,10 @@ struct GridwiseReduction_xy_to_x_blockwise
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha); accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta)) if(!float_equal_zero{}(beta))
{ {
auto threadwise_dst_load = auto threadwise_dst_load =
...@@ -570,11 +582,11 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -570,11 +582,11 @@ struct GridwiseReduction_xy_to_x_blockwise
make_tuple(I0), make_tuple(I0),
priorDstValue_buf); priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta); dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
} }
auto threadwise_dst_val_store = auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType, ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType, dstDataType,
decltype(ReducedDataDesc), decltype(ReducedDataDesc),
dst1dDescType, dst1dDescType,
...@@ -602,7 +614,7 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -602,7 +614,7 @@ struct GridwiseReduction_xy_to_x_blockwise
make_multi_index(block_global_1d_id)); make_multi_index(block_global_1d_id));
threadwise_dst_val_store.Run( threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf); ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run( threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
} }
......
...@@ -147,6 +147,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -147,6 +147,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha); accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta)) if(!float_equal_zero{}(beta))
{ {
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType, auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
...@@ -166,11 +170,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -166,11 +170,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
threadwise_dst_load.Run( threadwise_dst_load.Run(
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta); dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
} }
auto threadwise_dst_store = auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<compType, ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType, dstDataType,
decltype(ReducedDataDesc), decltype(ReducedDataDesc),
dst1dDescType, dst1dDescType,
...@@ -184,7 +188,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -184,7 +188,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_multi_index(thread_global_1d_id)); make_multi_index(thread_global_1d_id));
threadwise_dst_store.Run( threadwise_dst_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_buf); ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
}; };
template <> template <>
...@@ -271,6 +275,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -271,6 +275,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha); accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta)) if(!float_equal_zero{}(beta))
{ {
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType, auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
...@@ -290,11 +298,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -290,11 +298,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
threadwise_dst_load.Run( threadwise_dst_load.Run(
dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta); dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
} }
auto threadwise_dst_val_store = auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType, ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType, dstDataType,
decltype(ReducedDataDesc), decltype(ReducedDataDesc),
dst1dDescType, dst1dDescType,
...@@ -322,7 +330,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -322,7 +330,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_multi_index(thread_global_1d_id)); make_multi_index(thread_global_1d_id));
threadwise_dst_val_store.Run( threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf); ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run( threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
}; };
...@@ -430,6 +438,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -430,6 +438,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha); accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta)) if(!float_equal_zero{}(beta))
{ {
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType, auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
...@@ -449,11 +461,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -449,11 +461,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
threadwise_dst_load.Run( threadwise_dst_load.Run(
dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta); dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
} }
auto threadwise_dst_val_store = auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType, ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType, dstDataType,
decltype(ReducedDataDesc), decltype(ReducedDataDesc),
dst1dDescType, dst1dDescType,
...@@ -481,7 +493,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -481,7 +493,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_multi_index(thread_global_1d_id)); make_multi_index(thread_global_1d_id));
threadwise_dst_val_store.Run( threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf); ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run( threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
}; };
......
...@@ -156,6 +156,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise ...@@ -156,6 +156,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha); accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta)) if(!float_equal_zero{}(beta))
{ {
auto threadwise_dst_load = auto threadwise_dst_load =
...@@ -176,11 +180,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise ...@@ -176,11 +180,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
threadwise_dst_load.Run( threadwise_dst_load.Run(
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf(I0) * beta); dstValue_buf(I0) += priorDstValue_buf(I0) * beta;
} }
auto threadwise_dst_store = auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<compType, ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType, dstDataType,
decltype(ReducedDataDesc), decltype(ReducedDataDesc),
dst1dDescType, dst1dDescType,
...@@ -194,7 +198,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise ...@@ -194,7 +198,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_multi_index(warp_global_1d_id)); make_multi_index(warp_global_1d_id));
threadwise_dst_store.Run( threadwise_dst_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_buf); ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
} }
}; };
...@@ -291,6 +295,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise ...@@ -291,6 +295,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha); accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta)) if(!float_equal_zero{}(beta))
{ {
auto threadwise_dst_load = auto threadwise_dst_load =
...@@ -314,11 +322,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise ...@@ -314,11 +322,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_tuple(I0), make_tuple(I0),
priorDstValue_buf); priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta); dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
} }
auto threadwise_dst_val_store = auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType, ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType, dstDataType,
decltype(ReducedDataDesc), decltype(ReducedDataDesc),
dst1dDescType, dst1dDescType,
...@@ -346,7 +354,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise ...@@ -346,7 +354,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_multi_index(warp_global_1d_id)); make_multi_index(warp_global_1d_id));
threadwise_dst_val_store.Run( threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf); ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run( threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
} }
...@@ -466,6 +474,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise ...@@ -466,6 +474,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha); accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta)) if(!float_equal_zero{}(beta))
{ {
auto threadwise_dst_load = auto threadwise_dst_load =
...@@ -489,11 +501,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise ...@@ -489,11 +501,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_tuple(I0), make_tuple(I0),
priorDstValue_buf); priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta); dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
} }
auto threadwise_dst_val_store = auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType, ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType, dstDataType,
decltype(ReducedDataDesc), decltype(ReducedDataDesc),
dst1dDescType, dst1dDescType,
...@@ -521,7 +533,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise ...@@ -521,7 +533,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_multi_index(warp_global_1d_id)); make_multi_index(warp_global_1d_id));
threadwise_dst_val_store.Run( threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf); ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run( threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment