"vscode:/vscode.git/clone" did not exist on "ccbd8d907be06fa585e5298824760a959829936c"
Unverified commit 0a66c54e authored by zjing14, committed by GitHub

fixed multiple-definition issue of bfp16/fp32 conversion functions when building ckProfiler (#51)



* fixed bfloat16 issues

* refactor type_convert
Co-authored-by: Chao Liu <chao.liu2@amd.com>
parent 89e1ebd4
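Why the build broke: in the old data_type.hpp, type_convert was a class template, and the bf16/fp32 conversions were full specializations of its member operator(). A full specialization is an ordinary function, so defining one in a header without `inline` emits a definition in every translation unit that includes the header, and linking ckProfiler's many objects then fails with "multiple definition" errors. The refactor below turns type_convert into a function template whose full specializations are declared `inline`. A minimal before/after sketch, simplified from the hunks that follow (the real code also carries the __host__/__device__ qualifiers):

    // Before: a full specialization of a member template is an ordinary
    // function; defined in a header without `inline`, it is emitted once
    // per translation unit and collides at link time.
    template <typename T>
    struct type_convert
    {
        template <typename X>
        T operator()(X x) const { return static_cast<T>(x); }
    };

    template <>
    template <>
    float type_convert<float>::operator()<ushort>(ushort x) const
    {
        return bf16_to_f32(x);
    }

    // After: a function template; its full specializations are marked
    // `inline`, so all translation units share one definition.
    template <typename Y, typename X>
    Y type_convert(X x) { return static_cast<Y>(x); }

    template <>
    inline float type_convert(ushort x)
    {
        union { uint32_t i; float f; } u = {uint32_t(x) << 16}; // bf16 bits -> fp32
        return u.f;
    }

Call sites accordingly change from constructing a functor and invoking it, type_convert<T>{}(x), to a plain call, type_convert<T>(x); that substitution is the bulk of the mechanical churn in the hunks below.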
@@ -95,7 +95,7 @@ struct GridwiseReduction_xy_to_x_blockwise
     const auto zeroVal = opReduce::GetReductionZeroVal();
     const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
+        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
     auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
         p_dst_global, dst1dDesc.GetElementSpaceSize());
@@ -178,11 +178,11 @@ struct GridwiseReduction_xy_to_x_blockwise
     if(thread_local_id == 0)
     {
         if(!float_equal_one{}(alpha))
-            accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+            accuValue_buf(I0) *= type_convert<compType>(alpha);
         StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
-        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+        dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
         if(!float_equal_zero{}(beta))
         {
@@ -246,7 +246,7 @@ struct GridwiseReduction_xy_to_x_blockwise
     const auto zeroVal = opReduce::GetReductionZeroVal();
     const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
+        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
     auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
         p_dst_global, dst1dDesc.GetElementSpaceSize());
     auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
@@ -347,11 +347,11 @@ struct GridwiseReduction_xy_to_x_blockwise
     if(thread_local_id == 0)
     {
         if(!float_equal_one{}(alpha))
-            accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+            accuValue_buf(I0) *= type_convert<compType>(alpha);
         StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
-        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+        dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
         if(!float_equal_zero{}(beta))
         {
@@ -433,10 +433,8 @@ struct GridwiseReduction_xy_to_x_blockwise
     const auto zeroVal = opReduce::GetReductionZeroVal();
-    const auto src_global_val_buf =
-        make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
-                                                        src2dDesc.GetElementSpaceSize(),
-                                                        type_convert<srcDataType>{}(zeroVal));
+    const auto src_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        ws_values_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
     const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
         ws_indices_global, src2dDesc.GetElementSpaceSize());
     auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
@@ -553,11 +551,11 @@ struct GridwiseReduction_xy_to_x_blockwise
     if(thread_local_id == 0)
     {
         if(!float_equal_one{}(alpha))
-            accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+            accuValue_buf(I0) *= type_convert<compType>(alpha);
         StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
-        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+        dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
         if(!float_equal_zero{}(beta))
         {
...
@@ -85,7 +85,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
     const auto zeroVal = opReduce::GetReductionZeroVal();
     const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
+        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
     auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
         p_dst_global, dst1dDesc.GetElementSpaceSize());
@@ -145,11 +145,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
         make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
     if(!float_equal_one{}(alpha))
-        accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+        accuValue_buf(I0) *= type_convert<compType>(alpha);
     StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
-    dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+    dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
     if(!float_equal_zero{}(beta))
     {
@@ -207,7 +207,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
     const auto zeroVal = opReduce::GetReductionZeroVal();
     const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
+        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
     auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
         p_dst_global, dst1dDesc.GetElementSpaceSize());
     auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
@@ -273,11 +273,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
         make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
     if(!float_equal_one{}(alpha))
-        accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+        accuValue_buf(I0) *= type_convert<compType>(alpha);
     StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
-    dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+    dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
     if(!float_equal_zero{}(beta))
     {
@@ -350,10 +350,8 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
     const auto zeroVal = opReduce::GetReductionZeroVal();
-    const auto src_global_val_buf =
-        make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
-                                                        src2dDesc.GetElementSpaceSize(),
-                                                        type_convert<srcDataType>{}(zeroVal));
+    const auto src_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        ws_values_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
     const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
         ws_indices_global, src2dDesc.GetElementSpaceSize());
     auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
@@ -436,11 +434,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
         make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
     if(!float_equal_one{}(alpha))
-        accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+        accuValue_buf(I0) *= type_convert<compType>(alpha);
     StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
-    dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+    dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
     if(!float_equal_zero{}(beta))
     {
...
@@ -85,7 +85,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
     const auto zeroVal = opReduce::GetReductionZeroVal();
     const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
+        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
     auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
         p_dst_global, dst1dDesc.GetElementSpaceSize());
@@ -154,11 +154,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
     if(thread_inwarp_id == 0)
     {
         if(!float_equal_one{}(alpha))
-            accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+            accuValue_buf(I0) *= type_convert<compType>(alpha);
         StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
-        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+        dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
         if(!float_equal_zero{}(beta))
         {
@@ -218,7 +218,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
     const auto zeroVal = opReduce::GetReductionZeroVal();
     const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
+        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
     auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
         p_dst_global, dst1dDesc.GetElementSpaceSize());
     auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
@@ -293,11 +293,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
     if(thread_inwarp_id == 0)
     {
         if(!float_equal_one{}(alpha))
-            accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+            accuValue_buf(I0) *= type_convert<compType>(alpha);
         StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
-        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+        dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
         if(!float_equal_zero{}(beta))
         {
@@ -375,10 +375,8 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
     const auto zeroVal = opReduce::GetReductionZeroVal();
-    const auto src_global_val_buf =
-        make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
-                                                        src2dDesc.GetElementSpaceSize(),
-                                                        type_convert<srcDataType>{}(zeroVal));
+    const auto src_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+        ws_values_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
     const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
         ws_indices_global, src2dDesc.GetElementSpaceSize());
     auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
@@ -472,11 +470,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
     if(thread_inwarp_id == 0)
     {
         if(!float_equal_one{}(alpha))
-            accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+            accuValue_buf(I0) *= type_convert<compType>(alpha);
         StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
-        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+        dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
         if(!float_equal_zero{}(beta))
         {
...
@@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_multiblock
     __shared__ compType p_in_block_buffer[BlockBufferSize];
     const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
+        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
     auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
         ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize);
@@ -223,7 +223,7 @@ struct GridwiseReduction_xy_to_x_multiblock
     __shared__ int p_in_block_indices_buffer[BlockBufferSize];
     const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
+        p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
     auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
         ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize);
     auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
...
@@ -64,7 +64,7 @@ struct BlockwiseReduction_2d_block_buffer
             offset = blockIsOneRow
                          ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_local_id))
                          : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd));
-            compType opData = type_convert<compType>{}(block_buffer[offset]);
+            compType opData = type_convert<compType>(block_buffer[offset]);
             binop::calculate(lAccuData, opData);
         }
@@ -89,10 +89,10 @@ struct BlockwiseReduction_2d_block_buffer
                     ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id + indOffset))
                     : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0));
-                compType opData1 = type_convert<compType>{}(block_buffer[offset1]);
-                compType opData2 = type_convert<compType>{}(block_buffer[offset2]);
+                compType opData1 = type_convert<compType>(block_buffer[offset1]);
+                compType opData2 = type_convert<compType>(block_buffer[offset2]);
                 binop::calculate(opData1, opData2);
-                block_buffer(offset1) = type_convert<compType>{}(opData1);
+                block_buffer(offset1) = type_convert<compType>(opData1);
             }
             __syncthreads();
@@ -100,7 +100,7 @@ struct BlockwiseReduction_2d_block_buffer
         if(thread_local_id == 0)
         {
-            compType tmpVal = type_convert<compType>{}(block_buffer[0]);
+            compType tmpVal = type_convert<compType>(block_buffer[0]);
             binop::calculate(accuData, tmpVal);
         }
@@ -131,13 +131,13 @@ struct BlockwiseReduction_2d_block_buffer
                 index_t offset2 = buffer2dDesc.CalculateOffset(
                     make_tuple(otherDimInd, thread_local_id + indOffset));
-                compType currVal1 = type_convert<compType>{}(block_buffer[offset1]);
-                compType currVal2 = type_convert<compType>{}(block_buffer[offset2]);
+                compType currVal1 = type_convert<compType>(block_buffer[offset1]);
+                compType currVal2 = type_convert<compType>(block_buffer[offset2]);
                 int currIndex1 = block_indices_buffer[offset1];
                 int currIndex2 = block_indices_buffer[offset2];
                 binop::calculate(currVal1, currVal2, currIndex1, currIndex2);
-                block_buffer(offset1) = type_convert<compType>{}(currVal1);
+                block_buffer(offset1) = type_convert<compType>(currVal1);
                 block_indices_buffer(offset1) = currIndex1;
             }
             __syncthreads();
@@ -150,7 +150,7 @@ struct BlockwiseReduction_2d_block_buffer
         {
             index_t offset = buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, 0));
-            compType tmpVal = type_convert<compType>{}(block_buffer[offset]);
+            compType tmpVal = type_convert<compType>(block_buffer[offset]);
             int tmpIndex = block_indices_buffer[offset];
             binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex);
@@ -166,7 +166,7 @@ struct BlockwiseReduction_2d_block_buffer
         for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
         {
             offset = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd));
-            compType currVal = type_convert<compType>{}(block_buffer[offset]);
+            compType currVal = type_convert<compType>(block_buffer[offset]);
             int currIndex = block_indices_buffer[offset];
             binop::calculate(lAccuData, currVal, lAccuIndex, currIndex);
@@ -187,13 +187,13 @@ struct BlockwiseReduction_2d_block_buffer
                 index_t offset2 =
                     buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0));
-                compType currVal1 = type_convert<compType>{}(block_buffer[offset1]);
-                compType currVal2 = type_convert<compType>{}(block_buffer[offset2]);
+                compType currVal1 = type_convert<compType>(block_buffer[offset1]);
+                compType currVal2 = type_convert<compType>(block_buffer[offset2]);
                 int currIndex1 = block_indices_buffer[offset1];
                 int currIndex2 = block_indices_buffer[offset2];
                 binop::calculate(currVal1, currVal2, currIndex1, currIndex2);
-                block_buffer(offset1) = type_convert<compType>{}(currVal1);
+                block_buffer(offset1) = type_convert<compType>(currVal1);
                 block_indices_buffer(offset1) = currIndex1;
             }
@@ -202,7 +202,7 @@ struct BlockwiseReduction_2d_block_buffer
         if(thread_local_id == 0)
         {
-            compType tmpVal = type_convert<compType>{}(block_buffer[0]);
+            compType tmpVal = type_convert<compType>(block_buffer[0]);
             int tmpIndex = block_indices_buffer[0];
             binop::calculate(accuData, tmpVal, accuIndex, tmpIndex);
@@ -227,9 +227,9 @@ struct BlockwiseReduction_2d_block_buffer
         }
     };
-    // Initialize the block-wise indices buffer, the index for each element in the block-wise data
-    // buffer
-    // is calculated according to its position in the buffer and the global starting index
+    // Initialize the block-wise indices buffer, the index for each element in the block-wise
+    // data buffer is calculated according to its position in the buffer and the global starting
+    // index
     template <typename IdxBufferType>
     __device__ static void init_buffer_indices(IdxBufferType& block_indices_buffer, int indexStart)
    {
...
@@ -196,7 +196,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
                 src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
             dst_vector.template AsType<DstData>()(i) =
-                type_convert<DstData>{}(src_buf[Number<src_offset>{}]);
+                type_convert<DstData>(src_buf[Number<src_offset>{}]);
         });
         const bool is_dst_valid =
@@ -983,7 +983,7 @@ struct ThreadwiseTensorSliceTransfer_v3
                 buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
             dst_tmp_vector.template AsType<DstData>()(i) =
-                type_convert<DstData>{}(buffer_[Number<buffer_offset>{}]);
+                type_convert<DstData>(buffer_[Number<buffer_offset>{}]);
         });
         using dst_vector_t = typename decltype(dst_tmp_vector)::type;
@@ -1403,7 +1403,7 @@ struct ThreadwiseTensorSliceTransfer_v4
         // TODO: if SrcData and DstData are vector type, then static_cast may not compile
         static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
             dst_tmp_vector.template AsType<DstData>()(i) =
-                type_convert<DstData>{}(src_tmp_vector.template AsType<SrcData>()[i]);
+                type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
         });
         // copy data from dst_tmp_vector into dst_buf
...
@@ -351,7 +351,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 dst_vector_desc.CalculateOffset(dst_vector_idx);
             dst_vector.template AsType<DstData>()(Number<dst_vector_offset>{}) =
-                type_convert<DstData>{}(buffer_[Number<buffer_offset>{}]);
+                type_convert<DstData>(buffer_[Number<buffer_offset>{}]);
         });
         using dst_vector_t = typename decltype(dst_vector)::type;
@@ -750,7 +750,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
             constexpr index_t dst_offset = dst_desc.CalculateOffset(
                 dst_origin_idx + data_to_origin_disp_idx + src_vector_idx);
-            dst_buf(Number<dst_offset>{}) = type_convert<DstData>{}(
+            dst_buf(Number<dst_offset>{}) = type_convert<DstData>(
                 src_vector.template AsType<DstData>()[Number<src_vector_offset>{}]);
         });
     });
...
@@ -248,7 +248,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
 #if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
         static_ford<SliceLengths>{}([&](auto idx) {
             // convert from SrcData to DstData here
-            dst_thread_scratch_(idx) = type_convert<DstData>{}(src_thread_scratch_[idx]);
+            dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]);
         });
 #else
         // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
@@ -322,7 +322,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
         {
             static_ford<SliceLengths>{}([&](auto idx) {
                 // convert from SrcData to DstData here
-                dst_thread_scratch_(idx) = type_convert<DstData>{}(src_thread_scratch_[idx]);
+                dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]);
             });
         }
 #endif
...
@@ -927,23 +927,36 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
 using int8x32_t = typename vector_type<int8_t, 32>::type;
 using int8x64_t = typename vector_type<int8_t, 64>::type;
-__host__ __device__ float bf16_to_f32(ushort src_val)
+// Convert X to Y
+template <typename Y, typename X>
+__host__ __device__ Y type_convert(X x)
+{
+    return static_cast<Y>(x);
+}
+
+// convert bfp16 to fp32
+template <>
+inline __host__ __device__ float type_convert(ushort x)
 {
     union
     {
         uint32_t int32;
         float fp32;
-    } u = {uint32_t(src_val) << 16};
+    } u = {uint32_t(x) << 16};
     return u.fp32;
 }
-__host__ __device__ ushort f32_to_bf16(float src_val)
+// convert fp32 to bfp16
+template <>
+inline __host__ __device__ ushort type_convert(float x)
 {
     union
     {
         float fp32;
         uint32_t int32;
-    } u = {src_val};
+    } u = {x};
     if(~u.int32 & 0x7f800000)
     {
         // When the exponent bits are not all 1s, then the value is zero, normal,
@@ -976,40 +989,14 @@ __host__ __device__ ushort f32_to_bf16(float src_val)
         // the bfloat16's mantissa bits are all 0.
         u.int32 |= 0x10000; // Preserve signaling NaN
     }
     return uint16_t(u.int32 >> 16);
 }
-
-// data type conversion
-template <typename T>
-struct type_convert
-{
-    template <typename X>
-    __device__ T operator()(X x) const
-    {
-        return static_cast<T>(x);
-    }
-};
-
-template <>
-template <>
-__device__ float type_convert<float>::operator()<ushort>(ushort x) const
-{
-    return bf16_to_f32(x);
-}
-
-template <>
-template <>
-__device__ ushort type_convert<ushort>::operator()<float>(float x) const
-{
-    return f32_to_bf16(x);
-}
 
 // TODO: deprecate this
 template <typename T>
 struct inner_product_with_conversion
 {
-    static constexpr auto convert = type_convert<T>();
-
     template <typename X, index_t N>
     __device__ T operator()(typename vector_type<X, N>::type a,
                             typename vector_type<X, N>::type b) const
@@ -1020,13 +1007,16 @@ struct inner_product_with_conversion
         T acc = 0;
         static_for<0, N, 1>{}([&](auto i) {
-            acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
+            acc += type_convert<T>(a_vector.Scalars()[i]) * type_convert<T>(b_vector.Scalars()[i]);
         });
         return acc;
     }
-    __device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); }
+    __device__ T operator()(float_t a, float_t b) const
+    {
+        return type_convert<T>(a) * type_convert<T>(b);
+    }
     __device__ T operator()(int8x4_t a, int8x4_t b) const
     {
@@ -1036,7 +1026,8 @@
         T acc = 0;
         static_for<0, 4, 1>{}([&](auto i) {
-            acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
+            acc += type_convert<T>(a_vector.AsType<int8_t>()[i]) *
+                   type_convert<T>(b_vector.AsType<int8_t>()[i]);
         });
         return acc;
@@ -1050,7 +1041,8 @@
         T acc = 0;
         static_for<0, 8, 1>{}([&](auto i) {
-            acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
+            acc += type_convert<T>(a_vector.AsType<int8_t>()[i]) *
+                   type_convert<T>(b_vector.AsType<int8_t>()[i]);
         });
         return acc;
@@ -1064,7 +1056,8 @@
         T acc = 0;
         static_for<0, 16, 1>{}([&](auto i) {
-            acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
+            acc += type_convert<T>(a_vector.AsType<int8_t>()[i]) *
+                   type_convert<T>(b_vector.AsType<int8_t>()[i]);
         });
         return acc;
...
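With the refactor in place, a conversion is an ordinary template call usable from host and device alike, and type pairs without a dedicated specialization fall through to the static_cast in the primary template. A small host-side round-trip sketch (hypothetical test code, not part of the commit; assumes data_type.hpp is included and that `ushort` carries a bf16 bit pattern):

    #include <cstdio>

    int main()
    {
        float x  = 1.5f;                        // exactly representable in bf16
        ushort b = ck::type_convert<ushort>(x); // fp32 -> bf16 via the specialization above
        float y  = ck::type_convert<float>(b);  // bf16 -> fp32, exact: low mantissa bits are zero
        std::printf("%f -> 0x%04x -> %f\n", x, unsigned(b), y); // 1.500000 -> 0x3fc0 -> 1.500000
        return 0;
    }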
@@ -28,12 +28,6 @@ __device__ void inner_product<float, float, float>(const float& a, const float&
 #endif
 }
-template <>
-__device__ void inner_product<ushort, ushort, float>(const ushort& a, const ushort& b, float& c)
-{
-    c += bf16_to_f32(a) * bf16_to_f32(b);
-}
-
 template <>
 __device__ void
 inner_product<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
@@ -90,13 +84,12 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
     c = __builtin_amdgcn_sdot2(a, b, c, false);
 #endif
 #else
-    const auto convert = type_convert<int32_t>{};
     const vector_type<half_t, 2> a_vector{a};
     const vector_type<half_t, 2> b_vector{b};
     static_for<0, 2, 1>{}([&](auto i) {
-        c += convert(a_vector.AsType<half_t>()[i]) * convert(b_vector.AsType<half_t>()[i]);
+        c += type_convert<int32_t>(a_vector.AsType<half_t>()[i]) *
+             type_convert<int32_t>(b_vector.AsType<half_t>()[i]);
     });
 #endif
 }
@@ -156,13 +149,12 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
     c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
 #endif
 #else
-    const auto convert = type_convert<int32_t>{};
     const vector_type<int8_t, 4> a_vector{a};
     const vector_type<int8_t, 4> b_vector{b};
     static_for<0, 4, 1>{}([&](auto i) {
-        c += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
+        c += type_convert<int32_t>(a_vector.AsType<int8_t>()[i]) *
+             type_convert<int32_t>(b_vector.AsType<int8_t>()[i]);
     });
 #endif
 }
...
@@ -165,7 +165,7 @@ struct unary_identic
         scaler = 1.0f / static_cast<float>(divider);
     };
-    __device__ inline constexpr T operator()(T a) const { return a * type_convert<T>{}(scaler); };
+    __device__ inline constexpr T operator()(T a) const { return a * type_convert<T>(scaler); };
     float scaler = 1.0f;
 };
@@ -187,7 +187,7 @@ struct unary_square
     {
         a = a * a;
-        return a * type_convert<T>{}(scaler);
+        return a * type_convert<T>(scaler);
     };
     float scaler = 1.0f;
@@ -210,7 +210,7 @@ struct unary_abs
     {
         a = abs(a);
-        return a * type_convert<T>{}(scaler);
+        return a * type_convert<T>(scaler);
     };
     float scaler = 1.0f;
@@ -249,7 +249,7 @@ struct unary_abs<half_t, hasDividing>
     {
         a = static_cast<half_t>(__habs(a));
-        return a * type_convert<half_t>{}(scaler);
+        return a * type_convert<half_t>(scaler);
     };
     float scaler = 1.0f;
...
@@ -82,8 +82,8 @@ void host_convolution_forward(const Tensor<TIn>& in,
                 {
                     if constexpr(is_same<TIn, ushort>::value)
                     {
-                        v += ck::bf16_to_f32(in(n, c, hi, wi)) *
-                             ck::bf16_to_f32(wei(k, c, y, x));
+                        v += ck::type_convert<float>(in(n, c, hi, wi)) *
+                             ck::type_convert<float>(wei(k, c, y, x));
                     }
                     else
                     {
@@ -97,7 +97,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
         if constexpr(is_same<TOut, ushort>::value)
         {
-            out(n, k, ho, wo) = f32_to_bf16(v);
+            out(n, k, ho, wo) = type_convert<ushort>(v);
         }
         else
         {
@@ -120,8 +120,8 @@ void host_convolution_forward(const Tensor<TIn>& in,
                 {
                     if constexpr(is_same<TIn, ushort>::value)
                     {
-                        v += ck::bf16_to_f32(in(n, hi, wi, c)) *
-                             ck::bf16_to_f32(wei(k, y, x, c));
+                        v += ck::type_convert<float>(in(n, hi, wi, c)) *
+                             ck::type_convert<float>(wei(k, y, x, c));
                     }
                     else
                     {
@@ -134,7 +134,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
         }
        if constexpr(is_same<TOut, ushort>::value)
        {
-            out(n, ho, wo, k) = f32_to_bf16(v);
+            out(n, ho, wo, k) = ck::type_convert<ushort>(v);
        }
        else
        {
...
 #pragma once
 #include "host_tensor.hpp"
-template <>
-void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
-                                       const Tensor<ushort>& b,
-                                       Tensor<ushort>& c,
-                                       const GemmMatrixLayout layout)
-{
-    if(layout == GemmMatrixLayout::MK_KN_MN)
-    {
-        auto f_mk_kn_mn = [&](auto m, auto n) {
-            const int K = a.mDesc.GetLengths()[1];
-            double v = 0;
-            for(int k = 0; k < K; ++k)
-            {
-                v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(k, n));
-            }
-            c(m, n) = ck::f32_to_bf16(v);
-        };
-        make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
-            std::thread::hardware_concurrency());
-    }
-    else if(layout == GemmMatrixLayout::MK_NK_MN)
-    {
-        auto f_mk_nk_mn = [&](auto m, auto n) {
-            const int K = a.mDesc.GetLengths()[1];
-            double v = 0;
-            for(int k = 0; k < K; ++k)
-            {
-                v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(n, k));
-            }
-            c(m, n) = ck::f32_to_bf16(v);
-        };
-        make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
-            std::thread::hardware_concurrency());
-    }
-    else if(layout == GemmMatrixLayout::KM_KN_MN)
-    {
-        auto f_km_kn_mn = [&](auto m, auto n) {
-            const int K = a.mDesc.GetLengths()[0];
-            double v = 0;
-            for(int k = 0; k < K; ++k)
-            {
-                v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(k, n));
-            }
-            c(m, n) = ck::f32_to_bf16(v);
-        };
-        make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
-            std::thread::hardware_concurrency());
-    }
-    else if(layout == GemmMatrixLayout::KM_NK_MN)
-    {
-        auto f_km_nk_mn = [&](auto m, auto n) {
-            const int K = a.mDesc.GetLengths()[0];
-            double v = 0;
-            for(int k = 0; k < K; ++k)
-            {
-                v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(n, k));
-            }
-            c(m, n) = ck::f32_to_bf16(v);
-        };
-        make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
-            std::thread::hardware_concurrency());
-    }
-    else if(layout == GemmMatrixLayout::MK_KN_NM)
-    {
-        auto f_mk_kn_nm = [&](auto n, auto m) {
-            const int K = a.mDesc.GetLengths()[1];
-            double v = 0;
-            for(int k = 0; k < K; ++k)
-            {
-                v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(k, n));
-            }
-            c(n, m) = ck::f32_to_bf16(v);
-        };
-        make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
-            std::thread::hardware_concurrency());
-    }
-    else if(layout == GemmMatrixLayout::MK_NK_NM)
-    {
-        auto f_mk_nk_nm = [&](auto n, auto m) {
-            const int K = a.mDesc.GetLengths()[1];
-            double v = 0;
-            for(int k = 0; k < K; ++k)
-            {
-                v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(n, k));
-            }
-            c(n, m) = ck::f32_to_bf16(v);
-        };
-        make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
-            std::thread::hardware_concurrency());
-    }
-    else if(layout == GemmMatrixLayout::KM_KN_NM)
-    {
-        auto f_km_kn_nm = [&](auto n, auto m) {
-            const int K = a.mDesc.GetLengths()[0];
-            double v = 0;
-            for(int k = 0; k < K; ++k)
-            {
-                v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(k, n));
-            }
-            c(n, m) = ck::f32_to_bf16(v);
-        };
-        make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
-            std::thread::hardware_concurrency());
-    }
-    else if(layout == GemmMatrixLayout::KM_NK_NM)
-    {
-        auto f_km_nk_nm = [&](auto n, auto m) {
-            const int K = a.mDesc.GetLengths()[0];
-            double v = 0;
-            for(int k = 0; k < K; ++k)
-            {
-                v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(n, k));
-            }
-            c(n, m) = ck::f32_to_bf16(v);
-        };
-        make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
-            std::thread::hardware_concurrency());
-    }
-    else
-    {
-        throw std::runtime_error("wrong! not supported layout");
-    }
-}
-
 template <typename AType, typename BType, typename CType>
 void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
                         const Tensor<BType>& b_k_n,
...
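Removing this block is the same ODR story once more: host_gemm<ushort, ushort, ushort> is a full specialization of a function template, i.e. an ordinary function, and this header-resident definition produced a duplicate symbol as soon as two ckProfiler sources included the header. A two-file sketch of the failure mode (hypothetical names, not from the repo):

    // common.hpp
    template <typename T> void f(T);
    template <> void f<int>(int) {} // ordinary function defined in a header

    // a.cpp
    #include "common.hpp"           // emits f<int>

    // b.cpp
    #include "common.hpp"           // emits f<int> again -> multiple definition at link

Marking the specialization `inline` or moving its definition into a single .cpp would also have resolved it; this commit simply drops the bf16 specialization.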
@@ -299,53 +299,41 @@ HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides)
 void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout);
+float bf16_to_f32_(ushort src_val);
+
 template <typename T>
 void check_error(const Tensor<T>& ref, const Tensor<T>& result)
 {
     float error = 0;
     float max_diff = -1;
     float ref_value = 0, result_value = 0;
-    for(int i = 0; i < ref.mData.size(); ++i)
+
+    if constexpr(std::is_same<ushort, T>::value)
     {
-        error += std::abs(double(ref.mData[i]) - double(result.mData[i]));
-        float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
-        if(max_diff < diff)
+        for(int i = 0; i < ref.mData.size(); ++i)
         {
-            max_diff = diff;
-            ref_value = ref.mData[i];
-            result_value = result.mData[i];
+            error += std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i]));
+            float diff = std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i]));
+            if(max_diff < diff)
+            {
+                max_diff = diff;
+                ref_value = bf16_to_f32_(ref.mData[i]);
+                result_value = bf16_to_f32_(result.mData[i]);
+            }
         }
     }
-
-    std::cout << "error: " << error << std::endl;
-    std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
-}
-
-__host__ __device__ float bf16_to_f32(ushort src_val)
-{
-    union
-    {
-        uint32_t int32;
-        float fp32;
-    } u = {uint32_t(src_val) << 16};
-    return u.fp32;
-}
-
-template <>
-void check_error<ushort>(const Tensor<ushort>& ref, const Tensor<ushort>& result)
-{
-    float error = 0;
-    float max_diff = -1;
-    float ref_value = 0, result_value = 0;
-    for(int i = 0; i < ref.mData.size(); ++i)
+    else
     {
-        error += std::abs(bf16_to_f32(ref.mData[i]) - bf16_to_f32(result.mData[i]));
-        float diff = std::abs(bf16_to_f32(ref.mData[i]) - bf16_to_f32(result.mData[i]));
-        if(max_diff < diff)
+        for(int i = 0; i < ref.mData.size(); ++i)
         {
-            max_diff = diff;
-            ref_value = bf16_to_f32(ref.mData[i]);
-            result_value = bf16_to_f32(result.mData[i]);
+            error += std::abs(double(ref.mData[i]) - double(result.mData[i]));
+            float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
+            if(max_diff < diff)
+            {
+                max_diff = diff;
+                ref_value = ref.mData[i];
+                result_value = result.mData[i];
+            }
         }
     }
...
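check_error had the same hazard in miniature: check_error<ushort> was a full specialization defined in the header. The rewrite folds the bf16 branch into the primary template with C++17 `if constexpr`, which keeps everything a template (safely definable in a header) and deletes the duplicated loop. The dispatch pattern in isolation, as a sketch with a hypothetical as_float helper:

    #include <type_traits>

    // assumes bf16_to_f32_(ushort) is declared, as in the header above
    template <typename T>
    float as_float(T v)
    {
        if constexpr(std::is_same<ushort, T>::value)
            return bf16_to_f32_(v);       // reinterpret the bf16 bit pattern
        else
            return static_cast<float>(v); // every other element type converts directly
    }

Only the selected branch is instantiated for a given T, so the bf16 helper is never touched when T is float or int8_t.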
@@ -5,15 +5,25 @@
 #include "config.hpp"
 #include "data_type.hpp"
+template <typename T>
+struct GeneratorTensor_0
+{
+    template <typename... Is>
+    T operator()(Is...)
+    {
+        return T{0};
+    }
+};
+
 template <typename T>
 struct GeneratorTensor_1
 {
     int value = 1;
     template <typename... Is>
-    float operator()(Is...)
+    T operator()(Is...)
     {
-        return value;
+        return ck::type_convert<T>(value);
     }
 };
@@ -25,7 +35,7 @@ struct GeneratorTensor_1<ushort>
     template <typename... Is>
     ushort operator()(Is...)
     {
-        return ck::f32_to_bf16(value);
+        return ck::type_convert<ushort>(value);
     }
 };
@@ -41,17 +51,6 @@ struct GeneratorTensor_1<int8_t>
     }
 };
-struct GeneratorTensor_0
-{
-    int value = 0;
-
-    template <typename... Is>
-    float operator()(Is...)
-    {
-        return value;
-    }
-};
-
 template <typename T>
 struct GeneratorTensor_2
 {
@@ -59,7 +58,7 @@ struct GeneratorTensor_2
     int max_value = 1;
     template <typename... Is>
-    float operator()(Is...)
+    T operator()(Is...)
     {
         return (std::rand() % (max_value - min_value)) + min_value;
     }
@@ -75,7 +74,7 @@ struct GeneratorTensor_2<ushort>
     ushort operator()(Is...)
     {
         float tmp = (std::rand() % (max_value - min_value)) + min_value;
-        return ck::f32_to_bf16(tmp);
+        return ck::type_convert<ushort>(tmp);
     }
 };
@@ -99,7 +98,7 @@ struct GeneratorTensor_3
     T max_value = 1;
     template <typename... Is>
-    float operator()(Is...)
+    T operator()(Is...)
     {
         float tmp = float(std::rand()) / float(RAND_MAX);
@@ -120,7 +119,7 @@ struct GeneratorTensor_3<ushort>
         float fp32_tmp = min_value + tmp * (max_value - min_value);
-        return ck::f32_to_bf16(fp32_tmp);
+        return ck::type_convert<ushort>(fp32_tmp);
     }
 };
...
@@ -61,3 +61,13 @@ void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os)
     LogRange(os, desc.GetStrides(), ", ");
     os << "}" << std::endl;
 }
+
+float bf16_to_f32_(ushort src_val)
+{
+    union
+    {
+        uint32_t int32;
+        float fp32;
+    } u = {uint32_t(src_val) << 16};
+    return u.fp32;
+}
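This definition pairs with the declaration added to host_tensor.hpp above: the header now only declares float bf16_to_f32_(ushort), and this .cpp supplies the single out-of-line definition, so the linker sees exactly one copy no matter how many profiler objects include the header. A header-only alternative with the same effect would mark the definition `inline` (a sketch, not what the patch does):

    // equivalent header-only variant: `inline` permits identical repeated
    // definitions across translation units
    inline float bf16_to_f32_(ushort src_val)
    {
        union
        {
            uint32_t int32;
            float fp32;
        } u = {uint32_t(src_val) << 16};
        return u.fp32;
    }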
@@ -106,12 +106,12 @@ void profile_conv(int do_verification,
     {
     case 0: break;
     case 1:
-        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5});
-        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5});
+        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
+        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
         break;
     default:
-        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0});
-        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5});
+        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
+        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
     }
     if(do_verification)
...
@@ -122,12 +122,12 @@ void profile_gemm(int do_verification,
     {
     case 0: break;
     case 1:
-        a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5});
-        b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5});
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
         break;
     default:
-        a_m_k.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0});
-        b_k_n.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5});
+        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
     }
     if(do_verification)
...