"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "959b8d82d89e5570c815d54ed58216d6d57660fc"
Commit 4ddda63b authored by aska-0096

sanity check pass

parent 74f0d5de
@@ -108,26 +108,26 @@ using DeviceGemmInstance =
1, // MRepeat
8, // LRepeat
4, // NRepeat
S<4, 64, 1>, // ABlockTransfer
S<4, 64, 1>, // ABlockTransfer MK -> K0 M K1
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // B0BlockTransfer
S<4, 64, 1>, // B0BlockTransfer LK -> K0 L K1
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // B1BlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
S<4, 8, 8>, // B1BlockTransfer LN -> L0 N L1
S<0, 2, 1>,
S<0, 2, 1>,
1,
8,
1,
false,
1, // CShuffleMWmmaPerWavePerShuffle
2, // CShuffleNWmmaPerWavePerShuffle
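Note on the `MK -> K0 M K1` comments above: the block-transfer view splits the K axis so that `K1` consecutive elements (the scalar-per-vector width) stay contiguous per thread, enabling vectorized loads. A minimal standalone sketch of that index mapping; the function and struct names are illustrative, not CK's actual descriptor API:

```cpp
// Map a (m, k) coordinate of an M x K tile to the (k0, m, k1) view,
// where k = k0 * K1 + k1 and K1 is the per-thread vector width.
#include <cassert>
#include <cstdio>

struct K0MK1Index { int k0, m, k1; };

constexpr K0MK1Index to_k0_m_k1(int m, int k, int K1)
{
    return K0MK1Index{k / K1, m, k % K1};
}

int main()
{
    constexpr int K1 = 8; // matches the "8" scalar-per-vector arguments above
    const auto idx = to_k0_m_k1(/*m=*/3, /*k=*/21, K1);
    assert(idx.k0 == 2 && idx.m == 3 && idx.k1 == 5);
    std::printf("(k0, m, k1) = (%d, %d, %d)\n", idx.k0, idx.m, idx.k1);
    return 0;
}
```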
@@ -127,7 +127,7 @@ int run(int argc, char* argv[])
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 6: // Rand: b0 ; unit: a b1 fail
case 6: // Rand: b0 ; unit: a b1 pass
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
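Note on the generator cases: pairing one random operand with all-ones operands makes a failure attributable to a single input path, because a unit-valued A turns every element of A*B0 into a plain column sum of B0. A standalone sketch of that check (hypothetical, not the CK test harness):

```cpp
// With A all ones, C[m][n] is just the column sum of B, so every row of C
// must be identical and easy to verify by hand.
#include <array>
#include <cassert>

int main()
{
    constexpr int M = 2, K = 3, N = 2;
    std::array<std::array<int, K>, M> A{{{1, 1, 1}, {1, 1, 1}}};       // "unit: a"
    std::array<std::array<int, N>, K> B{{{-2, 1}, {0, 2}, {1, -1}}};   // "Rand: b0"

    for(int n = 0; n < N; ++n)
    {
        int col_sum = 0;
        for(int k = 0; k < K; ++k) col_sum += B[k][n];
        for(int m = 0; m < M; ++m)
        {
            int c = 0;
            for(int k = 0; k < K; ++k) c += A[m][k] * B[k][n];
            assert(c == col_sum); // every row of C equals B's column sums
        }
    }
    return 0;
}
```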
@@ -240,12 +240,6 @@ int run(int argc, char* argv[])
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
// for(int i =0; i< 128; i++){
// for(int j =0; j< 128; j++){
// printf("%0.2lf ", acc0_g_m_n.mData[i*128 +j]);
// }
// printf("\n");
// }
// masking
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
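Note on the masking step: the reference path applies the same mask the device kernel uses, which (assuming the usual attention-style rule) drives masked logits to -inf so softmax assigns them zero weight. A sketch of that semantics with a hypothetical out-of-range predicate; `C0MatrixMask` encapsulates the kernel's actual rule:

```cpp
// Sketch of masked-softmax semantics: masked logits are driven to -inf so
// exp() maps them to zero weight. The predicate (j >= valid_n) is
// illustrative only.
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

std::vector<float> masked_softmax_row(std::vector<float> logits, std::size_t valid_n)
{
    const float neg_inf = -std::numeric_limits<float>::infinity();
    for(std::size_t j = valid_n; j < logits.size(); ++j)
        logits[j] = neg_inf; // mask out-of-range columns

    float m = neg_inf;
    for(float x : logits) m = std::fmax(m, x);

    float s = 0.f;
    for(float& x : logits) { x = std::exp(x - m); s += x; }
    for(float& x : logits) x /= s; // masked entries end up exactly 0
    return logits;
}
```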
@@ -171,8 +171,6 @@ struct BlockwiseGemmWMMA
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
// printf("tid %03d, Mat-B offset %d\n", get_thread_local_1d_id()%32, CalculateBThreadOriginDataIndex().At(Number<3>{}));
}
// transposed WMMA output C' = B' * A'
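The `C' = B' * A'` comment is the transpose identity at work: computing the transposed product changes only the register/lane layout of the result, not its values.

```latex
% Transpose identity behind "C' = B' * A'":
(AB)^{\mathsf{T}} = B^{\mathsf{T}} A^{\mathsf{T}}
```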
@@ -301,9 +299,6 @@ struct BlockwiseGemmWMMA
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0),
a_thread_buf);
// static_for<0, a_thread_buf.size(), 1>{}([&](auto i) {
// a_thread_buf(i) = 1;
// });
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
@@ -323,9 +318,6 @@ struct BlockwiseGemmWMMA
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
// a_thread_vec.template AsType<FloatA>()(i) = 1;
// b_thread_vec.template AsType<FloatB>()(i) = 1;
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
@@ -333,12 +325,6 @@ struct BlockwiseGemmWMMA
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// printf("GPU Gemm0 input, Tid %03d, A%2d = %04x, B%2d = %0x4\n",
// get_thread_local_1d_id(),
// i.value, *(reinterpret_cast<uint16_t*>(&a_thread_vec.template AsType<FloatA>()(i))),
// i.value, *(reinterpret_cast<uint16_t*>(&b_thread_vec.template AsType<FloatB>()(i))));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
@@ -658,16 +658,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
constexpr auto t_lwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I4);
constexpr auto t_lsubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I5);
constexpr auto t_laccvgprs = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I6);
if(get_thread_local_1d_id()==0){
printf("t_mrepeat %d, t_mwave %d, t_mthreadpersubgroup %d, t_lrepeat %d, t_lwave %d, t_lsubgroup %d, t_laccvgprs %d \n",
t_mrepeat.value,
t_mwave.value,
t_mthreadpersubgroup.value,
t_lrepeat.value,
t_lwave.value,
t_lsubgroup.value,
t_laccvgprs.value);
}
// get acc0 thread map
constexpr auto m0_l_m1_to_m_l_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(t_mrepeat * t_mwave, t_mthreadpersubgroup)),
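The unmerge transform above splits the flat row index m into (m0, m1) with lengths (t_mrepeat * t_mwave, t_mthreadpersubgroup); as index arithmetic:

```latex
% make_unmerge_transform with lengths (L0, L1) decodes a flat index as
m = m_0 \cdot t_{\mathrm{mthreadpersubgroup}} + m_1,
\qquad 0 \le m_1 < t_{\mathrm{mthreadpersubgroup}}
```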
@@ -744,28 +734,28 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
// B1 matrix blockwise copy
auto b1_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
B1ElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BL0, NPerBlock, BL1>,
B1BlockTransferThreadClusterLengths_L0_N_L1,
B1BlockTransferThreadClusterArrangeOrder,
FloatB1,
FloatB1,
decltype(b1_grid_desc_l0_n_l1),
decltype(b1_block_desc_l0perblock_nperblock_l1),
B1BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B1BlockTransferSrcVectorDim,
2,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_L1,
1,
1,
B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord
NumGemmKPrefetchStage>(
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
/* typename SrcElementwiseOperation, */ B1ElementwiseOperation,
/* typename DstElementwiseOperation, */ tensor_operation::element_wise::PassThrough,
/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
/* typename BlockSliceLengths, */ Sequence<BL0, NPerBlock, BL1>,
/* typename ThreadClusterLengths, */ B1BlockTransferThreadClusterLengths_L0_N_L1,
/* typename ThreadClusterArrangeOrder, */ B1BlockTransferThreadClusterArrangeOrder,
/* typename SrcData, */ FloatB1,
/* typename DstData, */ FloatB1,
/* typename SrcDesc, */ decltype(b1_grid_desc_l0_n_l1),
/* typename DstDesc, */ decltype(b1_block_desc_l0perblock_nperblock_l1),
/* typename SrcDimAccessOrder, */ B1BlockTransferSrcAccessOrder,
/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>,
/* index_t SrcVectorDim, */ B1BlockTransferSrcVectorDim,
/* index_t DstVectorDim, */ 2,
/* index_t SrcScalarPerVector, */ B1BlockTransferSrcScalarPerVector,
/* index_t DstScalarPerVector, */ B1BlockTransferDstScalarPerVector_L1,
/* index_t SrcScalarStrideInVector, */ 1,
/* index_t DstScalarStrideInVector, */ 1,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */ B1ThreadTransferSrcResetCoordinateAfterRun,
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, // DstResetCoord
NumGemmKPrefetchStage>(
b1_grid_desc_l0_n_l1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b1_element_op,
@@ -793,7 +783,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
NPerWmma,
MRepeat,
NRepeat,
KPack>{make_tuple(0, 0, 0, 0, 0)};
KPack,
true>{make_tuple(0, 0, 0, 0, 0)};
auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
@@ -899,18 +890,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
// softmax
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
// printf("GPU Gemm 0, Tid %03d, GPU acc0 = %lf\n", get_thread_local_1d_id(), acc0_thread_buf[I0]);
// static_for<0, acc0_thread_buf.Size(), 1>{}([&](auto i) {
// printf("GPU Gemm0, Tid %03d, GPU acc%d = %lf\n", get_thread_local_1d_id(), i.value, acc0_thread_buf[i]);
// });
blockwise_softmax.Run(acc0_thread_buf, workspace_buf);
// printf("GPU SoftMax, Tid %03d, GPU acc0 = %lf\n", get_thread_local_1d_id(), acc0_thread_buf[I0]);
blockwise_softmax.Run(acc0_thread_buf, workspace_buf);
// TODO: may convert to log domain
running_max_new = mathext::max(max, running_max);
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
mathext::exp(max - running_max_new) * sum;
// gemm1
{
// TODO: explore using dynamic buffer for a1 thread buffer
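The running_max/running_sum update above implements the streaming-softmax recurrence used by flash-attention-style kernels: each key block's local max m_i and exponential sum s_i are folded into the running statistics after rescaling both to the new common max.

```latex
% Online softmax statistics update, per output row:
m^{\mathrm{new}} = \max\!\left(m^{\mathrm{run}},\, m_i\right),
\qquad
s^{\mathrm{new}} = e^{\,m^{\mathrm{run}} - m^{\mathrm{new}}}\, s^{\mathrm{run}}
                 + e^{\,m_i - m^{\mathrm{new}}}\, s_i
```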
@@ -949,8 +936,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
block_sync_lds();
// printf("GPU permute lanex, Tid %03d, GPU 0 = %04x\n", get_thread_local_1d_id(), *(reinterpret_cast<const uint16_t*>(&a1_thread_buf[I0])));
blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
block_sync_lds();
@@ -973,7 +958,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
a1_thread_buf);
block_sync_lds();
blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
}
} // end gemm1
@@ -997,8 +982,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
FloatAcc1 acc1 = acc1_thread_buf[I]; // P*V
FloatAcc1 c = c_thread_buf[I]; // O
FloatAcc1 c_new =
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
math::exp(max[iM] - running_max_new[iM]) * acc1) /
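The c_new expression is the matching output correction: the previously accumulated output c (already normalized by the old running sum) and the fresh partial product acc1 = P*V are rescaled to the new max and renormalized by the new running sum.

```latex
% Output correction, per output row:
c^{\mathrm{new}} =
\frac{ s^{\mathrm{run}}\, e^{\,m^{\mathrm{run}} - m^{\mathrm{new}}}\, c
     \;+\; e^{\,m_i - m^{\mathrm{new}}}\, \mathrm{acc1} }
     { s^{\mathrm{new}} }
```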
@@ -1022,18 +1007,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
/*******************************************************************************/
// write out to C, implement shuffle
{
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
// This API provides all the dimension sizes you need
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp =
blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4);
constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5);
constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6);
constexpr auto MWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1);
constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2);
constexpr auto NWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4);
constexpr auto NSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5);
constexpr auto NAccVgprs = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6);
// LDS descriptor, shuffle and write out in MRepeat x NRepeat times
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
@@ -1043,22 +1028,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
MWave, // MWave
MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
MAccVgprs)),
MThreadPerSubGroup // MThreadPerSubGroup = MPerWmma
)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
NWave, // NWave
NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
NSubGroup,
NAccVgprs))), // NSubGroup * NAccVgprs = NPerWmma
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{}));
make_tuple(Sequence<>{}, Sequence<0, 1, 2>{}, Sequence<>{}, Sequence<3, 4, 5, 6>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
@@ -1067,30 +1053,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
const auto m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MThreadPerSubGroup))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
const auto n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NSubGroup, NAccVgprs))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
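With the transposed WMMA output the accumulator VGPRs now run along N, so the N-side adaptor carries four factors; CalculateBottomIndex inverts the merge with mixed-radix arithmetic:

```latex
% Decoding of the flat n index into (n0, n1, n2, n3) with radices
% (NRepeat, NWave, NSubGroup, NAccVgprs):
n = \bigl( (n_0 \cdot N_{\mathrm{Wave}} + n_1) \cdot N_{\mathrm{SubGroup}} + n_2 \bigr)
    \cdot N_{\mathrm{AccVgprs}} + n_3
```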
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc1,
FloatCShuffle,
decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
decltype(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
decltype(c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMRepeatPerShuffle,
I1,
@@ -1098,21 +1084,21 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
CShuffleNRepeatPerShuffle,
I1,
I1,
MAccVgprs>,
NAccVgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
1, // vector write pixel
8, // vector write pixel
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
make_multi_index(0,
m_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
0,
n_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3]),
n_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
@@ -1144,7 +1130,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
// space filling curve for local reg & global memory
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs>,
SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, NAccVgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
Sequence<CShuffleMRepeatPerShuffle,
1,
@@ -1152,7 +1138,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
CShuffleNRepeatPerShuffle,
1,
1,
MAccVgprs>>{};
NAccVgprs>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
@@ -1172,10 +1158,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
@@ -1394,19 +1394,24 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
if(get_thread_local_1d_id() % 32 > 16){
if(get_thread_local_1d_id() % 32 < 16){
// apply type convert
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
dst_buf(Number<dst_offset + DstScalarPerVector>{}) = __builtin_amdgcn_permlanex16(type_convert<DstData>(dst_buf(Number<dst_offset + DstScalarPerVector>{})),
type_convert<DstData>(v),
LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
}
else{
// apply type convert
dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(v);
dst_buf(Number<dst_offset>{}) = __builtin_amdgcn_permlanex16(type_convert<DstData>(dst_buf(Number<dst_offset>{})),
type_convert<DstData>(v),
LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
}
SrcData d = 0;
int temp = 0;
temp = __builtin_amdgcn_permlanex16(temp, type_convert<int>(v),
LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
d = type_convert<float>(temp);
if(get_thread_local_1d_id() % 32 < 16){
dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(d);
}
else{
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(d);
}
});
});
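The rewrite funnels the exchanged value through a single `__builtin_amdgcn_permlanex16` call on an integer temporary instead of permuting each destination slot in place. Below is a rough host-side emulation of the cross-half exchange that `v_permlanex16_b32` performs within a 32-lane group; the selector-nibble decoding is my reading of the ISA docs, and the `old`/`fi`/`bound_ctrl` operands are deliberately ignored, so treat this strictly as a sketch:

```cpp
// Rough model of v_permlanex16_b32: each 16-lane half of a 32-lane group
// reads from the *opposite* half, with the source position named by a
// nibble of the combined 64-bit selector.
#include <array>
#include <cstdint>

std::array<std::uint32_t, 32> permlanex16_sketch(
    const std::array<std::uint32_t, 32>& src,
    std::uint32_t sel_lo, std::uint32_t sel_hi)
{
    std::array<std::uint32_t, 32> dst{};
    const std::uint64_t sel = (std::uint64_t(sel_hi) << 32) | sel_lo;
    for(int lane = 0; lane < 32; ++lane)
    {
        // Nibble p of the selector names the source lane (within the
        // opposite half) for position p of each half.
        const int nib      = static_cast<int>((sel >> (4 * (lane % 16))) & 0xF);
        const int src_lane = (lane < 16) ? 16 + nib : nib;
        dst[lane] = src[src_lane];
    }
    return dst;
}
```

In the diff above, `LowEightRowlaneIdx` and `HighEightRowLaneIdx` play the role of `sel_lo`/`sel_hi`, and the exchanged payload is the bit pattern of `v` carried in an `int` temporary.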
@@ -964,6 +964,32 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
return u.fp32;
}
template <>
inline __host__ __device__ constexpr int type_convert<int, float>(float x)
{
union
{
float fp32;
int int32;
} u = {x};
return u.int32;
}
template <>
inline __host__ __device__ constexpr float type_convert<float, int>(int x)
{
union
{
int int32;
float fp32;
} u = {x};
return u.fp32;
}
// convert fp32 to bfp16
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
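The new `type_convert<int, float>` / `type_convert<float, int>` specializations are bit reinterpretations, not numeric casts: they exist so floating-point payloads can round-trip through lane-permute builtins that only accept integer operands. A standalone check of that round trip, using `memcpy` as the strictly portable spelling of the union pun above:

```cpp
// float -> int -> float must preserve the exact bit pattern, because lane
// permutes move raw 32-bit values, not numeric values.
#include <cassert>
#include <cstring>

int   float_as_int(float x) { int i;   std::memcpy(&i, &x, sizeof i); return i; }
float int_as_float(int i)   { float x; std::memcpy(&x, &i, sizeof x); return x; }

int main()
{
    const float v = -123.456f;
    assert(int_as_float(float_as_int(v)) == v); // exact, not approximate
    return 0;
}
```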