Unverified Commit b3e8d57d authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

Tweak GEMM kernel (#38)

* add parameters

* tweak gemm

* tweak

* update conv

* update script

* adding bwd 1x1

* update script

* adding 1x1 bwd

* debugging bwd 1x1 failure

* update script

* update script

* test

* test v100

* clean up
parent 846f462b
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <stdlib.h> #include <stdlib.h>
#include <half.hpp> #include <half.hpp>
#include "config.hpp" #include "config.hpp"
#include "debug.hpp"
#include "print.hpp" #include "print.hpp"
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
...@@ -24,7 +25,7 @@ ...@@ -24,7 +25,7 @@
#define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1 #define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
enum ConvForwardAlgo enum ConvForwardAlgo
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <stdlib.h> #include <stdlib.h>
#include <half.hpp> #include <half.hpp>
#include "config.hpp" #include "config.hpp"
#include "debug.hpp"
#include "print.hpp" #include "print.hpp"
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
...@@ -111,7 +112,7 @@ int main(int argc, char* argv[]) ...@@ -111,7 +112,7 @@ int main(int argc, char* argv[])
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
#endif #endif
#if 1 #if 0
using in_data_t = float; using in_data_t = float;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = float; using out_data_t = float;
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <stdlib.h> #include <stdlib.h>
#include <half.hpp> #include <half.hpp>
#include "config.hpp" #include "config.hpp"
#include "debug.hpp"
#include "print.hpp" #include "print.hpp"
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
...@@ -16,11 +17,19 @@ ...@@ -16,11 +17,19 @@
#include "device_gemm_xdlops_mk_nk_mn.hpp" #include "device_gemm_xdlops_mk_nk_mn.hpp"
#include "device_gemm_xdlops_km_kn_mn.hpp" #include "device_gemm_xdlops_km_kn_mn.hpp"
#include "device_gemm_xdlops_km_nk_mn.hpp" #include "device_gemm_xdlops_km_nk_mn.hpp"
#include "device_gemm_xdlops_mk_kn_nm.hpp"
#include "device_gemm_xdlops_mk_nk_nm.hpp"
#include "device_gemm_xdlops_km_kn_nm.hpp"
#include "device_gemm_xdlops_km_nk_nm.hpp"
#define USE_GEMM_XDL_MK_KN_MN 1 #define USE_GEMM_XDL_MK_KN_MN 1
#define USE_GEMM_XDL_MK_NK_MN 1 #define USE_GEMM_XDL_MK_NK_MN 1
#define USE_GEMM_XDL_KM_KN_MN 1 #define USE_GEMM_XDL_KM_KN_MN 1
#define USE_GEMM_XDL_KM_NK_MN 1 #define USE_GEMM_XDL_KM_NK_MN 1
#define USE_GEMM_XDL_MK_KN_NM 0
#define USE_GEMM_XDL_MK_NK_NM 0
#define USE_GEMM_XDL_KM_KN_NM 0
#define USE_GEMM_XDL_KM_NK_NM 0
enum GemmAlgo enum GemmAlgo
{ {
...@@ -28,21 +37,21 @@ enum GemmAlgo ...@@ -28,21 +37,21 @@ enum GemmAlgo
Xdl_MK_NK_MN, // 1 Xdl_MK_NK_MN, // 1
Xdl_KM_KN_MN, // 2 Xdl_KM_KN_MN, // 2
Xdl_KM_NK_MN, // 3 Xdl_KM_NK_MN, // 3
Xdl_MK_KN_NM, // 4
Xdl_MK_NK_NM, // 5
Xdl_KM_KN_NM, // 6
Xdl_KM_NK_NM, // 7
}; };
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck; using namespace ck;
constexpr auto I0 = Number<0>{}; if(argc != 12)
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
// dynamic mode
if(argc != 10)
{ {
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
printf("rest: M, N, K\n"); printf("rest: M, N, K\n");
printf("debug_driver_gemm_xdlops_v2r3::M01, debug_driver_gemm_xdlops_v2r3::N01\n");
exit(1); exit(1);
} }
...@@ -57,6 +66,9 @@ int main(int argc, char* argv[]) ...@@ -57,6 +66,9 @@ int main(int argc, char* argv[])
const index_t N = std::stoi(argv[8]); const index_t N = std::stoi(argv[8]);
const index_t K = std::stoi(argv[9]); const index_t K = std::stoi(argv[9]);
debug::debug_driver_gemm_xdlops_v2r3::M01 = std::stoi(argv[10]);
debug::debug_driver_gemm_xdlops_v2r3::N01 = std::stoi(argv[11]);
#if 0 #if 0
using ab_data_t = float; using ab_data_t = float;
using acc_data_t = float; using acc_data_t = float;
...@@ -74,69 +86,44 @@ int main(int argc, char* argv[]) ...@@ -74,69 +86,44 @@ int main(int argc, char* argv[])
std::vector<std::size_t> a_lengths_host(2), b_lengths_host(2), c_lengths_host(2); std::vector<std::size_t> a_lengths_host(2), b_lengths_host(2), c_lengths_host(2);
std::vector<std::size_t> a_strides_host(2), b_strides_host(2), c_strides_host(2); std::vector<std::size_t> a_strides_host(2), b_strides_host(2), c_strides_host(2);
if(layout == GemmMatrixLayout::MK_KN_MN) // A
if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::MK_NK_MN ||
layout == GemmMatrixLayout::MK_KN_NM || layout == GemmMatrixLayout::MK_NK_NM)
{ {
a_lengths_host[0] = static_cast<std::size_t>(M); a_lengths_host[0] = static_cast<std::size_t>(M);
a_lengths_host[1] = static_cast<std::size_t>(K); a_lengths_host[1] = static_cast<std::size_t>(K);
a_strides_host[0] = static_cast<std::size_t>(K); a_strides_host[0] = static_cast<std::size_t>(K);
a_strides_host[1] = static_cast<std::size_t>(1); a_strides_host[1] = static_cast<std::size_t>(1);
b_lengths_host[0] = static_cast<std::size_t>(K);
b_lengths_host[1] = static_cast<std::size_t>(N);
b_strides_host[0] = static_cast<std::size_t>(N);
b_strides_host[1] = static_cast<std::size_t>(1);
c_lengths_host[0] = static_cast<std::size_t>(M);
c_lengths_host[1] = static_cast<std::size_t>(N);
c_strides_host[0] = static_cast<std::size_t>(N);
c_strides_host[1] = static_cast<std::size_t>(1);
} }
else if(layout == GemmMatrixLayout::MK_NK_MN) else
{ {
a_lengths_host[0] = static_cast<std::size_t>(M); a_lengths_host[0] = static_cast<std::size_t>(K);
a_lengths_host[1] = static_cast<std::size_t>(K); a_lengths_host[1] = static_cast<std::size_t>(M);
a_strides_host[0] = static_cast<std::size_t>(K); a_strides_host[0] = static_cast<std::size_t>(M);
a_strides_host[1] = static_cast<std::size_t>(1); a_strides_host[1] = static_cast<std::size_t>(1);
}
// B
if(layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN ||
layout == GemmMatrixLayout::MK_NK_NM || layout == GemmMatrixLayout::KM_NK_NM)
{
b_lengths_host[0] = static_cast<std::size_t>(N); b_lengths_host[0] = static_cast<std::size_t>(N);
b_lengths_host[1] = static_cast<std::size_t>(K); b_lengths_host[1] = static_cast<std::size_t>(K);
b_strides_host[0] = static_cast<std::size_t>(K); b_strides_host[0] = static_cast<std::size_t>(K);
b_strides_host[1] = static_cast<std::size_t>(1); b_strides_host[1] = static_cast<std::size_t>(1);
c_lengths_host[0] = static_cast<std::size_t>(M);
c_lengths_host[1] = static_cast<std::size_t>(N);
c_strides_host[0] = static_cast<std::size_t>(N);
c_strides_host[1] = static_cast<std::size_t>(1);
} }
else if(layout == GemmMatrixLayout::KM_KN_MN) else
{ {
a_lengths_host[0] = static_cast<std::size_t>(K);
a_lengths_host[1] = static_cast<std::size_t>(M);
a_strides_host[0] = static_cast<std::size_t>(M);
a_strides_host[1] = static_cast<std::size_t>(1);
b_lengths_host[0] = static_cast<std::size_t>(K); b_lengths_host[0] = static_cast<std::size_t>(K);
b_lengths_host[1] = static_cast<std::size_t>(N); b_lengths_host[1] = static_cast<std::size_t>(N);
b_strides_host[0] = static_cast<std::size_t>(N); b_strides_host[0] = static_cast<std::size_t>(N);
b_strides_host[1] = static_cast<std::size_t>(1); b_strides_host[1] = static_cast<std::size_t>(1);
c_lengths_host[0] = static_cast<std::size_t>(M);
c_lengths_host[1] = static_cast<std::size_t>(N);
c_strides_host[0] = static_cast<std::size_t>(N);
c_strides_host[1] = static_cast<std::size_t>(1);
} }
else if(layout == GemmMatrixLayout::KM_NK_MN)
{
a_lengths_host[0] = static_cast<std::size_t>(K);
a_lengths_host[1] = static_cast<std::size_t>(M);
a_strides_host[0] = static_cast<std::size_t>(M);
a_strides_host[1] = static_cast<std::size_t>(1);
b_lengths_host[0] = static_cast<std::size_t>(N);
b_lengths_host[1] = static_cast<std::size_t>(K);
b_strides_host[0] = static_cast<std::size_t>(K);
b_strides_host[1] = static_cast<std::size_t>(1);
// C
if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::KM_KN_MN ||
layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN)
{
c_lengths_host[0] = static_cast<std::size_t>(M); c_lengths_host[0] = static_cast<std::size_t>(M);
c_lengths_host[1] = static_cast<std::size_t>(N); c_lengths_host[1] = static_cast<std::size_t>(N);
c_strides_host[0] = static_cast<std::size_t>(N); c_strides_host[0] = static_cast<std::size_t>(N);
...@@ -144,7 +131,10 @@ int main(int argc, char* argv[]) ...@@ -144,7 +131,10 @@ int main(int argc, char* argv[])
} }
else else
{ {
std::runtime_error("wrong! not implemented"); c_lengths_host[0] = static_cast<std::size_t>(N);
c_lengths_host[1] = static_cast<std::size_t>(M);
c_strides_host[0] = static_cast<std::size_t>(M);
c_strides_host[1] = static_cast<std::size_t>(1);
} }
Tensor<ab_data_t> a(a_lengths_host, a_strides_host); Tensor<ab_data_t> a(a_lengths_host, a_strides_host);
...@@ -185,38 +175,6 @@ int main(int argc, char* argv[]) ...@@ -185,38 +175,6 @@ int main(int argc, char* argv[])
b.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread); b.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
} }
auto f_make_for_device_mk_kn_mn = [&]() {
const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1));
const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1));
const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
return make_tuple(a_desc, b_desc, c_desc);
};
auto f_make_for_device_mk_nk_mn = [&]() {
const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1));
const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1));
const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
return make_tuple(a_desc, b_desc, c_desc);
};
auto f_make_for_device_km_kn_mn = [&]() {
const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1));
const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1));
const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
return make_tuple(a_desc, b_desc, c_desc);
};
auto f_make_for_device_km_nk_mn = [&]() {
const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1));
const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1));
const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
return make_tuple(a_desc, b_desc, c_desc);
};
#if USE_GEMM_XDL_MK_KN_MN #if USE_GEMM_XDL_MK_KN_MN
if(algo == GemmAlgo::Xdl_MK_KN_MN) if(algo == GemmAlgo::Xdl_MK_KN_MN)
{ {
...@@ -225,10 +183,7 @@ int main(int argc, char* argv[]) ...@@ -225,10 +183,7 @@ int main(int argc, char* argv[])
throw std::runtime_error("wrong! layout"); throw std::runtime_error("wrong! layout");
} }
const auto descs = f_make_for_device_mk_kn_mn(); device_gemm_xdlops_mk_kn_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
device_gemm_xdlops_mk_kn_mn<ab_data_t, acc_data_t, c_data_t>(
descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
} }
#endif #endif
...@@ -240,10 +195,7 @@ int main(int argc, char* argv[]) ...@@ -240,10 +195,7 @@ int main(int argc, char* argv[])
throw std::runtime_error("wrong! layout"); throw std::runtime_error("wrong! layout");
} }
const auto descs = f_make_for_device_mk_nk_mn(); device_gemm_xdlops_mk_nk_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
device_gemm_xdlops_mk_nk_mn<ab_data_t, acc_data_t, c_data_t>(
descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
} }
#endif #endif
...@@ -255,10 +207,7 @@ int main(int argc, char* argv[]) ...@@ -255,10 +207,7 @@ int main(int argc, char* argv[])
throw std::runtime_error("wrong! layout"); throw std::runtime_error("wrong! layout");
} }
const auto descs = f_make_for_device_km_kn_mn(); device_gemm_xdlops_km_kn_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
device_gemm_xdlops_km_kn_mn<ab_data_t, acc_data_t, c_data_t>(
descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
} }
#endif #endif
...@@ -270,10 +219,55 @@ int main(int argc, char* argv[]) ...@@ -270,10 +219,55 @@ int main(int argc, char* argv[])
throw std::runtime_error("wrong! layout"); throw std::runtime_error("wrong! layout");
} }
const auto descs = f_make_for_device_km_nk_mn(); device_gemm_xdlops_km_nk_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
#if USE_GEMM_XDL_MK_KN_NM
if(algo == GemmAlgo::Xdl_MK_KN_NM)
{
if(layout != GemmMatrixLayout::MK_KN_NM)
{
throw std::runtime_error("wrong! layout");
}
device_gemm_xdlops_mk_kn_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
#if USE_GEMM_XDL_MK_NK_NM
if(algo == GemmAlgo::Xdl_MK_NK_NM)
{
if(layout != GemmMatrixLayout::MK_NK_NM)
{
throw std::runtime_error("wrong! layout");
}
device_gemm_xdlops_mk_nk_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
#if USE_GEMM_XDL_KM_KN_NM
if(algo == GemmAlgo::Xdl_KM_KN_NM)
{
if(layout != GemmMatrixLayout::KM_KN_NM)
{
throw std::runtime_error("wrong! layout");
}
device_gemm_xdlops_km_kn_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
#if USE_GEMM_XDL_KM_NK_NM
if(algo == GemmAlgo::Xdl_KM_NK_NM)
{
if(layout != GemmMatrixLayout::KM_NK_NM)
{
throw std::runtime_error("wrong! layout");
}
device_gemm_xdlops_km_nk_mn<ab_data_t, acc_data_t, c_data_t>( device_gemm_xdlops_km_nk_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
} }
#endif #endif
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#define DEVICE_HPP #define DEVICE_HPP
#include <memory> #include <memory>
#include <thread>
#include <chrono>
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
...@@ -74,6 +76,8 @@ float launch_and_time_kernel( ...@@ -74,6 +76,8 @@ float launch_and_time_kernel(
timer.End(); timer.End();
// std::this_thread::sleep_for (std::chrono::microseconds(10));
return timer.GetElapsedTime() / nrepeat; return timer.GetElapsedTime() / nrepeat;
} }
......
...@@ -7,6 +7,10 @@ enum GemmMatrixLayout ...@@ -7,6 +7,10 @@ enum GemmMatrixLayout
MK_NK_MN, // 1 MK_NK_MN, // 1
KM_KN_MN, // 2 KM_KN_MN, // 2
KM_NK_MN, // 3 KM_NK_MN, // 3
MK_KN_NM, // 4
MK_NK_NM, // 5
KM_KN_NM, // 6
KM_NK_NM, // 7
}; };
#endif #endif
...@@ -80,6 +80,78 @@ void host_gemm(const Tensor<AType>& a, ...@@ -80,6 +80,78 @@ void host_gemm(const Tensor<AType>& a,
make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
} }
else if(layout == GemmMatrixLayout::MK_KN_NM)
{
auto f_mk_kn_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(k, n));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_NK_NM)
{
auto f_mk_nk_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(n, k));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_KN_NM)
{
auto f_km_kn_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(k, n));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_NK_NM)
{
auto f_km_nk_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(n, k));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else else
{ {
throw std::runtime_error("wrong! not supported layout"); throw std::runtime_error("wrong! not supported layout");
......
WORKSPACE=$1
echo "workspace: " $WORKSPACE
docker run \
-it \
--rm \
--privileged \
--group-add sudo \
-w /root/workspace \
-v $WORKSPACE:/root/workspace \
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
/bin/bash
#--network host \
...@@ -4,24 +4,12 @@ ...@@ -4,24 +4,12 @@
export ROCR_VISIBLE_DEVICE=0 export ROCR_VISIBLE_DEVICE=0
export GPU_DEVICE_ORDINAL=0 export GPU_DEVICE_ORDINAL=0
## Boost make -j conv_fwd_driver_offline
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
## Compiling
#export OLC_DEBUG_HIP_VERBOSE=1
#export OLC_DEBUG_HIP_DUMP=1
#export OLC_DEBUG_SAVE_TEMP_DIR=1
#rm -rf /root/_hip_binary_kernels_/
#rm -rf /tmp/olCompile*
#make -j conv_fwd_driver_offline
#make -j conv_bwd_driver_offline #make -j conv_bwd_driver_offline
#make -j conv_wrw_driver_offline #make -j conv_wrw_driver_offline
#make -j conv_fwd_driver_online #make -j gemm_driver_offline
make -j gemm_driver_offline
DRIVER="./host/driver_offline/conv_fwd_driver_offline"
LAYOUT=$1 LAYOUT=$1
ALGO=$2 ALGO=$2
VERIFY=$3 VERIFY=$3
...@@ -29,30 +17,121 @@ INIT=$4 ...@@ -29,30 +17,121 @@ INIT=$4
LOG=$5 LOG=$5
REPEAT=$6 REPEAT=$6
################################################ layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads #M01=$7
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 #N01=$8
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 KBATCH=$7
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 ######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1
######### layout algo verify init log repeat M___ N___ K___
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01
# Resnet50
######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 1024 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 28 28 1 1 1 1 1 1 1 1
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 128 1 1 28 28 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 58 58 2 2 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 256 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 256 1 1 56 56 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 256 1 1 56 56 2 2 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 256 1 1 56 56 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 512 1 1 28 28 2 2 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 512 1 1 28 28 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 512 1 1 28 28 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 512 1 1 7 7 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 64 1 1 56 56 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 1 1 56 56 1 1 1 1 0 0 0 0
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 3 3 56 56 1 1 1 1 1 1 1 1
# 256x128x32 c64
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 56
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 56
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 224
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 # 128x128x32 c64
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 28
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 112
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 112
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 448
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 28
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 224
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 112
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH
#./host/driver_offline/conv_bwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
#./host/driver_offline/conv_wrw_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1 # 128x64x32 c64
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112
#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 # 64x128x32 c64
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
################################################ layout algo verify init log repeat M___ N___ K___ # 64x64x32 c32
#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 112
#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112
./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 448
#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 448
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment