Unverified Commit d520d0cf authored by Qianfeng's avatar Qianfeng Committed by GitHub
Browse files

Add int4 reduction examples (#372)

* Add int4 reduction examples

* Contain all using of int4_t inside the pre-compiling condition checking
parent f246fd2c
......@@ -225,6 +225,28 @@ int main(int argc, char* argv[])
arg.scales[0],
arg.scales[1]);
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
else if(arg.data_type == 7)
{
pass = reduce_blockwise_test<int4_t, int32_t, ReduceTensorOp::AVG, false, false>(
arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inLengths,
arg.reduceDims,
arg.scales[0],
arg.scales[1]);
pass = pass && reduce_blockwise_test<int4_t, int8_t, ReduceTensorOp::MAX, false, false>(
arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inLengths,
arg.reduceDims,
arg.scales[0],
arg.scales[1]);
}
#endif
}
else
{
......@@ -251,6 +273,15 @@ int main(int argc, char* argv[])
pass && reduce_blockwise_test<int8_t, int32_t, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// for testing int4_t using AVG operation
pass = pass && reduce_blockwise_test<int4_t, int32_t, ReduceTensorOp::AVG, false, false>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing int4_t using MAX operation
pass = pass && reduce_blockwise_test<int4_t, int8_t, ReduceTensorOp::MAX, false, false>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
#endif
// for testing 3D input
pass = pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 960}, {0, 1}, 1.0f, 0.0f);
......
......@@ -58,28 +58,47 @@ int reduce_blockwise_impl(bool do_verification,
std::is_same<InOutDataType, float>::value &&
(op_support_indices && !std::is_same<AccDataType, float>::value);
// 1) If InOutDataType is int8_t, must use int8_t as AccDataType for indexable reduction
// operations 2) If InOutDataType is int8_t, must use int32_t as AccDataType for non-indexable
// reduction operations
// 1) If InOutDataType is int8_t or int4_t, must use int8_t as AccDataType for indexable
// reduction operations 2) If InOutDataType is int8_t or int4_t, must use int32_t as AccDataType
// for non-indexable reduction operations
constexpr bool invalid_reduce_4 =
std::is_same<InOutDataType, int8_t>::value &&
((!op_support_indices && !std::is_same<AccDataType, int32_t>::value) ||
(op_support_indices && !std::is_same<AccDataType, int8_t>::value));
// 1) If InOutDataType is int8_t, the supported operation must be either indexable operations or
// ADD/AVG
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
constexpr bool invalid_reduce_4_2 =
std::is_same<InOutDataType, int4_t>::value &&
((!op_support_indices && !std::is_same<AccDataType, int32_t>::value) ||
(op_support_indices && !std::is_same<AccDataType, int8_t>::value));
#endif
// 1) If InOutDataType is int8_t or int4_t, the supported operation must be either indexable
// operations or ADD/AVG
constexpr bool invalid_reduce_5 = std::is_same<InOutDataType, int8_t>::value &&
(!op_support_indices && ReduceOpId != ReduceTensorOp::ADD &&
ReduceOpId != ReduceTensorOp::AVG);
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
constexpr bool invalid_reduce_5_2 = std::is_same<InOutDataType, int4_t>::value &&
(!op_support_indices && ReduceOpId != ReduceTensorOp::ADD &&
ReduceOpId != ReduceTensorOp::AVG);
#endif
// 1) If InOutDataType is bhalf_t, must use float as AccDataType for all reduction operations
constexpr bool invalid_reduce_6 =
std::is_same<InOutDataType, bhalf_t>::value && !std::is_same<AccDataType, float>::value;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
constexpr bool invalid_reduce =
(invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 || invalid_reduce_4 ||
invalid_reduce_5 || invalid_reduce_6 || invalid_reduce_4_2 || invalid_reduce_5_2);
#else
constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 ||
invalid_reduce_4 || invalid_reduce_5 || invalid_reduce_6);
#endif
if(invalid_reduce)
if constexpr(invalid_reduce)
{
std::cerr << "The reduction setting is invalid, exiting!" << std::endl;
return (-1);
......@@ -91,10 +110,17 @@ int reduce_blockwise_impl(bool do_verification,
using AccElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using InOutDataTypeInDevice = typename std::
conditional<std::is_same<InOutDataType, int4_t>::value, int8_t, InOutDataType>::type;
#else
using InOutDataTypeInDevice = InOutDataType;
#endif
using DeviceReduceInstance =
ck::tensor_operation::device::DeviceReduceMultiBlock<InOutDataType,
ck::tensor_operation::device::DeviceReduceMultiBlock<InOutDataTypeInDevice,
AccDataType,
InOutDataType,
InOutDataTypeInDevice,
Rank,
NumReduceDim,
ReduceOperation,
......@@ -166,13 +192,35 @@ int reduce_blockwise_impl(bool do_verification,
};
// these buffers are usually provided by the user application
DeviceMem in_dev(sizeof(InOutDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem out_dev(sizeof(InOutDataType) * out.mDesc.GetElementSpaceSize());
DeviceMem in_dev(sizeof(InOutDataTypeInDevice) * in.mDesc.GetElementSpaceSize());
DeviceMem out_dev(sizeof(InOutDataTypeInDevice) * out.mDesc.GetElementSpaceSize());
in_dev.ToDevice(in.mData.data());
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
if(std::is_same<InOutDataType, int4_t>::value)
{
std::vector<InOutDataTypeInDevice> tmp_buf(in.mData.size());
std::copy_n(in.mData.data(), in.mData.size(), tmp_buf.data());
in_dev.ToDevice(tmp_buf.data());
}
else
#endif
in_dev.ToDevice(in.mData.data());
if(beta != 0.0f)
out_dev.ToDevice(out.mData.data());
{
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
if(std::is_same<InOutDataType, int4_t>::value)
{
std::vector<InOutDataTypeInDevice> tmp_buf(in.mData.size());
std::copy_n(out.mData.data(), out.mData.size(), tmp_buf.data());
out_dev.ToDevice(tmp_buf.data());
}
else
#endif
out_dev.ToDevice(out.mData.data());
};
size_t indicesSizeInBytes = OutputIndex ? out.mDesc.GetElementSize() * sizeof(int32_t) : 0;
......@@ -261,7 +309,19 @@ int reduce_blockwise_impl(bool do_verification,
if(do_verification)
{
out_dev.FromDevice(out.mData.data());
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
if(std::is_same<InOutDataType, int4_t>::value)
{
std::vector<InOutDataTypeInDevice> tmp_buf(out.mData.size());
out_dev.FromDevice(tmp_buf.data());
std::copy_n(tmp_buf.data(), out.mData.size(), out.mData.data());
}
else
#endif
out_dev.FromDevice(out.mData.data());
pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
if(OutputIndex)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment