Commit b7c1259f authored by ltqin

init ok

parent 0acd3ebe
@@ -366,18 +366,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
         const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
         const auto b_grid_size = CalculateGridSize(M, N);
-        const auto nBatch           = get_block_1d_id() / b_grid_size;
-        const auto blockid_in_batch = get_block_1d_id() % b_grid_size;
+        const auto k_batch_id        = get_block_1d_id() / b_grid_size;
+        const auto block_id_in_batch = get_block_1d_id() % b_grid_size;
         if(get_block_1d_id() == 2000)
-            printf("grid size: %d, Batch: %d block_id: %d k0: %d\n",
+            printf("grid size: %d, k0: %d, blockid: %d, threadid %d, Batch: %d block_id: %d \n",
                    b_grid_size,
-                   nBatch,
-                   blockid_in_batch,
-                   K0);
+                   K0,
+                   get_block_1d_id(),
+                   get_thread_local_1d_id(),
+                   k_batch_id,
+                   block_id_in_batch);

         // divide block work by [M, N]
         const auto block_work_idx =
-            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(blockid_in_batch));
+            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(block_id_in_batch));

         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
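With split-K enabled, the 1-D grid is launched with KBatch * CalculateGridSize(M, N) workgroups, and the hunk above recovers each workgroup's K-batch id and its [M, N] tile id by integer division and modulo (the printf guarded by get_block_1d_id() == 2000 is debug tracing). A minimal host-side sketch of that decomposition, with illustrative values; blocks_per_batch stands in for CalculateGridSize(M, N):

    #include <cstdio>

    int main()
    {
        const int blocks_per_batch = 256; // stand-in for CalculateGridSize(M, N)
        for(int block_1d_id : {0, 255, 256, 700})
        {
            const int k_batch_id        = block_1d_id / blocks_per_batch; // which K slice
            const int block_id_in_batch = block_1d_id % blocks_per_batch; // which [M, N] tile
            std::printf("block %3d -> k_batch %d, tile %d\n",
                        block_1d_id, k_batch_id, block_id_in_batch);
        }
        return 0;
    }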
@@ -391,65 +393,69 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
+        constexpr auto a_b_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
+            make_tuple(1, Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
         constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
             make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
+        constexpr auto b_b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
+            make_tuple(1, Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
         constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
             make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

         // A matrix blockwise copy
         auto a_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
-                                            Sequence<KPerBlock, MPerBlock, K1>,
+                                            Sequence<1, KPerBlock, MPerBlock, K1>,
                                             ABlockTransferThreadSliceLengths_K0_M_K1,
                                             ABlockTransferThreadClusterLengths_K0_M_K1,
                                             ABlockTransferThreadClusterArrangeOrder,
                                             FloatAB,
                                             FloatAB,
-                                            decltype(a_k0_m_k1_grid_desc),
-                                            decltype(a_k0_m_k1_block_desc),
+                                            decltype(a_b_k0_m_k1_grid_desc),
+                                            decltype(a_b_k0_m_k1_block_desc),
                                             ABlockTransferSrcAccessOrder,
-                                            Sequence<1, 0, 2>,
+                                            Sequence<0, 2, 1, 3>,
                                             ABlockTransferSrcVectorDim,
-                                            2,
+                                            3,
                                             ABlockTransferSrcScalarPerVector,
                                             ABlockTransferDstScalarPerVector_K1,
                                             1,
                                             1,
                                             AThreadTransferSrcResetCoordinateAfterRun,
-                                            true>(a_k0_m_k1_grid_desc,
-                                                  make_multi_index(0, m_block_data_idx_on_grid, 0),
-                                                  a_k0_m_k1_block_desc,
-                                                  make_multi_index(0, 0, 0));
+                                            true>(a_b_k0_m_k1_grid_desc,
+                                                  make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
+                                                  a_b_k0_m_k1_block_desc,
+                                                  make_multi_index(0, 0, 0, 0));

         // B matrix blockwise copy
         auto b_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
-                                            Sequence<KPerBlock, NPerBlock, K1>,
+                                            Sequence<1, KPerBlock, NPerBlock, K1>,
                                             BBlockTransferThreadSliceLengths_K0_N_K1,
                                             BBlockTransferThreadClusterLengths_K0_N_K1,
                                             BBlockTransferThreadClusterArrangeOrder,
                                             FloatAB,
                                             FloatAB,
-                                            decltype(b_k0_n_k1_grid_desc),
-                                            decltype(b_k0_n_k1_block_desc),
+                                            decltype(b_b_k0_n_k1_grid_desc),
+                                            decltype(b_b_k0_n_k1_block_desc),
                                             BBlockTransferSrcAccessOrder,
-                                            Sequence<1, 0, 2>,
+                                            Sequence<0, 2, 1, 3>,
                                             BBlockTransferSrcVectorDim,
-                                            2,
+                                            3,
                                             BBlockTransferSrcScalarPerVector,
                                             BBlockTransferDstScalarPerVector_K1,
                                             1,
                                             1,
                                             BThreadTransferSrcResetCoordinateAfterRun,
-                                            true>(b_k0_n_k1_grid_desc,
-                                                  make_multi_index(0, n_block_data_idx_on_grid, 0),
-                                                  b_k0_n_k1_block_desc,
-                                                  make_multi_index(0, 0, 0));
+                                            true>(b_b_k0_n_k1_grid_desc,
+                                                  make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
+                                                  b_b_k0_n_k1_block_desc,
+                                                  make_multi_index(0, 0, 0, 0));

         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
@@ -490,8 +496,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         FloatAB* p_a_block = p_shared_block;
         FloatAB* p_b_block = p_shared_block + a_block_space_size;

-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);

         // hack to control index calculation when iterating over A and B matrix for threadwise copy
         constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
@@ -509,11 +515,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         // preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
-            b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+            a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+            b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);

-            a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
-            b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
+            a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
+            b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
         }

         // main body
@@ -521,25 +527,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         do
         {
-            a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
+            a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc,
                                                 a_block_slice_copy_step,
                                                 a_k0_m_k1_grid_move_slice_window_step_hack);
-            b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
+            b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc,
                                                 b_block_slice_copy_step,
                                                 b_k0_n_k1_grid_move_slice_window_step_hack);

-            a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+            a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);

             block_sync_lds();

-            b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+            b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);

             blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

             block_sync_lds();

-            a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
-            b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
+            a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
+            b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);

             k_block_data_begin += KPerBlock;
         } while(k_block_data_begin < (K0 - KPerBlock));
...
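The slice-copy step and main loop above keep their structure; the step vector just gains a leading zero so the batch coordinate stays fixed while the window advances KPerBlock along K0 each iteration. A runnable scalar analogue of the preload-then-overlap pipeline (no GPU or LDS; a staging variable plays the role of the LDS tile, and the real tail GEMM follows the loop outside this hunk):

    #include <cstdio>

    int main()
    {
        const int K0 = 16, KPerBlock = 4;
        int staged = 0;  // tile currently in "LDS" (here: its K0 offset)
        int k_read = 0;  // source window position
        int acc    = 0;  // stands in for c_thread_buf
        staged = k_read; // preload first tile
        int k_block_data_begin = 0;
        do
        {
            k_read += KPerBlock;     // MoveSrcSliceWindow
            const int next = k_read; // RunRead: global -> registers
            acc += staged;           // blockwise_gemm.Run on the staged tile
            staged = next;           // RunWrite: registers -> "LDS"
            k_block_data_begin += KPerBlock;
        } while(k_block_data_begin < (K0 - KPerBlock));
        acc += staged;               // tail GEMM on the last tile
        std::printf("tiles consumed at K0 offsets 0..%d, acc = %d\n", K0 - KPerBlock, acc);
        return 0;
    }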
@@ -62,21 +62,21 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
     constexpr index_t MRepeat = 2;
     constexpr index_t NRepeat = 2;

-    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 8>;
-    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
+    using GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1   = Sequence<1, 1, 2, 8>;
+    using GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;

     // using vector load 4, so config's wo*ho must be a multiple of 4
     constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
     constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

-    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 8>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
+    using GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1   = Sequence<1, 1, 2, 8>;
+    using GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;

     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 1;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;

-    constexpr index_t KBatch = 96;
+    constexpr index_t KBatch = 2;
 #elif 1
     // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
     constexpr index_t BlockSize = 256;
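KBatch drops from 96 to 2 in this config. Each unit of KBatch multiplies the launched grid, so larger values buy parallelism over the reduction dimension at the cost of more partial-result accumulation traffic on the output. A quick sketch of the launch-size arithmetic, with illustrative extents (the real CalculateGridSize depends on the block-cluster mapping):

    #include <cstdio>

    int main()
    {
        const int M = 256, N = 256, MPerBlock = 128, NPerBlock = 128;
        const int blocks_per_batch = (M / MPerBlock) * (N / NPerBlock);
        for(int k_batch : {1, 2, 96})
            std::printf("KBatch = %2d -> %d workgroups\n",
                        k_batch, k_batch * blocks_per_batch);
        return 0;
    }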
@@ -123,20 +123,24 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
     // HACK: hacks that control index calculation when iterating over A, B, C matrix
     constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{},   // 0+: GemmK0
-                              Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmM
-                              Sequence<0, 0, 1, 0, 0>{}),  // 2+: GemmK1
-                   make_tuple(Sequence<0, 0, 2, 0, 0>{},   // 0-: GemmK0
-                              Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmM
-                              Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: GemmB
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: GemmK0
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: GemmM
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}),  // 3+: GemmK1
+                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: GemmB
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: GemmK0
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: GemmM
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1

     constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},  // 0+: GemmK0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},  // 1+: GemmN
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},  // 0-: GemmK0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},  // 1-: GemmN
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1
+        make_tuple(
+            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 0+: GemmB
+                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 1+: GemmK0
+                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 2+: GemmN
+                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
+            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 0-: GemmB
+                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 1-: GemmK0
+                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 2-: GemmN
+                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1

     constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
         make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
@@ -157,10 +161,10 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
                               Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

     constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
-        Sequence<0, 0, 1, 0, 0>{};
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};

     constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
-        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};

     for(index_t i = 0; i < 5; ++i)
     {
@@ -181,19 +185,19 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
         GemmK1,
         MRepeat,
         NRepeat,
-        GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
-        GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
-        Sequence<1, 0, 2>,
-        Sequence<1, 0, 2>,
-        2,
+        GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1,
+        GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1,
+        Sequence<0, 2, 1, 3>,
+        Sequence<0, 2, 1, 3>,
+        3,
         GemmABlockTransferSrcScalarPerVector_GemmK1,
         GemmABlockTransferDstScalarPerVector_GemmK1,
         false, // don't move back src coordinate after threadwise copy
-        GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
-        GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
-        Sequence<1, 0, 2>,
-        Sequence<1, 0, 2>,
-        2,
+        GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1,
+        GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1,
+        Sequence<0, 2, 1, 3>,
+        Sequence<0, 2, 1, 3>,
+        3,
         GemmBBlockTransferSrcScalarPerVector_GemmN,
         GemmBBlockTransferDstScalarPerVector_GemmK1,
         false, // don't move back src coordinate after threadwise copy
...
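In the template arguments above, the old 3-D orders Sequence<1, 0, 2> become 4-D Sequence<0, 2, 1, 3>: the new batch dimension stays outermost, the K0/M (or K0/N) swap is preserved, and the vector dimension moves from index 2 to 3. A tiny check of what that permutation means as a traversal order (labels only; not library code):

    #include <array>
    #include <cstdio>

    int main()
    {
        const std::array<const char*, 4> dims{{"GemmB", "GemmK0", "GemmM", "GemmK1"}};
        const std::array<int, 4> order{{0, 2, 1, 3}}; // new arrange/access order
        std::printf("traversal:");
        for(const int d : order)
            std::printf(" %s", dims[d]); // GemmB GemmM GemmK0 GemmK1
        std::printf("\n");
        return 0;
    }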
@@ -233,6 +233,8 @@ int main(int argc, char* argv[])
                                      in_right_pads_dev);
     };

+    // set zero to wei_device
+    wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread);
 #if USE_CONV_WRW_V4R4R2_XDL_NCHW
     if(algo == ConvBackwardWeightAlgo::V4R4R2XDLNCHW)
     {
@@ -267,8 +269,6 @@ int main(int argc, char* argv[])
         {
             throw std::runtime_error("wrong! layout");
         }

-        // set zero to wei_device
-        wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread);

         const auto tmp = f_make_for_device_nchw();
...
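The last two hunks move the zeroing of wei_device out of the NCHW-only branch so it runs once before any algorithm is dispatched. This matters for split-K: each of the KBatch grids adds a partial dW into the same output buffer (presumably via atomic adds on the C output), so the buffer must start at zero. A scalar sketch of the idea, with invented names and values:

    #include <cstdio>

    int main()
    {
        const int k_batch = 2;
        float wei = 0.0f; // GenerateTensorValue(GeneratorTensor_0{}, ...) on the device buffer
        const float partial_dw[2] = {1.5f, 2.5f}; // per-batch partial results (made up)
        for(int b = 0; b < k_batch; ++b)
            wei += partial_dw[b]; // per-batch accumulation into global memory
        std::printf("accumulated dW = %g\n", wei); // 4
        return 0;
    }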