"composable_kernel/include/utility/array.hpp" did not exist on "8c385cf5cf25219d235e52be50d1d3f4a0a21f87"
Commit 5c27dcd5 authored by Jing Zhang's avatar Jing Zhang
Browse files

add fp32 mfma instructions

parent 21755b5d
......@@ -11,7 +11,8 @@ namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GemmMPerBlock,
template <typename FloatAB,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmMPerWave,
index_t GemmNPerWave,
......@@ -109,7 +110,7 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
constexpr auto xdlops_gemm = XdlopsGemm<float, GemmMPerWave, GemmNPerWave, GemmKPerWave>{};
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPerWave>{};
constexpr auto CLayout = xdlops_gemm.GetCLayout();
......
......@@ -34,7 +34,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, KPack>{};
static constexpr auto xdlops_gemm = XdlopsGemm<FloatA, MPerWave, NPerWave, KPack>{};
static constexpr index_t MWaves = M1 / MPerWave;
static constexpr index_t NWaves = N1 / NPerWave;
......
......@@ -306,14 +306,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple(Sequence<0, 3>{}, Sequence<1, 2>{}));
const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline<BlockSize,
FloatAB,
FloatAB,
decltype(a_k0_m0_m1_k1_block_desc),
decltype(b_k0_n0_n1_k1_block_desc),
MPerWave,
NPerWave,
KPack>{};
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB,
FloatAB,
decltype(a_k0_m0_m1_k1_block_desc),
decltype(b_k0_n0_n1_k1_block_desc),
MPerWave,
NPerWave,
KPack>{};
constexpr auto CLayout = blockwise_gemm.GetCLayout();
constexpr index_t BlkSize = CLayout.GetBlkSize();
......@@ -327,7 +327,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple(Number<NumBlks>{}, Number<BlkSize>{}));
StaticBuffer<AddressSpace::Vgpr,
vector_type<float, c_blk_nb_bs_desc.GetElementSpaceSize()>,
vector_type<FloatAB, c_blk_nb_bs_desc.GetElementSpaceSize()>,
c_mr_nr_nx_desc.GetElementSpaceSize()>
c_thread_buf;
......@@ -488,7 +488,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_blk_buf_;
StaticBuffer<AddressSpace::Vgpr, FloatAB, BlkSize> c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
......@@ -498,7 +498,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple(mr_i, nr_i, xdlops_i))>{}];
static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(j) = c_blk.template AsType<float>()[Number<
c_blk_buf_(j) = c_blk.template AsType<FloatAB>()[Number<
c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}];
});
......
......@@ -10,19 +10,19 @@ namespace ck {
enum struct mfma_instr
{
// fp32
/// fp32
mfma_f32_32x32x1xf32 = 0,
mfma_f32_16x16x1xf32,
mfma_f32_4x4x1xf32,
mfma_f32_32x32x2xf32, // k reduction
mfma_f32_16x16x4xf32, // k reduction
// fp16
/// fp16
mfma_f32_32x32x4f16,
mfma_f32_16x16x4f16,
mfma_f32_4x4x4f16,
mfma_f32_32x32x8f16, // k reduction
mfma_f32_16x16x16f16, // k reduction
// bfp16
/// bfp16
mfma_f32_32x32x2bf16,
mfma_f32_16x16x2bf16,
mfma_f32_4x4x2bf16,
......@@ -58,7 +58,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
......@@ -87,7 +87,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
return intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
......@@ -110,17 +110,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
const auto p_a = reinterpret_cast<const float*>(a);
const auto p_b = reinterpret_cast<const float*>(b);
return intrin_mfma_f32_16x16x4f32(p_a, p_b, reg_c);
intrin_mfma_f32_16x16x4f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
......@@ -143,17 +139,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
const auto p_a = reinterpret_cast<const float*>(a);
const auto p_b = reinterpret_cast<const float*>(b);
return intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
......@@ -177,17 +169,13 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
const auto p_a = reinterpret_cast<const float*>(a);
const auto p_b = reinterpret_cast<const float*>(b);
return intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
......@@ -523,20 +511,13 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
}
};
template <mfma_instr instr,
index_t MPerXdlops_,
index_t NPerXdlops_,
index_t MRepeats_,
index_t NRepeats_,
class CType_>
template <mfma_instr instr, index_t MPerXdlops_, index_t NPerXdlops_>
struct xdlops_info
{
static constexpr auto mfma_type = mfma_info<instr>{};
static constexpr index_t MPerXdlops = MPerXdlops_;
static constexpr index_t NPerXdlops = NPerXdlops_;
static constexpr index_t MRepeats = MRepeats_;
static constexpr index_t NRepeats = NRepeats_;
static constexpr bool IsABroadcast()
{
......@@ -555,8 +536,6 @@ struct xdlops_info
}
static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; }
static constexpr auto GetCType() { return CType_{}; }
};
template <class base_type, index_t MPerWave, index_t NPerWave, index_t KPack>
......@@ -570,55 +549,43 @@ struct XdlopsGemm
template <>
static constexpr auto GetXdlopsInfo<float, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 1, float64_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 64, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 32, 1, 1, float32_t>{};
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64, 1, 1, float32_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 64, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 64, 16, 1, 1, float16_t>{};
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64, 1, 1, float16_t>{};
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64, 1, 1, float8_t>{};
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64, 1, 1, float4_t>{};
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32, 1, 1, float16_t>{};
return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16, 1, 1, float4_t>{};
return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16>{};
}
#if 0
......
......@@ -201,42 +201,6 @@ extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x1f32;
// template <index_t AStride, index_t BStride>
// struct intrin_mfma_f32_32x32x1f32<128, 64, AStride, BStride>
//{
//__device__ static c_vec32_4_t::VecType
// run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
//{
// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
// reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
// reg_c.s.z =
// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
// reg_c.s.w =
// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
// return reg_c;
//}
//};
// template <index_t AStride, index_t BStride>
// struct intrin_mfma_f32_32x32x1f32<64, 128, AStride, BStride>
//{
//__device__ static c_vec32_4_t::VecType
// run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
//{
// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
// reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
// reg_c.s.z =
// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
// reg_c.s.w =
// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
// return reg_c;
//}
//};
template <index_t COffset>
struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
{
......@@ -262,27 +226,22 @@ struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
}
};
// template <index_t AStride, index_t BStride>
// struct intrin_mfma_f32_32x32x1f32<64, 32, AStride, BStride>
//{
//__device__ static c_vec32_1_t::VecType
// run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
//{
// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
// return reg_c;
//}
//};
// template <index_t AStride, index_t BStride>
// struct intrin_mfma_f32_32x32x1f32<32, 64, AStride, BStride>
//{
//__device__ static c_vec32_1_t::VecType
// run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
//{
// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
// return reg_c;
//}
//};
// Specialization of intrin_mfma_f32_32x32x1f32 for MPerWave = 32, NPerWave = 64.
//
// Run() issues a single llvm_intrin_amdgcn_mfma_f32_32x32x1f32 call. The
// accumulator operand is read from, and the result written back to, element 0
// of the float32_t view of reg_c's entry at index COffset.
//
// reg_a / reg_b are forwarded to the intrinsic unchanged.
// NOTE(review): the trailing literal arguments (1, 0, 0) are presumably the
// MFMA cbsz / abid / blgp modifiers -- confirm against the intrinsic's
// declaration before changing them.
template <index_t COffset>
struct intrin_mfma_f32_32x32x1f32<32, 64, COffset>
{
    template <class FloatA, class FloatB, class FloatC>
    __device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
    {
        // Compile-time indices: which reg_c entry to use, and which slot of
        // its float32_t view holds the accumulator.
        constexpr auto c_idx    = Number<COffset>{};
        constexpr auto acc_slot = Number<0>{};

        // Current accumulator value, read through the const accessor.
        const auto acc_in = reg_c[c_idx].template AsType<float32_t>()[acc_slot];

        // Multiply-accumulate and store the result back into the same slot.
        reg_c(c_idx).template AsType<float32_t>()(acc_slot) =
            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, acc_in, 1, 0, 0);
    }
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x2f32;
......@@ -304,58 +263,89 @@ struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
}
};
__device__ c_vec4_1_t::VecType
intrin_mfma_f32_16x16x4f32(const float* reg_a, const float* reg_b, c_vec4_1_t::VecType reg_c)
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x4f32;
template <index_t COffset>
struct intrin_mfma_f32_16x16x4f32<16, 16, COffset>
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
template <class FloatA, class FloatB, class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType
intrin_mfma_f32_16x16x1f32(const float* reg_a, const float* reg_b, c_vec16_1_t::VecType reg_c);
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x1f32;
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x1f32<16, 64>(const float* reg_a,
const float* reg_b,
c_vec16_1_t::VecType reg_c)
template <index_t COffset>
struct intrin_mfma_f32_16x16x1f32<16, 64, COffset>
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
return reg_c;
}
template <class FloatA, class FloatB, class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
{
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x1f32<64, 16>(const float* reg_a,
const float* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
return reg_c;
}
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
2,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave>
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_4x4x1f32;
template <>
struct intrin_mfma_f32_4x4x1f32<4, 64>
template <index_t COffset>
struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
{
__device__ static c_vec4_1_t::VecType
run(const float* reg_a, const float* reg_b, c_vec4_1_t::VecType reg_c)
template <class FloatA, class FloatB, class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
return reg_c;
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
}
};
template <>
struct intrin_mfma_f32_4x4x1f32<8, 64>
template <index_t COffset>
struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
{
__device__ static c_vec4_2_t::VecType
run(const float* reg_a, const float* reg_b, c_vec4_2_t::VecType reg_c)
template <class FloatA, class FloatB, class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
return reg_c;
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
4,
1,
0);
}
};
......
......@@ -49,6 +49,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
static_assert(1 == InWeiVectorSize, "support InWeiVectorSize == 1 only!");
#if 1
// run-time variables
const auto in_n_c_hi_wi_desc =
......@@ -108,12 +110,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmMPerWave = 8;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPerWave = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
constexpr index_t MRepeat = 8;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
......@@ -131,7 +133,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad<GemmMPerBlock,
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad<TInWei,
GemmMPerBlock,
GemmNPerBlock,
GemmMPerWave,
GemmNPerWave,
......@@ -148,7 +151,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
{
float ave_time = launch_kernel_dynamic_gemm_xdlops_v1<
BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type,
TInWei,
TAcc,
TOut,
InMemoryDataOperation::Set,
......@@ -188,10 +191,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
decltype(descs[I5]),
decltype(descs[I6]),
decltype(descs[I7]),
decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
decltype(descs[I8])>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
descs[I0],
descs[I1],
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment