Commit 9a06cc7f authored by Jing Zhang
Browse files

clean

parent be103129
...@@ -332,9 +332,9 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k, ...@@ -332,9 +332,9 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr index_t CThreadTransferDstScalarPerVector = 1; constexpr index_t CThreadTransferDstScalarPerVector = 1;
#endif #endif
const index_t K = a_m_k.mDesc.GetLengths()[1]; const auto K = a_m_k.mDesc.GetLengths()[1];
const index_t M = a_m_k.mDesc.GetLengths()[0]; const auto M = a_m_k.mDesc.GetLengths()[0];
const index_t N = b_k_n.mDesc.GetLengths()[1]; const auto N = b_k_n.mDesc.GetLengths()[1];
constexpr auto K1Number = Number<K1>{}; constexpr auto K1Number = Number<K1>{};
const auto K0 = K / K1Number; const auto K0 = K / K1Number;
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "conv_common.hpp" #include "conv_common.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
//#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
//#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1 #define USE_DYNAMIC_MODE 1
...@@ -65,7 +65,7 @@ void host_convolution_forward(const Tensor<TIn>& in, ...@@ -65,7 +65,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
float v = 0; double v = 0;
for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
{ {
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
...@@ -77,18 +77,33 @@ void host_convolution_forward(const Tensor<TIn>& in, ...@@ -77,18 +77,33 @@ void host_convolution_forward(const Tensor<TIn>& in,
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in.mDesc.GetLengths()[3]) wi < in.mDesc.GetLengths()[3])
{ {
v += ck::type_convert<float>(in(n, c, hi, wi)) * if constexpr(is_same<TIn, ushort>::value)
ck::type_convert<float>(wei(k, c, y, x)); {
v += ck::type_convert<float>(in(n, c, hi, wi)) *
ck::type_convert<float>(wei(k, c, y, x));
}
else
{
v += static_cast<const double>(in(n, c, hi, wi)) *
static_cast<const double>(wei(k, c, y, x));
}
} }
} }
} }
} }
out(n, k, ho, wo) = ck::type_convert<TOut>(v); if constexpr(is_same<TOut, ushort>::value)
{
out(n, k, ho, wo) = ck::type_convert<ushort>(static_cast<float>(v));
}
else
{
out(n, k, ho, wo) = v;
}
}; };
auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) { auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
float v = 0; double v = 0;
for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c) for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c)
{ {
for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y) for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y)
...@@ -100,14 +115,28 @@ void host_convolution_forward(const Tensor<TIn>& in, ...@@ -100,14 +115,28 @@ void host_convolution_forward(const Tensor<TIn>& in,
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
wi < in.mDesc.GetLengths()[2]) wi < in.mDesc.GetLengths()[2])
{ {
v += ck::type_convert<float>(in(n, hi, wi, c)) * if constexpr(is_same<TIn, ushort>::value)
ck::type_convert<float>(wei(k, y, x, c)); {
v += ck::type_convert<float>(in(n, hi, wi, c)) *
ck::type_convert<float>(wei(k, y, x, c));
}
else
{
v += static_cast<const double>(in(n, hi, wi, c)) *
static_cast<const double>(wei(k, y, x, c));
}
} }
} }
} }
} }
if constexpr(is_same<TOut, ushort>::value)
out(n, ho, wo, k) = ck::type_convert<TOut>(v); {
out(n, ho, wo, k) = ck::type_convert<ushort>(static_cast<float>(v));
}
else
{
out(n, ho, wo, k) = v;
}
}; };
if(layout == ConvTensorLayout::NCHW) if(layout == ConvTensorLayout::NCHW)
...@@ -225,11 +254,11 @@ int main(int argc, char* argv[]) ...@@ -225,11 +254,11 @@ int main(int argc, char* argv[])
using in_data_t = float; using in_data_t = float;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = float; using out_data_t = float;
#elif 0 #elif 1
using in_data_t = half_t; using in_data_t = half_t;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = half_t; using out_data_t = half_t;
#elif 1 #elif 0
using in_data_t = ushort; using in_data_t = ushort;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = ushort; using out_data_t = ushort;
......
...@@ -3,7 +3,7 @@ rm -f CMakeCache.txt ...@@ -3,7 +3,7 @@ rm -f CMakeCache.txt
rm -f *.cmake rm -f *.cmake
rm -rf CMakeFiles rm -rf CMakeFiles
MY_PROJECT_SOURCE=../ MY_PROJECT_SOURCE=../../..
MY_PROJECT_INSTALL=../install.dir MY_PROJECT_INSTALL=../install.dir
cmake \ cmake \
......
...@@ -24,22 +24,22 @@ REPEAT=$7 ...@@ -24,22 +24,22 @@ REPEAT=$7
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment