"src/vscode:/vscode.git/clone" did not exist on "d720b2132e74a16cd44f98947e667e4a4442adc5"
Commit fe797cf2 authored by ltqin's avatar ltqin
Browse files

rename device convolution file and function name

parent ad7bd495
......@@ -14,7 +14,7 @@ template <typename TInWei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nkhw(
void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
......
......@@ -128,9 +128,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
// const auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc);
const auto a_b_k0_m_k1_grid_desc = GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc);
const auto b_b_k0_n_k1_grid_desc = GridwiseGemm::MakeBBK0NK1GridDescriptor(b_k0_n_k1_grid_desc);
{
// std::cout << "k batch number is: " << kbatch << std::endl;
}
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{
throw std::runtime_error(
......@@ -150,7 +148,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
{
std::cout << "gridSize : " << grid_size << grid_size << std::endl;
std::cout << "gridSize : " << grid_size << std::endl;
}
const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
FloatAB,
......
......@@ -13,16 +13,16 @@
#include "host_conv_bwd_weight.hpp"
#include "device_tensor.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp"
#define USE_DYNAMIC_MODE 1
#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1
#define USE_CONV_WRW_V4R4R3_XDL_NCHW 1
#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 1
enum ConvBackwardWeightAlgo
{
V4R4R2XDLNCHW,
V4R4R3XDLNCHW,
V4R4R2XDLATOMICNCHW,
};
int main(int argc, char* argv[])
......@@ -262,8 +262,8 @@ int main(int argc, char* argv[])
}
#endif
#if USE_CONV_WRW_V4R4R3_XDL_NCHW
if(algo == ConvBackwardWeightAlgo::V4R4R3XDLNCHW)
#if USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW
if(algo == ConvBackwardWeightAlgo::V4R4R2XDLATOMICNCHW)
{
if(layout != ConvTensorLayout::NCHW)
{
......@@ -272,10 +272,10 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw();
device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nkhw<in_data_t,
device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw<
in_data_t,
acc_data_t,
out_data_t>(
tmp[I0],
out_data_t>(tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment