"include/vscode:/vscode.git/clone" did not exist on "29053edd9ef87bbef6d9ecdb57fd93aa095b6274"
Commit 3ba485b6 authored by Jing Zhang's avatar Jing Zhang
Browse files

resolve merge conflicts

parents 04c1aa31 a3c80265
...@@ -198,7 +198,9 @@ template <index_t NDimSpatial, ...@@ -198,7 +198,9 @@ template <index_t NDimSpatial,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock, index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler(),
typename AComputeType = ADataType,
typename BComputeType = AComputeType>
struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
: public DeviceGroupedConvBwdDataMultipleD<NDimSpatial, : public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
ALayout, // output image ALayout, // output image
...@@ -211,7 +213,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -211,7 +213,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
EDataType, // input image EDataType, // input image
AElementwiseOp, AElementwiseOp,
BElementwiseOp, BElementwiseOp,
CDEElementwiseOp> CDEElementwiseOp,
AComputeType,
BComputeType>
{ {
// TODO: Extend support for more spatial dimensions. // TODO: Extend support for more spatial dimensions.
static_assert(NDimSpatial == 2 || NDimSpatial == 3, static_assert(NDimSpatial == 2 || NDimSpatial == 3,
...@@ -312,9 +316,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -312,9 +316,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ABDataType, // TODO: distinguish A/B datatype ABDataType,
ABDataType, // TODO: distinguish A/B datatype ABDataType,
ABDataType, // TODO: distinguish A/B datatype AComputeType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
DsDataType, DsDataType,
...@@ -354,7 +358,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -354,7 +358,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched,
PipelineVersion::v1,
BComputeType>;
template <typename Desc_K0_M_K1> template <typename Desc_K0_M_K1>
static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1)
......
...@@ -48,7 +48,8 @@ struct ComputePtrOffsetOfStridedBatch ...@@ -48,7 +48,8 @@ struct ComputePtrOffsetOfStridedBatch
} // namespace } // namespace
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatA,
typename FloatB,
typename FloatC, typename FloatC,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
...@@ -64,8 +65,8 @@ __global__ void ...@@ -64,8 +65,8 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_batched_gemm_xdlops_bwd_weight( kernel_batched_gemm_xdlops_bwd_weight(
const FloatAB* __restrict__ p_a_grid, const FloatA* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
...@@ -91,7 +92,7 @@ __global__ void ...@@ -91,7 +92,7 @@ __global__ void
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
...@@ -163,7 +164,9 @@ template <ck::index_t NDimSpatial, ...@@ -163,7 +164,9 @@ template <ck::index_t NDimSpatial,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl> index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
: public DeviceGroupedConvBwdWeight<NDimSpatial, : public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout, InLayout,
...@@ -174,7 +177,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -174,7 +177,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
OutDataType, OutDataType,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation> OutElementwiseOperation,
ComputeTypeA,
ComputeTypeB>
{ {
using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle; using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle;
...@@ -1045,7 +1050,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1045,7 +1050,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize, BlockSize,
ADataType, // TODO: distinguish A/B datatype ADataType,
BDataType,
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
...@@ -1090,7 +1096,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1090,7 +1096,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true, true,
true>; true,
1,
PipelineVersion::v1,
ComputeTypeA,
ComputeTypeB>;
// Argument // Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
...@@ -1217,8 +1227,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1217,8 +1227,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
InElementwiseOperation a_element_op_; OutElementwiseOperation a_element_op_;
OutElementwiseOperation b_element_op_; InElementwiseOperation b_element_op_;
WeiElementwiseOperation c_element_op_; WeiElementwiseOperation c_element_op_;
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
...@@ -1281,7 +1291,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1281,7 +1291,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType,
BDataType,
CDataType, CDataType,
OutElementwiseOperation, OutElementwiseOperation,
InElementwiseOperation, InElementwiseOperation,
......
...@@ -211,7 +211,8 @@ template <index_t NDimSpatial, ...@@ -211,7 +211,8 @@ template <index_t NDimSpatial,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock, index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> typename ComputeDataType = ADataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
: public DeviceGroupedConvFwdMultipleD<NDimSpatial, : public DeviceGroupedConvFwdMultipleD<NDimSpatial,
ALayout, ALayout,
...@@ -224,7 +225,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -224,7 +225,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
EDataType, EDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation> CDEElementwiseOperation,
ComputeDataType>
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle; using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle;
...@@ -323,8 +325,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -323,8 +325,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>; using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
using ComputeDataType = ADataType;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
......
...@@ -186,25 +186,6 @@ struct Bilinear ...@@ -186,25 +186,6 @@ struct Bilinear
y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1)); y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
}; };
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
{
const float x0_tmp = type_convert<float>(x0);
const float x1_tmp = type_convert<float>(x1);
const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp;
y = type_convert<bhalf_t>(y_tmp);
};
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x1);
const float y_tmp = alpha_ * x0 + beta_ * x1_tmp;
y = y_tmp;
};
template <> template <>
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>( __host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
......
...@@ -33,12 +33,6 @@ struct PassThrough ...@@ -33,12 +33,6 @@ struct PassThrough
y = type_convert<float>(x); y = type_convert<float>(x);
} }
template <>
__host__ __device__ void operator()<double, float>(double& y, const float& x) const
{
y = type_convert<double>(x);
}
template <> template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const __host__ __device__ void operator()<float, float>(float& y, const float& x) const
{ {
...@@ -75,12 +69,6 @@ struct PassThrough ...@@ -75,12 +69,6 @@ struct PassThrough
y = type_convert<bhalf_t>(x); y = type_convert<bhalf_t>(x);
} }
template <>
__host__ __device__ void operator()<float, bhalf_t>(float& y, const bhalf_t& x) const
{
y = type_convert<float>(x);
}
template <> template <>
__host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const __host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
{ {
...@@ -156,6 +144,38 @@ struct PassThrough ...@@ -156,6 +144,38 @@ struct PassThrough
y = type_convert<f8_t>(x); y = type_convert<f8_t>(x);
} }
#endif #endif
#if defined CK_ENABLE_BF8
template <>
__host__ __device__ void operator()<bf8_t, bf8_t>(bf8_t& y, const bf8_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<float, bf8_t>(float& y, const bf8_t& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<bf8_t, float>(bf8_t& y, const float& x) const
{
y = type_convert<bf8_t>(x);
}
template <>
__host__ __device__ void operator()<half_t, bf8_t>(half_t& y, const bf8_t& x) const
{
y = type_convert<half_t>(x);
}
template <>
__host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const
{
y = ck::type_convert<bf8_t>(x);
}
#endif
}; };
struct UnaryConvert struct UnaryConvert
...@@ -210,20 +230,6 @@ struct Scale ...@@ -210,20 +230,6 @@ struct Scale
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const; __host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
y = ck::type_convert<half_t>(scale_) * x;
};
template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
const float x_tmp = ck::type_convert<float>(x);
const float y_tmp = scale_ * x_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
};
template <> template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const __host__ __device__ void operator()<float, float>(float& y, const float& x) const
{ {
......
...@@ -522,6 +522,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -522,6 +522,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
ABDataType, ABDataType,
ABDataType,
AccDataType, AccDataType,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
......
...@@ -628,7 +628,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle ...@@ -628,7 +628,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
Gemm1KPack, Gemm1KPack,
false, // TransposeC false, // TransposeC
Gemm1KPack, // AMmaKStride Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{ Gemm1KPack *
XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, FloatAB, false>{}.K0PerXdlops>{
// BMmaKStride // BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin make_tuple(0, 0, 0, 0)}; // A_origin
......
...@@ -880,7 +880,12 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -880,7 +880,12 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
Gemm1KPack, Gemm1KPack,
false, // TransposeC false, // TransposeC
Gemm1KPack, // AMmaKStride Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl, Gemm1KPack, false>{} Gemm1KPack * XdlopsGemm<A0B0B1DataType,
Gemm0MPerXdl,
Gemm0NPerXdl,
Gemm1KPack,
A0B0B1DataType,
false>{}
.K0PerXdlops>{ // BMmaKStride .K0PerXdlops>{ // BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin make_tuple(0, 0, 0, 0)}; // A_origin
......
...@@ -794,7 +794,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle ...@@ -794,7 +794,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
Gemm1KPack, Gemm1KPack,
true, // TransposeC true, // TransposeC
Gemm1KPack, // AMmaKStride Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{ Gemm1KPack *
XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, FloatAB, false>{}.K0PerXdlops>{
// BMmaKStride // BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin make_tuple(0, 0, 0, 0)}; // A_origin
......
...@@ -649,7 +649,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -649,7 +649,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
Gemm1KPack, Gemm1KPack,
true, // TransposeC true, // TransposeC
Gemm1KPack, // AMmaKStride Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{ Gemm1KPack *
XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, FloatAB, false>{}.K0PerXdlops>{
// BMmaKStride // BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin make_tuple(0, 0, 0, 0)}; // A_origin
......
...@@ -504,6 +504,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -504,6 +504,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAB,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
......
...@@ -428,7 +428,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -428,7 +428,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
[&](auto i) { [&](auto i) {
using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>; using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
return MakeAGridDescriptor_M_K<ALayout, GemmSpec>(MRaws[i], KRaws[i], AsStride[i]); return MakeAGridDescriptor_M_N<ALayout, GemmSpec>(MRaws[i], KRaws[i], AsStride[i]);
}, },
Number<NumATensor>{}); Number<NumATensor>{});
} }
......
...@@ -470,6 +470,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -470,6 +470,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAB,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment