"driver/src/conv_driver.cpp" did not exist on "7faf269c995e5594935a16dfdae75a49f62f4991"
Commit 788f4786 authored by Chao Liu
Browse files

tweak

parent 5a79ff1e
......@@ -26,14 +26,74 @@ struct PassThrough
// Element-wise functor combining three inputs into one float:
//   result = max(0.1 * (v0 + v1), 0) + v2
// (the name suggests v1 is a bias and v2 a residual term — confirm against the caller).
// A host overload and a device overload with the same parameter list are defined
// separately, so the device side can use a differently arranged computation.
// NOTE(review): same-signature __host__ / __device__ overloads need a compiler that
// supports overloading on the execution-space qualifier (presumably a clang-based
// HIP toolchain here — confirm).
struct BiasReluAdd
{
// Host overload: the straightforward form of the formula.
// v0 is a float; v1 and v2 may be any types usable in float arithmetic.
template <typename T1, typename T2>
// NOTE(review): the next two lines are both operator() signatures with no body between
// them. This looks like the old and the new line of a diff shown together; only one of
// them can be live (presumably the __host__-only one, since a separate __device__
// overload follows) — confirm against the real file.
__host__ __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
float a = v0 + v1;
// 0.1 is a double literal, so this multiply is done in double and then narrowed to float.
float b = 0.1 * a;
// Clamp negative values to zero.
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
}
// Device overload: several alternative formulations selected by the preprocessor.
// Only the first branch whose condition is non-zero is compiled; as written that is
// the first "#elif 1" branch below. The per-branch remarks on register use and
// throughput are the original author's measurements.
template <typename T1, typename T2>
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
#if 0
// Variant 1 (disabled): same statements as the host overload.
// Author's note: uses few registers, but the multiply by the double literal 0.1 is
// an fp64 multiply.
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif 0
// Variant 2 (disabled): scale factor converted to float first, so the multiply is fp32.
// Author's note: this version spills registers.
float a = v0 + v1;
float b = float(0.1) * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif 0
// Variant 3 (disabled): factors alpha out of the whole expression:
//   alpha * (v2 / alpha + max(v0 + v1, 0))
// which equals max(alpha * (v0 + v1), 0) + v2 for alpha > 0, up to floating-point rounding.
// Author's note: uses many registers, but does not spill.
constexpr float alpha = 0.1;
constexpr float alpha_inv = 1.0 / alpha;
float a = v2 * alpha_inv;
float b = v1 + v0;
float c = b > 0 ? b : 0;
float d = alpha * (a + c);
return d;
#elif 1
// Variant 4 (ACTIVE): same factoring as variant 3, but the clamp uses max() instead of
// a ternary.
// Author's note: uses many registers (no spill), 89 Tflops.
constexpr float alpha = 0.1;
// alpha is promoted to double for this division; the quotient is then narrowed to float.
constexpr float alpha_inv = 1.0 / alpha;
float a = v2 * alpha_inv;
float b = v1 + v0;
float c = max(b, float(0));
float d = alpha * (a + c);
return d;
#elif 1
// Variant 5: never compiled as written, because the preceding "#elif 1" is already taken.
// Performs the scale with an explicit v_mul_f32_e32 instruction via inline assembly
// (presumably AMD GCN ISA — confirm), with alpha passed under the "s" constraint and
// the sum under the "v" constraint.
// Author's note: this version spills registers, 89 Tflops.
float a = v0 + v1;
float alpha = 0.1;
float b;
asm volatile("\n \
v_mul_f32_e32 %0, %1, %2 \n \
"
: "=v"(b)
: "s"(alpha), "v"(a));
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#endif
}
};
......@@ -114,53 +174,63 @@ void host_reference_calculation(const Tensor<TIn>& in_n_c_hi_wi,
int main(int argc, char* argv[])
{
// if(argc != 4)
bool do_verification = 0;
int init_method = 0;
int nrepeat = 5;
// Conv shape
ck::index_t N = 128;
ck::index_t K = 256;
ck::index_t C = 192;
ck::index_t Y = 3;
ck::index_t X = 3;
ck::index_t Hi = 71;
ck::index_t Wi = 71;
ck::index_t conv_stride_h = 2;
ck::index_t conv_stride_w = 2;
ck::index_t conv_dilation_h = 1;
ck::index_t conv_dilation_w = 1;
ck::index_t in_left_pad_h = 1;
ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
}
if(argc == 19)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
C = std::stoi(argv[6]);
Y = std::stoi(argv[7]);
X = std::stoi(argv[8]);
Hi = std::stoi(argv[9]);
Wi = std::stoi(argv[10]);
conv_stride_h = std::stoi(argv[11]);
conv_stride_w = std::stoi(argv[12]);
conv_dilation_h = std::stoi(argv[13]);
conv_dilation_w = std::stoi(argv[14]);
in_left_pad_h = std::stoi(argv[15]);
in_left_pad_w = std::stoi(argv[16]);
in_right_pad_h = std::stoi(argv[17]);
in_right_pad_w = std::stoi(argv[18]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n");
// exit(0);
exit(0);
}
const bool do_verification = std::stoi(argv[1]);
const int init_method = std::stoi(argv[2]);
const int nrepeat = std::stoi(argv[3]);
// Conv shape
#if 0
const ck::index_t N = 128;
const ck::index_t K = 256;
const ck::index_t C = 192;
const ck::index_t Y = 3;
const ck::index_t X = 3;
const ck::index_t Hi = 71;
const ck::index_t Wi = 71;
const ck::index_t conv_stride_h = 2;
const ck::index_t conv_stride_w = 2;
const ck::index_t conv_dilation_h = 1;
const ck::index_t conv_dilation_w = 1;
const ck::index_t in_left_pad_h = 1;
const ck::index_t in_left_pad_w = 1;
const ck::index_t in_right_pad_h = 1;
const ck::index_t in_right_pad_w = 1;
#else
const ck::index_t N = std::stoi(argv[4]);
const ck::index_t K = std::stoi(argv[5]);
const ck::index_t C = std::stoi(argv[6]);
const ck::index_t Y = std::stoi(argv[7]);
const ck::index_t X = std::stoi(argv[8]);
const ck::index_t Hi = std::stoi(argv[9]);
const ck::index_t Wi = std::stoi(argv[10]);
const ck::index_t conv_stride_h = std::stoi(argv[11]);
const ck::index_t conv_stride_w = std::stoi(argv[12]);
const ck::index_t conv_dilation_h = std::stoi(argv[13]);
const ck::index_t conv_dilation_w = std::stoi(argv[14]);
const ck::index_t in_left_pad_h = std::stoi(argv[15]);
const ck::index_t in_left_pad_w = std::stoi(argv[16]);
const ck::index_t in_right_pad_h = std::stoi(argv[17]);
const ck::index_t in_right_pad_w = std::stoi(argv[18]);
#endif
const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment