".github/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "9fcf73069f30bbc75cd52b7a36ec961129f239cb"
Commit eac1753d authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Avoid convert to compType from dstDataType before writting the output value

parent 5a9f6308
......@@ -180,6 +180,10 @@ struct GridwiseReduction_xy_to_x_blockwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
......@@ -200,11 +204,11 @@ struct GridwiseReduction_xy_to_x_blockwise
threadwise_dst_load.Run(
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
......@@ -218,7 +222,7 @@ struct GridwiseReduction_xy_to_x_blockwise
make_multi_index(block_global_1d_id));
threadwise_dst_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
}
};
......@@ -345,6 +349,10 @@ struct GridwiseReduction_xy_to_x_blockwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
......@@ -368,11 +376,11 @@ struct GridwiseReduction_xy_to_x_blockwise
make_tuple(I0),
priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
......@@ -400,7 +408,7 @@ struct GridwiseReduction_xy_to_x_blockwise
make_multi_index(block_global_1d_id));
threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
}
......@@ -547,6 +555,10 @@ struct GridwiseReduction_xy_to_x_blockwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
......@@ -570,11 +582,11 @@ struct GridwiseReduction_xy_to_x_blockwise
make_tuple(I0),
priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
......@@ -602,7 +614,7 @@ struct GridwiseReduction_xy_to_x_blockwise
make_multi_index(block_global_1d_id));
threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
}
......
......@@ -147,6 +147,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
......@@ -166,11 +170,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
threadwise_dst_load.Run(
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
......@@ -184,7 +188,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_multi_index(thread_global_1d_id));
threadwise_dst_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
};
template <>
......@@ -271,6 +275,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
......@@ -290,11 +298,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
threadwise_dst_load.Run(
dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
......@@ -322,7 +330,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_multi_index(thread_global_1d_id));
threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
};
......@@ -430,6 +438,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
......@@ -449,11 +461,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
threadwise_dst_load.Run(
dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
......@@ -481,7 +493,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_multi_index(thread_global_1d_id));
threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
};
......
......@@ -156,6 +156,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
......@@ -176,11 +180,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
threadwise_dst_load.Run(
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf(I0) * beta);
dstValue_buf(I0) += priorDstValue_buf(I0) * beta;
}
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
......@@ -194,7 +198,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_multi_index(warp_global_1d_id));
threadwise_dst_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
}
};
......@@ -291,6 +295,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
......@@ -314,11 +322,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_tuple(I0),
priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
......@@ -346,7 +354,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_multi_index(warp_global_1d_id));
threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
}
......@@ -466,6 +474,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
......@@ -489,11 +501,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_tuple(I0),
priorDstValue_buf);
accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
......@@ -521,7 +533,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_multi_index(warp_global_1d_id));
threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment