"test/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "26941fa3777618dfc659e638e524b65f22dd32a6"
Commit fe797cf2 authored by ltqin's avatar ltqin
Browse files

rename device convolution file and function name

parent ad7bd495
...@@ -14,7 +14,7 @@ template <typename TInWei, ...@@ -14,7 +14,7 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nkhw( void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths, const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths, const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths, const OutLengths& out_n_k_ho_wo_lengths,
......
...@@ -128,9 +128,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, ...@@ -128,9 +128,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
// const auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc); // const auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc);
const auto a_b_k0_m_k1_grid_desc = GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc); const auto a_b_k0_m_k1_grid_desc = GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc);
const auto b_b_k0_n_k1_grid_desc = GridwiseGemm::MakeBBK0NK1GridDescriptor(b_k0_n_k1_grid_desc); const auto b_b_k0_n_k1_grid_desc = GridwiseGemm::MakeBBK0NK1GridDescriptor(b_k0_n_k1_grid_desc);
{
// std::cout << "k batch number is: " << kbatch << std::endl;
}
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{ {
throw std::runtime_error( throw std::runtime_error(
...@@ -150,7 +148,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, ...@@ -150,7 +148,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
{ {
std::cout << "gridSize : " << grid_size << grid_size << std::endl; std::cout << "gridSize : " << grid_size << std::endl;
} }
const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm, const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
FloatAB, FloatAB,
......
...@@ -13,16 +13,16 @@ ...@@ -13,16 +13,16 @@
#include "host_conv_bwd_weight.hpp" #include "host_conv_bwd_weight.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp"
#define USE_DYNAMIC_MODE 1 #define USE_DYNAMIC_MODE 1
#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1 #define USE_CONV_WRW_V4R4R2_XDL_NCHW 1
#define USE_CONV_WRW_V4R4R3_XDL_NCHW 1 #define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 1
enum ConvBackwardWeightAlgo enum ConvBackwardWeightAlgo
{ {
V4R4R2XDLNCHW, V4R4R2XDLNCHW,
V4R4R3XDLNCHW, V4R4R2XDLATOMICNCHW,
}; };
int main(int argc, char* argv[]) int main(int argc, char* argv[])
...@@ -262,8 +262,8 @@ int main(int argc, char* argv[]) ...@@ -262,8 +262,8 @@ int main(int argc, char* argv[])
} }
#endif #endif
#if USE_CONV_WRW_V4R4R3_XDL_NCHW #if USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW
if(algo == ConvBackwardWeightAlgo::V4R4R3XDLNCHW) if(algo == ConvBackwardWeightAlgo::V4R4R2XDLATOMICNCHW)
{ {
if(layout != ConvTensorLayout::NCHW) if(layout != ConvTensorLayout::NCHW)
{ {
...@@ -272,20 +272,20 @@ int main(int argc, char* argv[]) ...@@ -272,20 +272,20 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw(); const auto tmp = f_make_for_device_nchw();
device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nkhw<in_data_t, device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw<
acc_data_t, in_data_t,
out_data_t>( acc_data_t,
tmp[I0], out_data_t>(tmp[I0],
tmp[I1], tmp[I1],
tmp[I2], tmp[I2],
tmp[I3], tmp[I3],
tmp[I4], tmp[I4],
tmp[I5], tmp[I5],
tmp[I6], tmp[I6],
in, in,
wei_device, wei_device,
out, out,
nrepeat); nrepeat);
} }
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment