Commit c32f0c3b authored by carlushuang's avatar carlushuang
Browse files

update xdnn desc as cmd line

parent 8d0e00a0
......@@ -12,6 +12,7 @@
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include "envvar.hpp"
#include "xdnn_desc.hpp"
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
......@@ -237,6 +238,39 @@ int main(int argc, char* argv[])
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
if(ck::getenv_int("CK_USE_XDNN_DESC", 0) == 1)
{
assert(argc == 4);
data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
ck::desc_t xdnn_desc;
if(str2desc(&xdnn_desc, argv[3]) == XDNN_OK)
{
N = xdnn_desc.mb;
K = xdnn_desc.oc;
C = xdnn_desc.ic;
Y = xdnn_desc.kh;
X = xdnn_desc.kw;
Hi = xdnn_desc.ih;
Wi = xdnn_desc.iw;
conv_stride_h = xdnn_desc.sh;
conv_stride_w = xdnn_desc.sw;
conv_dilation_h = xdnn_desc.dh;
conv_dilation_w = xdnn_desc.dw;
in_left_pad_h = xdnn_desc.ph;
in_left_pad_w = xdnn_desc.pw;
in_right_pad_h = xdnn_desc.ph;
in_right_pad_w = xdnn_desc.pw;
}
else
{
printf("fail to parse xdnn arg:%s\n", argv[3]);
}
}
else
{
if(argc == 1)
{
data_type = 0;
......@@ -276,6 +310,7 @@ int main(int argc, char* argv[])
"RightPx\n");
exit(1);
}
}
auto Run = [&](auto input_type, auto wei_type, auto out_type) {
using InDataType = decltype(input_type);
......@@ -333,6 +368,8 @@ int main(int argc, char* argv[])
<< ", Dilation(H, W):" << conv_dilation_h << ", " << conv_dilation_w
<< ", Threads:" << omp_get_max_threads() << std::endl;
fflush(stdout);
int per_pixel_check = 0;
switch(init_method)
{
......
......@@ -12,6 +12,7 @@
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include "envvar.hpp"
#include "xdnn_desc.hpp"
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
......@@ -273,6 +274,39 @@ int main(int argc, char* argv[])
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
if(ck::getenv_int("CK_USE_XDNN_DESC", 0) == 1)
{
assert(argc == 4);
data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
ck::desc_t xdnn_desc;
if(str2desc(&xdnn_desc, argv[3]) == XDNN_OK)
{
N = xdnn_desc.mb;
K = xdnn_desc.oc;
C = xdnn_desc.ic;
Y = xdnn_desc.kh;
X = xdnn_desc.kw;
Hi = xdnn_desc.ih;
Wi = xdnn_desc.iw;
conv_stride_h = xdnn_desc.sh;
conv_stride_w = xdnn_desc.sw;
conv_dilation_h = xdnn_desc.dh;
conv_dilation_w = xdnn_desc.dw;
in_left_pad_h = xdnn_desc.ph;
in_left_pad_w = xdnn_desc.pw;
in_right_pad_h = xdnn_desc.ph;
in_right_pad_w = xdnn_desc.pw;
}
else
{
printf("fail to parse xdnn arg:%s\n", argv[3]);
}
}
else
{
if(argc == 1)
{
data_type = 0;
......@@ -312,6 +346,7 @@ int main(int argc, char* argv[])
"RightPx\n");
exit(1);
}
}
auto Run = [&](auto input_type, auto wei_type, auto out_type) {
using InDataType = decltype(input_type);
......@@ -389,6 +424,8 @@ int main(int argc, char* argv[])
<< ", Dilation(H, W):" << conv_dilation_h << ", " << conv_dilation_w
<< ", Threads:" << omp_get_max_threads() << std::endl;
fflush(stdout);
int per_pixel_check = 0;
switch(init_method)
{
......
#pragma once
#include <string>
#include <vector>
#include <functional>
#define XDNN_OK 0
#define XDNN_FAIL 1
namespace ck {
int sanitize_desc(int& ndims,
std::vector<std::reference_wrapper<int64_t>> d,
std::vector<std::reference_wrapper<int64_t>> h,
std::vector<std::reference_wrapper<int64_t>> w,
const std::vector<int64_t>& def_values,
bool must_have_spatial)
{
size_t N = d.size();
assert(h.size() == N && w.size() == N && def_values.size() == N);
ndims = 5;
// check output spatial values
const bool no_d = d[0].get() == 0;
const bool no_h = h[0].get() == 0;
const bool no_w = w[0].get() == 0;
if(no_d)
ndims--;
if(no_d && no_h)
ndims--;
if(no_d && no_h && no_w)
ndims--;
if(must_have_spatial && ndims <= 2)
return XDNN_FAIL;
if(ndims == 5)
{
if(no_h && no_w)
{
// User specified values for the d dimension but not values for h
// and w dimensions. Propagate d values to h and w dimensions.
for(size_t n = 0; n < N; ++n)
w[n].get() = h[n].get() = d[n].get();
}
else if(!no_h && !no_w)
{
// User specified them all, good to go.
}
else
{
// Problem is not cubic and one of h or w dimension is missing.
return XDNN_FAIL;
}
}
else if(ndims == 4 && no_w)
{
// User specified values for the h dimension but not values for the w
// dimension. Propagate h values to the w dimension.
for(size_t n = 0; n < N; ++n)
w[n].get() = h[n].get();
}
for(size_t n = 0; n < N; ++n)
{
if(ndims < 5)
d[n].get() = def_values[n];
if(ndims < 4)
h[n].get() = def_values[n];
if(ndims < 3)
w[n].get() = def_values[n];
}
return XDNN_OK;
}
struct desc_t
{
int64_t g, mb;
int64_t ic, id, ih, iw;
int64_t oc, od, oh, ow;
int64_t kd, kh, kw;
int64_t sd, sh, sw;
int64_t pd, ph, pw;
int64_t pd_r, ph_r, pw_r; // End side padding for each dimension
int64_t dd, dh, dw;
bool has_groups;
const char* name;
int ndims;
// Initialize dependent opposite-side paddings values
// from the shape parameters
void init_pad_r(bool is_deconv)
{
pw_r = opp_pad(is_deconv, iw, ow, kw, sw, pw, dw);
ph_r = opp_pad(is_deconv, ih, oh, kh, sh, ph, dh);
pd_r = opp_pad(is_deconv, id, od, kd, sd, pd, dd);
}
int64_t desc_nelems(int arg, int mask) const;
private:
int64_t
opp_pad(bool is_deconv, int64_t i, int64_t o, int64_t k, int64_t s, int64_t p, int64_t d) const
{
return is_deconv ? (i - 1) * s - o + ((k - 1) * (d + 1) + 1) - p
: (o - 1) * s - i + ((k - 1) * (d + 1) + 1) - p;
}
};
static inline int str2desc(desc_t* desc, const char* str, bool is_deconv = false)
{
/* canonical form:
* gXmbX_icXidXihXiwX_ocXodXohXowX_kdXkhXkwX_sdXshXswX_pdXphXpwX_ddXdhXdwXnS
*
* where X is number, S - string
* note: symbol `_` is ignored
*
* implicit rules:
* - if smaller dimensions are not specified => square or cubic form;
* - if output is undefined => compute output;
* - if padding is undefined => compute trivial padding;
*/
desc_t d{0};
d.g = 1;
d.mb = 2;
d.sd = d.sh = d.sw = 1;
d.pd = d.ph = d.pw = -1;
const char* s = str;
assert(s);
#define CASE_NN(prb, c) \
do \
{ \
if(!strncmp(prb, s, strlen(prb))) \
{ \
ok = 1; \
s += strlen(prb); \
char* end_s; \
d.c = strtol(s, &end_s, 10); \
s += (end_s - s); \
/* check any # groups, including one, works correctly */ \
if(!strncmp(prb, "g", 1)) \
d.has_groups = true; \
if(d.c < 0) \
return XDNN_FAIL; \
/* printf("@@@debug: %s: %d\n", prb, d. c); */ \
} \
} while(0)
#define CASE_N(c) CASE_NN(#c, c)
while(*s)
{
int ok = 0;
CASE_N(g);
CASE_N(mb);
CASE_N(ic);
CASE_N(id);
CASE_N(ih);
CASE_N(iw);
CASE_N(oc);
CASE_N(od);
CASE_N(oh);
CASE_N(ow);
CASE_N(kd);
CASE_N(kh);
CASE_N(kw);
CASE_N(sd);
CASE_N(sh);
CASE_N(sw);
CASE_N(pd);
CASE_N(ph);
CASE_N(pw);
CASE_N(dd);
CASE_N(dh);
CASE_N(dw);
if(*s == 'n')
{
d.name = s + 1;
break;
}
if(*s == '_')
++s;
if(!ok)
return XDNN_FAIL;
}
#undef CASE_NN
#undef CASE_N
if(d.has_groups && d.g <= 0)
return XDNN_FAIL;
if(d.ic == 0 || d.oc == 0)
return XDNN_FAIL;
if(d.sd <= 0 || d.sh <= 0 || d.sw <= 0)
return XDNN_FAIL;
auto compute_out = [](bool is_deconv, int64_t i, int64_t k, int64_t s, int64_t p, int64_t d) {
if(is_deconv)
return (i - 1) * s + (k - 1) * (d + 1) - 2 * p + 1;
else
return (i - ((k - 1) * (d + 1) + 1) + 2 * p) / s + 1;
};
auto compute_pad = [](bool is_deconv, int64_t o, int64_t i, int64_t k, int64_t s, int64_t d) {
if(is_deconv)
return ((i - 1) * s - o + ((k - 1) * (d + 1) + 1)) / 2;
else
return ((o - 1) * s - i + ((k - 1) * (d + 1) + 1)) / 2;
};
const bool no_d = (d.id | d.kd | d.od | d.dd) == 0 && d.sd == 1 && d.pd < 1;
const bool no_h = (d.ih | d.kh | d.oh | d.dh) == 0 && d.sh == 1 && d.ph < 1;
const bool no_w = (d.iw | d.kw | d.ow | d.dw) == 0 && d.sw == 1 && d.pw < 1;
// printf("no_h:%d, no_w:%d, d.iw:%d\n", no_h, no_w, d.iw);
if(!no_d)
{
if(!d.id || !d.kd)
return XDNN_FAIL;
if(!d.od)
{
if(d.pd < 0)
d.pd = 0;
d.od = compute_out(is_deconv, d.id, d.kd, d.sd, d.pd, d.dd);
if(d.od <= 0)
return XDNN_FAIL;
}
else if(d.pd < 0)
d.pd = compute_pad(is_deconv, d.od, d.id, d.kd, d.sd, d.dd);
}
if(!no_h)
{
if(!d.ih || !d.kh)
return XDNN_FAIL;
if(!d.oh)
{
if(d.ph < 0)
d.ph = 0;
d.oh = compute_out(is_deconv, d.ih, d.kh, d.sh, d.ph, d.dh);
if(d.oh <= 0)
return XDNN_FAIL;
}
else if(d.ph < 0)
d.ph = compute_pad(is_deconv, d.oh, d.ih, d.kh, d.sh, d.dh);
}
if(!no_w)
{
if(!d.iw || !d.kw)
return XDNN_FAIL;
if(!d.ow)
{
if(d.pw < 0)
d.pw = 0;
d.ow = compute_out(is_deconv, d.iw, d.kw, d.sw, d.pw, d.dw);
if(d.ow <= 0)
return XDNN_FAIL;
}
else if(d.pw < 0)
d.pw = compute_pad(is_deconv, d.ow, d.iw, d.kw, d.sw, d.dw);
}
if(sanitize_desc(d.ndims,
{d.od, d.id, d.kd, d.sd, d.pd, d.dd},
{d.oh, d.ih, d.kh, d.sh, d.ph, d.dh},
{d.ow, d.iw, d.kw, d.sw, d.pw, d.dw},
{1, 1, 1, 1, 0, 0},
true) != XDNN_OK)
return XDNN_FAIL;
d.init_pad_r(is_deconv);
*desc = d;
// TODO: this is difference CK~OneDNN
d.dh++;
d.dw++;
d.dd++;
return XDNN_OK;
}
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment