Unverified Commit 16d7c4d2 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Add grouped conv bwd weight wmma (#985)

* Add grouped conv bwd weight wmma

* Update README, changelog, profiler

* Minor fixes

* Fix grouped conv bwd wei dl kernel

* Minor fixes

* Minor stylistic fixes
parent 39430bfd
...@@ -14,7 +14,7 @@ None ...@@ -14,7 +14,7 @@ None
### Additions ### Additions
- Added an image to a column kernel (#867) - Added an image to a column kernel (#867)
- Added a column to an image kernel (#930) - Added a column to an image kernel (#930)
- Support for 3D grouped convolution forward on RDNA 3 GPUs (#935) - Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985)
- Grouped convolution support for small K and C (#822 #879 #897) - Grouped convolution support for small K and C (#822 #879 #897)
- Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) - Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
- Support for bf16/f32/f16 and NHWGC (2D and 3d) grouped convolution backward data (#757 #799) - Support for bf16/f32/f16 and NHWGC (2D and 3d) grouped convolution backward data (#757 #799)
......
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight) add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
if(result EQUAL 0) if(result EQUAL 0)
...@@ -18,6 +19,14 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -18,6 +19,14 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif() endif()
endif() endif()
set(target 1) set(target 1)
endif()
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight)
add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
endif()
set(target 1)
endif() endif()
endforeach() endforeach()
......
...@@ -46,25 +46,21 @@ struct CommonLayoutSetting ...@@ -46,25 +46,21 @@ struct CommonLayoutSetting
using OutputLayout = OutputLay; using OutputLayout = OutputLay;
}; };
template <ck::index_t NDimSpatial>
struct CommonLayoutSettingSelector;
namespace ctl = ck::tensor_layout::convolution; namespace ctl = ck::tensor_layout::convolution;
template <ck::index_t NDimSpatial>
template <> struct CommonLayoutSettingSelector
struct CommonLayoutSettingSelector<1> final : CommonLayoutSetting<ctl::GNWC, ctl::GKXC, ctl::GNWK> : CommonLayoutSetting<ck::tuple_element_t<NDimSpatial - 1,
{ ck::Tuple<ck::tensor_layout::convolution::GNWC,
}; ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::GNDHWC>>,
template <> ck::tuple_element_t<NDimSpatial - 1,
struct CommonLayoutSettingSelector<2> final ck::Tuple<ck::tensor_layout::convolution::GKXC,
: CommonLayoutSetting<ctl::GNHWC, ctl::GKYXC, ctl::GNHWK> ck::tensor_layout::convolution::GKYXC,
{ ck::tensor_layout::convolution::GKZYXC>>,
}; ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWK,
template <> ck::tensor_layout::convolution::GNHWK,
struct CommonLayoutSettingSelector<3> final ck::tensor_layout::convolution::GNDHWK>>>
: CommonLayoutSetting<ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>
{ {
}; };
...@@ -84,10 +80,10 @@ struct ExecutionConfig final ...@@ -84,10 +80,10 @@ struct ExecutionConfig final
bool time_kernel = false; bool time_kernel = false;
}; };
#define DefaultConvParam \ #define DefaultConvParam \
ck::utils::conv::ConvParam \ ck::utils::conv::ConvParam \
{ \ { \
2, 4, 1, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, { 1, 1 } \ 3, 4, 1, 128, 256, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, { 1, 1, 1 } \
} }
inline void print_help_msg() inline void print_help_msg()
......
...@@ -76,4 +76,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe ...@@ -76,4 +76,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc" #include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } int main(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return 1;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
using InDataType = F16;
using WeiDataType = F16;
using OutDataType = F16;
using AccDataType = F32;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
NDimSpatial,
ck::tensor_layout::convolution::GNDHWC,
ck::tensor_layout::convolution::GKZYXC,
ck::tensor_layout::convolution::GNDHWK,
InDataType, // InDataType
WeiDataType, // WeiDataType
OutDataType, // OutDataType
AccDataType, // AccDataType
InElementOp, // InElementwiseOperation
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
16, // MPerWMMA
16, // NPerWMMA
4, // MRepeat
2, // NRepeat
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<0, 2, 1>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
4,
2,
S<1, 32, 1, 8>,
1>;
template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
#include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return 1;
}
switch(conv_param.num_dim_spatial_)
{
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
}
...@@ -78,4 +78,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe ...@@ -78,4 +78,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc" #include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } int main(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return 1;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
}
...@@ -77,4 +77,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe ...@@ -77,4 +77,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc" #include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } int main(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return 1;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
}
...@@ -83,4 +83,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe ...@@ -83,4 +83,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc" #include "run_grouped_conv_bwd_weight_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } int main(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return 1;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
default: break;
}
return 1;
}
...@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial> ...@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial>
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
const ck::utils::conv::ConvParam& conv_param) const ck::utils::conv::ConvParam& conv_param)
{ {
// Dl op doesn't support split_k > 1 // Dl and WMMA ops don't support split_k > 1
constexpr ck::index_t split_k = 1; constexpr ck::index_t split_k = 1;
const auto in_g_n_c_wis_desc = const auto in_g_n_c_wis_desc =
...@@ -143,23 +143,3 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -143,23 +143,3 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
return true; return true;
} }
bool run_grouped_conv_bwd_weight_example(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
if(!parse_cmd_args(argc, argv, config, conv_param))
{
return false;
}
switch(conv_param.num_dim_spatial_)
{
case 1: return run_grouped_conv_bwd_weight<1>(config, conv_param);
case 2: return run_grouped_conv_bwd_weight<2>(config, conv_param);
case 3: return run_grouped_conv_bwd_weight<3>(config, conv_param);
}
return false;
}
...@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle< const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
GridwiseGemm, GridwiseGemm,
ADataType, ADataType,
BDataType, BDataType,
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
...@@ -22,32 +23,6 @@ namespace ck { ...@@ -22,32 +23,6 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace {
struct ComputePtrOffsetOfStridedBatch
{
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideC_;
};
} // namespace
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
...@@ -952,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -952,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
Block2CTileMap block_2_ctile_map_; Block2CTileMap block_2_ctile_map_;
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
// element-wise op // element-wise op
OutElementwiseOperation a_element_op_; OutElementwiseOperation a_element_op_;
...@@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>, remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>, remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DeviceOp::Block2CTileMap>, remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch, ComputePtrOffsetOfStridedBatch<I0>,
has_main_loop, has_main_loop,
has_double_loop>; has_double_loop>;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -21,32 +22,6 @@ namespace ck { ...@@ -21,32 +22,6 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace {
struct ComputePtrOffsetOfStridedBatch
{
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideC_;
};
} // namespace
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatA, typename FloatA,
typename FloatB, typename FloatB,
...@@ -1222,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1222,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Block2CTileMap block_2_ctile_map_; Block2CTileMap block_2_ctile_map_;
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
...@@ -1301,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1301,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<DeviceOp::Block2CTileMap>, remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch, ComputePtrOffsetOfStridedBatch<I0>,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
...@@ -1348,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1348,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 1)
{ {
if constexpr(!is_GNWK_GKXC_GNWC) if constexpr(!is_GNWK_GKXC_GNWC)
......
...@@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle< const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
GridwiseOp, GridwiseOp,
ADataType, ADataType,
BDataType, BDataType,
......
...@@ -43,7 +43,13 @@ struct ComputePtrOffsetOfStridedBatch ...@@ -43,7 +43,13 @@ struct ComputePtrOffsetOfStridedBatch
return ds_offset; return ds_offset;
} }
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
// alias for kernels without multiple D
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideE_); return g_idx * static_cast<long_index_t>(BatchStrideE_);
} }
...@@ -52,6 +58,7 @@ struct ComputePtrOffsetOfStridedBatch ...@@ -52,6 +58,7 @@ struct ComputePtrOffsetOfStridedBatch
index_t BatchStrideB_; index_t BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_; Array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_; index_t BatchStrideE_;
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
}; };
} // namespace device } // namespace device
......
...@@ -36,7 +36,7 @@ __global__ void ...@@ -36,7 +36,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle( kernel_grouped_conv_multiple_d_wmma_cshuffle(
const ADataType* __restrict__ p_a_grid, const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid, const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
...@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// CheckValidity for kernels without multi D
template <typename Block2CTileMap> template <typename Block2CTileMap>
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n, const EGridDesc_M_N& e_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
...@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2))) K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
...@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
return true; return true;
} }
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
return CheckValidity(
a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n, block_2_ctile_map);
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{ {
const index_t num_loop = K / (K0PerBlock * K1); const index_t num_loop = K / (K0PerBlock * K1);
......
...@@ -6,8 +6,6 @@ ...@@ -6,8 +6,6 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
......
...@@ -163,6 +163,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances ...@@ -163,6 +163,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
...@@ -177,6 +201,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances ...@@ -177,6 +201,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
...@@ -202,6 +251,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances ...@@ -202,6 +251,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
...@@ -231,6 +304,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_ ...@@ -231,6 +304,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_
BF8, BF8,
F8>>>& instances); F8>>>& instances);
#endif #endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef DL_KERNELS #ifdef DL_KERNELS
// dl // dl
...@@ -694,6 +792,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -694,6 +792,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif #endif
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
op_ptrs);
} }
#endif #endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
...@@ -708,6 +810,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -708,6 +810,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
op_ptrs); op_ptrs);
} }
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
op_ptrs);
}
#endif #endif
} }
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> && else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
...@@ -737,6 +849,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -737,6 +849,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif #endif
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
op_ptrs);
} }
#endif #endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
...@@ -752,6 +868,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -752,6 +868,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs); op_ptrs);
} }
#endif #endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<OutDataType, half_t> &&
......
set(GROUPED_CONV1D_BWD_WEIGHT set(GROUPED_CONV1D_BWD_WEIGHT
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp) xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp)
if(DL_KERNELS) if(DL_KERNELS)
list(APPEND GROUPED_CONV1D_BWD_WEIGHT list(APPEND GROUPED_CONV1D_BWD_WEIGHT
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp
device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp
device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp
device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp) dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp)
endif() endif()
add_instance_library(device_grouped_conv1d_bwd_weight_instance ${GROUPED_CONV1D_BWD_WEIGHT}) add_instance_library(device_grouped_conv1d_bwd_weight_instance ${GROUPED_CONV1D_BWD_WEIGHT})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment