"...composable_kernel.git" did not exist on "2090160a778967bbc4ac33ff4dfa318845c061d7"
Commit 55159365 authored by Chao Liu's avatar Chao Liu
Browse files

refactor type_convert

parent d8a632a8
......@@ -95,7 +95,7 @@ struct GridwiseReduction_xy_to_x_blockwise
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_dst_global, dst1dDesc.GetElementSpaceSize());
......@@ -178,11 +178,11 @@ struct GridwiseReduction_xy_to_x_blockwise
if(thread_local_id == 0)
{
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
accuValue_buf(I0) *= type_convert<compType>(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
......@@ -246,7 +246,7 @@ struct GridwiseReduction_xy_to_x_blockwise
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_dst_global, dst1dDesc.GetElementSpaceSize());
auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
......@@ -347,11 +347,11 @@ struct GridwiseReduction_xy_to_x_blockwise
if(thread_local_id == 0)
{
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
accuValue_buf(I0) *= type_convert<compType>(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
......@@ -433,10 +433,8 @@ struct GridwiseReduction_xy_to_x_blockwise
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
src2dDesc.GetElementSpaceSize(),
type_convert<srcDataType>{}(zeroVal));
const auto src_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
ws_values_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
ws_indices_global, src2dDesc.GetElementSpaceSize());
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
......@@ -553,11 +551,11 @@ struct GridwiseReduction_xy_to_x_blockwise
if(thread_local_id == 0)
{
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
accuValue_buf(I0) *= type_convert<compType>(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
......
......@@ -85,7 +85,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_dst_global, dst1dDesc.GetElementSpaceSize());
......@@ -145,11 +145,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
accuValue_buf(I0) *= type_convert<compType>(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
......@@ -207,7 +207,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_dst_global, dst1dDesc.GetElementSpaceSize());
auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
......@@ -273,11 +273,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
accuValue_buf(I0) *= type_convert<compType>(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
......@@ -350,10 +350,8 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
src2dDesc.GetElementSpaceSize(),
type_convert<srcDataType>{}(zeroVal));
const auto src_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
ws_values_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
ws_indices_global, src2dDesc.GetElementSpaceSize());
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
......@@ -436,11 +434,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
accuValue_buf(I0) *= type_convert<compType>(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
......
......@@ -85,7 +85,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_dst_global, dst1dDesc.GetElementSpaceSize());
......@@ -154,11 +154,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if(thread_inwarp_id == 0)
{
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
accuValue_buf(I0) *= type_convert<compType>(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
......@@ -218,7 +218,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_dst_global, dst1dDesc.GetElementSpaceSize());
auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
......@@ -293,11 +293,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if(thread_inwarp_id == 0)
{
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
accuValue_buf(I0) *= type_convert<compType>(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
......@@ -375,10 +375,8 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
src2dDesc.GetElementSpaceSize(),
type_convert<srcDataType>{}(zeroVal));
const auto src_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
ws_values_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
ws_indices_global, src2dDesc.GetElementSpaceSize());
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
......@@ -472,11 +470,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if(thread_inwarp_id == 0)
{
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
accuValue_buf(I0) *= type_convert<compType>(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
......
......@@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_multiblock
__shared__ compType p_in_block_buffer[BlockBufferSize];
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize);
......@@ -223,7 +223,7 @@ struct GridwiseReduction_xy_to_x_multiblock
__shared__ int p_in_block_indices_buffer[BlockBufferSize];
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize);
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
......
......@@ -64,7 +64,7 @@ struct BlockwiseReduction_2d_block_buffer
offset = blockIsOneRow
? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_local_id))
: buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd));
compType opData = type_convert<compType>{}(block_buffer[offset]);
compType opData = type_convert<compType>(block_buffer[offset]);
binop::calculate(lAccuData, opData);
}
......@@ -89,10 +89,10 @@ struct BlockwiseReduction_2d_block_buffer
? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id + indOffset))
: buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0));
compType opData1 = type_convert<compType>{}(block_buffer[offset1]);
compType opData2 = type_convert<compType>{}(block_buffer[offset2]);
compType opData1 = type_convert<compType>(block_buffer[offset1]);
compType opData2 = type_convert<compType>(block_buffer[offset2]);
binop::calculate(opData1, opData2);
block_buffer(offset1) = type_convert<compType>{}(opData1);
block_buffer(offset1) = type_convert<compType>(opData1);
}
__syncthreads();
......@@ -100,7 +100,7 @@ struct BlockwiseReduction_2d_block_buffer
if(thread_local_id == 0)
{
compType tmpVal = type_convert<compType>{}(block_buffer[0]);
compType tmpVal = type_convert<compType>(block_buffer[0]);
binop::calculate(accuData, tmpVal);
}
......@@ -131,13 +131,13 @@ struct BlockwiseReduction_2d_block_buffer
index_t offset2 = buffer2dDesc.CalculateOffset(
make_tuple(otherDimInd, thread_local_id + indOffset));
compType currVal1 = type_convert<compType>{}(block_buffer[offset1]);
compType currVal2 = type_convert<compType>{}(block_buffer[offset2]);
compType currVal1 = type_convert<compType>(block_buffer[offset1]);
compType currVal2 = type_convert<compType>(block_buffer[offset2]);
int currIndex1 = block_indices_buffer[offset1];
int currIndex2 = block_indices_buffer[offset2];
binop::calculate(currVal1, currVal2, currIndex1, currIndex2);
block_buffer(offset1) = type_convert<compType>{}(currVal1);
block_buffer(offset1) = type_convert<compType>(currVal1);
block_indices_buffer(offset1) = currIndex1;
}
__syncthreads();
......@@ -150,7 +150,7 @@ struct BlockwiseReduction_2d_block_buffer
{
index_t offset = buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, 0));
compType tmpVal = type_convert<compType>{}(block_buffer[offset]);
compType tmpVal = type_convert<compType>(block_buffer[offset]);
int tmpIndex = block_indices_buffer[offset];
binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex);
......@@ -166,7 +166,7 @@ struct BlockwiseReduction_2d_block_buffer
for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
{
offset = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd));
compType currVal = type_convert<compType>{}(block_buffer[offset]);
compType currVal = type_convert<compType>(block_buffer[offset]);
int currIndex = block_indices_buffer[offset];
binop::calculate(lAccuData, currVal, lAccuIndex, currIndex);
......@@ -187,13 +187,13 @@ struct BlockwiseReduction_2d_block_buffer
index_t offset2 =
buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0));
compType currVal1 = type_convert<compType>{}(block_buffer[offset1]);
compType currVal2 = type_convert<compType>{}(block_buffer[offset2]);
compType currVal1 = type_convert<compType>(block_buffer[offset1]);
compType currVal2 = type_convert<compType>(block_buffer[offset2]);
int currIndex1 = block_indices_buffer[offset1];
int currIndex2 = block_indices_buffer[offset2];
binop::calculate(currVal1, currVal2, currIndex1, currIndex2);
block_buffer(offset1) = type_convert<compType>{}(currVal1);
block_buffer(offset1) = type_convert<compType>(currVal1);
block_indices_buffer(offset1) = currIndex1;
}
......@@ -202,7 +202,7 @@ struct BlockwiseReduction_2d_block_buffer
if(thread_local_id == 0)
{
compType tmpVal = type_convert<compType>{}(block_buffer[0]);
compType tmpVal = type_convert<compType>(block_buffer[0]);
int tmpIndex = block_indices_buffer[0];
binop::calculate(accuData, tmpVal, accuIndex, tmpIndex);
......@@ -227,9 +227,9 @@ struct BlockwiseReduction_2d_block_buffer
}
};
// Initialize the block-wise indices buffer; the index for each element in the block-wise
// data buffer is calculated according to its position in the buffer and the global starting
// index.
template <typename IdxBufferType>
__device__ static void init_buffer_indices(IdxBufferType& block_indices_buffer, int indexStart)
{
......
......@@ -196,7 +196,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
dst_vector.template AsType<DstData>()(i) =
type_convert<DstData>{}(src_buf[Number<src_offset>{}]);
type_convert<DstData>(src_buf[Number<src_offset>{}]);
});
const bool is_dst_valid =
......@@ -983,7 +983,7 @@ struct ThreadwiseTensorSliceTransfer_v3
buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
dst_tmp_vector.template AsType<DstData>()(i) =
type_convert<DstData>{}(buffer_[Number<buffer_offset>{}]);
type_convert<DstData>(buffer_[Number<buffer_offset>{}]);
});
using dst_vector_t = typename decltype(dst_tmp_vector)::type;
......@@ -1403,7 +1403,7 @@ struct ThreadwiseTensorSliceTransfer_v4
// TODO: if SrcData and DstData are vetor type, then static_cast may not compile
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
dst_tmp_vector.template AsType<DstData>()(i) =
type_convert<DstData>{}(src_tmp_vector.template AsType<SrcData>()[i]);
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
......
......@@ -351,7 +351,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
dst_vector_desc.CalculateOffset(dst_vector_idx);
dst_vector.template AsType<DstData>()(Number<dst_vector_offset>{}) =
type_convert<DstData>{}(buffer_[Number<buffer_offset>{}]);
type_convert<DstData>(buffer_[Number<buffer_offset>{}]);
});
using dst_vector_t = typename decltype(dst_vector)::type;
......@@ -750,7 +750,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + src_vector_idx);
dst_buf(Number<dst_offset>{}) = type_convert<DstData>{}(
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(
src_vector.template AsType<DstData>()[Number<src_vector_offset>{}]);
});
});
......
......@@ -248,7 +248,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here
dst_thread_scratch_(idx) = type_convert<DstData>{}(src_thread_scratch_[idx]);
dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]);
});
#else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
......@@ -322,7 +322,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
{
static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here
dst_thread_scratch_(idx) = type_convert<DstData>{}(src_thread_scratch_[idx]);
dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]);
});
}
#endif
......
......@@ -927,23 +927,36 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
static __host__ __device__ float bf16_to_f32(ushort src_val)
// Convert a value of type X to a value of type Y.
// The generic path is a plain static_cast; non-trivial conversions
// (e.g. bf16 <-> fp32) are supplied as explicit specializations.
template <typename Y, typename X>
__host__ __device__ Y type_convert(X x)
{
    const Y converted = static_cast<Y>(x);
    return converted;
}
// convert bfp16 to fp32
// NOTE: the template arguments must be spelled out explicitly here — the
// destination type Y appears only in the return type of the primary template
// and therefore cannot be deduced when declaring a full specialization.
template <>
inline __host__ __device__ float type_convert<float, ushort>(ushort x)
{
    // bf16 is the upper 16 bits of an IEEE-754 fp32 value: widen the bit
    // pattern into the high half of a 32-bit word and reinterpret as float.
    union
    {
        uint32_t int32;
        float fp32;
    } u = {uint32_t(x) << 16};

    return u.fp32;
}
static __host__ __device__ ushort f32_to_bf16(float src_val)
// convert fp32 to bfp16
template <>
inline __host__ __device__ ushort type_convert(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {src_val};
} u = {x};
if(~u.int32 & 0x7f800000)
{
// When the exponent bits are not all 1s, then the value is zero, normal,
......@@ -976,40 +989,14 @@ static __host__ __device__ ushort f32_to_bf16(float src_val)
// the bloat16's mantissa bits are all 0.
u.int32 |= 0x10000; // Preserve signaling NaN
}
return uint16_t(u.int32 >> 16);
}
// data type conversion
template <typename T>
struct type_convert
{
template <typename X>
__device__ T operator()(X x) const
{
return static_cast<T>(x);
}
};
template <>
template <>
__device__ float type_convert<float>::operator()<ushort>(ushort x) const
{
return bf16_to_f32(x);
}
template <>
template <>
__device__ ushort type_convert<ushort>::operator()<float>(float x) const
{
return f32_to_bf16(x);
return uint16_t(u.int32 >> 16);
}
// TODO: deprecate this
template <typename T>
struct inner_product_with_conversion
{
static constexpr auto convert = type_convert<T>();
template <typename X, index_t N>
__device__ T operator()(typename vector_type<X, N>::type a,
typename vector_type<X, N>::type b) const
......@@ -1020,13 +1007,16 @@ struct inner_product_with_conversion
T acc = 0;
static_for<0, N, 1>{}([&](auto i) {
acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
acc += type_convert<T>(a_vector.Scalars()[i]) * type_convert<T>(b_vector.Scalars()[i]);
});
return acc;
}
__device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); }
__device__ T operator()(float_t a, float_t b) const
{
return type_convert<T>(a) * type_convert<T>(b);
}
__device__ T operator()(int8x4_t a, int8x4_t b) const
{
......@@ -1036,7 +1026,8 @@ struct inner_product_with_conversion
T acc = 0;
static_for<0, 4, 1>{}([&](auto i) {
acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
acc += type_convert<T>(a_vector.AsType<int8_t>()[i]) *
type_convert<T>(b_vector.AsType<int8_t>()[i]);
});
return acc;
......@@ -1050,7 +1041,8 @@ struct inner_product_with_conversion
T acc = 0;
static_for<0, 8, 1>{}([&](auto i) {
acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
acc += type_convert<T>(a_vector.AsType<int8_t>()[i]) *
type_convert<T>(b_vector.AsType<int8_t>()[i]);
});
return acc;
......@@ -1064,7 +1056,8 @@ struct inner_product_with_conversion
T acc = 0;
static_for<0, 16, 1>{}([&](auto i) {
acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
acc += type_convert<T>(a_vector.AsType<int8_t>()[i]) *
type_convert<T>(b_vector.AsType<int8_t>()[i]);
});
return acc;
......
......@@ -28,12 +28,6 @@ __device__ void inner_product<float, float, float>(const float& a, const float&
#endif
}
template <>
__device__ void inner_product<ushort, ushort, float>(const ushort& a, const ushort& b, float& c)
{
c += bf16_to_f32(a) * bf16_to_f32(b);
}
template <>
__device__ void
inner_product<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
......@@ -90,13 +84,12 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const h
c = __builtin_amdgcn_sdot2(a, b, c, false);
#endif
#else
const auto convert = type_convert<int32_t>{};
const vector_type<half_t, 2> a_vector{a};
const vector_type<half_t, 2> b_vector{b};
static_for<0, 2, 1>{}([&](auto i) {
c += convert(a_vector.AsType<half_t>()[i]) * convert(b_vector.AsType<half_t>()[i]);
c += type_convert<int32_t>(a_vector.AsType<half_t>()[i]) *
type_convert<int32_t>(b_vector.AsType<half_t>()[i]);
});
#endif
}
......@@ -156,13 +149,12 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
#else
const auto convert = type_convert<int32_t>{};
const vector_type<int8_t, 4> a_vector{a};
const vector_type<int8_t, 4> b_vector{b};
static_for<0, 4, 1>{}([&](auto i) {
c += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
c += type_convert<int32_t>(a_vector.AsType<int8_t>()[i]) *
type_convert<int32_t>(b_vector.AsType<int8_t>()[i]);
});
#endif
}
......
......@@ -165,7 +165,7 @@ struct unary_identic
scaler = 1.0f / static_cast<float>(divider);
};
__device__ inline constexpr T operator()(T a) const { return a * type_convert<T>{}(scaler); };
__device__ inline constexpr T operator()(T a) const { return a * type_convert<T>(scaler); };
float scaler = 1.0f;
};
......@@ -187,7 +187,7 @@ struct unary_square
{
a = a * a;
return a * type_convert<T>{}(scaler);
return a * type_convert<T>(scaler);
};
float scaler = 1.0f;
......@@ -210,7 +210,7 @@ struct unary_abs
{
a = abs(a);
return a * type_convert<T>{}(scaler);
return a * type_convert<T>(scaler);
};
float scaler = 1.0f;
......@@ -249,7 +249,7 @@ struct unary_abs<half_t, hasDividing>
{
a = static_cast<half_t>(__habs(a));
return a * type_convert<half_t>{}(scaler);
return a * type_convert<half_t>(scaler);
};
float scaler = 1.0f;
......
......@@ -82,8 +82,8 @@ void host_convolution_forward(const Tensor<TIn>& in,
{
if constexpr(is_same<TIn, ushort>::value)
{
v += ck::bf16_to_f32(in(n, c, hi, wi)) *
ck::bf16_to_f32(wei(k, c, y, x));
v += ck::type_convert<float>(in(n, c, hi, wi)) *
ck::type_convert<float>(wei(k, c, y, x));
}
else
{
......@@ -97,7 +97,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
if constexpr(is_same<TOut, ushort>::value)
{
out(n, k, ho, wo) = f32_to_bf16(v);
out(n, k, ho, wo) = type_convert<ushort>(v);
}
else
{
......@@ -120,8 +120,8 @@ void host_convolution_forward(const Tensor<TIn>& in,
{
if constexpr(is_same<TIn, ushort>::value)
{
v += ck::bf16_to_f32(in(n, hi, wi, c)) *
ck::bf16_to_f32(wei(k, y, x, c));
v += ck::type_convert<float>(in(n, hi, wi, c)) *
ck::type_convert<float>(wei(k, y, x, c));
}
else
{
......@@ -134,7 +134,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
}
if constexpr(is_same<TOut, ushort>::value)
{
out(n, ho, wo, k) = f32_to_bf16(v);
out(n, ho, wo, k) = ck::type_convert<ushort>(v);
}
else
{
......
......@@ -5,15 +5,25 @@
#include "config.hpp"
#include "data_type.hpp"
// Tensor-element generator that yields zero for every coordinate.
template <typename T>
struct GeneratorTensor_0
{
    // The index pack is accepted for interface compatibility with the other
    // GeneratorTensor_* functors and is intentionally ignored.
    template <typename... Is>
    T operator()(Is...)
    {
        T zero{0};
        return zero;
    }
};
template <typename T>
struct GeneratorTensor_1
{
int value = 1;
template <typename... Is>
float operator()(Is...)
T operator()(Is...)
{
return value;
return ck::type_convert<T>(value);
}
};
......@@ -25,7 +35,7 @@ struct GeneratorTensor_1<ushort>
template <typename... Is>
ushort operator()(Is...)
{
return ck::f32_to_bf16(value);
return ck::type_convert<ushort>(value);
}
};
......@@ -41,17 +51,6 @@ struct GeneratorTensor_1<int8_t>
}
};
struct GeneratorTensor_0
{
int value = 0;
template <typename... Is>
float operator()(Is...)
{
return value;
}
};
template <typename T>
struct GeneratorTensor_2
{
......@@ -59,7 +58,7 @@ struct GeneratorTensor_2
int max_value = 1;
template <typename... Is>
float operator()(Is...)
T operator()(Is...)
{
return (std::rand() % (max_value - min_value)) + min_value;
}
......@@ -75,7 +74,7 @@ struct GeneratorTensor_2<ushort>
ushort operator()(Is...)
{
float tmp = (std::rand() % (max_value - min_value)) + min_value;
return ck::f32_to_bf16(tmp);
return ck::type_convert<ushort>(tmp);
}
};
......@@ -99,7 +98,7 @@ struct GeneratorTensor_3
T max_value = 1;
template <typename... Is>
float operator()(Is...)
T operator()(Is...)
{
float tmp = float(std::rand()) / float(RAND_MAX);
......@@ -120,7 +119,7 @@ struct GeneratorTensor_3<ushort>
float fp32_tmp = min_value + tmp * (max_value - min_value);
return ck::f32_to_bf16(fp32_tmp);
return ck::type_convert<ushort>(fp32_tmp);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment