Commit 9a06cc7f authored by Jing Zhang's avatar Jing Zhang
Browse files

clean

parent be103129
......@@ -332,9 +332,9 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#endif
const index_t K = a_m_k.mDesc.GetLengths()[1];
const index_t M = a_m_k.mDesc.GetLengths()[0];
const index_t N = b_k_n.mDesc.GetLengths()[1];
const auto K = a_m_k.mDesc.GetLengths()[1];
const auto M = a_m_k.mDesc.GetLengths()[0];
const auto N = b_k_n.mDesc.GetLengths()[1];
constexpr auto K1Number = Number<K1>{};
const auto K0 = K / K1Number;
......
......@@ -12,10 +12,10 @@
#include "host_tensor_generator.hpp"
#include "conv_common.hpp"
#include "device_tensor.hpp"
//#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
//#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1
......@@ -65,7 +65,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
constexpr auto I1 = Number<1>{};
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
float v = 0;
double v = 0;
for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
{
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
......@@ -76,19 +76,34 @@ void host_convolution_forward(const Tensor<TIn>& in,
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in.mDesc.GetLengths()[3])
{
if constexpr(is_same<TIn, ushort>::value)
{
v += ck::type_convert<float>(in(n, c, hi, wi)) *
ck::type_convert<float>(wei(k, c, y, x));
}
else
{
v += static_cast<const double>(in(n, c, hi, wi)) *
static_cast<const double>(wei(k, c, y, x));
}
}
}
}
}
out(n, k, ho, wo) = ck::type_convert<TOut>(v);
if constexpr(is_same<TOut, ushort>::value)
{
out(n, k, ho, wo) = ck::type_convert<ushort>(static_cast<float>(v));
}
else
{
out(n, k, ho, wo) = v;
}
};
auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
float v = 0;
double v = 0;
for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c)
{
for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y)
......@@ -99,15 +114,29 @@ void host_convolution_forward(const Tensor<TIn>& in,
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
wi < in.mDesc.GetLengths()[2])
{
if constexpr(is_same<TIn, ushort>::value)
{
v += ck::type_convert<float>(in(n, hi, wi, c)) *
ck::type_convert<float>(wei(k, y, x, c));
}
else
{
v += static_cast<const double>(in(n, hi, wi, c)) *
static_cast<const double>(wei(k, y, x, c));
}
}
}
out(n, ho, wo, k) = ck::type_convert<TOut>(v);
}
}
if constexpr(is_same<TOut, ushort>::value)
{
out(n, ho, wo, k) = ck::type_convert<ushort>(static_cast<float>(v));
}
else
{
out(n, ho, wo, k) = v;
}
};
if(layout == ConvTensorLayout::NCHW)
......@@ -225,11 +254,11 @@ int main(int argc, char* argv[])
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
#elif 0
#elif 1
using in_data_t = half_t;
using acc_data_t = float;
using out_data_t = half_t;
#elif 1
#elif 0
using in_data_t = ushort;
using acc_data_t = float;
using out_data_t = ushort;
......
......@@ -3,7 +3,7 @@ rm -f CMakeCache.txt
rm -f *.cmake
rm -rf CMakeFiles
MY_PROJECT_SOURCE=../
MY_PROJECT_SOURCE=../../..
MY_PROJECT_INSTALL=../install.dir
cmake \
......
......@@ -24,22 +24,22 @@ REPEAT=$7
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment