Commit 0bb08f4b authored by aska-0096's avatar aska-0096
Browse files

1. Enable 2-stage global Prefetch ( May cause VGPR spilling)

2. Enable FP16 accumulator blockwise_gemm
parent 716860e3
......@@ -35,24 +35,24 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
BElementOp,
CElementOp,
GemmDefault,
1, // Prefetch stage
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
2, // Prefetch stage
128, // BlockSize
64, // MPerBlock
128, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 64, 1>,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
......@@ -61,7 +61,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 32, 1, 8>,
S<1, 32, 1, 4>,
8>;
// clang-format on
......
......@@ -47,6 +47,10 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
break;
case 5:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-2.f, 2.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
......
......@@ -80,34 +80,34 @@ using DeviceOpInstance =
BElementOp,
CDEElementOp,
GemmSpec,
1,
64,
32,
2,
128,
64,
128,
64,
8,
16,
16,
2,
2,
S<4, 16, 1>,
4,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
4,
4,
8,
8,
true,
S<4, 16, 1>,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
4,
4,
8,
8,
true,
1,
1,
S<1, 2, 1, 32>,
1>;
S<1, 32, 1, 4>,
8>;
int main(int argc, char* argv[])
{
......
......@@ -67,24 +67,24 @@ using DeviceOpInstanceKKNN =
ASpec,
BSpec,
DESpec,
1,
256,
2,
128,
64,
128,
32,
8,
16,
16,
8,
1,
S<4, 64, 1>,
2,
4,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
......@@ -93,7 +93,7 @@ using DeviceOpInstanceKKNN =
true,
1,
1,
S<1, 16, 1, 16>,
S<1, 32, 1, 4>,
8>;
using DeviceOpInstance = DeviceOpInstanceKKNN;
......
......@@ -211,10 +211,27 @@ struct BlockwiseGemmWMMA
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
return make_naive_tensor_descriptor(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{},
I1,
I1,
Number<NRepeat>{},
I1,
I1,
MAccVgprs),
make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
AccStride)
);
#if 0
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
......@@ -225,6 +242,7 @@ struct BlockwiseGemmWMMA
I1,
NThreadPerSubGroup,
MAccVgprs));
#endif
}
template <typename CGridDesc_M_N>
......
......@@ -140,8 +140,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto AEnableLds_manu = true;
static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
......
......@@ -151,9 +151,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static constexpr auto B0EnableLds_manu = true;
static constexpr auto B1EnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu;
static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch >1);
static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch >1);
static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch >1);
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
......
......@@ -101,8 +101,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto AEnableLds_manu = true;
static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
......
......@@ -94,8 +94,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = false;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch>1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch>1);
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
......@@ -467,7 +467,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> || is_same_v<AccDataType, int32_t>))
{
printf("DeviceOp err: AccDataType");
return false;
......
......@@ -177,8 +177,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = false;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumGemmKPrefetchStage > 1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
......
......@@ -868,7 +868,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
/* index_t SrcScalarStrideInVector, */ 1,
/* index_t DstScalarStrideInVector, */ 1,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
NumGemmKPrefetchStage>(
a_grid_desc,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
......@@ -943,7 +944,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
1,
1,
B0ThreadTransferSrcResetCoordinateAfterRun,
true>(
true,
NumGemmKPrefetchStage>(
b0_grid_desc,
make_multi_index(0, 0, 0),
b0_element_op,
......
......@@ -874,7 +874,8 @@ struct GridwiseGemmMultipleD_Wmma
/* index_t SrcScalarStrideInVector, */ 1,
/* index_t DstScalarStrideInVector, */ 1,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
NumGemmKPrefetchStage>(
a_grid_desc,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
......@@ -950,7 +951,8 @@ struct GridwiseGemmMultipleD_Wmma
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
true,
NumGemmKPrefetchStage>(
b_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
......
......@@ -636,7 +636,8 @@ struct GridwiseGemm_Wmma
/* index_t SrcScalarStrideInVector, */ 1,
/* index_t DstScalarStrideInVector, */ 1,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
NumGemmKPrefetchStage>(
a_grid_desc,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
......@@ -712,7 +713,8 @@ struct GridwiseGemm_Wmma
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
true,
NumGemmKPrefetchStage>(
b_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
......@@ -814,10 +816,11 @@ struct GridwiseGemm_Wmma
/*******************************************************************************/
// write out to C, implement shuffle
{
// C mapping in single thread.
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
// This API Provide All dimension (size) you need
// C mapping in single block
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
......
......@@ -89,6 +89,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
......@@ -100,12 +101,11 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
index_t NPerWmma,
bool AssemblyBackend,
class FloatA,
class FloatB,
class FloatC>
......@@ -113,7 +113,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
{
if constexpr(wave_size == 32)
{
intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma, AssemblyBackend>::Run(a, b, reg_c);
intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
......@@ -134,6 +134,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
......@@ -141,7 +142,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
m_per_wmma * n_per_wmma * acc_data_size *acc_pack_number/ wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
......@@ -158,7 +159,6 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
}
};
#ifdef CK_UNPACKED_ACC_DESC_LOGIC
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
WaveSize,
......@@ -171,6 +171,7 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 2;
static constexpr index_t acc_pack_number = 2;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
......@@ -178,12 +179,11 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
index_t NPerWmma,
index_t Opsel,
class FloatA,
class FloatB,
class FloatC>
......@@ -191,15 +191,14 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
{
if constexpr(wave_size == 32)
{
intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
}
}
};
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
WaveSize,
......@@ -212,6 +211,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 2;
static constexpr index_t acc_pack_number = 2;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
......@@ -219,7 +219,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
......@@ -232,17 +232,15 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
{
if constexpr(wave_size == 32)
{
intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
}
}
};
#endif
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
WaveSize,
......@@ -255,6 +253,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
......@@ -262,7 +261,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
m_per_wmma * n_per_wmma * acc_data_size *acc_pack_number / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
......@@ -351,7 +350,7 @@ struct WmmaSelector
static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
selected_wmma.acc_data_size ==
selected_wmma.acc_data_size * selected_wmma.acc_pack_number ==
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
"WRONG! Invalid Number of Accumulator Register");
}
......@@ -469,7 +468,7 @@ struct WmmaGemm
__device__ static constexpr index_t GetRegSizePerWmma()
{
return wmma_instr.num_acc_vgprs_per_wave;
return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number;
}
__device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
......@@ -497,12 +496,12 @@ struct WmmaGemm
"(int8, int32) or (int4, int32)!");
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(
wmma_instr.template run<MPerWmma, NPerWmma>(
p_a_wave, p_b_wave, p_c_thread);
}
else
{
wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(
wmma_instr.template run<MPerWmma, NPerWmma>(
p_b_wave, p_a_wave, p_c_thread);
}
}
......@@ -556,7 +555,7 @@ struct WmmaGemm
__host__ __device__ static constexpr auto
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
{
return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{}, Number<wmma_instr.acc_pack_number>{});
}
};
......
......@@ -12,11 +12,11 @@ namespace ck {
/********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave, bool AssemblyBackend>
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;
template <bool AssemblyBackend>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16, AssemblyBackend>
template <>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment