Commit eaa68635 authored by Adam Osewski

Fix RunWrite.

parent 125a39d1
@@ -101,7 +101,7 @@ __global__ void
index_t gemm_tile_id_end = grid_size_grp;
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
typename GridwiseGemm::AccType,
GridwiseGemm::GetMPerXdl() * GridwiseGemm::GetNPerXdl(),
GridwiseGemm::GetMXdlPerWave() * GridwiseGemm::GetNXdlPerWave(),
GridwiseGemm::GetCThreadBufferVectorSize(),
true>
results_buffer;
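The accumulator tuple is now sized by the number of xdlops repeats per wave (MXdlPerWave * NXdlPerWave) instead of the per-xdlops tile footprint (MPerXdl * NPerXdl). A minimal standalone sketch of the new sizing; the parameters below are illustrative assumptions, not values from this kernel's configuration:

// Sketch only: all parameters are assumed for illustration.
#include <cstdio>

int main()
{
    constexpr int MXdlPerWave = 4; // assumed xdlops repeats along M per wave
    constexpr int NXdlPerWave = 2; // assumed xdlops repeats along N per wave
    constexpr int VectorSize  = 4; // assumed GetCThreadBufferVectorSize()

    // One accumulator vector per (m, n) xdlops repeat:
    constexpr int num_vectors = MXdlPerWave * NXdlPerWave;
    std::printf("vectors: %d, scalars: %d\n", num_vectors, num_vectors * VectorSize);
    return 0;
}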
@@ -308,8 +308,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
// M0 - MBlock
// M1 - MPerBlock
// N0 - NBlock
// N1 - NVecPerThread
// N2 - NVecSize
// N1 - N repeats
// N2 - NVecSize * cluster length
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeEGridDescriptor_M0M1_N0N1N2(const EGridDesc_M_N& e_grid_desc_m_n)
@@ -330,33 +330,18 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
constexpr auto workspace_thread_desc_m0m1_n0n1n2 = MakeReductionThreadDesc_M0M1_N0N1N2();
// # of threads in NDim * vector load size * # repeats per thread
constexpr auto NPerBlockPadded = cluster_length_reduce.At(I2) *
workspace_thread_desc_m0m1_n0n1n2.GetLength(I3) *
workspace_thread_desc_m0m1_n0n1n2.GetLength(I4);
constexpr auto NPerBlockPad = NPerBlockPadded - Number<NPerBlock>{};
const auto e_grid_desc_m0m1_n0n1pad = transform_tensor_descriptor(
e_grid_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_pass_through_transform(
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0)),
make_pass_through_transform(
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
make_pass_through_transform(
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)),
make_right_pad_transform(Number<NPerBlock>{}, NPerBlockPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto e_grid_desc_m0m1_n0n1n2 = transform_tensor_descriptor(
e_grid_desc_m0m1_n0n1pad,
e_grid_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_pass_through_transform(e_grid_desc_m0m1_n0n1pad.GetLength(I0)),
make_pass_through_transform(e_grid_desc_m0m1_n0n1pad.GetLength(I1)),
make_pass_through_transform(e_grid_desc_m0m1_n0n1pad.GetLength(I2)),
make_unmerge_transform(make_tuple(
workspace_thread_desc_m0m1_n0n1n2.GetLength(I3) * cluster_length_reduce.At(I2),
workspace_thread_desc_m0m1_n0n1n2.GetLength(I4)))),
make_pass_through_transform(
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0)),
make_pass_through_transform(
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
make_pass_through_transform(
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)),
make_unmerge_transform(make_tuple(workspace_thread_desc_m0m1_n0n1n2.GetLength(I3),
workspace_thread_desc_m0m1_n0n1n2.GetLength(I4) *
cluster_length_reduce.At(I2)))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
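The right-pad transform is gone: N is now unmerged directly into (N repeats, vector size * cluster length), so the split must cover NPerBlock exactly. A sketch of the new index mapping under assumed sizes (NPerBlock = 128, cluster length = 8, vector size = 4):

// Sketch of the new N decomposition: n = n1 * (N2 * Cluster) + n2.
// All sizes are assumptions for illustration only.
#include <cassert>
#include <cstdio>

int main()
{
    constexpr int NPerBlock = 128; // assumed block tile in N
    constexpr int Cluster   = 8;   // assumed reduce threads along N
    constexpr int N2        = 4;   // assumed vector load/store size

    constexpr int N1 = NPerBlock / (Cluster * N2); // 4 repeats per thread
    static_assert(N1 * N2 * Cluster == NPerBlock, "exact split, no pad required");

    for(int n1 = 0; n1 < N1; ++n1)
        for(int n2 = 0; n2 < N2 * Cluster; ++n2)
        {
            const int n = n1 * (N2 * Cluster) + n2; // unmerge(N1, N2 * Cluster)
            assert(n < NPerBlock);                  // exact cover of NPerBlock
        }
    std::printf("N1 = %d repeats per thread\n", N1);
    return 0;
}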
@@ -436,8 +421,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
auto K_t = KBatch * KPerBlock;
if(!(K % K_t == 0))
if(!(K % KPerBlock == 0))
{
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
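The divisibility requirement is relaxed from K % (KBatch * KPerBlock) == 0 to K % KPerBlock == 0; the per-batch conditions are validated separately below. A small illustration with assumed sizes:

// Illustration only; K, KPerBlock and KBatch are assumed values.
#include <cstdio>

int main()
{
    constexpr int K = 192, KPerBlock = 64, KBatch = 2;
    const bool old_check = (K % (KBatch * KPerBlock)) == 0; // false: 192 % 128 != 0
    const bool new_check = (K % KPerBlock) == 0;            // true:  192 % 64 == 0
    std::printf("old: %d, new: %d\n", old_check, new_check);
    return 0;
}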
@@ -540,6 +524,33 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
}
}
const auto k_batch_size =
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KBatch;
if(k_batch_size < KPerBlock)
{
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "The k-batch size (" << k_batch_size
<< ") value is less than KPerBlock!\n"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
if(k_batch_size % KPerBlock != 0)
{
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "The k-batch size (" << k_batch_size
<< ") value is not a multiple of KPerBlock!\n"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop =
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
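The added guards derive the per-batch K extent from the A descriptor (the product of its K0 and K1 lengths spans the whole K dimension) and reject splits whose batches are smaller than, or not a multiple of, KPerBlock. The same checks in isolation, with assumed lengths:

// Host-side sketch of the added guards; lengths are assumed, not read
// from a real descriptor.
#include <cstdio>

bool check_k_batch(int ak0, int ak1, int k_batch, int k_per_block)
{
    const int k_batch_size = (ak0 * ak1) / k_batch;   // K elements per split-K batch
    if(k_batch_size < k_per_block) return false;      // batch too small for one K tile
    if(k_batch_size % k_per_block != 0) return false; // batch not tile-aligned
    return true;
}

int main()
{
    std::printf("%d\n", check_k_batch(64, 8, 4, 64));  // 512/4 = 128 per batch -> 1
    std::printf("%d\n", check_k_batch(64, 8, 16, 64)); // 512/16 = 32 < 64      -> 0
    return 0;
}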
@@ -624,8 +635,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
__device__ __host__ static constexpr auto GetNPerBlock() { return NPerBlock; }
__device__ __host__ static constexpr auto GetMPerXdl() { return MPerXdl; }
__device__ __host__ static constexpr auto GetNPerXdl() { return NPerXdl; }
__device__ __host__ static constexpr auto GetMXdlPerWave() { return MXdlPerWave; }
__device__ __host__ static constexpr auto GetNXdlPerWave() { return NXdlPerWave; }
__device__ static constexpr auto GetCThreadBufferVectorSize()
{
@@ -646,7 +657,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
return BlockwiseGemmT::xdlops_gemm.GetRegSizePerXdlops();
}
template <bool HasMainKBlockLoop, typename Block2ETileMap, typename CThreadBuf>
template <typename Block2ETileMap, typename CThreadBuf>
__device__ static void RunGEMM(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
void* __restrict__ p_shared,
@@ -656,7 +667,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const Block2ETileMap& block_2_etile_map,
CThreadBuf& c_thread_buf,
const index_t k_tiles)
const index_t k_batch,
const index_t next_k_tiles)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -726,16 +738,11 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
true,
NumGemmKPrefetchStage>;
const index_t ak0_start_idx = kbatch_id * AK0PerBlock;
const index_t bk0_start_idx = kbatch_id * BK0PerBlock;
if(blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
{
printf("[RunGEMM] bid: %d, ak0_start_idx: %d, bk0_start_idx: %d\n",
static_cast<index_t>(blockIdx.x),
ak0_start_idx,
bk0_start_idx);
}
const index_t num_k_tiles_per_batch =
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
(KPerBlock * k_batch);
const index_t ak0_start_idx = kbatch_id * num_k_tiles_per_batch * AK0PerBlock.value;
const index_t bk0_start_idx = kbatch_id * num_k_tiles_per_batch * BK0PerBlock.value;
// A matrix blockwise copy
auto a_blockwise_copy =
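The batch start index is no longer a bare kbatch_id * AK0PerBlock: it is scaled by the number of KPerBlock tiles each k-batch owns. Worked numbers, assuming AK0PerBlock = KPerBlock / AK1 (as the slice-copy steps below suggest) and illustrative descriptor lengths:

// Sketch of the new start-index arithmetic; all sizes are assumptions.
#include <cstdio>

int main()
{
    constexpr int AK0 = 64, AK1 = 8;             // assumed lengths, so K = 512
    constexpr int KPerBlock = 64, KBatch = 4;
    constexpr int AK0PerBlock = KPerBlock / AK1; // 8 K0 slices per tile (assumed relation)

    const int num_k_tiles_per_batch = (AK0 * AK1) / (KPerBlock * KBatch); // 2 tiles
    for(int kbatch_id = 0; kbatch_id < KBatch; ++kbatch_id)
    {
        const int ak0_start_idx = kbatch_id * num_k_tiles_per_batch * AK0PerBlock;
        std::printf("batch %d starts at AK0 = %d\n", kbatch_id, ak0_start_idx);
        // prints 0, 16, 32, 48: each batch owns 2 tiles of 8 K0 slices
    }
    return 0;
}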
@@ -777,25 +784,15 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock.value, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock.value, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
// TODO: what if AK1 != BK1 ???
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(k_tiles);
// __builtin_amdgcn_readfirstlane((a_grid_desc_ak0_m_ak1.GetLength(I1) *
// a_grid_desc_ak0_m_ak1.GetLength(I3)) /
// KPerBlock);
if(blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
{
printf("[RunGEMM] bid: %d, num_k_block_main_loop %d\n",
static_cast<index_t>(blockIdx.x),
num_k_block_main_loop);
}
const auto gridwise_gemm_pipeline = GridwiseGemmPipe();
const index_t num_k_block_main_loop =
__builtin_amdgcn_readfirstlane(next_k_tiles * num_k_tiles_per_batch);
const bool has_k_block_main_loop =
gridwise_gemm_pipeline.CalculateHasMainLoop(num_k_block_main_loop);
bool clear_c_thread_buf = true;
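The pipeline object now comes from the GridwiseGemmPipe alias, and the main-loop trip count becomes next_k_tiles * num_k_tiles_per_batch, i.e. the number of KPerBlock tiles this pass consumes. Continuing the assumed numbers from the previous sketch:

// Fragment reusing the assumed sizes above; values are illustrative.
int main()
{
    constexpr int num_k_tiles_per_batch = 2; // assumed: (AK0 * AK1) / (KPerBlock * KBatch)
    constexpr int next_k_tiles          = 3; // assumed: k-batches consumed in this pass
    constexpr int num_k_block_main_loop = next_k_tiles * num_k_tiles_per_batch;
    static_assert(num_k_block_main_loop == 6, "3 batches x 2 tiles each");
    return 0;
}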
@@ -813,25 +810,47 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
KPack,
LoopSched>();
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop,
clear_c_thread_buf);
if(has_k_block_main_loop)
{
gridwise_gemm_pipeline.template Run<true>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop,
clear_c_thread_buf);
}
else
{
gridwise_gemm_pipeline.template Run<false>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop,
clear_c_thread_buf);
}
}
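HasMainKBlockLoop is no longer a template parameter the caller must supply: both instantiations are compiled and a runtime branch picks one, which simplifies the call site at the cost of some code size. A generic sketch of this dispatch pattern; the payload is illustrative, not this kernel's pipeline:

// Runtime dispatch onto a compile-time bool; illustrative payload only.
#include <cstdio>

template <bool HasMainLoop>
void run_pipeline(int iterations)
{
    if constexpr(HasMainLoop)
        std::printf("main-loop path, %d iterations\n", iterations);
    else
        std::printf("tail-only path\n");
}

void dispatch(int iterations)
{
    // Both branches are instantiated; the condition selects one at runtime.
    if(iterations > 1) run_pipeline<true>(iterations);
    else               run_pipeline<false>(iterations);
}

int main()
{
    dispatch(6);
    dispatch(1);
    return 0;
}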
template <bool HasMainKBlockLoop, typename Block2ETileMap, typename CThreadBuf>
template <typename Block2ETileMap, typename CThreadBuf>
__device__ static void RunGEMM(const void* __restrict__ p_a_grid_,
const void* __restrict__ p_b_grid_,
void* __restrict__ p_shared,
@@ -840,30 +859,31 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
const index_t M,
const index_t N,
const index_t K,
const index_t StrideA,
const index_t StrideB,
const index_t KBatch,
const index_t stride_a,
const index_t stride_b,
const index_t k_batch,
const Block2ETileMap& block_2_etile_map,
CThreadBuf& c_thread_buf,
const index_t k_tiles)
const index_t next_k_tiles)
{
const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
// tensor descriptors for block/thread-wise copy
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(M, K, StrideA, KBatch);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(K, N, StrideB, KBatch);
RunGEMM<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_shared,
a_element_op,
b_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
block_2_etile_map,
c_thread_buf,
k_tiles);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(M, K, stride_a, k_batch);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(K, N, stride_b, k_batch);
RunGEMM(p_a_grid,
p_b_grid,
p_shared,
a_element_op,
b_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
block_2_etile_map,
c_thread_buf,
k_batch,
next_k_tiles);
}
template <typename CThreadBuf>
@@ -1098,24 +1118,24 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3)>{};
}
// M0 - 1
// M1 - M elements per thread
// N0 - 1
// N1 - N repeats per thread
// N2 - Vector load/store size
__device__ static constexpr auto MakeReductionThreadDesc_M0M1_N0N1N2()
{
constexpr auto cluster_lengths = GetClusterLengthReduction_M0_N0N1();
constexpr auto N1_elems =
math::integer_divide_ceil(Number<NPerBlock>{}, cluster_lengths.At(I2));
static_assert(N1_elems % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0,
"Invalid ReductionThreadDesc M0M1_N0N1N2! N1_elems have to be a multiple of "
"CDEShuffleBlockTransferScalarPerVector_NPerBlock!");
constexpr auto N2 = Number<CDEShuffleBlockTransferScalarPerVector_NPerBlock>{};
constexpr auto N1 = math::integer_divide_ceil(N1_elems, N2);
constexpr auto N1 = Number<NPerBlock>{} / (Number<cluster_lengths.At(I2)>{} * N2);
constexpr auto M1 = math::integer_divide_ceil(Number<MPerBlock>{}, cluster_lengths.At(I0));
static_assert(
Number<M1>{} * cluster_lengths.At(I0) >= Number<MPerBlock>{},
Number<M1>{} * cluster_lengths.At(I0) == Number<MPerBlock>{},
"Invalid ReductionThreadDesc M0M1_N0N1N2! M1 * cluster_length[0] have to be grater "
"or equal to MPerBlock.");
static_assert(Number<N1>{} * Number<N2>{} * cluster_lengths.At(I2) >= Number<NPerBlock>{},
static_assert(Number<N1>{} * Number<N2>{} * cluster_lengths.At(I2) == Number<NPerBlock>{},
"Invalid ReductionThreadDesc M0M1_N0N1N2! N1 * N2 * cluster_length[2] have "
"to be grater or equal to NPerBlock.");
@@ -1129,6 +1149,15 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
{
constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
constexpr auto reduce_cluster_desc = make_cluster_descriptor(cluster_length_reduce);
static_assert(ThisThreadBlock::GetNumOfThread() >= reduce_cluster_desc.GetElementSize(),
"Error! ThisThreadBlock::GetNumOfThread() too small");
if(ThisThreadBlock::GetThreadId() >= reduce_cluster_desc.GetElementSize())
{
return;
}
const auto reduce_thread_cluster_idx =
reduce_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
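The new guard allows the thread block to be larger than the reduction cluster: the static_assert guarantees coverage, and surplus threads return before indexing the cluster descriptor. The shape of the guard in isolation, with assumed sizes and an assumed row-major cluster mapping:

// Sketch of the added thread guard; sizes and mapping are assumptions.
#include <cstdio>

void reduce_entry(int thread_id)
{
    constexpr int BlockSize   = 256;    // assumed threads per block
    constexpr int ClusterSize = 16 * 8; // assumed reduce cluster size (M0 * N0N1)
    static_assert(BlockSize >= ClusterSize, "block must cover the cluster");

    if(thread_id >= ClusterSize)
        return; // surplus threads take no part in the reduction

    const int m_id = thread_id / 8; // assumed row-major cluster decomposition
    const int n_id = thread_id % 8;
    std::printf("thread %d -> (m = %d, n = %d)\n", thread_id, m_id, n_id);
}

int main()
{
    reduce_entry(5);   // mapped to (0, 5)
    reduce_entry(200); // >= 128, returns immediately
    return 0;
}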
@@ -1139,27 +1168,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
const auto workspace_grid_desc_m0m1_n0n1 =
MakeWorkspaceGridDesc_GridSize_MPerBlock_I1_NPerBlock(get_grid_size());
// # of threads in NDim * vector load size * # repeats per thread
constexpr auto NPerBlockPadded = cluster_length_reduce.At(I2) *
workspace_thread_desc_m0m1_n0n1n2.GetLength(I3) *
workspace_thread_desc_m0m1_n0n1n2.GetLength(I4);
constexpr auto NPerBlockPad = NPerBlockPadded - Number<NPerBlock>{};
const auto workspace_grid_desc_m0m1_n0n1pad = transform_tensor_descriptor(
workspace_grid_desc_m0m1_n0n1,
make_tuple(make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I0)),
make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I1)),
make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I2)),
make_right_pad_transform(Number<NPerBlock>{}, NPerBlockPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto workspace_grid_desc_m0m1_n0n1n2 = transform_tensor_descriptor(
workspace_grid_desc_m0m1_n0n1pad,
workspace_grid_desc_m0m1_n0n1,
make_tuple(
make_pass_through_transform(workspace_grid_desc_m0m1_n0n1pad.GetLength(I0)),
make_pass_through_transform(workspace_grid_desc_m0m1_n0n1pad.GetLength(I1)),
make_pass_through_transform(workspace_grid_desc_m0m1_n0n1pad.GetLength(I2)),
make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I0)),
make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I1)),
make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I2)),
make_unmerge_transform(make_tuple(workspace_thread_desc_m0m1_n0n1n2.GetLength(I3),
workspace_thread_desc_m0m1_n0n1n2.GetLength(I4) *
cluster_length_reduce.At(I2)))),
@@ -1255,7 +1269,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
__device__ static void RunWrite(DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
/* void* __restrict__ p_shared, */
const AccumulationBuffer& acc_buff,
AccumulationBuffer& acc_buff,
const index_t M,
const index_t N,
const std::array<index_t, NumDTensor> StrideDs,
@@ -1301,9 +1315,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
},
Number<NumDTensor>{});
auto aux_vgpr_buf =
StaticBuffer<AddressSpaceEnum::Vgpr, EDataType, ScalarPerVector, true>{};
constexpr auto d_vgpr_buf_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, I1, Number<ScalarPerVector>{}));
@@ -1312,6 +1323,14 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
constexpr auto reduce_cluster_desc = make_cluster_descriptor(cluster_length_reduce);
static_assert(ThisThreadBlock::GetNumOfThread() >= reduce_cluster_desc.GetElementSize(),
"Error! ThisThreadBlock::GetNumOfThread() too small");
if(ThisThreadBlock::GetThreadId() >= reduce_cluster_desc.GetElementSize())
{
return;
}
const auto reduce_thread_cluster_idx =
reduce_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
@@ -1344,6 +1363,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
},
Number<NumDTensor>{});
// Each thread writes consecutive M rows and strided N columns
auto e_grid_store =
ThreadwiseTensorSliceTransfer_v1r3<EDataType,
EDataType,
@@ -1368,7 +1388,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
constexpr auto MIter = workspace_thread_desc_m0m1_n0n1n2.GetLength(I1);
constexpr auto NIter = workspace_thread_desc_m0m1_n0n1n2.GetLength(I3);
constexpr auto n1_step = cluster_length_reduce.At(I2);
constexpr auto n1_step = I1;
constexpr auto d_grid_M1_fwd_step = make_multi_index(I0, I1, I0, I0, I0);
constexpr auto d_grid_N1_fwd_step = make_multi_index(I0, I0, I0, n1_step, I0);
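Because the unmerged N2 dimension now absorbs the cluster length, advancing the N1 coordinate by one already moves the window a full vector-size * cluster-length span; the previous step of cluster_length_reduce.At(I2) would overshoot. In numbers, reusing the earlier assumptions:

// Fragment with assumed sizes, as in the earlier sketches.
int main()
{
    constexpr int Cluster = 8, N2 = 4;                // assumed
    constexpr int scalars_per_n1_step = N2 * Cluster; // one +1 step in N1 spans 32 scalars
    static_assert(scalars_per_n1_step == 32, "N2 absorbs the cluster length");
    return 0;
}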
@@ -1410,7 +1430,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2(cde_element_op, tie(aux_vgpr_buf(I)), src_data_refs);
unpack2(cde_element_op, tie(acc_buff(acc_buf_offset + I)), src_data_refs);
});
// if(is_thread_local_1d_id_idx<0, 1, 8, 39>())
@@ -1429,35 +1449,42 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
// }
e_grid_store.Run(workspace_thread_desc_m0m1_n0n1n2,
make_tuple(I0, I0, I0, I0, I0),
aux_vgpr_buf,
make_tuple(I0, m_idx, I0, n_idx, I0),
acc_buff,
e_grid_desc_m0m1_n0n1n2,
e_grid_buf);
if constexpr(n_idx != (NIter - 1))
if constexpr(NIter != 1)
{
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
d_grid_N1_fwd_step);
});
e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_N1_fwd_step);
}
else
{
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
d_grid_N1_bwd_step);
});
e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_N1_bwd_step);
if constexpr(n_idx != (NIter - 1))
{
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
d_grid_N1_fwd_step);
});
e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2,
d_grid_N1_fwd_step);
}
else
{
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
d_grid_N1_bwd_step);
});
e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2,
d_grid_N1_bwd_step);
}
}
}); // NIter
{
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
d_grid_M1_fwd_step);
});
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
d_grid_M1_fwd_step);
});
e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_M1_fwd_step);
e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_M1_fwd_step);
}
}); // MIter
}
};
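For reference, the RunWrite traversal after this change: each thread sweeps its N repeats forward, rewinds the slice window on the last N iteration (assuming d_grid_N1_bwd_step returns to the row start), then steps one M row down; the cde_element_op result is written through acc_buff directly rather than staged in the removed aux_vgpr_buf. The traversal order in miniature:

// Miniature of the traversal; iteration counts and the rewind step
// are assumptions for illustration.
#include <cstdio>

int main()
{
    constexpr int MIter = 2, NIter = 4, n1_step = 1;
    int m = 0, n = 0;
    for(int mi = 0; mi < MIter; ++mi)
    {
        for(int ni = 0; ni < NIter; ++ni)
        {
            std::printf("store tile at (m = %d, n = %d)\n", m, n);
            if(NIter != 1)
            {
                if(ni != NIter - 1)
                    n += n1_step;               // forward N1 step
                else
                    n -= (NIter - 1) * n1_step; // assumed rewind to the row start
            }
        }
        m += 1; // M1 forward step after each row
    }
    return 0;
}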