Commit 27dc055b authored by aska-0096

fix a host tensor bug and clean up flash-attn code

parent 4ddda63b
@@ -43,9 +43,10 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;
-static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
-static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed;
+static constexpr auto ASpec = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr auto BSpec = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceOpInstanceKKNN =
@@ -64,18 +65,18 @@ using DeviceOpInstanceKKNN =
BElementOp,
CDEElementOp,
GemmSpec,
-ABSpec,
+ASpec,
-ABSpec,
+BSpec,
DESpec,
256,
128,
-256,
+128,
-8,
+4,
8,
16,
16,
4,
-4,
+2,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
@@ -252,21 +253,6 @@ int main(int argc, char* argv[])
ck::index_t K0 = 2048;
-// A[G0, G1, M0, M1, K0]
-std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
-std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
-// B[G0, G1, N0, N1, K0]
-std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
-std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
-// D[G0, G1, M0, N0, M1, N1]
-std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
-std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
-// E[G0, G1, M0, N0, M1, N1]
-std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
-std::vector<ck::index_t> e_gs_ms_ns_strides{
-G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
if(argc == 1)
{
// use default case
@@ -277,13 +263,43 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
+else if(argc == 11)
+{
+do_verification = std::stoi(argv[1]);
+init_method = std::stoi(argv[2]);
+time_kernel = std::stoi(argv[3]);
+G0 = std::stoi(argv[4]);
+G1 = std::stoi(argv[5]);
+M0 = std::stoi(argv[6]);
+M1 = std::stoi(argv[7]);
+N0 = std::stoi(argv[8]);
+N1 = std::stoi(argv[9]);
+K0 = std::stoi(argv[10]);
+}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
+printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n");
exit(0);
}
+// A[G0, G1, M0, M1, K0]
+std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
+std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
+// B[G0, G1, N0, N1, K0]
+std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
+std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
+// D[G0, G1, M0, N0, M1, N1]
+std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
+std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
+// E[G0, G1, M0, N0, M1, N1]
+std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
+std::vector<ck::index_t> e_gs_ms_ns_strides{
+G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
...
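Note on the stride vectors above: A, B and E are packed tensors, so each stride is the product of all faster-varying lengths (E is packed in its memory layout [G0, G1, M0, N0, M1, N1], then listed in ms-then-ns order), while D uses zero strides on M0/M1 to broadcast a [G0, G1, N0, N1] bias over the M dimensions for the Add epilogue. A minimal sketch of that packed-stride rule, assuming a hypothetical helper name and the example's existing includes:

// Hypothetical helper: packed row-major strides for a length vector, e.g.
// lengths {G0, G1, M0, M1, K0} -> {G1*M0*M1*K0, M0*M1*K0, M1*K0, K0, 1},
// which matches a_gs_ms_ks_strides and b_gs_ns_ks_strides above.
std::vector<ck::index_t> packed_strides(const std::vector<ck::index_t>& lengths)
{
    std::vector<ck::index_t> strides(lengths.size(), 1);
    for(int i = static_cast<int>(lengths.size()) - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * lengths[i + 1];
    return strides;
}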
@@ -29,12 +29,12 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
block_dim.y,
block_dim.z);
-const int nrepeat = 1;
+const int nrepeat = 100;
-// printf("Warm up 1 time\n");
+printf("Warm up 1 time\n");
// warm up
-// kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
printf("Start running %d times...\n", nrepeat);
...
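The restored warm-up launch plus nrepeat = 100 timed launches is the standard HIP event-timing pattern. A hedged, self-contained sketch of that pattern (dummy_kernel and time_kernel_avg_ms are illustrative names, not the code of this file; the HIP event calls are the regular runtime API):

#include <hip/hip_runtime.h>

__global__ void dummy_kernel() {}  // stand-in for the kernel being timed

float time_kernel_avg_ms(int nrepeat, hipStream_t stream)
{
    dummy_kernel<<<1, 64, 0, stream>>>();  // warm up once before timing

    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, stream);
    for(int i = 0; i < nrepeat; ++i)
        dummy_kernel<<<1, 64, 0, stream>>>();
    hipEventRecord(stop, stream);
    hipEventSynchronize(stop);

    float total_ms = 0.f;
    hipEventElapsedTime(&total_ms, start, stop);  // milliseconds over all launches

    hipEventDestroy(start);
    hipEventDestroy(stop);
    return total_ms / nrepeat;  // per-launch average
}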
@@ -771,6 +771,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
+printf("DeviceOp: Arch check failure\n");
return false;
}
}
@@ -785,6 +786,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
arg.e_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
+printf("GridwiseOp: Validity check failure\n");
return false;
}
@@ -799,6 +801,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if(!(arg.a_mz_stride_ == 1 &&
arg.a_grid_desc_k0_m_k1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
{
+printf("DeviceOp: Vector Access A-m check failure\n");
return false;
}
}
@@ -807,6 +810,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if(!(arg.a_kz_stride_ == 1 &&
arg.a_grid_desc_k0_m_k1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
{
+printf("DeviceOp: Vector Access A-k check failure\n");
return false;
}
}
@@ -817,6 +821,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if(!(arg.b_nz_stride_ == 1 &&
arg.b_grid_desc_k0_n_k1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
{
+printf("DeviceOp: Vector Access B-n check failure\n");
return false;
}
}
@@ -825,6 +830,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if(!(arg.b_kz_stride_ == 1 &&
arg.b_grid_desc_k0_n_k1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
{
+printf("DeviceOp: Vector Access B-k check failure\n");
return false;
}
}
@@ -838,6 +844,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
CDEShuffleBlockTransferScalarPerVector_NPerBlock ==
0))
{
+printf("DeviceOp: Vector Access D-n check failure\n");
valid_d_access = false;
}
});
@@ -854,6 +861,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
0) ||
CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1))
{
+printf("DeviceOp: Vector Access E-n check failure\n");
return false;
}
...
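Each of the new diagnostics above reports the same underlying rule: a dimension can only be read or written with vector accesses if it is the contiguous (stride-1) dimension and its length is divisible by the configured ScalarPerVector. Restated as a standalone sketch (the helper name is illustrative):

// Illustrative restatement of the checks that now print a message on failure,
// e.g. the A-k check: a_kz_stride_ == 1 && GetLength(I2) % ABlockTransferSrcScalarPerVector == 0.
bool can_vectorize(ck::index_t stride, ck::index_t length, ck::index_t scalar_per_vector)
{
    return stride == 1 && length % scalar_per_vector == 0;
}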
@@ -352,6 +352,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto N = b1_grid_desc_l0_n_l1.GetLength(I1);
+printf("M = %d, L = %d, K = %d, N = %d\n", M, L, K, N);
const auto KPerBlock = K0PerBlock * K1Value;
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
{
@@ -730,7 +732,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
// dst Rowlane
// 0x76543210 0xfedcba98
// src Rowlane
-0x76543210, 0xfedcba98>{tensor_operation::element_wise::PassThrough{}};
+0x76543210, 0xfedcba98,
+false>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
...
@@ -148,14 +148,12 @@ __global__ void
const Block2CTileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
-//printf("entry kernel launch");
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-//printf("before compute_ptr_offset call");
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
@@ -170,13 +168,9 @@ __global__ void
DsPointer p_ds_grid_grp;
-//printf("before allocate pointer d");
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
-//printf("before entry");
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_ds_grid_grp,
@@ -469,16 +463,23 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
if(!valid)
{
+printf("GridwiseOp: D descriptor dimension check failure\n");
return false;
}
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
+{
+printf("GridwiseOp: ABE descriptor dimension cross check failure\n");
return false;
+}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
+{
+printf("GridwiseOp: Problemsize descriptor dimension check failure\n");
return false;
+}
// check gridwise gemm pipeline
const auto num_k_loop = K0 / K0PerBlock;
@@ -570,7 +571,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
const CDEElementwiseOperation& cde_element_op,
const Block2CTileMap& block_2_ctile_map)
{
-//printf("safe entry");
// clang-format off
/*******************************************************************************/
// Memory buffer zone.
@@ -716,7 +716,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
c_thread_buf,
K0BlockMainLoop);
/*******************************************************************************/
-//printf("safe 1");
// write out to C, implement shuffle
{
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
...
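The problem-size check above (M % MPerBlock == 0, and so on) is what GemmSpecialization::MNKPadding, selected in the example at the top of this commit, is meant to satisfy: the device op pads M, N and K up to whole block tiles so arbitrary problem sizes pass the gridwise check. A hedged sketch of that rounding (the helper name is illustrative):

// Round a problem dimension up to a multiple of its per-block tile size,
// e.g. M = 1000 with MPerBlock = 128 -> 1024.
ck::index_t pad_to_tile(ck::index_t length, ck::index_t per_block)
{
    return ((length + per_block - 1) / per_block) * per_block;
}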
@@ -1311,10 +1311,11 @@ template <typename SrcData,
typename ElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
uint32_t LowEightRowlaneIdx,
uint32_t HighEightRowLaneIdx,
+bool IntraRowSwizzlePerm,
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
@@ -1389,29 +1390,33 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
-SrcData v;
+SrcData v_this_row, v_theother_row;
+// int type temp value due to intrinsic requirement
+int temp = 0;
// apply element-wise operation
-element_op_(v, src_buf[Number<src_offset>{}]);
+element_op_(v_this_row, src_buf[Number<src_offset>{}]);
+// apply intra-row swizzle permute
+if constexpr(IntraRowSwizzlePerm){
+// origin: 0xfedcba98, 0x76543210
+temp = __builtin_amdgcn_permlane16(temp, type_convert<int>(v_this_row), 0xeca86420, 0xfdb97531, 1, 0);
+v_this_row = type_convert<float>(temp);
+}
+// apply inter-row permute
+temp = __builtin_amdgcn_permlanex16(temp, type_convert<int>(v_this_row), LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
+v_theother_row = type_convert<float>(temp);
if(get_thread_local_1d_id() % 32 < 16){
// apply type convert
-dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
+dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v_this_row);
+dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(v_theother_row);
}
else{
// apply type convert
-dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(v);
+dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(v_this_row);
+dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v_theother_row);
-}
-SrcData d = 0;
-int temp = 0;
-temp = __builtin_amdgcn_permlanex16(temp, type_convert<int>(v),
-LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
-d = type_convert<float>(temp);
-if(get_thread_local_1d_id() % 32 < 16){
-dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(d);
-}
-else{
-dst_buf(Number<dst_offset>{}) = type_convert<DstData>(d);
}
});
});
...
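For readers unfamiliar with the two intrinsics used above: __builtin_amdgcn_permlane16 permutes lanes within a 16-lane row of the wave, and __builtin_amdgcn_permlanex16 reads from the other 16-lane row; in both, each nibble of the two 32-bit selectors names the source lane for one destination lane. A hedged device-side sketch of the new IntraRowSwizzlePerm path (the function name is illustrative; the selector constants are the ones used above; gfx11 hardware assumed):

// Swizzle within the row with selectors 0xeca86420 / 0xfdb97531, which pack the
// even source lanes into lanes 0-7 and the odd source lanes into lanes 8-15,
// then exchange with the other 16-lane row using the identity selectors
// 0x76543210 / 0xfedcba98.
__device__ float swizzle_then_cross_row(float v)
{
    int t = 0;
    t = __builtin_amdgcn_permlane16(t, __float_as_int(v), 0xeca86420, 0xfdb97531, 1, 0);
    t = __builtin_amdgcn_permlanex16(t, t, 0x76543210, 0xfedcba98, 1, 0);
    return __int_as_float(t);
}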
@@ -972,7 +972,6 @@ inline __host__ __device__ constexpr int type_convert<int, float>(float x)
float fp32;
int int32;
} u = {x};
-// u.fp32 = x;
return u.int32;
}
@@ -985,7 +984,6 @@ inline __host__ __device__ constexpr float type_convert<float, int>(int x)
int int32;
float fp32;
} u = {x};
-// u.fp32 = x;
return u.fp32;
}
...
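The union initialization above is the bit-reinterpretation these type_convert specializations rely on (the int overloads feed the permlane intrinsics). A hedged alternative sketch of the same conversion, relying only on HIP compiling with clang, which provides __builtin_bit_cast; shown to clarify what the union computes, not what the file uses:

// Equivalent float-to-int bit reinterpretation via clang's builtin.
inline __host__ __device__ constexpr int float_bits_as_int(float x)
{
    return __builtin_bit_cast(int, x);
}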
@@ -396,7 +396,7 @@ struct Tensor
}
case 6: {
auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) {
-(*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4, i5);
+(*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5);
};
make_ParallelTensorFunctor(f,
mDesc.GetLengths()[0],
...
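This is the host tensor bug from the commit title: the rank-6 functor dropped i5, so every element along the last dimension wrote to the same location and rank-6 reference tensors were initialized incorrectly. A hedged usage sketch of the path that exercises the fix (Tensor and GenerateTensorValue are CK host utilities; the EDataType name and GeneratorTensor_3 initializer follow CK's example code and are illustrative here):

// Filling a rank-6 host tensor, such as e_gs_ms_ns in the example above, goes
// through the case-6 functor fixed here.
Tensor<EDataType> e_host(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
e_host.GenerateTensorValue(GeneratorTensor_3<EDataType>{0.0, 1.0});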