import numpy as np
import tvm
import logging
import sys, time, subprocess
from tvm import autotvm
import topi
import json
from topi.util import get_const_tuple
import os

# Conv2d shape parameters; each can be overridden through an environment variable.
op_attributes = {
    "N": int(os.environ['N']) if 'N' in os.environ else 64,
    "C": int(os.environ['C']) if 'C' in os.environ else 3,
    "H": int(os.environ['H']) if 'H' in os.environ else 229,
    "W": int(os.environ['W']) if 'W' in os.environ else 229,
    "F": int(os.environ['F']) if 'F' in os.environ else 32,
    "K": int(os.environ['K']) if 'K' in os.environ else 5,
    "ST": int(os.environ['ST']) if 'ST' in os.environ else 1,
    "PD": int(os.environ['PD']) if 'PD' in os.environ else 2,
}


@autotvm.template
def get_template_op(**kargs):
    N = op_attributes["N"]
    CI = op_attributes["C"]
    H = op_attributes["H"]
    W = op_attributes["W"]
    CO = op_attributes["F"]
    KH = KW = op_attributes["K"]
    stride = op_attributes["ST"]
    padding = op_attributes["PD"]
    dilation = 1

    data = tvm.placeholder((N, CI, H, W), name='data')
    kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
    conv = topi.nn.conv2d_nchw(
        data, kernel, (stride, stride), (padding, padding),
        dilation=dilation, out_dtype='float32')
    s = tvm.create_schedule([conv.op])
    cfg = autotvm.get_config()

    ##### space definition begin #####
    n, f, y, x = s[conv].op.axis
    rc, ry, rx = s[conv].op.reduce_axis
    cfg.define_split("tile_f", f, num_outputs=4)
    cfg.define_split("tile_y", y, num_outputs=4)
    cfg.define_split("tile_x", x, num_outputs=4)
    cfg.define_split("tile_rc", rc, num_outputs=2)
    cfg.define_split("tile_ry", ry, num_outputs=2)
    cfg.define_split("tile_rx", rx, num_outputs=2)
    cfg.define_knob("auto_unroll_max_step", [0, 125, 256])

    target = tvm.target.current_target()
    if target.target_name in ['nvptx', 'rocm']:
        cfg.define_knob("unroll_explicit", [1])
    else:
        cfg.define_knob("unroll_explicit", [0, 1])
    ##### space definition end #####

    # inline padding
    pad_data, kernel = s[conv].op.input_tensors
    s[pad_data].compute_inline()
    if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
        s[kernel].compute_inline()

    if conv.op in s.outputs:
        output = conv
        OL = s.cache_write(conv, 'local')
    else:
        output = s.outputs[0].output(0)
        s[conv].set_scope('local')
        OL = conv

    # create cache stages for the padded input and the weights
    AA = s.cache_read(pad_data, 'shared', [OL])
    WW = s.cache_read(kernel, 'shared', [OL])

    # tile and bind spatial axes
    n, f, y, x = s[output].op.axis
    # kernel_scope: dummy outer axis used to attach the kernel-wide pragmas below
    kernel_scope, n = s[output].split(n, nparts=1)

    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    bf = s[output].fuse(n, bf)
    s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
    s[output].bind(by, tvm.thread_axis("blockIdx.y"))
    s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[output].bind(vf, tvm.thread_axis("vthread"))
    s[output].bind(vy, tvm.thread_axis("vthread"))
    s[output].bind(vx, tvm.thread_axis("vthread"))
    s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
    s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
    s[OL].compute_at(s[output], tx)

    # tile reduction axes
    n, f, y, x = s[OL].op.axis
    rc, ry, rx = s[OL].op.reduce_axis
    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
    ryo, ryi = cfg['tile_ry'].apply(s, OL, ry)
    rxo, rxi = cfg['tile_rx'].apply(s, OL, rx)
    s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)

    s[AA].compute_at(s[OL], rxo)
    s[WW].compute_at(s[OL], rxo)

    # cooperative fetching: all threads of a block load the shared tiles together
    for load in [AA, WW]:
        n, f, y, x = s[load].op.axis
        fused = s[load].fuse(n, f, y, x)
        tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
        ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
        tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

    # unroll
    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)

    # kernel layout is (CO, CI, KH, KW), matching the placeholder above
    N, CO, OH, OW = get_const_tuple(output.shape)
    _, CI, KH, KW = get_const_tuple(kernel.shape)
    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
    return s, [data, kernel, conv]
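# --- Usage sketch (not part of the original template) ------------------------
# A minimal, hedged example of how this template could be tuned with the
# TVM 0.6-era autotvm API implied by the imports above. The tuner choice,
# trial count, and the 'conv2d_nchw.log' path are illustrative assumptions.
if __name__ == '__main__':
    # surface tuning progress on stdout
    logging.getLogger('autotvm').setLevel(logging.DEBUG)
    logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))

    # the template takes no positional args; shapes come from op_attributes
    task = autotvm.task.create(get_template_op, args=(), target='cuda')
    print(task.config_space)

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=5, repeat=3, timeout=20))

    tuner = autotvm.tuner.XGBTuner(task)
    tuner.tune(n_trial=100,
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file('conv2d_nchw.log')])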