Commit c454d419 authored by lisj

Remove the submodule's .gitignore

parent 3359c1f1
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved. #
# This file is part of the LIBXSMM library. #
# #
# For information on the license, see the LICENSE file. #
# Further information: https://github.com/hfp/libxsmm/ #
# SPDX-License-Identifier: BSD-3-Clause #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
from string import Template
from datetime import date
import libxsmm_utilities
import fnmatch
import sys
if __name__ == "__main__":
argc = len(sys.argv)
if 1 < argc:
# required argument(s)
filename = sys.argv[1]
# default configuration if no arguments are given
ilp64 = offload = precision = flags = threshold = 0
sync = jit = 1
alpha = beta = 1
cacheline = 64
prefetch = -1
wrap = 1
malloc = 0
mnklist = list()
# optional argument(s)
if 2 < argc:
ilp64 = int(sys.argv[2])
if 3 < argc:
offload = int(sys.argv[3])
if 4 < argc:
cacheline = libxsmm_utilities.sanitize_alignment(int(sys.argv[4]))
if 5 < argc:
precision = int(sys.argv[5])
if 6 < argc:
prefetch = int(sys.argv[6])
if 7 < argc:
threshold = int(sys.argv[7])
if 8 < argc:
sync = int(sys.argv[8])
if 9 < argc:
jit = int(sys.argv[9])
if 10 < argc:
flags = int(sys.argv[10])
if 11 < argc:
alpha = int(sys.argv[11])
if 12 < argc:
beta = int(sys.argv[12])
if 13 < argc:
wrap = int(sys.argv[13])
if 14 < argc:
malloc = int(sys.argv[14])
if 15 < argc:
mnklist = sorted(libxsmm_utilities.load_mnklist(sys.argv[15:], 0))
version, branch, realversion = libxsmm_utilities.version_branch()
major, minor, update, patch = libxsmm_utilities.version_numbers(
version
)
if 0 == threshold:
threshold = 64 * 64 * 64
maxmnk = libxsmm_utilities.max_mnk(mnklist, threshold)
maxdim = int(maxmnk ** (1.0 / 3.0) + 0.5)
avgdim = int(0.5 * maxdim + 0.5)
avgm = libxsmm_utilities.median(
list(map(lambda mnk: mnk[0], mnklist)), avgdim, False
)
avgn = libxsmm_utilities.median(
list(map(lambda mnk: mnk[1], mnklist)), avgdim, False
)
avgk = libxsmm_utilities.median(
list(map(lambda mnk: mnk[2], mnklist)), avgdim, False
)
maxm = libxsmm_utilities.max_mnk(mnklist, avgdim, 0)
maxn = libxsmm_utilities.max_mnk(mnklist, avgdim, 1)
maxk = libxsmm_utilities.max_mnk(mnklist, avgdim, 2)
substitute = {
"VERSION": realversion,
"BRANCH": branch,
"MAJOR": major,
"MINOR": minor,
"UPDATE": update,
"PATCH": patch,
"DATE": date.today().strftime("%Y%m%d"),
"CACHELINE": cacheline,
"PREFETCH": [-1, prefetch][0 <= prefetch],
"MAX_MNK": maxmnk,
"MAX_DIM": maxdim,
"AVG_DIM": int((maxdim + 1) / 2),
"MAX_M": [maxdim, maxm][avgm < maxm],
"MAX_N": [maxdim, maxn][avgn < maxn],
"MAX_K": [maxdim, maxk][avgk < maxk],
"FLAGS": flags,
"ILP64": [0, 1][0 != ilp64],
"ALPHA": alpha,
"BETA": beta,
"WRAP": wrap,
"MALLOC": malloc,
"SYNC": [0, 1][0 != sync],
"JIT": [0, 1][0 != jit],
"LIBXSMM_OFFLOAD_BUILD": ["", "\n#define LIBXSMM_OFFLOAD_BUILD"][
0 != offload
],
"MNK_PREPROCESSOR_LIST": "",
}
template = Template(open(filename, "r").read())
if fnmatch.fnmatch(filename, "*.h*"):
if mnklist:
first = mnklist[0]
for mnk in mnklist:
mnkstr = "_".join(map(str, mnk))
if mnk != first:
substitute["MNK_PREPROCESSOR_LIST"] += "\n"
if 2 != precision:
substitute["MNK_PREPROCESSOR_LIST"] += (
"#define LIBXSMM_SMM_" + mnkstr
)
if mnk != first or 0 == precision:
substitute["MNK_PREPROCESSOR_LIST"] += "\n"
if 1 != precision:
substitute["MNK_PREPROCESSOR_LIST"] += (
"#define LIBXSMM_DMM_" + mnkstr
)
print(template.substitute(substitute))
else:
substitute["BLASINT_KIND"] = ["C_INT", "C_LONG_LONG"][0 != ilp64]
print(template.safe_substitute(substitute))
else:
sys.tracebacklimit = 0
raise ValueError(sys.argv[0] + ": wrong number of arguments!")
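# Usage sketch (hedged): the argument order mirrors the argv parsing above;
# the script and template names below are illustrative, not taken from this
# listing.
#
#   python3 libxsmm_config.py libxsmm_config.h \
#     0 0 64 0 -1 0 1 1 0 1 1 1 0 2_2_2 4_4_4
#
# maps to: ilp64=0, offload=0, cacheline=64, precision=0, prefetch=-1,
# threshold=0, sync=1, jit=1, flags=0, alpha=1, beta=1, wrap=1, malloc=0,
# followed by the MNK list; the script substitutes the $-placeholders
# (e.g. $VERSION, $MAX_MNK) in the template and prints the result to stdout.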
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved. #
# This file is part of the LIBXSMM library. #
# #
# For information on the license, see the LICENSE file. #
# Further information: https://github.com/hfp/libxsmm/ #
# SPDX-License-Identifier: BSD-3-Clause #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
import libxsmm_utilities
import sys
import os
if __name__ == "__main__":
argc = len(sys.argv)
if 1 < argc:
arg1_filename = [sys.argv[1], ""]["0" == sys.argv[1]]
arg1_isfile = os.path.isfile(arg1_filename)
base = 1
if arg1_isfile:
print("#if !defined(_WIN32)")
print("{ static const char *const build_state =")
print('# include "../' + os.path.basename(arg1_filename) + '"')
print(" ;")
print(" internal_build_state = build_state;")
print("}")
print("#endif")
base = 2
if (base + 2) < argc:
precision = int(sys.argv[base + 0])
threshold = int(sys.argv[base + 1])
mnklist = libxsmm_utilities.load_mnklist(sys.argv[base + 2:], 0)
print(
"/* omit registering code if JIT is enabled"
" and if an ISA extension is found"
)
print(
" * which is beyond the static code"
" path used to compile the library"
)
print(" */")
print("#if (0 != LIBXSMM_JIT) && !defined(__MIC__)")
print(
"if (LIBXSMM_X86_GENERIC > libxsmm_target_archid "
"/* JIT code gen. is not available */"
)
print(
" /* conditions allows to avoid JIT "
"(if static code is good enough) */"
)
print(
" || (LIBXSMM_STATIC_TARGET_ARCH == libxsmm_target_archid)"
)
print(
" || (LIBXSMM_X86_AVX512_CORE <= libxsmm_target_archid &&"
)
print(
" libxsmm_cpuid_vlen32(LIBXSMM_STATIC_TARGET_ARCH) =="
)
print(
" libxsmm_cpuid_vlen32(libxsmm_target_archid)))"
)
print("#endif")
print("{")
print(" libxsmm_xmmfunction func;")
for mnk in mnklist:
mstr, nstr, kstr, mnkstr = (
str(mnk[0]),
str(mnk[1]),
str(mnk[2]),
"_".join(map(str, mnk)),
)
mnksig = mstr + ", " + nstr + ", " + kstr
# prefer registering double-precision kernels
# when approaching an exhausted registry
if 1 != precision: # only double-precision
print(
" func.dmm = (libxsmm_dmmfunction)libxsmm_dmm_"
+ mnkstr
+ ";"
)
print(
" internal_register_static_code("
+ "LIBXSMM_GEMM_PRECISION_F64, "
+ mnksig
+ ", func, new_registry);"
)
for mnk in mnklist:
mstr, nstr, kstr, mnkstr = (
str(mnk[0]),
str(mnk[1]),
str(mnk[2]),
"_".join(map(str, mnk)),
)
mnksig = mstr + ", " + nstr + ", " + kstr
# prefer registering double-precision kernels
# when approaching an exhausted registry
if 2 != precision: # only single-precision
print(
" func.smm = (libxsmm_smmfunction)libxsmm_smm_"
+ mnkstr
+ ";"
)
print(
" internal_register_static_code("
+ "LIBXSMM_GEMM_PRECISION_F32, "
+ mnksig
+ ", func, new_registry);"
)
print("}")
else:
sys.tracebacklimit = 0
raise ValueError(sys.argv[0] + ": wrong number of arguments!")
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved. #
# This file is part of the LIBXSMM library. #
# #
# For information on the license, see the LICENSE file. #
# Further information: https://github.com/hfp/libxsmm/ #
# SPDX-License-Identifier: BSD-3-Clause #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
from string import Template
import libxsmm_utilities
import fnmatch
import sys
if __name__ == "__main__":
argc = len(sys.argv)
if 1 < argc:
# required argument(s)
filename = sys.argv[1]
# default configuration if no arguments are given
precision = 0 # all
ifversion = 1 # interface
prefetch = -1 # auto
mnklist = list()
# optional argument(s)
if 2 < argc:
ivalue = int(sys.argv[2])
ifversion = (ivalue >> 2)
precision = (ivalue & 3)
if 3 < argc:
prefetch = int(sys.argv[3])
if 4 < argc:
mnklist = sorted(libxsmm_utilities.load_mnklist(sys.argv[4:], 0))
template = Template(open(filename, "r").read())
if fnmatch.fnmatch(filename, "*.h*"):
optional = [", ...", ""][0 <= prefetch]
substitute = {"MNK_INTERFACE_LIST": ""}
for mnk in mnklist:
mnkstr = "_".join(map(str, mnk))
if 2 != precision:
pfsig = [
optional + ");",
",\n "
"const float* pa, "
"const float* pb, "
"const float* pc);"
][0 < prefetch]
substitute["MNK_INTERFACE_LIST"] += (
"\nLIBXSMM_API void libxsmm_smm_"
+ mnkstr
+ "(const float* a, const float* b, float* c"
+ pfsig
)
if 1 != precision:
pfsig = [
optional + ");",
",\n "
"const double* pa, "
"const double* pb, "
"const double* pc);"
][0 < prefetch]
substitute["MNK_INTERFACE_LIST"] += (
"\nLIBXSMM_API void libxsmm_dmm_"
+ mnkstr
+ "(const double* a, const double* b, double* c"
+ pfsig
)
if 0 == precision:
substitute["MNK_INTERFACE_LIST"] += "\n"
if mnklist and 0 != precision:
substitute["MNK_INTERFACE_LIST"] += "\n"
print(template.substitute(substitute))
else: # Fortran interface
if 1 > ifversion and 0 != ifversion:
raise ValueError("Fortran interface level is inconsistent!")
            # Fortran's OPTIONAL attribute allows always generating an
            # interface with a prefetch signature (more flexible usage)
if 0 == prefetch:
prefetch = -1
version, branch, realversion = libxsmm_utilities.version_branch(16)
major, minor, update, patch = libxsmm_utilities.version_numbers(
version
)
substitute = {
"VERSION": realversion,
"BRANCH": branch,
"MAJOR": major,
"MINOR": minor,
"UPDATE": update,
"PATCH": patch,
"MNK_INTERFACE_LIST": "",
"CONTIGUOUS": ["", ", CONTIGUOUS"][1 < ifversion]
}
if mnklist:
substitute["MNK_INTERFACE_LIST"] += "\n"
for mnk in mnklist:
mnkstr = "_".join(map(str, mnk))
if 0 == precision:
substitute["MNK_INTERFACE_LIST"] += (
"\n "
"!DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_smm_"
+ mnkstr
+ ", libxsmm_dmm_"
+ mnkstr
)
elif 2 != precision:
substitute["MNK_INTERFACE_LIST"] += (
"\n "
"!DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_smm_"
+ mnkstr
)
elif 1 != precision:
substitute["MNK_INTERFACE_LIST"] += (
"\n "
"!DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_dmm_"
+ mnkstr
)
substitute["MNK_INTERFACE_LIST"] += "\n INTERFACE"
optional = [", OPTIONAL", ""][0 < prefetch]
bindc = ["", "BIND(C)"][0 < prefetch]
for mnk in mnklist:
mnkstr = "_".join(map(str, mnk))
if 2 != precision:
pfsiga = [
") BIND(C)\n",
","
+ "&".rjust(26 - len(mnkstr))
+ "\n & pa, pb, pc) "
+ bindc
+ "\n"
][0 != prefetch]
pfsigb = [
"",
" REAL(C_FLOAT), "
"INTENT(IN)" + optional + " :: "
"pa(*), "
"pb(*), "
"pc(*)\n"
][0 != prefetch]
substitute["MNK_INTERFACE_LIST"] += (
"\n "
"PURE SUBROUTINE libxsmm_smm_"
+ mnkstr
+ "(a, b, c"
+ pfsiga
+ " IMPORT :: C_FLOAT\n"
" REAL(C_FLOAT), "
"INTENT(IN) :: a(*), b(*)\n"
" REAL(C_FLOAT), "
"INTENT(INOUT) :: c(*)\n"
+ pfsigb
+ " END SUBROUTINE"
)
if 1 != precision:
pfsiga = [
") BIND(C)\n",
","
+ "&".rjust(26 - len(mnkstr))
+ "\n & pa, pb, pc) "
+ bindc
+ "\n"
][0 != prefetch]
pfsigb = [
"",
" REAL(C_DOUBLE), "
"INTENT(IN)" + optional + " :: "
"pa(*), "
"pb(*), "
"pc(*)\n"
][0 != prefetch]
substitute["MNK_INTERFACE_LIST"] += (
"\n "
"PURE SUBROUTINE libxsmm_dmm_"
+ mnkstr
+ "(a, b, c"
+ pfsiga
+ " IMPORT :: C_DOUBLE\n"
" REAL(C_DOUBLE), "
"INTENT(IN) :: a(*), b(*)\n"
" REAL(C_DOUBLE), "
"INTENT(INOUT) :: c(*)\n"
+ pfsigb
+ " END SUBROUTINE"
)
substitute["MNK_INTERFACE_LIST"] += "\n END INTERFACE"
print(template.safe_substitute(substitute))
else:
sys.tracebacklimit = 0
raise ValueError(sys.argv[0] + ": wrong number of arguments!")
#!/usr/bin/env sh
SRCDIR=../src
GREP=$(command -v grep)
if [ "" = "${GREP}" ]; then
>&2 echo "Error: missing prerequisites!"
exit 1
fi
cat << EOM
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_SOURCE_H
#define LIBXSMM_SOURCE_H
#if defined(LIBXSMM_MACROS_H)
# error Please do not include any LIBXSMM header other than libxsmm_source.h!
#endif
#if defined(LIBXSMM_BUILD)
# error LIBXSMM_BUILD cannot be defined for the header-only LIBXSMM!
#endif
/**
 * This header is intentionally called "libxsmm_source.h" since the following block
 * includes *internal* files, and thereby exposes LIBXSMM's implementation.
 * The so-called "header-only" usage model gives up the clearly defined binary interface
 * (including support for hot-fixes after deployment), and requires rebuilding client
 * code for every (internal) change of LIBXSMM. Please make sure to only rely on the
* public interface as the internal implementation may change without notice.
*/
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
EOM
HERE=$(cd "$(dirname "$0")" && pwd -P)
if [ "" = "$1" ]; then
DSTDIR=${SRCDIR}
else
DSTDIR=$1
fi
# determine order of filenames in directory list
export LC_ALL=C
# good-enough pattern to match a main function, and to exclude this translation unit
for FILE in $(cd "${HERE}/${SRCDIR}" && ${GREP} -L "main[[:space:]]*(.*)" ./*.c); do
BASENAME=$(basename "${FILE}")
echo "#include \"${DSTDIR}/${BASENAME}\""
done
cat << EOM
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
#endif /*LIBXSMM_SOURCE_H*/
EOM
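# Usage sketch (paths and script name illustrative): redirect stdout to
# create the header, e.g.
#   ./libxsmm_source.sh > ../include/libxsmm_source.h
# An optional first argument overrides the prefix of the generated #include
# lines (default: ../src).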
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved. #
# This file is part of the LIBXSMM library. #
# #
# For information on the license, see the LICENSE file. #
# Further information: https://github.com/hfp/libxsmm/ #
# SPDX-License-Identifier: BSD-3-Clause #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
import sys
if __name__ == "__main__":
argc = len(sys.argv)
if 6 == argc:
precision = int(sys.argv[1])
m, n, k = int(sys.argv[2]), int(sys.argv[3]), int(sys.argv[4])
prefetch = int(sys.argv[5])
mnkstr = str(m) + "_" + str(n) + "_" + str(k)
optional = ["", ", ..."][0 > prefetch]
signature = ["a, b, c", "a, b, c, pa, pb, pc"][0 < prefetch]
if 2 != precision:
pfsig = [
optional + ")",
"\n"
", const float* pa"
", const float* pb"
", const float* pc)",
][0 < prefetch]
            print()
            print()
print(
"LIBXSMM_API void libxsmm_smm_"
+ mnkstr
+ "(const float* a, const float* b, float* c"
+ pfsig
)
print("{")
print(
"#if defined(__AVX512F__) && "
"defined(LIBXSMM_GENTARGET_skx_sp) && \\"
)
print(" !(defined(__AVX512PF__) && defined(__AVX512ER__))")
print(" libxsmm_smm_" + mnkstr + "_skx(" + signature + ");")
print(
"#elif defined(__AVX512F__) && "
"defined(LIBXSMM_GENTARGET_knl_sp)"
)
print(" libxsmm_smm_" + mnkstr + "_knl(" + signature + ");")
print(
"#elif defined(__AVX2__) && "
"defined(LIBXSMM_GENTARGET_hsw_sp)"
)
print(" libxsmm_smm_" + mnkstr + "_hsw(" + signature + ");")
print(
"#elif defined(__AVX__) && "
"defined(LIBXSMM_GENTARGET_snb_sp)"
)
print(" libxsmm_smm_" + mnkstr + "_snb(" + signature + ");")
print(
"#elif defined(__SSE3__) && "
"defined(LIBXSMM_GENTARGET_wsm_sp)"
)
print(" libxsmm_smm_" + mnkstr + "_wsm(" + signature + ");")
print("#else")
print(
" const char transa = (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & "
"LIBXSMM_FLAGS) ? 'N' : 'T');"
)
print(
" const char transb = (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & "
"LIBXSMM_FLAGS) ? 'N' : 'T');"
)
print(" const float alpha = LIBXSMM_ALPHA, beta = LIBXSMM_BETA;")
print(
" const libxsmm_blasint "
"m = " + str(m) + ", "
"n = " + str(n) + ", "
"k = " + str(k) + ";"
)
if 0 < prefetch:
print(
" LIBXSMM_UNUSED(pa);"
" LIBXSMM_UNUSED(pb);"
" LIBXSMM_UNUSED(pc);"
)
print(
" LIBXSMM_INLINE_XGEMM(float, float, &transa, &transb,"
" &m, &n, &k, &alpha, a, &m, b, &k, &beta, c, &m);"
)
print("#endif")
print("}")
            print()
            print()
print(
"LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_smm_"
+ mnkstr
+ ")(const float* a, const float* b, float* c"
+ pfsig
+ ";"
)
print(
"LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_smm_"
+ mnkstr
+ ")(const float* a, const float* b, float* c"
+ pfsig
)
print("{")
print(" libxsmm_smm_" + mnkstr + "(" + signature + ");")
print("}")
if 1 != precision:
pfsig = [
optional + ")",
"\n"
", const double* pa"
", const double* pb"
", const double* pc)",
][0 < prefetch]
            print()
            print()
print(
"LIBXSMM_API void libxsmm_dmm_"
+ mnkstr
+ "(const double* a, const double* b, double* c"
+ pfsig
)
print("{")
print(
"#if defined(__AVX512F__) && "
"defined(LIBXSMM_GENTARGET_skx_dp) && \\"
)
print(" !(defined(__AVX512PF__) && defined(__AVX512ER__))")
print(" libxsmm_dmm_" + mnkstr + "_skx(" + signature + ");")
print(
"#elif defined(__AVX512F__) && "
"defined(LIBXSMM_GENTARGET_knl_dp)"
)
print(" libxsmm_dmm_" + mnkstr + "_knl(" + signature + ");")
print(
"#elif defined(__AVX2__) && "
"defined(LIBXSMM_GENTARGET_hsw_dp)"
)
print(" libxsmm_dmm_" + mnkstr + "_hsw(" + signature + ");")
print(
"#elif defined(__AVX__) && "
"defined(LIBXSMM_GENTARGET_snb_dp)"
)
print(" libxsmm_dmm_" + mnkstr + "_snb(" + signature + ");")
print(
"#elif defined(__SSE3__) && "
"defined(LIBXSMM_GENTARGET_wsm_dp)"
)
print(" libxsmm_dmm_" + mnkstr + "_wsm(" + signature + ");")
print("#else")
print(
" const char transa = (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & "
"LIBXSMM_FLAGS) ? 'N' : 'T');"
)
print(
" const char transb = (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & "
"LIBXSMM_FLAGS) ? 'N' : 'T');"
)
print(" const double alpha = LIBXSMM_ALPHA, beta = LIBXSMM_BETA;")
print(
" const libxsmm_blasint "
"m = " + str(m) + ", "
"n = " + str(n) + ", "
"k = " + str(k) + ";"
)
if 0 < prefetch:
print(
" LIBXSMM_UNUSED(pa);"
" LIBXSMM_UNUSED(pb);"
" LIBXSMM_UNUSED(pc);"
)
print(
" LIBXSMM_INLINE_XGEMM(double, double, &transa, &transb,"
" &m, &n, &k, &alpha, a, &m, b, &k, &beta, c, &m);"
)
print("#endif")
print("}")
            print()
            print()
print(
"LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_dmm_"
+ mnkstr
+ ")(const double* a, const double* b, double* c"
+ pfsig
+ ";"
)
print(
"LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_dmm_"
+ mnkstr
+ ")(const double* a, const double* b, double* c"
+ pfsig
)
print("{")
print(" libxsmm_dmm_" + mnkstr + "(" + signature + ");")
print("}")
else:
sys.tracebacklimit = 0
raise ValueError(sys.argv[0] + ": wrong number of arguments!")
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved. #
# This file is part of the LIBXSMM library. #
# #
# For information on the license, see the LICENSE file. #
# Further information: https://github.com/hfp/libxsmm/ #
# SPDX-License-Identifier: BSD-3-Clause #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
import itertools
import operator
import inspect
import sys
import os
try:
from functools import reduce
except ImportError:
pass
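# Helper for the legacy input format below: picks the nearest preceding
# non-empty list, wrapping around if needed, e.g.
# upper_list([[2], [], []], 1) == [2].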
def upper_list(lists, level):
nlist = len(lists)
upper = [level, level + nlist][1 > level] - 1
above = lists[upper]
if above:
return above
elif -nlist <= level:
return upper_list(lists, level - 1)
else:
return []
# https://docs.python.org/3/library/itertools.html#itertools.product
def itertools_product(*args):
# product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
# product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
pools = [tuple(pool) for pool in args]
result = [[]]
for pool in pools:
result = [x + [y] for x in result for y in pool]
for prod in result:
yield tuple(prod)
def load_mnklist(argv, threshold, inputformat=0, resultset=None):
if resultset is None:
resultset = set()
if 0 == inputformat: # indexes format
resultset = set(map(lambda mnk: tuple(map(int, mnk.split("_"))), argv))
elif -1 == inputformat: # new input format
groups = map(
lambda group: [int(i) for i in group.split()],
" ".join(argv[0:]).split(","),
)
resultset = set(
itertools.chain(
*[list(itertools_product(*(i, i, i))) for i in groups]
)
)
elif -2 == inputformat: # legacy format
mlist = list(
map(
int,
map(
lambda s: str(s).replace(",", " ").strip(),
argv[2:2 + int(argv[0])],
),
)
)
nlist = list(
map(
int,
map(
lambda s: str(s).replace(",", " ").strip(),
argv[2 + int(argv[0]):2 + int(argv[0]) + int(argv[1])],
),
)
)
klist = list(
map(
int,
map(
lambda s: str(s).replace(",", " ").strip(),
argv[2 + int(argv[0]) + int(argv[1]):],
),
)
)
mnk = [mlist, nlist, klist]
top = [
[mlist, upper_list(mnk, 0)][0 == len(mlist)],
[nlist, upper_list(mnk, 1)][0 == len(nlist)],
[klist, upper_list(mnk, 2)][0 == len(klist)],
]
for m in top[0]:
for n in top[1]:
if not nlist:
n = m
for k in top[2]:
if not klist:
k = n
if not mlist:
m = k
resultset.add((m, n, k))
else:
sys.tracebacklimit = 0
raise ValueError("load_mnklist: unexpected input format!")
if 0 != threshold: # threshold requested
return set(
filter(
lambda mnk: (0 < mnk[0])
and (0 < mnk[1])
and (0 < mnk[2])
and (threshold >= (mnk[0] * mnk[1] * mnk[2])),
resultset,
)
)
else:
return set(
filter(
lambda mnk: (0 < mnk[0]) and (0 < mnk[1]) and (0 < mnk[2]),
resultset,
)
)
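# Usage sketch for the input formats above (values illustrative):
#   load_mnklist(["2_2_2", "4_4_4"], 0)   -> {(2, 2, 2), (4, 4, 4)}
#   load_mnklist(["2 3"], 0, -1)          -> all (m, n, k) with m, n, k in {2, 3}
#   load_mnklist(["2_2_2", "4_4_4"], 10)  -> {(2, 2, 2)}, since 4*4*4 = 64 > 10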
def max_mnk(mnklist, init=0, index=None):
if index is not None and 0 <= index and index < 3:
mapped = map(lambda mnk: mnk[index], mnklist)
else:
mapped = map(lambda mnk: mnk[0] * mnk[1] * mnk[2], mnklist)
return reduce(max, mapped, init)
def median(list_of_numbers, fallback=None, average=True):
size = len(list_of_numbers)
if 0 < size:
# TODO: use nth element
list_of_numbers.sort()
size2 = int(size / 2)
if average and 0 == (size - size2 * 2):
medval = int(
0.5 * (list_of_numbers[size2 - 1] + list_of_numbers[size2])
+ 0.5
)
else:
medval = list_of_numbers[size2]
if fallback is not None:
result = min(medval, fallback)
else:
result = medval
elif fallback is not None:
result = fallback
else:
sys.tracebacklimit = 0
raise ValueError("median: empty list!")
return result
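# Examples: median([1, 2, 3]) == 2; median([1, 2, 3, 4]) == 3 (rounded mean of
# the two middle elements); median([], 16) == 16 (fallback); a given fallback
# also caps the result, e.g. median([8, 8, 8], 4) == 4.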
def is_pot(num):
    return 0 <= num and 0 == (num & (num - 1))
def sanitize_alignment(alignment):
if 0 >= alignment:
alignment = [1, 64][0 != alignment]
elif not is_pot(alignment):
sys.tracebacklimit = 0
raise ValueError(
"sanitize_alignment: alignment must be a Power of Two (POT)!"
)
return alignment
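# Examples: sanitize_alignment(0) == 1, sanitize_alignment(-8) == 64 (request
# the default alignment), sanitize_alignment(64) == 64, and
# sanitize_alignment(48) raises since 48 is not a power of two.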
def align_value(n, typesize, alignment):
if 0 < typesize and 0 < alignment:
        return (
            ((n * typesize + alignment - 1) // alignment) * alignment
        ) // typesize
else:
sys.tracebacklimit = 0
raise ValueError("align_value: invalid input!")
def version_branch_from_file(version_filepath):
version_file = open(version_filepath, "r")
version, branch, sep = "1.0", "", "-"
try:
version_list, n = version_file.read().replace("\n", "").split(sep), 0
for word in version_list:
if not reduce(
operator.and_,
(subword.isdigit() for subword in word.split(".")),
True,
):
branch += [sep + word, word][0 == n]
n += 1
else:
break
version = sep.join(version_list[n:])
finally:
version_file.close()
return (version, branch)
def version_numbers(version, branch=None):
version_list = version.split("-")
if not version_list[0][0].isdigit():
vbranch = version_list[0]
else:
vbranch = "master"
if branch is None or vbranch == branch:
minor = update = patch = 0
major = 1
n = len(version_list)
if 1 < n:
patch_list = version_list[n - 1]
if 1 == len(patch_list.split(".")):
version_list = version_list[n - 2].split(".")
if version_list != [vbranch]:
patch = int(patch_list)
else:
major = int(patch_list)
else:
version_list = patch_list.split(".")
else:
version_list = version.split(".")
n = len(version_list)
try:
if 0 < n:
major = int(version_list[0])
if 1 < n:
minor = int(version_list[1])
if 2 < n:
update = int(version_list[2])
except ValueError:
# if 1 == n: major = 0
pass
else:
major = minor = update = patch = -1
return major, minor, update, patch
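# Sketch (version strings illustrative): version_numbers("1.16.1-1305")
# yields (1, 16, 1, 1305), i.e. major.minor.update plus a patch/build count,
# while version_numbers("1.16.1") yields (1, 16, 1, 0).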
def version_branch(max_strlen=-1):
version_filename = "version.txt"
filepath_default = os.path.realpath(
os.path.join(
os.path.dirname(inspect.getfile(inspect.currentframe())),
"..",
version_filename,
)
)
filepath_local = os.path.realpath(version_filename) # local version file
realversion, branch = version_branch_from_file(filepath_default)
version = realversion
out_of_tree = filepath_default != filepath_local
if out_of_tree and os.path.isfile(filepath_local):
local, ignored = version_branch_from_file(filepath_local)
if version_numbers(realversion) < version_numbers(local):
version = local
if 0 < max_strlen:
start = int(max_strlen / 3)
cut = max(
branch.rfind("-", start, max_strlen),
branch.rfind("_", start, max_strlen),
branch.rfind(".", start, max_strlen),
)
if start < cut:
branch = branch[0:cut]
else:
branch = branch[0:max_strlen]
return (version, branch, realversion)
if __name__ == "__main__":
argc = len(sys.argv)
if 1 < argc:
arg1 = int(sys.argv[1])
else:
arg1 = 0
if -1 == arg1:
if 5 < argc:
# threshold = int(sys.argv[2])
mnk_size = int(sys.argv[3])
dims = load_mnklist(sys.argv[4:4 + mnk_size], 0, -1)
dims = load_mnklist(sys.argv[4 + mnk_size:], 0, -2, dims)
mnklist = map(lambda mnk: "_".join(map(str, mnk)), sorted(dims))
print(" ".join(mnklist))
elif 3 == argc:
major, minor, update, patch = (
version_numbers(sys.argv[2], "release")
)
print(["0", "1"][0 == patch])
elif 0 <= arg1:
if 0 == arg1 and 3 == argc:
major, minor, update, patch = version_numbers(sys.argv[2])
print(major) # soname version
else:
version, branch, realversion = version_branch()
major, minor, update, patch = version_numbers(version)
if 1 == arg1:
print(major)
elif 2 == arg1:
print(minor)
elif 3 == arg1:
print(update)
elif 4 == arg1:
print(patch)
elif "" != branch:
print("{}-{}".format(branch, realversion))
else:
print(realversion)
else:
sys.tracebacklimit = 0
raise ValueError(
"{}: wrong ({}) number of arguments ('{}') given!".format(
sys.argv[0], argc - 1, " ".join(sys.argv[1:]))
)
#!/usr/bin/env sh
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved. #
# This file is part of the LIBXSMM library. #
# #
# For information on the license, see the LICENSE file. #
# Further information: https://github.com/hfp/libxsmm/ #
# SPDX-License-Identifier: BSD-3-Clause #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
GIT=$(command -v git)
SHIFT=0
if [ "$1" ]; then
SHIFT=$1
fi
NAME=$(${GIT} rev-parse --abbrev-ref HEAD 2>/dev/null)
MAIN=$(${GIT} describe --tags --match "[0-9]*" --abbrev=0 2>/dev/null)
if [ "${MAIN}" ]; then
VERSION="${NAME}-${MAIN}"
REVC=$(${GIT} rev-list --count --no-merges "${MAIN}"..HEAD 2>/dev/null)
else
VERSION=${NAME}
REVC=$(${GIT} rev-list --count --no-merges HEAD 2>/dev/null)
fi
echo "${VERSION}-$((REVC+SHIFT))"
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#include <libxsmm_cpuid.h>
#include <libxsmm_generator.h>
#include <libxsmm_memory.h>
#include <libxsmm_sync.h>
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <signal.h>
#include <setjmp.h>
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
#if defined(_MSC_VER)
# define LIBXSMM_CPUID_ARM_ENC16(OP0, OP1, CRN, CRM, OP2) ( \
(((OP0) & 1) << 14) | \
(((OP1) & 7) << 11) | \
(((CRN) & 15) << 7) | \
(((CRM) & 15) << 3) | \
(((OP2) & 7) << 0))
# define ID_AA64ISAR1_EL1 LIBXSMM_CPUID_ARM_ENC16(0b11, 0b000, 0b0000, 0b0110, 0b001)
# define ID_AA64PFR0_EL1 LIBXSMM_CPUID_ARM_ENC16(0b11, 0b000, 0b0000, 0b0100, 0b000)
# define LIBXSMM_CPUID_ARM_MRS(RESULT, ID) RESULT = _ReadStatusReg(ID)
#else
# define LIBXSMM_CPUID_ARM_MRS(RESULT, ID) __asm__ __volatile__( \
"mrs %0," LIBXSMM_STRINGIFY(ID) : "=r"(RESULT))
#endif
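/* Worked example for the encoding above (only the low bit of OP0 is kept,
 * matching MSVC's ARM64 system-register operand):
 *   ID_AA64ISAR1_EL1 = ENC16(0b11, 0b000, 0b0000, 0b0110, 0b001)
 *                    = ((3 & 1) << 14) | (6 << 3) | 1 = 0x4031 */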
#if defined(LIBXSMM_PLATFORM_AARCH64)
LIBXSMM_APIVAR_DEFINE(jmp_buf internal_cpuid_arm_jmp_buf);
LIBXSMM_API_INTERN void internal_cpuid_arm_sigill(int /*signum*/);
LIBXSMM_API_INTERN void internal_cpuid_arm_sigill(int signum)
{
void (*const handler)(int) = signal(signum, internal_cpuid_arm_sigill);
LIBXSMM_ASSERT(SIGILL == signum);
if (SIG_ERR != handler) longjmp(internal_cpuid_arm_jmp_buf, 1);
}
#endif
LIBXSMM_API int libxsmm_cpuid_arm(libxsmm_cpuid_info* info)
{
static int result = LIBXSMM_TARGET_ARCH_UNKNOWN;
#if defined(LIBXSMM_PLATFORM_AARCH64)
#if defined(__APPLE__) && defined(__arm64__)
result = LIBXSMM_AARCH64_V81;
#else
if (LIBXSMM_TARGET_ARCH_UNKNOWN == result) { /* avoid redetecting features */
void (*const handler)(int) = signal(SIGILL, internal_cpuid_arm_sigill);
result = LIBXSMM_AARCH64_V81;
if (SIG_ERR != handler) {
uint64_t capability; /* 64-bit value */
if (0 == setjmp(internal_cpuid_arm_jmp_buf)) {
LIBXSMM_CPUID_ARM_MRS(capability, ID_AA64ISAR1_EL1);
if (0xF & capability) { /* DPB */
result = LIBXSMM_AARCH64_V82;
if (0 == setjmp(internal_cpuid_arm_jmp_buf)) {
LIBXSMM_CPUID_ARM_MRS(capability, ID_AA64PFR0_EL1);
if (0xF & (capability >> 32)) { /* SVE */
result = LIBXSMM_AARCH64_A64FX;
}
}
}
}
/* restore original state */
signal(SIGILL, handler);
}
if (NULL != info) LIBXSMM_MEMZERO127(info);
}
#endif
#else
# if !defined(NDEBUG)
static int error_once = 0;
if (0 != libxsmm_verbosity /* library code is expected to be mute */
&& 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
{
fprintf(stderr, "LIBXSMM WARNING: libxsmm_cpuid_arm called on non-ARM platform!\n");
}
# endif
if (NULL != info) LIBXSMM_MEMZERO127(info);
#endif
return result;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#include <libxsmm_generator.h>
#include <libxsmm_memory.h>
#include <libxsmm_sync.h>
#if !defined(_WIN32)
# include <sys/mman.h>
#endif
#if defined(LIBXSMM_PLATFORM_X86)
/* XGETBV: receive results (EAX, EDX) for eXtended Control Register (XCR). */
/* CPUID, receive results (EAX, EBX, ECX, EDX) for requested FUNCTION/SUBFN. */
#if defined(_MSC_VER) /*defined(_WIN32) && !defined(__GNUC__)*/
# define LIBXSMM_XGETBV(XCR, EAX, EDX) { \
unsigned long long libxsmm_xgetbv_ = _xgetbv(XCR); \
EAX = (int)libxsmm_xgetbv_; \
EDX = (int)(libxsmm_xgetbv_ >> 32); \
}
# define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) { \
int libxsmm_cpuid_x86_[/*4*/] = { 0, 0, 0, 0 }; \
__cpuidex(libxsmm_cpuid_x86_, FUNCTION, SUBFN); \
EAX = (unsigned int)libxsmm_cpuid_x86_[0]; \
EBX = (unsigned int)libxsmm_cpuid_x86_[1]; \
ECX = (unsigned int)libxsmm_cpuid_x86_[2]; \
EDX = (unsigned int)libxsmm_cpuid_x86_[3]; \
}
# elif defined(__GNUC__) || !defined(_CRAYC)
# if (64 > (LIBXSMM_BITS))
LIBXSMM_EXTERN LIBXSMM_RETARGETABLE int __get_cpuid( /* prototype */
unsigned int, unsigned int*, unsigned int*, unsigned int*, unsigned int*);
# define LIBXSMM_XGETBV(XCR, EAX, EDX) EAX = (EDX) = 0xFFFFFFFF
# define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) \
EAX = (EBX) = (EDX) = 0; ECX = (SUBFN); \
__get_cpuid(FUNCTION, &(EAX), &(EBX), &(ECX), &(EDX))
# else /* 64-bit */
# define LIBXSMM_XGETBV(XCR, EAX, EDX) __asm__ __volatile__( \
".byte 0x0f, 0x01, 0xd0" /*xgetbv*/ : "=a"(EAX), "=d"(EDX) : "c"(XCR) \
)
# define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) \
__asm__ __volatile__ (".byte 0x0f, 0xa2" /*cpuid*/ \
: "=a"(EAX), "=b"(EBX), "=c"(ECX), "=d"(EDX) \
: "a"(FUNCTION), "b"(0), "c"(SUBFN), "d"(0) \
)
# endif
# else /* legacy Cray Compiler */
# define LIBXSMM_XGETBV(XCR, EAX, EDX) EAX = (EDX) = 0
# define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) EAX = (EBX) = (ECX) = (EDX) = 0
# endif
#endif
#define LIBXSMM_CPUID_CHECK(VALUE, CHECK) ((CHECK) == ((CHECK) & (VALUE)))
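/* LIBXSMM_CPUID_CHECK tests that all bits of CHECK are set in VALUE, e.g.
 * LIBXSMM_CPUID_CHECK(ebx, 0x10010000) only passes if both AVX512F
 * (0x00010000) and AVX512CD (0x10000000) are reported. */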
LIBXSMM_API int libxsmm_cpuid_x86(libxsmm_cpuid_info* info)
{
static int result = LIBXSMM_TARGET_ARCH_UNKNOWN;
#if defined(LIBXSMM_PLATFORM_X86)
unsigned int eax, ebx, ecx, edx;
LIBXSMM_CPUID_X86(0, 0/*ecx*/, eax, ebx, ecx, edx);
if (1 <= eax) { /* CPUID max. leaf */
/* avoid redetecting features but redetect on request (info given) */
if (LIBXSMM_TARGET_ARCH_UNKNOWN == result || NULL != info) {
int feature_cpu = LIBXSMM_X86_GENERIC, feature_os = LIBXSMM_X86_GENERIC, has_context = 0;
unsigned int maxleaf = eax;
# if defined(__linux__)
if (0 == libxsmm_se && LIBXSMM_TARGET_ARCH_UNKNOWN == result) {
FILE *const selinux = fopen("/sys/fs/selinux/enforce", "rb");
if (NULL != selinux) {
if (1 == fread(&libxsmm_se, 1/*sizeof(char)*/, 1/*count*/, selinux)) {
libxsmm_se = ('0' != libxsmm_se ? 1 : 0);
}
else { /* conservative assumption in case of read-error */
libxsmm_se = 1;
}
fclose(selinux);
}
}
# elif defined(MAP_JIT)
libxsmm_se = 1;
# endif
LIBXSMM_CPUID_X86(1, 0/*ecx*/, eax, ebx, ecx, edx);
if (LIBXSMM_CPUID_CHECK(ecx, 0x00000001)) { /* SSE3(0x00000001) */
if (LIBXSMM_CPUID_CHECK(ecx, 0x00100000)) { /* SSE42(0x00100000) */
if (LIBXSMM_CPUID_CHECK(ecx, 0x10000000)) { /* AVX(0x10000000) */
if (LIBXSMM_CPUID_CHECK(ecx, 0x00001000)) { /* FMA(0x00001000) */
unsigned int ecx2;
LIBXSMM_CPUID_X86(7, 0/*ecx*/, eax, ebx, ecx2, edx);
/* AVX512F(0x00010000), AVX512CD(0x10000000) */
if (LIBXSMM_CPUID_CHECK(ebx, 0x10010000)) { /* Common */
/* AVX512DQ(0x00020000), AVX512BW(0x40000000), AVX512VL(0x80000000) */
if (LIBXSMM_CPUID_CHECK(ebx, 0xC0020000)) { /* AVX512-Core */
if (LIBXSMM_CPUID_CHECK(ecx2, 0x00000800)) { /* VNNI */
unsigned int edx2; /* we need to save edx for AMX check */
# if 0 /* no check required yet */
unsigned int ecx3;
LIBXSMM_CPUID_X86(7, 1/*ecx*/, eax, ebx, ecx3, edx);
# else
LIBXSMM_CPUID_X86(7, 1/*ecx*/, eax, ebx, ecx2, edx2);
# endif
if (LIBXSMM_CPUID_CHECK(eax, 0x00000020)) { /* BF16 */
feature_cpu = LIBXSMM_X86_AVX512_CPX;
if (LIBXSMM_CPUID_CHECK(edx, 0x03400000)) { /* AMX-TILE, AMX-INT8, AMX-BF16 */
feature_cpu = LIBXSMM_X86_AVX512_SPR;
}
}
else feature_cpu = LIBXSMM_X86_AVX512_CLX; /* CLX */
}
else feature_cpu = LIBXSMM_X86_AVX512_CORE; /* SKX */
}
/* AVX512PF(0x04000000), AVX512ER(0x08000000) */
else if (LIBXSMM_CPUID_CHECK(ebx, 0x0C000000)) { /* AVX512-MIC */
if (LIBXSMM_CPUID_CHECK(edx, 0x0000000C)) { /* KNM */
feature_cpu = LIBXSMM_X86_AVX512_KNM;
}
else feature_cpu = LIBXSMM_X86_AVX512_MIC; /* KNL */
}
else feature_cpu = LIBXSMM_X86_AVX512; /* AVX512-Common */
}
else feature_cpu = LIBXSMM_X86_AVX2;
}
else feature_cpu = LIBXSMM_X86_AVX;
}
else feature_cpu = LIBXSMM_X86_SSE42;
}
else feature_cpu = LIBXSMM_X86_SSE3;
}
# if !defined(LIBXSMM_INTRINSICS_DEBUG)
LIBXSMM_ASSERT_MSG(LIBXSMM_STATIC_TARGET_ARCH <= LIBXSMM_MAX(LIBXSMM_X86_GENERIC, feature_cpu), "missed detecting ISA extensions");
/* coverity[dead_error_line] */
if (LIBXSMM_STATIC_TARGET_ARCH > feature_cpu) feature_cpu = LIBXSMM_STATIC_TARGET_ARCH;
# endif
/* XSAVE/XGETBV(0x04000000), OSXSAVE(0x08000000) */
if (LIBXSMM_CPUID_CHECK(ecx, 0x0C000000)) { /* OS SSE support */
feature_os = LIBXSMM_MIN(LIBXSMM_X86_SSE42, feature_cpu);
if (LIBXSMM_X86_AVX <= feature_cpu) {
LIBXSMM_XGETBV(0, eax, edx);
if (LIBXSMM_CPUID_CHECK(eax, 0x00000006)) { /* OS XSAVE 256-bit */
feature_os = LIBXSMM_MIN(LIBXSMM_X86_AVX2, feature_cpu);
if (LIBXSMM_CPUID_CHECK(eax, 0x000000E0)) { /* OS XSAVE 512-bit */
feature_os = LIBXSMM_MIN(LIBXSMM_X86_AVX512_CPX, feature_cpu);
if (LIBXSMM_X86_AVX512_SPR <= feature_cpu && 7 <= maxleaf
&& LIBXSMM_CPUID_CHECK(eax, 0x00060000)) /* OS XSAVE 512-bit */
{
feature_os = feature_cpu; /* unlimited AMX */
}
}
}
}
}
else if (LIBXSMM_X86_GENERIC <= feature_cpu) {
/* assume FXSAVE, which should be fine
* 16 years after the first x86_64 OS
*/
feature_os = LIBXSMM_X86_SSE42;
}
else feature_os = LIBXSMM_TARGET_ARCH_GENERIC;
has_context = (LIBXSMM_STATIC_TARGET_ARCH >= feature_cpu || feature_os >= feature_cpu) ? 1 : 0;
if (LIBXSMM_TARGET_ARCH_UNKNOWN == result && 0 != libxsmm_verbosity) { /* library code is expected to be mute */
# if !defined(LIBXSMM_TARGET_ARCH)
const int target_vlen32 = libxsmm_cpuid_vlen32(feature_cpu);
const char *const compiler_support = (libxsmm_cpuid_vlen32(LIBXSMM_MAX_STATIC_TARGET_ARCH) < target_vlen32
? "" : (((2 <= libxsmm_verbosity || 0 > libxsmm_verbosity) && LIBXSMM_MAX_STATIC_TARGET_ARCH < feature_cpu)
? "highly " : NULL));
if (NULL != compiler_support) {
const char *const name = libxsmm_cpuid_name( /* exclude MIC when running on Core processors */
(((LIBXSMM_X86_AVX512_MIC == LIBXSMM_MAX_STATIC_TARGET_ARCH) ||
(LIBXSMM_X86_AVX512_KNM == LIBXSMM_MAX_STATIC_TARGET_ARCH)) && (LIBXSMM_X86_AVX512_CORE <= feature_cpu))
? LIBXSMM_X86_AVX2 : LIBXSMM_MAX_STATIC_TARGET_ARCH);
fprintf(stderr, "LIBXSMM WARNING: %soptimized non-JIT code paths are limited to \"%s\"!\n", compiler_support, name);
}
# endif
# if !defined(NDEBUG) && defined(__OPTIMIZE__)
fprintf(stderr, "LIBXSMM WARNING: library is optimized without -DNDEBUG and contains debug code!\n");
# endif
# if !defined(__APPLE__) || !defined(__MACH__) /* permitted features */
if (0 == has_context) {
fprintf(stderr, "LIBXSMM WARNING: detected CPU features are not permitted by the OS!\n");
if (0 == libxsmm_se) {
fprintf(stderr, "LIBXSMM WARNING: downgraded code generation to supported features!\n");
}
}
# endif
}
/* macOS is faulting AVX-512 (on-demand larger state) */
result = feature_cpu;
# if !defined(__APPLE__) || !defined(__MACH__)
# if 0 /* opportunistic */
if (0 == libxsmm_se)
# endif
{ /* only permitted features */
result = LIBXSMM_MIN(feature_cpu, feature_os);
}
# endif
if (NULL != info) {
LIBXSMM_CPUID_X86(0x80000007, 0/*ecx*/, eax, ebx, ecx, edx);
info->constant_tsc = LIBXSMM_CPUID_CHECK(edx, 0x00000100);
info->has_context = has_context;
}
}
}
else {
if (NULL != info) LIBXSMM_MEMZERO127(info);
result = LIBXSMM_X86_GENERIC;
}
#else
# if !defined(NDEBUG)
static int error_once = 0;
if (0 != libxsmm_verbosity /* library code is expected to be mute */
&& 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
{
fprintf(stderr, "LIBXSMM WARNING: libxsmm_cpuid_x86 called on non-x86 platform!\n");
}
# endif
if (NULL != info) LIBXSMM_MEMZERO127(info);
#endif
return result;
}
LIBXSMM_API int libxsmm_cpuid(void)
{
#if defined(LIBXSMM_PLATFORM_X86)
return libxsmm_cpuid_x86(NULL/*info*/);
#else
return libxsmm_cpuid_arm(NULL/*info*/);
#endif
}
/**
* This implementation also accounts for non-x86 platforms,
 * which not only allows resolving any given ID but also
 * falls back gracefully ("unknown").
*/
LIBXSMM_API const char* libxsmm_cpuid_name(int id)
{
const char* target_arch = NULL;
switch (id) {
case LIBXSMM_X86_AVX512_SPR: {
target_arch = "spr";
} break;
case LIBXSMM_X86_AVX512_CPX: {
target_arch = "cpx";
} break;
case LIBXSMM_X86_AVX512_CLX: {
target_arch = "clx";
} break;
case LIBXSMM_X86_AVX512_CORE: {
target_arch = "skx";
} break;
case LIBXSMM_X86_AVX512_KNM: {
target_arch = "knm";
} break;
case LIBXSMM_X86_AVX512_MIC: {
target_arch = "knl";
} break;
case LIBXSMM_X86_AVX512: {
/* TODO: rework BE to use target ID instead of set of strings (target_arch = "avx3") */
target_arch = "hsw";
} break;
case LIBXSMM_X86_AVX2: {
target_arch = "hsw";
} break;
case LIBXSMM_X86_AVX: {
target_arch = "snb";
} break;
case LIBXSMM_X86_SSE42: {
target_arch = "wsm";
} break;
case LIBXSMM_X86_SSE3: {
target_arch = "sse3";
} break;
case LIBXSMM_AARCH64_V81: {
target_arch = "aarch64";
} break;
case LIBXSMM_AARCH64_A64FX: {
target_arch = "a64fx";
} break;
case LIBXSMM_TARGET_ARCH_GENERIC: {
target_arch = "generic";
} break;
default: if (LIBXSMM_X86_GENERIC <= id) {
target_arch = "x86_64";
}
else {
target_arch = "unknown";
}
}
LIBXSMM_ASSERT(NULL != target_arch);
return target_arch;
}
/**
* This implementation also accounts for non-x86 platforms,
 * which not only allows resolving any given ID but also
 * falls back gracefully (scalar).
*/
LIBXSMM_API int libxsmm_cpuid_vlen32(int id)
{
int result;
#if defined(LIBXSMM_PLATFORM_X86)
if (LIBXSMM_X86_AVX512 <= id) {
result = 16;
}
else if (LIBXSMM_X86_AVX <= id) {
result = 8;
}
else if (LIBXSMM_X86_SSE42 <= id) {
result = 4;
}
else
#elif defined(LIBXSMM_PLATFORM_AARCH64)
if (LIBXSMM_AARCH64_V81 == id) {
result = 4;
}
else if (LIBXSMM_AARCH64_A64FX == id) {
result = 16;
}
else
#else
LIBXSMM_UNUSED(id);
#endif
{ /* scalar */
result = 1;
}
return result;
}
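/* Examples: libxsmm_cpuid_vlen32(LIBXSMM_X86_AVX2) == 8 (256-bit vectors hold
 * eight 32-bit elements); libxsmm_cpuid_vlen32(LIBXSMM_X86_AVX512_CORE) == 16. */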
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DIFF_H
#define LIBXSMM_DIFF_H
#include <libxsmm_intrinsics_x86.h>
#if !defined(LIBXSMM_DIFF_AVX512_ENABLED) && 0
# define LIBXSMM_DIFF_AVX512_ENABLED
#endif
#define LIBXSMM_DIFF_4_DECL(A) const uint32_t */*const*/ A = NULL
#define LIBXSMM_DIFF_4_ASSIGN(A, B) (A) = (B)
#define LIBXSMM_DIFF_4_LOAD(A, SRC) A = (const uint32_t*)(SRC)
#define LIBXSMM_DIFF_4(A, B, ...) ((unsigned char)(0 != (*(A) ^ (*(const uint32_t*)(B)))))
#define LIBXSMM_DIFF_8_DECL(A) const uint64_t */*const*/ A = NULL
#define LIBXSMM_DIFF_8_ASSIGN(A, B) (A) = (B)
#define LIBXSMM_DIFF_8_LOAD(A, SRC) A = (const uint64_t*)(SRC)
#define LIBXSMM_DIFF_8(A, B, ...) ((unsigned char)(0 != (*(A) ^ (*(const uint64_t*)(B)))))
#define LIBXSMM_DIFF_SSE_DECL(A) __m128i A = LIBXSMM_INTRINSICS_MM_UNDEFINED_SI128()
#define LIBXSMM_DIFF_SSE_ASSIGN(A, B) (A) = (B)
#define LIBXSMM_DIFF_SSE_LOAD(A, SRC) A = LIBXSMM_INTRINSICS_LOADU_SI128((const __m128i*)(SRC))
#define LIBXSMM_DIFF_SSE(A, B, ...) ((unsigned char)(0xFFFF != _mm_movemask_epi8(_mm_cmpeq_epi8( \
A, LIBXSMM_INTRINSICS_LOADU_SI128((const __m128i*)(B))))))
#if (LIBXSMM_X86_GENERIC <= LIBXSMM_STATIC_TARGET_ARCH) /*|| defined(LIBXSMM_INTRINSICS_TARGET)*/
# define LIBXSMM_DIFF_16_DECL LIBXSMM_DIFF_SSE_DECL
# define LIBXSMM_DIFF_16_ASSIGN LIBXSMM_DIFF_SSE_ASSIGN
# define LIBXSMM_DIFF_16_LOAD LIBXSMM_DIFF_SSE_LOAD
# define LIBXSMM_DIFF_16 LIBXSMM_DIFF_SSE
#else
# define LIBXSMM_DIFF_16_DECL(A) const uint64_t */*const*/ A = NULL
# define LIBXSMM_DIFF_16_ASSIGN(A, B) (A) = (B)
# define LIBXSMM_DIFF_16_LOAD(A, SRC) A = (const uint64_t*)(SRC)
# define LIBXSMM_DIFF_16(A, B, ...) ((unsigned char)(0 != (((A)[0] ^ (*(const uint64_t*)(B))) | \
((A)[1] ^ ((const uint64_t*)(B))[1]))))
#endif
#define LIBXSMM_DIFF_AVX2_DECL(A) __m256i A = LIBXSMM_INTRINSICS_MM256_UNDEFINED_SI256()
#define LIBXSMM_DIFF_AVX2_ASSIGN(A, B) (A) = (B)
#define LIBXSMM_DIFF_AVX2_LOAD(A, SRC) A = _mm256_loadu_si256((const __m256i*)(SRC))
#define LIBXSMM_DIFF_AVX2(A, B, ...) ((unsigned char)(-1 != _mm256_movemask_epi8(_mm256_cmpeq_epi8( \
A, _mm256_loadu_si256((const __m256i*)(B))))))
#if (LIBXSMM_X86_AVX2 <= LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_DIFF_32_DECL LIBXSMM_DIFF_AVX2_DECL
# define LIBXSMM_DIFF_32_ASSIGN LIBXSMM_DIFF_AVX2_ASSIGN
# define LIBXSMM_DIFF_32_LOAD LIBXSMM_DIFF_AVX2_LOAD
# define LIBXSMM_DIFF_32 LIBXSMM_DIFF_AVX2
#else
# define LIBXSMM_DIFF_32_DECL(A) LIBXSMM_DIFF_16_DECL(A); LIBXSMM_DIFF_16_DECL(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _))
# define LIBXSMM_DIFF_32_ASSIGN(A, B) LIBXSMM_DIFF_16_ASSIGN(A, B); LIBXSMM_DIFF_16_ASSIGN(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _), LIBXSMM_CONCATENATE3(libxsmm_diff_32_, B, _))
# define LIBXSMM_DIFF_32_LOAD(A, SRC) LIBXSMM_DIFF_16_LOAD(A, SRC); LIBXSMM_DIFF_16_LOAD(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _), (const uint64_t*)(SRC) + 2)
# define LIBXSMM_DIFF_32(A, B, ...) ((unsigned char)(0 != LIBXSMM_DIFF_16(A, B, __VA_ARGS__) ? 1 : LIBXSMM_DIFF_16(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _), (const uint64_t*)(B) + 2, __VA_ARGS__)))
#endif
#define LIBXSMM_DIFF_48_DECL(A) LIBXSMM_DIFF_16_DECL(A); LIBXSMM_DIFF_32_DECL(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _))
#define LIBXSMM_DIFF_48_ASSIGN(A, B) LIBXSMM_DIFF_16_ASSIGN(A, B); LIBXSMM_DIFF_32_ASSIGN(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _), LIBXSMM_CONCATENATE3(libxsmm_diff_48_, B, _))
#define LIBXSMM_DIFF_48_LOAD(A, SRC) LIBXSMM_DIFF_16_LOAD(A, SRC); LIBXSMM_DIFF_32_LOAD(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _), (const uint64_t*)(SRC) + 2)
#define LIBXSMM_DIFF_48(A, B, ...) ((unsigned char)(0 != LIBXSMM_DIFF_16(A, B, __VA_ARGS__) ? 1 : LIBXSMM_DIFF_32(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _), (const uint64_t*)(B) + 2, __VA_ARGS__)))
#define LIBXSMM_DIFF_64SW_DECL(A) LIBXSMM_DIFF_32_DECL(A); LIBXSMM_DIFF_32_DECL(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _))
#define LIBXSMM_DIFF_64SW_ASSIGN(A, B) LIBXSMM_DIFF_32_ASSIGN(A, B); LIBXSMM_DIFF_32_ASSIGN(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _), LIBXSMM_CONCATENATE3(libxsmm_diff_64_, B, _))
#define LIBXSMM_DIFF_64SW_LOAD(A, SRC) LIBXSMM_DIFF_32_LOAD(A, SRC); LIBXSMM_DIFF_32_LOAD(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _), (const uint64_t*)(SRC) + 4)
#define LIBXSMM_DIFF_64SW(A, B, ...) ((unsigned char)(0 != LIBXSMM_DIFF_32(A, B, __VA_ARGS__) ? 1 : LIBXSMM_DIFF_32(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _), (const uint64_t*)(B) + 4, __VA_ARGS__)))
#if defined(LIBXSMM_DIFF_AVX512_ENABLED)
# define LIBXSMM_DIFF_AVX512_DECL(A) __m512i A = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32()
# define LIBXSMM_DIFF_AVX512_ASSIGN(A, B) (A) = (B)
# define LIBXSMM_DIFF_AVX512_LOAD(A, SRC) A = _mm512_loadu_si512((const __m512i*)(SRC))
# define LIBXSMM_DIFF_AVX512(A, B, ...) ((unsigned char)(0xFFFF != (unsigned int)/*_cvtmask16_u32*/(_mm512_cmpeq_epi32_mask( \
A, _mm512_loadu_si512((const __m512i*)(B))))))
#else
# define LIBXSMM_DIFF_AVX512_DECL LIBXSMM_DIFF_64SW_DECL
# define LIBXSMM_DIFF_AVX512_ASSIGN LIBXSMM_DIFF_64SW_ASSIGN
# define LIBXSMM_DIFF_AVX512_LOAD LIBXSMM_DIFF_64SW_LOAD
# define LIBXSMM_DIFF_AVX512 LIBXSMM_DIFF_64SW
#endif
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_DIFF_64_DECL LIBXSMM_DIFF_AVX512_DECL
# define LIBXSMM_DIFF_64_ASSIGN LIBXSMM_DIFF_AVX512_ASSIGN
# define LIBXSMM_DIFF_64_LOAD LIBXSMM_DIFF_AVX512_LOAD
# define LIBXSMM_DIFF_64 LIBXSMM_DIFF_AVX512
#else
# define LIBXSMM_DIFF_64_DECL LIBXSMM_DIFF_64SW_DECL
# define LIBXSMM_DIFF_64_ASSIGN LIBXSMM_DIFF_64SW_ASSIGN
# define LIBXSMM_DIFF_64_LOAD LIBXSMM_DIFF_64SW_LOAD
# define LIBXSMM_DIFF_64 LIBXSMM_DIFF_64SW
#endif
#define LIBXSMM_DIFF_DECL(N, A) LIBXSMM_CONCATENATE3(LIBXSMM_DIFF_, N, _DECL)(A)
#define LIBXSMM_DIFF_LOAD(N, A, SRC) LIBXSMM_CONCATENATE3(LIBXSMM_DIFF_, N, _LOAD)(A, SRC)
#define LIBXSMM_DIFF(N) LIBXSMM_CONCATENATE(LIBXSMM_DIFF_, N)
#define LIBXSMM_DIFF_N(TYPE, RESULT, DIFF, A, BN, ELEMSIZE, STRIDE, HINT, N) { \
const char* libxsmm_diff_b_ = (const char*)(BN) + (size_t)(HINT) * (STRIDE); \
for (RESULT = (HINT); (RESULT) < (N); ++(RESULT)) { \
if (0 == DIFF(A, libxsmm_diff_b_, ELEMSIZE)) break; \
libxsmm_diff_b_ += (STRIDE); \
} \
if ((N) == (RESULT)) { /* wrong hint */ \
TYPE libxsmm_diff_r_ = 0; \
libxsmm_diff_b_ = (const char*)(BN); /* reset */ \
for (; libxsmm_diff_r_ < (HINT); ++libxsmm_diff_r_) { \
if (0 == DIFF(A, libxsmm_diff_b_, ELEMSIZE)) { \
RESULT = libxsmm_diff_r_; \
break; \
} \
libxsmm_diff_b_ += (STRIDE); \
} \
} \
}
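/* Usage sketch: given N keys of ELEMSIZE bytes each, stored STRIDE bytes
 * apart starting at BN, LIBXSMM_DIFF_N searches for key A beginning at index
 * HINT and wraps around on a miss; RESULT receives the index of the first
 * match, or N if no block compares equal. */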
/** Function type representing the diff-functionality. */
LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE unsigned int (*libxsmm_diff_function)(
const void* /*a*/, const void* /*b*/, ... /*size*/);
/** Compare two data blocks of 4 bytes each. */
LIBXSMM_API unsigned char libxsmm_diff_4(const void* a, const void* b, ...);
/** Compare two data blocks of 8 bytes each. */
LIBXSMM_API unsigned char libxsmm_diff_8(const void* a, const void* b, ...);
/** Compare two data blocks of 16 bytes each. */
LIBXSMM_API unsigned char libxsmm_diff_16(const void* a, const void* b, ...);
/** Compare two data blocks of 32 bytes each. */
LIBXSMM_API unsigned char libxsmm_diff_32(const void* a, const void* b, ...);
/** Compare two data blocks of 48 bytes each. */
LIBXSMM_API unsigned char libxsmm_diff_48(const void* a, const void* b, ...);
/** Compare two data blocks of 64 bytes each. */
LIBXSMM_API unsigned char libxsmm_diff_64(const void* a, const void* b, ...);
#endif /*LIBXSMM_DIFF_H*/
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst, Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include <libxsmm_dnn.h>
#include "libxsmm_main.h"
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <math.h>
#if defined(_OPENMP)
# include <omp.h>
#endif
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
LIBXSMM_API_INTERN void libxsmm_dnn_init(int target_arch)
{
LIBXSMM_UNUSED(target_arch);
}
LIBXSMM_API_INTERN void libxsmm_dnn_finalize(void)
{
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_get_feature_map_blocks( int C, int K, int* C_block, int* K_block, int* fm_lp_block, libxsmm_dnn_datatype datatype_in, libxsmm_dnn_datatype datatype_out ) {
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
int ifmblock = 0;
int ofmblock = 0;
int lp_block = 0;
int tmp_max_c_block = 64;
int tmp_max_k_block = 64;
int tmp_block = 0;
/* init libxsmm */
LIBXSMM_INIT
/* C */
if ( ((libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) && (datatype_in == LIBXSMM_DNN_DATATYPE_BF16)) ||
(libxsmm_target_archid < LIBXSMM_X86_AVX512 ) ) {
tmp_max_c_block = 32;
} else if ( libxsmm_target_archid == LIBXSMM_AARCH64_V81 ) {
tmp_max_c_block = 16;
}
if ( C < tmp_max_c_block ) {
ifmblock = C;
} else {
for ( tmp_block = 1; tmp_block <= tmp_max_c_block; tmp_block *= 2 ) {
if ( C % tmp_block == 0 ) ifmblock = tmp_block;
}
}
/* K */
if ( ((libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) && (datatype_in == LIBXSMM_DNN_DATATYPE_BF16)) ||
(libxsmm_target_archid < LIBXSMM_X86_AVX512 ) ) {
tmp_max_k_block = 32;
} else if ( libxsmm_target_archid == LIBXSMM_AARCH64_V81 ) {
tmp_max_k_block = 16;
}
if ( K < tmp_max_k_block ) {
ofmblock = K;
} else {
for ( tmp_block = 1; tmp_block <= tmp_max_k_block; tmp_block *= 2 ) {
if ( K % tmp_block == 0 ) ofmblock = tmp_block;
}
}
/* when do we need VNNI format? */
if ( (datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
lp_block = 1;
} else if ( (datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
lp_block = 2;
} else if ( (datatype_in == LIBXSMM_DNN_DATATYPE_I16) && ((datatype_out == LIBXSMM_DNN_DATATYPE_I32) || (datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) {
lp_block = 2;
} else if (datatype_in == LIBXSMM_DNN_DATATYPE_I8) {
lp_block = 4;
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
*C_block = ifmblock;
*K_block = ofmblock;
*fm_lp_block = lp_block;
return status;
}
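/* Worked example (assuming a pre-SPR AVX-512 target and F32 in/out):
 * C = 100 yields *C_block == 4 (the largest power of two up to 64 that
 * divides 100), K = 64 yields *K_block == 64, and *fm_lp_block == 1
 * since F32 does not need a VNNI layout. */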
LIBXSMM_API const char* libxsmm_dnn_get_error(libxsmm_dnn_err_t code)
{
switch (code) {
case LIBXSMM_DNN_SUCCESS:
return "LIBXSMM DNN Success!";
case LIBXSMM_DNN_WARN_FALLBACK:
return "LIBXSMM DNN Warning: Falling back to naive code as target is currently not supported by LIBXSMM!";
case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_N_BLOCKING:
return "LIBXSMM DNN Warning: RNN cell suboptimal minibatch blocking!";
case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_C_BLOCKING:
return "LIBXSMM DNN Warning: RNN cell suboptimal input feature blocking!";
case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_K_BLOCKING:
return "LIBXSMM DNN Warning: RNN cell suboptimal output feature blocking!";
case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_N_BLOCKING:
return "LIBXSMM DNN Warning: FC layer suboptimal minibatch blocking!";
case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_C_BLOCKING:
return "LIBXSMM DNN Warning: FC layer suboptimal input feature blocking!";
case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_K_BLOCKING:
return "LIBXSMM DNN Warning: FC layer suboptimal output feature blocking!";
case LIBXSMM_DNN_ERR_GENERAL:
return "LIBXSMM DNN Error: General error occurred!";
case LIBXSMM_DNN_ERR_CREATE_HANDLE:
return "LIBXSMM DNN Error: Handle creation failed!";
case LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE:
return "LIBXSMM DNN Error: Requested datatype is not available!";
case LIBXSMM_DNN_ERR_INVALID_BLOCKING:
return "LIBXSMM DNN Error: Requested Input/Output buffer size cannot be blocked!";
case LIBXSMM_DNN_ERR_INVALID_HANDLE:
return "LIBXSMM DNN Error: An invalid handle was provided!";
case LIBXSMM_DNN_ERR_DATA_NOT_BOUND:
return "LIBXSMM DNN Error: Not all required sources and destinations have been bound to convolution!";
case LIBXSMM_DNN_ERR_CREATE_TENSOR:
return "LIBXSMM DNN Error: Tensor creation failed!";
case LIBXSMM_DNN_ERR_INVALID_TENSOR:
return "LIBXSMM DNN Error: Invalid tensor was specified!";
case LIBXSMM_DNN_ERR_MISMATCH_TENSOR:
return "LIBXSMM DNN Error: Tensor doesn't match handle it should be bind to!";
case LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR:
return "LIBXSMM DNN Error: Invalid handle or tensor!";
case LIBXSMM_DNN_ERR_INVALID_KIND:
return "LIBXSMM DNN Error: Invalid convolution kind!";
case LIBXSMM_DNN_ERR_INVALID_FORMAT_NCHW:
return "LIBXSMM DNN Error: NCHW format is currently not natively supported by LIBXSMM!";
case LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT:
return "LIBXSMM DNN Error: Unsupported destination format when copying data!";
case LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT:
return "LIBXSMM DNN Error: Unsupported source format when copying data!";
case LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE:
return "LIBXSMM DNN Error: Unsupported format when requesting a convolution!";
case LIBXSMM_DNN_ERR_INVALID_FORMAT_KCRS:
return "LIBXSMM DNN Error: KCRS format is currently not natively supported by LIBXSMM!";
case LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL:
return "LIBXSMM DNN Error: Invalid format was specified!";
case LIBXSMM_DNN_ERR_CREATE_LAYOUT:
return "LIBXSMM DNN Error: Layout creation failed!";
case LIBXSMM_DNN_ERR_INVALID_LAYOUT:
return "LIBXSMM DNN Error: Invalid layout was specified!";
case LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH:
return "LIBXSMM DNN Error: Unsupported architecture!";
case LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED:
return "LIBXSMM DNN Error: scratch binding failed as scratch was not allocated!";
case LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE:
return "LIBXSMM DNN Error: an unknown tensor type was provided!";
case LIBXSMM_DNN_ERR_INVALID_ALGO:
return "LIBXSMM DNN Error: Invalid algorithm was specified!";
case LIBXSMM_DNN_ERR_INVALID_PADDING:
return "LIBXSMM DNN Error: Invalid padding was specified!";
case LIBXSMM_DNN_ERR_TIME_STEPS_TOO_SMALL:
return "LIBXSMM DNN Error: time steps should be >= 2 for RNN/LSTM!";
case LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS:
return "LIBXSMM DNN Error: failed to create internal layout arrays!";
    case LIBXSMM_DNN_ERR_NOT_IMPLEMENTED:
      return "LIBXSMM DNN Error: the requested functionality is not yet implemented!";
    case LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER:
      return "LIBXSMM DNN Error: the requested order of fusion in batch norm is not yet implemented!";
    case LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION:
      return "LIBXSMM DNN Error: the requested fusion in batch norm is not yet implemented!";
case LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN:
return "LIBXSMM DNN Error: Unsupported format when requesting a fused batch norm!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING:
      return "LIBXSMM DNN Error: Unsupported pooling operation was requested!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_FC:
      return "LIBXSMM DNN Error: Unsupported format when requesting a fully-connected layer!";
    case LIBXSMM_DNN_ERR_RNN_INVALID_SEQ_LEN:
      return "LIBXSMM DNN Error: the max sequence length is shorter than the sequence length being set!";
    case LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER:
      return "LIBXSMM DNN Error: the requested order of fusion in group norm is not yet implemented!";
    case LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION:
      return "LIBXSMM DNN Error: the requested fusion in group norm is not yet implemented!";
    case LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION:
      return "LIBXSMM DNN Error: the requested fusion in the fully-connected layer is not yet implemented!";
default:
return "LIBXSMM DNN Error: Unknown error or warning occurred!";
}
}
LIBXSMM_API size_t libxsmm_dnn_typesize(libxsmm_dnn_datatype datatype)
{
switch (datatype) {
case LIBXSMM_DNN_DATATYPE_F32: return 4;
case LIBXSMM_DNN_DATATYPE_I32: return 4;
case LIBXSMM_DNN_DATATYPE_BF16: return 2;
case LIBXSMM_DNN_DATATYPE_I16: return 2;
case LIBXSMM_DNN_DATATYPE_I8: return 1;
/* no error expected as the argument is an enumeration value; compiler-checked */
default: return 1;
}
}
LIBXSMM_API size_t libxsmm_dnn_get_simd_width(libxsmm_dnn_datatype datatype)
{
size_t l_cl_width_bytes;
/* init libxsmm */
LIBXSMM_INIT
if ( libxsmm_target_archid == LIBXSMM_X86_GENERIC ||
libxsmm_target_archid == LIBXSMM_X86_SSE3 ||
libxsmm_target_archid == LIBXSMM_X86_SSE42 ) {
l_cl_width_bytes = 16;
} else if ( libxsmm_target_archid == LIBXSMM_X86_AVX2 ||
libxsmm_target_archid == LIBXSMM_X86_AVX ) {
l_cl_width_bytes = 32;
} else {
l_cl_width_bytes = 64;
}
return l_cl_width_bytes/libxsmm_dnn_typesize(datatype);
}
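/* Worked example (illustrative): on an AVX-512 target the code above assumes
   64-byte vectors, so F32 (4 bytes) yields a SIMD width of 64/4 = 16 elements
   and BF16 (2 bytes) yields 64/2 = 32; on SSE targets (16-byte vectors) the
   same datatypes yield 4 and 8 elements, respectively. */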
LIBXSMM_API_INLINE float libxsmm_internal_get_max( float* in_buffer, int length ) {
float absmax_value = LIBXSMM_ABS(in_buffer[0]);
int i = 0;
#ifdef _OPENMP
LIBXSMM_OMP_VAR(i);
# pragma omp parallel private(i)
{
float my_absmax_value = absmax_value;
# pragma omp for
for (i = 0; i < length; ++i ) {
if (LIBXSMM_ABS(in_buffer[i]) > my_absmax_value) {
my_absmax_value = LIBXSMM_ABS(in_buffer[i]);
}
}
# pragma omp critical
{
if (my_absmax_value > absmax_value) {
absmax_value = my_absmax_value;
}
}
}
#else
for (i = 1; i < length; ++i ) {
if (LIBXSMM_ABS(in_buffer[i]) > absmax_value) {
absmax_value = LIBXSMM_ABS(in_buffer[i]);
}
}
#endif
return absmax_value;
}
LIBXSMM_API_INLINE unsigned char libxsmm_internal_get_max_exp( float* in_buffer, int length ) {
libxsmm_intfloat val_exp;
unsigned char max_exp = 0;
/* bit-wise conversion to int */
val_exp.f = libxsmm_internal_get_max( in_buffer, length );
/* shift by mantissa to the right and convert to char */
max_exp = (unsigned char)((val_exp.ui & LIBXSMM_DNN_MASK_ABS_F32) >> LIBXSMM_DNN_MANT_SZ_F32);
return max_exp;
}
LIBXSMM_API_INLINE short libxsmm_internal_quantize_scalar_no_scf( float input, unsigned char max_exp, unsigned char add_shift, int round_mode ) {
libxsmm_intfloat value;
unsigned int qvalue = 0;
unsigned int mant = 0;
unsigned int sign = 0;
unsigned char rhs = 0;
unsigned char exp_off = 0;
/* init libxsmm */
LIBXSMM_INIT
/* in case of zero we don't need to do anything */
if (LIBXSMM_FEQ(input, 0)) {
qvalue = 0;
} else {
/* let's get a float copy to work on */
/* vinp = LIBXSMM_INTRINSICS_MM512_LOAD_PS( in_buffer[i] ); */
value.f = input;
/* let's compute the offset of the current exp at pos i from max offset, we need to mask the sign bit though */
/*__m512i vexp = _mm512_cvtps_epi32(_mm512_getexp_ps (vinp));
__m512i vexp_off = _mm512_sub_epi32(maxexpf, vexp);*/
exp_off = (unsigned char)(max_exp - ((value.ui & LIBXSMM_DNN_MASK_ABS_F32) >> LIBXSMM_DNN_MANT_SZ_F32));
/* cut out mantissa and set leading bit */
/*__m512i mmask = _mm512_set1_epi32(LIBXSMM_DNN_MASK_MANT_F32);
__m512i vmant = _mm512_or_epi32(_mm512_set1_epi32(0x1 << LIBXSMM_DNN_MANT_SZ_F32), _mm512_and_epi32( _mm512_castps_si512( vinp ), mmask));*/
mant = ((0x1 << LIBXSMM_DNN_MANT_SZ_F32) | (value.ui & LIBXSMM_DNN_MASK_MANT_F32));
/* extract sign */
/* __mmask16 smask = _mm512_cmplt_ps_mask (inp, _mm512_set1_ps(0)); */
sign = ((value.ui & LIBXSMM_DNN_MASK_SIGN_F32) >> (LIBXSMM_DNN_SZ_F32-1));
/* calculate rhs, be aware of the now explicit leading bit, @TODO add DFP8/4 */
rhs = (unsigned char)((LIBXSMM_DNN_MANT_SZ_F32+1) - LIBXSMM_DNN_MANT_DFP16 + exp_off + add_shift);
/* some safety, to generate 0 when we fall off quant region, @TODO issue a LIBXSMM WARNING: that we shifted out the entire mantissa */
if (rhs > (LIBXSMM_DNN_MANT_SZ_F32+1)) {
rhs = (LIBXSMM_DNN_MANT_SZ_F32+1);
}
/* finally shift the value into the region we need; this is now a (15-add_shift)-bit number for the max value in in_buffer */
qvalue = (mant >> rhs);
/* handle sign, two's complement */
if ( (sign > 0) && (qvalue > 0) ) {
qvalue = (~qvalue + 1);
}
if (round_mode == LIBXSMM_DNN_QUANT_BIAS_ROUND) {
/* biased rounding towards the next bigger number */
/* first determine from the original number whether we need bias rounding, @TODO needs fix for F64 */
int bias_needed = (mant & (0x3 << (rhs-2)));
/* apply bias */
if (bias_needed > 0) {
qvalue++;
}
} else if (round_mode == LIBXSMM_DNN_QUANT_NEAREST_ROUND) {
int nearest_needed = (mant & (0x1 << (rhs-1)));
/* apply rounding */
if ((nearest_needed > 0) && (rhs > 1)) {
qvalue++;
}
} else if (round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND) {
/* stochastic rounding, as implemented in the IBM paper from 2015, @TODO, fix F64 and DFP8 */
const float eps = LIBXSMM_DNN_RES_DFP16;
/* coverity[dont_call] */
const float r = (float)rand();
libxsmm_intfloat fvalue;
float p, q;
/* masking all bits which will be shifted out */
fvalue.ui = value.ui & ((LIBXSMM_DNN_MASK_FULL_F32) << rhs);
/* drawing a random number */
p = r/((float)RAND_MAX);
q = (input - fvalue.f)/eps;
/* apply rounding if needed */
if ((p + q) > 0.5f) {
++qvalue;
}
} else {
/* do nothing about rounding, just chop */
}
}
return (short)qvalue;
}
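/* Worked example for the shift-based DFP16 quantization above (a sketch,
   assuming LIBXSMM_DNN_MANT_SZ_F32 == 23 and LIBXSMM_DNN_MANT_DFP16 == 15):
   for the element of maximum magnitude exp_off is 0, so with add_shift == 0
   rhs = (23+1) - 15 + 0 + 0 = 9; shifting the 24-bit mantissa (implicit
   leading bit made explicit) right by 9 keeps 15 magnitude bits, the widest
   magnitude a signed 16-bit fixed-point value can hold. Elements of smaller
   magnitude have a larger exp_off and are shifted further right. */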
/* @TODO make this routine aware of any int type */
LIBXSMM_API void libxsmm_dnn_quantize( float* in_buffer, short* out_buffer, int length, unsigned char add_shift, unsigned char* scf, int round_mode ) {
int i = 0;
/* init libxsmm */
LIBXSMM_INIT
/* in case we are using FP-Mul based quantization we use a different path for now;
@TODO unify the paths by using similar vectorization for both */
if ( round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND ) {
const float max_value = libxsmm_internal_get_max( in_buffer, length );
int maxexp = 0;
/* take return value of LIBXSMM_FREXPF to mute static analysis issue */
float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift);
scfq = libxsmm_sexp2_i8i(-maxexp);
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
if ( length % 16 == 0 ) {
__m512 vscfq = _mm512_set1_ps(scfq);
#ifdef _OPENMP
# pragma omp parallel for private(i)
#endif
for (i = 0; i < length; i+=16 ) {
_mm256_stream_si256( (__m256i *)&(out_buffer[i]), LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( &(in_buffer[i]), vscfq ) );
}
} else {
#endif
#ifdef _OPENMP
# pragma omp parallel for private(i)
#endif
for (i = 0; i < length; ++i ) {
out_buffer[i] = (short)LIBXSMM_ROUNDF(in_buffer[i] * scfq);
}
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
}
#endif
/* @TODO, we need to potentially fix this unsigned char problem */
#if !defined(NDEBUG) /* library code is expected to be mute */
if (maxexp > 0) {
fprintf(stderr, "error quant fil\n");
}
#endif
*scf = (unsigned char)(-maxexp);
} else {
/* get max exponent */
unsigned char max_exp = libxsmm_internal_get_max_exp( in_buffer, length );
/* if we go for stochastic rounding, let's initialize random seed */
if ( round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND ) {
srand(libxsmm_timer_tick() % ((unsigned int)-1));
}
#ifdef _OPENMP
# pragma omp parallel for private(i)
#endif
for (i = 0; i < length; ++i ) {
out_buffer[i] = libxsmm_internal_quantize_scalar_no_scf( in_buffer[i], max_exp, add_shift, round_mode );
}
*scf = (unsigned char)(14 - add_shift - (max_exp - 127));
}
}
LIBXSMM_API void libxsmm_dnn_quantize_act( float* in_buffer, short* out_buffer, unsigned int N, unsigned int C, unsigned int H, unsigned int W, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode ) {
LIBXSMM_VLA_DECL(5, const float, in, in_buffer, C/cblk_f32, H, W, cblk_f32);
LIBXSMM_VLA_DECL(6, short, out, out_buffer, C/(cblk_i16*lp_blk), H, W, cblk_i16, lp_blk);
const unsigned int cblk = C/(cblk_i16*lp_blk);
int i1 = 0, i2 = 0, i3 = 0, i4 = 0, i5, i6;
/* init libxsmm */
LIBXSMM_INIT
/* some quick and dirty checks */
assert((C % cblk_f32) == 0);
assert((C % cblk_i16) == 0);
/* in case we are using FP-Mul based quantization we use a different path for now;
@TODO unify the paths by using similar vectorization for both */
if ( round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND ) {
const float max_value = libxsmm_internal_get_max( in_buffer, N*C*H*W );
int maxexp = 0;
/* take return value of LIBXSMM_FREXPF to mute static analysis issue */
float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift);
scfq = libxsmm_sexp2_i8i(-maxexp);
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
if ( (cblk_f32 == 16) && (cblk_i16*lp_blk == 16) ) {
__m512 vscfq = _mm512_set1_ps(scfq);
#ifdef _OPENMP
LIBXSMM_OMP_VAR(i1);
# pragma omp parallel for private(i1)
#endif
for (i1 = 0; i1 < (int)(N*C*H*W); i1 += 16 ) {
_mm256_stream_si256( (__m256i *)&(out_buffer[i1]), LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( &(in_buffer[i1]), vscfq ) );
}
} else {
#endif
#ifdef _OPENMP
LIBXSMM_OMP_VAR(i1); LIBXSMM_OMP_VAR(i2); LIBXSMM_OMP_VAR(i3); LIBXSMM_OMP_VAR(i4); LIBXSMM_OMP_VAR(i5); LIBXSMM_OMP_VAR(i6);
# pragma omp parallel for private(i1, i2, i3, i4, i5, i6) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
for (i1 = 0; i1 < (int)N; ++i1 ) {
for (i2 = 0; i2 < (int)cblk; ++i2 ) {
for (i3 = 0; i3 < (int)H; ++i3 ) {
for (i4 = 0; i4 < (int)W; ++i4 ) {
for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
for (i6 = 0; i6 < (int)lp_blk; ++i6 ) {
const int fi1 = i1;
const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)/cblk_f32;
const int fi3 = i3;
const int fi4 = i4;
const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)%cblk_f32;
LIBXSMM_VLA_ACCESS(6, out, i1, i2, i3, i4, i5, i6, cblk, H, W, cblk_i16, lp_blk) = (short)LIBXSMM_ROUNDF(
LIBXSMM_VLA_ACCESS(5, in, fi1, fi2, fi3, fi4, fi5, C / cblk_f32, H, W, cblk_f32) * scfq);
}
}
}
}
}
}
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
}
#endif
/* @TODO, we need to potentially fix this unsigned char problem */
#if !defined(NDEBUG) /* library code is expected to be mute */
if (maxexp > 0) {
fprintf(stderr, "error quant act\n");
}
#endif
*scf = (unsigned char)(-maxexp);
} else {
/* get max exponent */
unsigned char max_exp = libxsmm_internal_get_max_exp( in_buffer, N*C*H*W );
/* if we go for stochastic rounding, let's initialize random seed */
if ( round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND ) {
srand(libxsmm_timer_tick() % ((unsigned int)-1));
}
#ifdef _OPENMP
# pragma omp parallel for private(i1, i2, i3, i4, i5, i6) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
for (i1 = 0; i1 < (int)N; ++i1 ) {
for (i2 = 0; i2 < (int)cblk; ++i2 ) {
for (i3 = 0; i3 < (int)H; ++i3 ) {
for (i4 = 0; i4 < (int)W; ++i4 ) {
for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
for (i6 = 0; i6 < (int)lp_blk; ++i6 ) {
const int fi1 = i1;
const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)/cblk_f32;
const int fi3 = i3;
const int fi4 = i4;
const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)%cblk_f32;
LIBXSMM_VLA_ACCESS(6, out, i1, i2, i3, i4, i5, i6, cblk, H, W, cblk_i16, lp_blk) = libxsmm_internal_quantize_scalar_no_scf(
LIBXSMM_VLA_ACCESS(5, in, fi1, fi2, fi3, fi4, fi5, C / cblk_f32, H, W, cblk_f32), max_exp, add_shift, round_mode);
}
}
}
}
}
}
*scf = (unsigned char)(14 - add_shift - (max_exp - 127));
}
}
LIBXSMM_API void libxsmm_dnn_quantize_fil( float* in_buffer, short* out_buffer, unsigned int K, unsigned int C, unsigned int R, unsigned int S, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int kblk_f32, unsigned int kblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode ) {
LIBXSMM_VLA_DECL(6, const float, in, in_buffer, C/cblk_f32, R, S, cblk_f32, kblk_f32);
LIBXSMM_VLA_DECL(7, short, out, out_buffer, C/(cblk_i16*lp_blk), R, S, cblk_i16, kblk_i16, lp_blk);
unsigned int cblk = C/(cblk_i16*lp_blk);
unsigned int kblk = K/kblk_i16;
int i1 = 0, i2 = 0, i3 = 0, i4 = 0, i5, i6, i7;
/* some quick and dirty checks */
assert((C % cblk_f32) == 0);
assert((C % (cblk_i16*lp_blk)) == 0);
assert((K % kblk_f32) == 0);
assert((K % kblk_i16) == 0);
assert((lp_blk % 2) == 0);
/* init libxsmm */
LIBXSMM_INIT
/* in case we are using FP-Mul based quantization we use a different path for now;
@TODO unify the paths by using similar vectorization for both */
if ( round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND ) {
const float max_value = libxsmm_internal_get_max( in_buffer, K*C*R*S );
int maxexp = 0;
/* take return value of LIBXSMM_FREXPF to mute static analysis issue */
float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift);
scfq = libxsmm_sexp2_i8i(-maxexp);
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
if ( (kblk_f32 == 16) && (cblk_f32 == 16) && (kblk_i16 == 16) && (cblk_i16*lp_blk == 16) ) {
const __m512 vscfq = _mm512_set1_ps(scfq);
const __m512i permute_compact_idx = _mm512_set_epi32(15,14,13,12,7,6,5,4,11,10,9,8,3,2,1,0);
#ifdef _OPENMP
# pragma omp parallel for private(i1, i2, i3, i4, i5) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
for (i1 = 0; i1 < (int)kblk; ++i1 ) {
for (i2 = 0; i2 < (int)cblk; ++i2 ) {
for (i3 = 0; i3 < (int)R; ++i3 ) {
for (i4 = 0; i4 < (int)S; ++i4 ) {
for (i5 = 0; i5 < 16; i5+=2 ) {
__m256i even_ch = LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16(
&LIBXSMM_VLA_ACCESS(6, in, i1, i2, i3, i4, i5 + 0, 0, C / cblk_f32, R, S, cblk_f32, kblk_f32), vscfq);
__m256i odd_ch = LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16(
&LIBXSMM_VLA_ACCESS(6, in, i1, i2, i3, i4, i5 + 1, 0, C / cblk_f32, R, S, cblk_f32, kblk_f32), vscfq);
__m256i compressed_lo = _mm256_unpacklo_epi16(even_ch, odd_ch);
__m256i compressed_hi = _mm256_unpackhi_epi16(even_ch, odd_ch);
__m512i compact = _mm512_inserti64x4( _mm512_setzero_si512(), compressed_lo, 0);
compact = _mm512_inserti64x4(compact, compressed_hi, 1);
compact = _mm512_permutexvar_epi32(permute_compact_idx, compact);
LIBXSMM_INTRINSICS_MM512_STREAM_SI512(
(void*)&LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5 / 2, 0, 0, cblk, R, S, cblk_i16, kblk_i16, lp_blk),
compact);
}
}
}
}
}
} else {
#endif
#ifdef _OPENMP
LIBXSMM_OMP_VAR(i1); LIBXSMM_OMP_VAR(i2); LIBXSMM_OMP_VAR(i3); LIBXSMM_OMP_VAR(i4); LIBXSMM_OMP_VAR(i5); LIBXSMM_OMP_VAR(i6); LIBXSMM_OMP_VAR(i7);
# pragma omp parallel for private(i1, i2, i3, i4, i5, i6, i7) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
for (i1 = 0; i1 < (int)kblk; ++i1 ) {
for (i2 = 0; i2 < (int)cblk; ++i2 ) {
for (i3 = 0; i3 < (int)R; ++i3 ) {
for (i4 = 0; i4 < (int)S; ++i4 ) {
for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
for (i6 = 0; i6 < (int)kblk_i16; ++i6 ) {
for (i7 = 0; i7 < (int)lp_blk; ++i7 ) {
const int fi1 = ((i1*kblk_i16)+i6)/kblk_f32;
const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)/cblk_f32;
const int fi3 = i3;
const int fi4 = i4;
const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)%cblk_f32;
const int fi6 = ((i1*kblk_i16)+i6)%kblk_f32;
LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5, i6, i7, cblk, R, S, cblk_i16, kblk_i16, lp_blk) = (short)LIBXSMM_ROUNDF(
LIBXSMM_VLA_ACCESS(6, in, fi1, fi2, fi3, fi4, fi5, fi6, C / cblk_f32, R, S, cblk_f32, kblk_f32) * scfq);
}
}
}
}
}
}
}
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
}
#endif
/* @TODO, we need to potentially fix this unsigned char problem */
#if !defined(NDEBUG) /* library code is expected to be mute */
if (maxexp > 0) {
fprintf(stderr, "error quant fil\n");
}
#endif
*scf = (unsigned char)(-maxexp);
} else {
/* get max exponent */
unsigned char max_exp = libxsmm_internal_get_max_exp( in_buffer, K*C*R*S );
/* if we go for stochastic rounding, let's initialize random seed */
if ( round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND ) {
srand(libxsmm_timer_tick() % ((unsigned int)-1));
}
#ifdef _OPENMP
# pragma omp parallel for private(i1, i2, i3, i4, i5, i6, i7) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
for (i1 = 0; i1 < (int)kblk; ++i1 ) {
for (i2 = 0; i2 < (int)cblk; ++i2 ) {
for (i3 = 0; i3 < (int)R; ++i3 ) {
for (i4 = 0; i4 < (int)S; ++i4 ) {
for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
for (i6 = 0; i6 < (int)kblk_i16; ++i6 ) {
for (i7 = 0; i7 < (int)lp_blk; ++i7 ) {
const int fi1 = ((i1*kblk_i16)+i6)/kblk_f32;
const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)/cblk_f32;
const int fi3 = i3;
const int fi4 = i4;
const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)%cblk_f32;
const int fi6 = ((i1*kblk_i16)+i6)%kblk_f32;
LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5, i6, i7, cblk, R, S, cblk_i16, kblk_i16, lp_blk) = libxsmm_internal_quantize_scalar_no_scf(
LIBXSMM_VLA_ACCESS(6, in, fi1, fi2, fi3, fi4, fi5, fi6, C / cblk_f32, R, S, cblk_f32, kblk_f32), max_exp, add_shift, round_mode);
}
}
}
}
}
}
}
*scf = (unsigned char)(14 - add_shift - (max_exp - 127));
}
}
LIBXSMM_API void libxsmm_dnn_dequantize( short* in_buffer, float* out_buffer, int length, unsigned char scf ) {
const float val_exp = libxsmm_sexp2_i8i(-scf);
int i = 0;
#ifdef _OPENMP
# pragma omp parallel for private(i)
#endif
for ( i = 0; i < length; ++i ) {
out_buffer[i] = ((float)in_buffer[i])*val_exp;
}
}
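#if 0
/* Minimal round-trip usage sketch (illustrative, not part of the library);
   "buf_f32", "buf_i16" and "buf_out" are hypothetical user buffers. */
{
  float buf_f32[256], buf_out[256];
  short buf_i16[256];
  unsigned char scf = 0;
  int n;
  for (n = 0; n < 256; ++n) buf_f32[n] = (float)n / 256.0f;
  libxsmm_dnn_quantize(buf_f32, buf_i16, 256, 0/*add_shift*/, &scf, LIBXSMM_DNN_QUANT_NEAREST_ROUND);
  libxsmm_dnn_dequantize(buf_i16, buf_out, 256, scf); /* buf_out approximates buf_f32 */
}
#endif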
LIBXSMM_API void libxsmm_truncate_convert_f32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int length) {
unsigned int i = 0;
/* truncate buffer to bf16 */
for ( i = 0; i < length; ++i ) {
libxsmm_bfloat16_hp t;
t.f = in[i];
out[i] = t.i[1];
}
}
LIBXSMM_API void libxsmm_rnaz_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len) {
unsigned int i = 0;
/* round buffer to bf16 (round-nearest, ties away from zero) */
for ( i = 0; i < len; ++i ) {
unsigned int int_round = 0;
unsigned int do_round = 1;
int_round = *((unsigned int*)&(in[i]));
/* we don't round NaN and inf */
if ( (int_round & 0x7f800000) == 0x7f800000 ) {
do_round = 0;
}
/* perform round nearest tie away from zero */
if ( do_round != 0 ) {
int_round = int_round + 0x00008000;
}
/* create the bf16 value by shifting out the lower 16 bits */
int_round = int_round >> 16;
out[i] = (libxsmm_bfloat16)int_round;
}
}
LIBXSMM_API void libxsmm_rne_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len) {
unsigned int i = 0;
/* round buffer to bf16 (round-nearest-even) */
for ( i = 0; i < len; ++i ) {
unsigned int int_round = 0;
unsigned int do_round = 1;
int_round = *((unsigned int*)&(in[i]));
/* we don't round NaN and inf */
if ( (int_round & 0x7f800000) == 0x7f800000 ) {
do_round = 0;
}
/* perform round nearest tie even */
if ( do_round != 0 ) {
unsigned int fixup = (int_round >> 16) & 1;
int_round = int_round + 0x00007fff + fixup;
}
/* create the bf16 value by shifting out the lower 16 bits */
int_round = int_round >> 16;
out[i] = (libxsmm_bfloat16)int_round;
}
}
LIBXSMM_API void libxsmm_convert_bf16_f32(const libxsmm_bfloat16* in, float* out, unsigned int length) {
unsigned int i = 0;
/* up-convert is super simple */
for ( i = 0; i < length; ++i ) {
libxsmm_bfloat16_hp t;
t.i[1] = in[i];
t.i[0] = 0;
out[i] = t.f;
}
}
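/* Worked example for the three FP32->BF16 conversions above (illustrative):
   0x3F800001 (just above 1.0f) truncates to 0x3F80, and round-nearest-away
   (+0x8000) also gives 0x3F80; for an exact tie such as 0x3F808000, RNAZ
   rounds up to 0x3F81, while RNE adds 0x7FFF plus the fixup bit (bit 16 is
   0 here) and stays at 0x3F80, i.e. it rounds to the even neighbor. */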
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst, Alexander Heinecke, Evangelos Georganas, Rajkishore Barik (Intel Corp.)
******************************************************************************/
#include <libxsmm_sync.h>
#include "libxsmm_main.h"
#include "libxsmm_dnn_convolution_forward.h"
#include "libxsmm_dnn_convolution_backward.h"
#include "libxsmm_dnn_convolution_weight_update.h"
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <math.h>
#if defined(_OPENMP)
# include <omp.h>
#endif
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
#define MIXED 0
#define KHWC 1
#define HWKC 2
#define CHWK 3
#define HWCK 4
/**********************************************************/
/* Helper functions for convolutions' general param setup */
/**********************************************************/
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_ifmblock( libxsmm_dnn_layer* handle ) {
int result = 1;
int ofm, lp;
libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K, &result, &ofm, &lp, handle->desc.datatype_in, handle->desc.datatype_out );
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_ofmblock( libxsmm_dnn_layer* handle ) {
int result = 1;
int ifm, lp;
libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K, &ifm, &result, &lp, handle->desc.datatype_in, handle->desc.datatype_out );
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fm_lp_block( libxsmm_dnn_layer* handle ) {
int result = 1;
int ifm, ofm;
libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K, &ifm, &ofm, &result, handle->desc.datatype_in, handle->desc.datatype_out );
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fallback_loops_fwd( libxsmm_dnn_layer* handle ) {
int result = 0;
/* FIXME: For now fallback only if MB is not divisible by number of threads */
if (handle->desc.N % handle->desc.threads != 0) {
result = 1;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksifm( libxsmm_dnn_layer* handle ) {
int result = handle->desc.C / handle->ifmblock;
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksofm( libxsmm_dnn_layer* handle ) {
int result = handle->desc.K / handle->ofmblock;
return result;
}
/**********************************************************/
/* Helper functions for FWD convolutions' parameter setup */
/**********************************************************/
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_ofw_rb( libxsmm_dnn_layer* handle ) {
int result = 0;
result = handle->ofw;
if (handle->ofw == 56) {
result = 28;
}
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) {
if (handle->ofw % 2 == 0) {
result = handle->ofw/2;
}
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_pack_input_fwd( libxsmm_dnn_layer* handle ) {
int result = 0;
/* Pack only for small images with large K (to amortize the packing overhead); packing is only possible for 1x1 convolutions */
if ((handle->ofw <= 14) && (handle->desc.K > 512) && (handle->desc.R == 1) && (handle->desc.S == 1) && (handle->desc.u == 2) && (handle->desc.v == 2)) {
result = 1;
}
/* For SPR we allow packing more aggressively to generate more efficient BRGEMMs */
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) {
if ((handle->ofw <= 14) && (handle->desc.R == 1) && (handle->desc.S == 1) && (handle->desc.u == 2) && (handle->desc.v == 2)) {
result = 1;
}
}
/* Make sure we don't pack when the minibatch is not divisible by the number of threads, since H is potentially used for parallelism */
if (handle->desc.N != handle->desc.threads) {
result = 0;
}
/* we don't pack for int8 */
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) {
result = 0;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_ofh_rb( libxsmm_dnn_layer* handle ) {
int result = 1;
/* Multiple rows for "small" images and 1x1 convolutions */
if ((handle->ofh <= 14) && (handle->desc.R == 1) && (handle->desc.S == 1)) {
result = handle->ofh;
}
/* In this case we will be using fallback generic loops, thus ofh_rb should be 1 */
if ((handle->desc.N % handle->desc.threads != 0) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) {
result = 1;
}
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) {
if (handle->ofw == 7 && handle->ofh == 7 && handle->desc.R == 3 && handle->desc.S == 3) {
result = 7;
}
if (handle->ofw == 14 && handle->ofh == 14 /*&& handle->desc.R == 3 && handle->desc.S == 3*/) {
result = 2;
}
}
/* Make sure we don't use multiple rows when we don't pack input and convolutions are strided */
if ((handle->pack_input == 0) && ((handle->desc.u !=1 ) || (handle->desc.v != 1))) {
result = 1;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_pixels_gemm( libxsmm_dnn_layer* handle ) {
int result = handle->fwd_ofw_rb * handle->fwd_ofh_rb;
/* In the case below we redundantly compute pixels in order to use AMX efficiently */
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) {
if (handle->desc.R != 1 || handle->desc.S != 1) {
if (handle->ofw < 24) {
result = (handle->fwd_ofw_rb+2*handle->desc.pad_w) * (handle->fwd_ofh_rb-2) + 2 * (handle->fwd_ofw_rb+handle->desc.pad_w);
}
}
}
return result;
}
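/* Worked example for the redundant-pixel case above (illustrative): with
   fwd_ofw_rb == 7, fwd_ofh_rb == 7 and pad_w == 1 (a 7x7 output tile of a
   3x3 convolution), the GEMM covers (7+2*1)*(7-2) + 2*(7+1) = 61 pixels
   instead of 7*7 = 49, so the AMX tiles can operate on contiguous rows. */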
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_block_H( libxsmm_dnn_layer* handle ) {
int result = 14;
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) {
/* Spatial dimension block tuning for SPR */
if ((handle->ofh == 7 && handle->desc.u == 2) || (handle->ofh == 14 && handle->desc.R != 3 ) || handle->ofh == 27 || (handle->ofh == 28 && handle->desc.R == 1) || handle->ofh == 48 || handle->ofh == 54 || handle->ofh == 56 || handle->ofh == 112 ) {
result = 4;
}
} else {
/* Block H only for large images */
if (handle->ofh >= 28) {
result = 4;
}
if (handle->ofh == 28 && handle->desc.R == 3 ) {
result = 14;
}
}
/* Make sure it is divisible by the ofh_rb factor in the kernel */
while ( result % handle->fwd_ofh_rb != 0 ) {
result--;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksifm_blocking( libxsmm_dnn_layer* handle ) {
int result = 1;
/* For 1x1 convolutions bring all IFMs into the kernel unless the filters are huge */
if ((handle->desc.R == 1) && (handle->desc.S == 1) ) {
result = handle->blocksifm;
if ((handle->desc.C >= 2048) && (handle->desc.K >= 512)) {
result = 1;
}
if ( (handle->target_archid < LIBXSMM_X86_AVX512) && (handle->desc.C >= 512) ) {
result = 2;
}
if ( (handle->target_archid < LIBXSMM_X86_AVX512) && (handle->desc.C >= 1024) ) {
result = 4;
}
} else {
result = 1;
/* For small images we can bring in more IFMs even if NOT a 1x1 convolution */
if (handle->ofw <= 7) {
result = 2;
}
}
if (handle->blocksifm % result != 0) {
result = 1;
}
/* In case of SPR always bring in the full accumulation chain */
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
result = handle->blocksifm;
}
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) {
result = handle->blocksifm;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_loop_order_fwd( libxsmm_dnn_layer* handle ) {
int result = 0;
/* Switch to loop order 1 only for a 1x1 convolution with a "large" input image and "small" K */
if ((handle->desc.H >= 28) && (handle->desc.R == 1) && (handle->desc.S == 1) && (handle->desc.C >=512) && (handle->desc.K <=512)) {
result = 1;
}
if (handle->ofw == 56 && handle->desc.R == 1 && handle->desc.C == 256 && handle->desc.K == 64 ) {
result = 1;
}
if (handle->ofw == 28 && handle->desc.R == 1) {
result = 1;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_fwd_IFM( libxsmm_dnn_layer* handle ) {
int result = 8;
if (handle->ofw == 7 && handle->desc.C == 2048 && handle->desc.K == 512) {
result = 4;
}
/* Make sure it is divisible by the IFM blocking used in the kernel */
while (result % handle->blocksifm_blocking != 0) {
result++;
}
result = LIBXSMM_MIN(handle->blocksifm, result);
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_fwd_OFM( libxsmm_dnn_layer* handle ) {
int result = 8;
if (handle->ofw == 14 && handle->desc.K == 1024) {
result = 16;
}
if (handle->ofw == 7) {
result = 16;
}
result = LIBXSMM_MIN(handle->blocksofm, result);
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_use_ofm_parallelization( libxsmm_dnn_layer* handle ) {
int result = 0;
#if 0
/* Use "hybrid" minibatch/ofm parallelization if we have huge filters */
if ((handle->desc.R >= 3) && (handle->desc.S >= 3) && (handle->desc.C >= 512) && (handle->desc.K >= 512)) {
result = 1;
}
#endif
if ((handle->ofw <= 7) && (handle->desc.C == 1024) && (handle->desc.K == 512)) {
result = 1;
}
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
if (handle->ofw == 7) {
result = 1;
}
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_rim_fmas_fwd( libxsmm_dnn_layer* handle ) {
int result = 0;
/* Avoid rim FMA if the convolution is 3x3 (non-strided) and the image is "small" */
if ((handle->desc.R == 3) && (handle->desc.S == 3) &&
(handle->desc.u == 1) && (handle->desc.v == 1) &&
(handle->desc.pad_h_in == 1) && (handle->desc.pad_w_in == 1) &&
(handle->desc.H == handle->desc.W) ) {
if (handle->ofw <= 28) {
result = 1;
}
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) {
result = 0;
}
}
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
result = 0;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_shuffle_filter_accesses( libxsmm_dnn_layer* handle ) {
int result = 0;
/* Shuffle filter accesses only if "pure minibatch" parallelization and large filters are involved */
if ((handle->use_ofm_parallelization == 0) && (handle->desc.C > 512) && (handle->desc.K > 512)) {
result = 1;
}
if (handle->ofw == 7 && handle->desc.R == 3 && handle->desc.C == 512) {
result = 1;
}
if (handle->ofw == 7 && handle->desc.R == 1 && handle->desc.C == 512 && handle->desc.K == 2048) {
result = 1;
}
if (handle->ofw == 7 && handle->desc.R == 1 && handle->desc.C == 2048 && handle->desc.K == 512) {
result = 1;
}
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) {
result = 0;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_acc_load( libxsmm_dnn_layer* handle ) {
int result = 0;
if ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) {
if ((handle->desc.R == 1) && (handle->desc.S == 1)) {
if (handle->blocksifm_blocking == handle->blocksifm) {
result = 1;
}
} else {
if ((handle->blocksifm_blocking == handle->blocksifm) && (handle->avoid_fmas_in_rim == 0)) {
result = 1;
}
}
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_init_fwd_gemm_flags( libxsmm_dnn_layer* handle ) {
int result = 0;
#if defined(LIBXSMM_DNN_CONVOLUTION_SETUP_USE_NTS)
/* If the image is large and the output is NOT already held in accumulators, then use streaming stores */
if ((handle->ofw >= 56) && (handle->desc.K >= 256) && (handle->avoid_acc_load == 1) && (handle->desc.R == 1) && (handle->desc.S == 1)) {
result = LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT;
}
if (handle->ofw == 56 && handle->desc.C == 64 && handle->desc.K == 64 && handle->desc.R == 1) {
result = LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT;
}
if (handle->ofw == 56 && handle->desc.C == 256 && handle->desc.K == 64 && handle->desc.R == 1) {
result = LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT;
}
/* Disable since the GEMM output is going to f32 scratch */
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16 || handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) {
result = 0;
}
#else
LIBXSMM_UNUSED(handle);
#endif
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
result = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_padding_copy( libxsmm_dnn_layer* handle ) {
int result = 0;
if ( (handle->desc.pad_h != handle->desc.pad_h_in) && (handle->desc.pad_w != handle->desc.pad_w_in) ) {
result = 1;
}
return result;
}
LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_fwd_scratch( libxsmm_dnn_layer* handle ) {
handle->fwd_packing_padding_scratch_size = 0;
/* packing of input */
if ( handle->pack_input != 0 ) {
handle->fwd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C *
handle->desc.H/handle->desc.u *
handle->desc.W/handle->desc.v *
libxsmm_dnn_typesize(handle->datatype_in);
}
/* logical padding with copying on the fly */
if ( handle->fwd_padding_copy != 0 ) {
handle->fwd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C *
(handle->desc.H + 2*handle->desc.pad_h) *
(handle->desc.W + 2*handle->desc.pad_w) *
libxsmm_dnn_typesize(handle->datatype_in);
}
/* output buffer in high precision when we use low-precision compute (BF16/I8) */
if ( ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16 ) ||
( handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8 ) ) {
handle->fwd_lp_output_full_scratch_size = (size_t) LIBXSMM_MAX(handle->desc.threads * handle->fwd_gemm_pixels * handle->ofmblock * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32), handle->desc.N * handle->desc.K * handle->ofwp * handle->ofhp * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32));
handle->fwd_lp_output_block_scratch_size = (size_t)handle->desc.threads * handle->fwd_ofw_rb *
handle->fwd_ofh_rb * handle->ofmblock *
libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32);
} else {
handle->fwd_lp_output_full_scratch_size = 0;
handle->fwd_lp_output_block_scratch_size = 0;
}
/* align sizes to full cacheline */
handle->fwd_packing_padding_scratch_size += ( handle->fwd_packing_padding_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
LIBXSMM_CACHELINE - (handle->fwd_packing_padding_scratch_size % LIBXSMM_CACHELINE);
handle->fwd_lp_output_full_scratch_size += ( handle->fwd_lp_output_full_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
LIBXSMM_CACHELINE - (handle->fwd_lp_output_full_scratch_size % LIBXSMM_CACHELINE);
handle->fwd_lp_output_block_scratch_size += ( handle->fwd_lp_output_block_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
LIBXSMM_CACHELINE - (handle->fwd_lp_output_block_scratch_size % LIBXSMM_CACHELINE);
/* set offsets */
handle->fwd_packing_padding_scratch_offset = 0;
handle->fwd_lp_output_full_scratch_offset = handle->fwd_packing_padding_scratch_size;
handle->fwd_lp_output_block_scratch_offset = handle->fwd_lp_output_full_scratch_offset +
handle->fwd_lp_output_full_scratch_size;
/* set overall scratch size for forward */
handle->fwd_scratch_size = handle->fwd_packing_padding_scratch_size +
handle->fwd_lp_output_full_scratch_size +
handle->fwd_lp_output_block_scratch_size;
}
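/* The cacheline round-up above follows the pattern
   size += (size % LIBXSMM_CACHELINE == 0) ? 0 : LIBXSMM_CACHELINE - size % LIBXSMM_CACHELINE;
   e.g. with a 64-byte cacheline a 100-byte request is padded by 28 bytes to
   128 bytes, so all subsequent scratch offsets start cacheline-aligned. */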
/**********************************************************/
/* Helper functions for BWD convolutions' parameter setup */
/**********************************************************/
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fallback_loops_bwd( libxsmm_dnn_layer* handle ) {
int result = 0;
/* FIXME: Fallback if MB is not divisible by number of threads */
if (handle->desc.N % handle->desc.threads != 0) {
result = 1;
}
if (handle->desc.R == 1 && handle->desc.S == 1 && (handle->desc.pad_h != 0 || handle->desc.pad_w != 0)) {
result = 1;
}
if ((handle->desc.R > 1 && handle->desc.pad_h == 0) || (handle->desc.S > 1 && handle->desc.pad_w == 0)) {
result = 1;
}
if ((handle->desc.R > 1 && (handle->desc.pad_h_out == 0 || handle->desc.pad_h_in == 0)) ||
(handle->desc.S > 1 && (handle->desc.pad_w_out == 0 || handle->desc.pad_w_in == 0)) ) {
result = 1;
}
if ((handle->desc.R > 1 && handle->desc.u > 1) || (handle->desc.S > 1 && handle->desc.v > 1)) {
result = 1;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_ofw_rb( libxsmm_dnn_layer* handle ) {
int result = libxsmm_dnn_convolution_setup_fwd_ofw_rb(handle);
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_ofh_rb( libxsmm_dnn_layer* handle ) {
int result = libxsmm_dnn_convolution_setup_fwd_ofh_rb(handle);
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_pixels_gemm( libxsmm_dnn_layer* handle ) {
int result = handle->bwd_ofw_rb * handle->bwd_ofh_rb;
/* In the case below we redundantly compute pixels in order to use AMX efficiently */
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) {
if (handle->desc.R != 1 || handle->desc.S != 1) {
if (handle->ofw < 24) {
result = (handle->bwd_ofw_rb+2*handle->desc.pad_w) * (handle->bwd_ofh_rb-2) + 2 * (handle->bwd_ofw_rb+handle->desc.pad_w);
}
}
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_block_H( libxsmm_dnn_layer* handle ) {
int result = 0;
result = libxsmm_dnn_convolution_setup_fwd_block_H(handle);
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_loop_order_bwd( libxsmm_dnn_layer* handle ) {
int result = 0;
result = libxsmm_dnn_convolution_setup_loop_order_fwd(handle);
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_bwd_IFM( libxsmm_dnn_layer* handle ) {
int result = 0;
result = LIBXSMM_MIN(handle->blocksifm, 16);
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_bwd_OFM( libxsmm_dnn_layer* handle ) {
int result = 8;
while (result % handle->blocksofm_blocking != 0) {
result++;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_pack_input_bwd( libxsmm_dnn_layer* handle ) {
int result = 0;
if ((handle->desc.u != 1) && (handle->bwd_ofh_rb != 1)) {
result = 1;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_use_ifm_parallelization( libxsmm_dnn_layer* handle ) {
int result = 0;
if (handle->ofw <= 7) {
result = 1;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_rim_fmas_bwd( libxsmm_dnn_layer* handle ) {
int result = libxsmm_dnn_convolution_setup_avoid_rim_fmas_fwd(handle);
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksofm_blocking( libxsmm_dnn_layer* handle ) {
int result = 0;
if (handle->desc.R == 1 && handle->desc.S == 1) {
result = handle->blocksofm;
} else {
result = 1;
if (handle->desc.R == 3 && handle->desc.S == 3 && handle->ofh == 7 && handle->ofw == 7) {
result = 2;
}
}
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) {
result = handle->blocksofm;
}
if (handle->blocksofm % result != 0) {
result = 1;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_init_bwd_gemm_flags( libxsmm_dnn_layer* handle ) {
int result = 0;
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) {
result = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_spread_input_bwd( libxsmm_dnn_layer* handle ) {
int result = 0;
if (((handle->desc.u != 1) || (handle->desc.v != 1)) && (handle->bwd_ofh_rb == 1)) {
result = 1;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_acc_load_bwd( libxsmm_dnn_layer* handle ) {
int result = 0;
if ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) {
if ((handle->desc.R == 1) && (handle->desc.S == 1)) {
if (handle->blocksofm_blocking == handle->blocksofm) {
result = 1;
}
} else {
if ((handle->blocksofm_blocking == handle->blocksofm) && (handle->avoid_fmas_in_rim == 0)) {
result = 1;
}
}
}
return result;
}
LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_bwd_scratch( libxsmm_dnn_layer* handle ) {
/* transpose of weights */
handle->bwd_filter_trans_scratch_size = (size_t)handle->desc.C * handle->desc.K *
handle->desc.R * handle->desc.S *
libxsmm_dnn_typesize(handle->datatype_in);
handle->bwd_packing_padding_scratch_size = 0;
/* packing of input */
if ( handle->pack_input_bwd != 0 ) {
handle->bwd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C *
handle->ofhp * handle->ofwp *
libxsmm_dnn_typesize(handle->datatype_in);
}
/* logical padding with copying on the fly */
if ( handle->use_fallback_bwd_loops != 0 ) {
handle->bwd_packing_padding_scratch_size = (size_t)handle->desc.threads * handle->ifmblock *
(handle->desc.H + 2*handle->desc.pad_h) *
(handle->desc.W + 2*handle->desc.pad_w) *
libxsmm_dnn_typesize(handle->datatype_in);
}
/* input buffer in high precision when we use BF16 */
if ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16 ) {
handle->bwd_lp_input_full_scratch_size = (size_t) LIBXSMM_MAX(handle->desc.threads * handle->bwd_gemm_pixels * handle->ifmblock * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32), handle->desc.N * handle->desc.C * handle->ifwp * handle->ifhp * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32));
/* logical padding with copying on the fly */
if ( handle->use_fallback_bwd_loops != 0 ) {
handle->bwd_packing_padding_scratch_size = (size_t)handle->desc.threads * handle->ifmblock *
(handle->desc.H + 2*handle->desc.pad_h) *
(handle->desc.W + 2*handle->desc.pad_w) *
libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32);
}
} else {
handle->bwd_lp_input_full_scratch_size = 0;
}
/* align sizes to full cacheline */
handle->bwd_filter_trans_scratch_size += ( handle->bwd_filter_trans_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
LIBXSMM_CACHELINE - (handle->bwd_filter_trans_scratch_size % LIBXSMM_CACHELINE);
handle->bwd_packing_padding_scratch_size += ( handle->bwd_packing_padding_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
LIBXSMM_CACHELINE - (handle->bwd_packing_padding_scratch_size % LIBXSMM_CACHELINE);
handle->bwd_lp_input_full_scratch_size += ( handle->bwd_lp_input_full_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
LIBXSMM_CACHELINE - (handle->bwd_lp_input_full_scratch_size % LIBXSMM_CACHELINE);
/* set offsets */
handle->bwd_filter_trans_scratch_offset = 0;
handle->bwd_packing_padding_scratch_offset = handle->bwd_filter_trans_scratch_size;
handle->bwd_lp_input_full_scratch_offset = handle->bwd_packing_padding_scratch_offset +
handle->bwd_packing_padding_scratch_size;
/* set overall scratch size for backward */
handle->bwd_scratch_size = handle->bwd_filter_trans_scratch_size +
handle->bwd_packing_padding_scratch_size +
handle->bwd_lp_input_full_scratch_size;
}
/**********************************************************/
/* Helper functions for UPD convolutions' parameter setup */
/**********************************************************/
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_loop_order_upd( libxsmm_dnn_layer* handle ) {
int result = 1;
if (handle->ofh == 28 && handle->desc.R == 1 && handle->desc.u == 1 && handle->desc.C == 128 && handle->desc.K == 512) {
result = 0;
}
if (handle->ofh == 28 && handle->desc.R == 3 && handle->desc.u == 1 && handle->desc.C == 128 && handle->desc.K == 128) {
result = 0;
}
if (handle->ofw == 28 && handle->desc.R == 1 && handle->desc.C == 256 && handle->desc.K == 512) {
result = 0;
}
if (handle->ofw == 14 && !(handle->desc.R == 1 && handle->desc.C == 1024 && handle->desc.K == 256)) {
result = 0;
}
if (handle->ofw == 7) {
result = 0;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_pack_input_upd( libxsmm_dnn_layer* handle ) {
int result = 0;
/* Pack input only for very small images, 1x1 convs, with large K to amortize the relevant overhead */
if ((handle->ofh <= 7) && (handle->desc.R == 1) && (handle->desc.S == 1) && (handle->desc.u != 1) && (handle->desc.v != 1) && (handle->desc.K >= 2048)) {
result = 1;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_rim_fmas_upd( libxsmm_dnn_layer* handle ) {
int result = 0;
/* Avoid rim FMAs only for small images */
if ( (handle->ofh <= 7) && (handle->desc.R == 3) && (handle->desc.S == 3) && (handle->desc.pad_w == 1) && (handle->desc.pad_h == 1)) {
result = 1;
}
if (handle->desc.N != handle->desc.threads) {
result = 0;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_upd_ofw_rb( libxsmm_dnn_layer* handle ) {
int result = 1;
result = handle->ofw;
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_upd_ofh_rb( libxsmm_dnn_layer* handle ) {
int result = 1;
/* Restrict the reduction chain, which is ofw_rb*ofh_rb */
if (handle->ofh <= 28 ) {
result = handle->ofh;
}
/* In the following scenario, with strided convolutions and a non-batch-reduce kernel, make sure we have ofh_rb = 1 */
if ((handle->desc.u != 1) && (handle->desc.v != 1) && (handle->upd_use_batchreduce == 0) && (handle->upd_pack_input == 0)) {
result = 1;
}
/* If using the linearized task view with strided convolutions, make sure ofh_rb is 1 */
if (handle->upd_linearized_tasklist == 1 && handle->upd_avoid_rim_fmas == 0 && handle->upd_pack_input == 0 && handle->desc.u != 1) {
result = 1;
}
if (handle->upd_linearized_tasklist == 1 && handle->upd_use_batchreduce == 0 && (handle->desc.R != 1 || handle->desc.S != 1)) {
result = 1;
}
if (handle->upd_linearized_tasklist == 0 && handle->upd_use_batchreduce == 0 && (handle->desc.R != 1 || handle->desc.S != 1)) {
result = 1;
}
if (handle->ofw == 56 && handle->desc.R == 1) {
result = 2;
}
if (handle->upd_linearized_tasklist == 1 && handle->upd_use_batchreduce == 1 && handle->upd_avoid_rim_fmas == 1) {
result = handle->ofh;
}
if ((handle->desc.N != handle->desc.threads) && (handle->desc.R > 1 || handle->desc.S > 1 ) && (handle->desc.u > 1 || handle->desc.v > 1 )) {
result = 1;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_upd_IFM( libxsmm_dnn_layer* handle ) {
int result = 1;
if (handle->ofh == 56 && handle->desc.R == 1 && handle->desc.S == 1 && handle->desc.u == 1 && handle->desc.v == 1) {
result = 4;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_upd_OFM( libxsmm_dnn_layer* handle ) {
int result = 1;
LIBXSMM_UNUSED(handle);
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_img_batchreduce_block( libxsmm_dnn_layer* handle ) {
int result = 1;
LIBXSMM_UNUSED(handle);
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_use_batchreduce_upd( libxsmm_dnn_layer* handle ) {
int result = 1;
/* If W is large, no need for batchreduce kernel */
if (handle->ofw >= 56) {
result = 0;
}
/* If we have packed the input, then disable batch-reduce GEMM */
if (handle->upd_pack_input == 1) {
result = 0;
}
if (handle->upd_linearized_tasklist == 1 && handle->upd_avoid_rim_fmas == 0) {
result = 0;
}
if (handle->upd_linearized_tasklist == 1 && handle->upd_avoid_rim_fmas == 1) {
result = 1;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_weight_copies_upd( libxsmm_dnn_layer* handle ) {
int result = handle->desc.threads;
if (handle->ofw <= 14) {
result = 9;
}
if (handle->ofw == 14 && handle->desc.N == 92 && handle->desc.threads == 92) {
result = 23;
}
if (handle->ofw == 7 && handle->desc.N == 92 && handle->desc.threads == 92 && handle->desc.R == 3 && handle->desc.S == 3 && handle->desc.u == 1 && handle->desc.v == 1) {
result = 23;
}
while (handle->desc.threads % result != 0) {
result--;
}
/* FIXME: Hardcoded logic for N=27, N=26 */
if (handle->desc.N == 27 && handle->desc.threads == 27 && handle->desc.R == 1 && handle->ofw == 14 && handle->desc.u == 1) {
result = 7;
}
if (((handle->ofh == 14) || (handle->ofw == 7 && handle->desc.u == 2)) && handle->desc.N == 26 && handle->desc.threads == 26) {
result = 13;
}
if ((handle->desc.N != handle->desc.threads) && !(handle->upd_linearized_tasklist == 0 && handle->upd_use_batchreduce == 0)) {
result = handle->desc.N;
}
/* Make sure we use a single copy when using the linearized task view */
if (handle->upd_linearized_tasklist == 1) {
result = 1;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_linearized_tasklist_upd( libxsmm_dnn_layer* handle ) {
int result = 0;
/* Use the linearized task-list (i.e. no reduction) only for small images and large filters */
if (handle->ofh <= 10 && handle->ofw <= 10) {
result = 1;
}
if (handle->ofw == 7 && handle->desc.N == 92 && handle->desc.threads == 92 && handle->desc.R == 3 && handle->desc.S == 3 && handle->desc.u == 1 && handle->desc.v == 1) {
result = 0;
}
if (handle->ofh == 14 && handle->ofw == 14 && handle->desc.N == 23 && handle->desc.threads == 23) {
result = 1;
}
#if 0
if ((handle->blocksofm * handle->blocksifm * handle->desc.R * handle->desc.S > (handle->desc.threads * 4)) && (handle->ofh <= 56)) {
result = 1;
}
#endif
if (handle->desc.u == 2 && handle->desc.v == 2 && handle->desc.K == 512) {
result = 0;
}
return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_init_upd_gemm_flags( libxsmm_dnn_layer* handle ) {
int result = 0;
LIBXSMM_UNUSED(handle);
return result;
}
LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_bf16_upd( libxsmm_dnn_layer* handle ) {
int remainder_pixels, max_init_offset, max_compute_offset_input, input_compute_pad, accum_length_pixels, compute_pixels;
const int multiple_target = 2;
int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2 * handle->desc.pad_h : handle->ifhp;
int IFWP = (handle->upd_padding_copy == 1) ? handle->ifwp + 2 * handle->desc.pad_w : handle->ifwp;
int OFHP = (handle->upd_padding_copy == 1) ? handle->ofhp + 2 * handle->desc.pad_h : handle->ofhp;
int OFWP = (handle->upd_padding_copy == 1) ? handle->ofwp + 2 * handle->desc.pad_w : handle->ofwp;
handle->upd_linearized_pixels = 1;
if (handle->desc.S != 1 && handle->desc.v != 1) {
handle->upd_linearized_pixels = 0;
handle->upd_trans_w_only = 0;
}
/* For large images facilitate the "large" transposes by blocking the pixel/reduction domains */
if (handle->ofw >= 56 && handle->ofh >=56 && handle->desc.R == 1 && handle->desc.S == 1 && handle->desc.u == 1 && handle->desc.v == 1) {
handle->upd_linearized_pixels = 0;
handle->upd_trans_w_only = 1;
}
handle->on_the_fly_input_packing = 0;
handle->upd_pack_input_upfront = 0;
handle->use_hybrid_imgofm_parallelization = 0;
handle->upd_linearized_tasklist = 0;
if (handle->upd_linearized_pixels == 1) {
/* Logistics to pad the accumulation chain length */
compute_pixels = handle->ofw * handle->ofh + 2 * handle->desc.pad_w * (handle->ofh-1);
remainder_pixels = (compute_pixels % multiple_target == 0) ? 0 : (compute_pixels/multiple_target+1)*multiple_target - compute_pixels;
accum_length_pixels = compute_pixels + remainder_pixels;
/* In this case compact input upfront */
if (handle->desc.R == 1 && handle->desc.S == 1 && (handle->desc.u != 1 || handle->desc.v != 1)) {
handle->upd_pack_input_upfront = 1;
}
/* Logistics for input transpose and additional pixel padding */
max_init_offset = 2 * handle->desc.pad_h * IFWP + 2 * handle->desc.pad_w;
max_compute_offset_input = max_init_offset + accum_length_pixels;
input_compute_pad = (max_compute_offset_input > IFWP*IFHP) ? max_compute_offset_input - IFWP*IFHP : 0;
handle->input_pixels = IFWP * IFHP + input_compute_pad;
if (handle->upd_pack_input_upfront) {
handle->input_pixels = accum_length_pixels;
}
handle->output_pixels = accum_length_pixels;
handle->pixel_blocking = accum_length_pixels;
handle->n_used_pixels = accum_length_pixels;
handle->compute_pixels = compute_pixels;
handle->use_intermediate_f32_wt_tensor = (handle->pixel_blocking == handle->n_used_pixels) ? 0 : 1;
if (handle->ofw <= 14) {
handle->use_hybrid_imgofm_parallelization = 1;
handle->weight_copies = libxsmm_dnn_convolution_setup_weight_copies_upd(handle);
if (handle->ofw == 14 && handle->desc.K >= 1024) {
handle->use_hybrid_imgofm_parallelization = 0;
handle->weight_copies = handle->desc.threads;
}
} else {
handle->weight_copies = handle->desc.threads;
}
}
if (handle->upd_linearized_pixels == 0) {
handle->weight_copies = handle->desc.threads;
if (handle->desc.v != 1) {
handle->on_the_fly_input_packing = 1;
}
remainder_pixels = (handle->ofw % multiple_target == 0) ? 0 : (handle->ofw/multiple_target+1)*multiple_target - handle->ofw;
handle->ofwp_extended = OFWP + remainder_pixels;
handle->ifwp_extended = IFWP + remainder_pixels;
handle->output_pixels = OFHP * handle->ofwp_extended;
handle->batchreduce_h_pixels = 1; /* TODO: the upd_trans_w_only branches were identical; differentiate them or keep 1 */
handle->use_intermediate_f32_wt_tensor = (handle->batchreduce_h_pixels == handle->ofh) ? 0 : 1;
}
if (handle->desc.N != handle->desc.threads) {
handle->use_intermediate_f32_wt_tensor = 1;
handle->use_hybrid_imgofm_parallelization = 0;
handle->weight_copies = LIBXSMM_MIN(handle->desc.N, handle->desc.threads);
}
}
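/* Worked example for the pixel logistics above (illustrative): for a layer
   with ofw == 14, ofh == 14 and pad_w == 1, compute_pixels = 14*14 +
   2*1*(14-1) = 222, which is already a multiple of multiple_target == 2, so
   remainder_pixels == 0 and the accumulation chain covers 222 pixels; for
   ofw == 7, ofh == 7 and pad_w == 0, compute_pixels = 49 is odd and gets
   padded by one remainder pixel to an even chain of 50. */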
LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_bf16_upd_amx( libxsmm_dnn_layer* handle ) {
/* JIT related variables... */
libxsmm_blasint LDA = handle->ofmblock;
libxsmm_blasint LDB = handle->input_pixels;
libxsmm_blasint LDC = handle->ofmblock;
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG;
int l_tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
size_t stride_a, stride_b;
int unroll_hint;
float beta;
int remainder_pixels, max_init_offset, max_compute_offset_input, input_compute_pad, accum_length_pixels, compute_pixels;
const int multiple_target = 32;
int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2 * handle->desc.pad_h : handle->ifhp;
int IFWP = (handle->upd_padding_copy == 1) ? handle->ifwp + 2 * handle->desc.pad_w : handle->ifwp;
int OFWP = (handle->upd_padding_copy == 1) ? handle->ofwp + 2 * handle->desc.pad_w : handle->ofwp;
handle->upd_linearized_pixels = 1;
if (handle->desc.S != 1 && handle->desc.v != 1) {
handle->upd_linearized_pixels = 0;
}
handle->fuse_upd_transposes = 1;
handle->pack_to_cnhw = 0;
handle->on_the_fly_input_packing = 0;
handle->upd_pack_input_upfront = 0;
handle->use_hybrid_imgofm_parallelization = 0;
handle->upd_linearized_tasklist = 0;
if (((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) && (handle->ofw == 7) && (handle->desc.R == 1) && (handle->desc.S == 1) ) {
handle->pack_to_cnhw = 1;
}
if (handle->upd_linearized_pixels == 1) {
if (handle->pack_to_cnhw == 0) {
handle->fuse_upd_transposes = 1;
/* Logistics to pad the accumulation chain length */
compute_pixels = handle->ofw * handle->ofh + 2 * handle->desc.pad_w * (handle->ofh-1);
remainder_pixels = (compute_pixels % multiple_target == 0) ? 0 : (compute_pixels/multiple_target+1)*multiple_target - compute_pixels;
accum_length_pixels = compute_pixels + remainder_pixels;
/* In this case compact input upfront */
if (handle->desc.R == 1 && handle->desc.S == 1 && (handle->desc.u != 1 || handle->desc.v != 1)) {
handle->upd_pack_input_upfront = 1;
}
/* Logistics for input transpose and additional pixel padding */
max_init_offset = 2 * handle->desc.pad_h * IFWP + 2 * handle->desc.pad_w;
max_compute_offset_input = max_init_offset + accum_length_pixels;
input_compute_pad = (max_compute_offset_input > IFWP*IFHP) ? max_compute_offset_input - IFWP*IFHP : 0;
handle->input_pixels = IFWP*IFHP + input_compute_pad;
if (handle->upd_pack_input_upfront) {
handle->input_pixels = accum_length_pixels;
}
handle->output_pixels = accum_length_pixels;
handle->pixel_blocking = accum_length_pixels;
handle->n_used_pixels = accum_length_pixels;
handle->compute_pixels = compute_pixels;
handle->use_intermediate_f32_wt_tensor = (handle->pixel_blocking == handle->n_used_pixels) ? 0 : 1;
#if 0
handle->scratch2_size = (size_t) (handle->desc.N * handle->output_pixels * handle->desc.K * sizeof(float)/2);
if (handle->use_intermediate_f32_wt_tensor) {
handle->scratch2_size += (size_t) handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K * handle->desc.threads * sizeof(float);
}
handle->scratch3_size = (size_t) (handle->desc.N * handle->input_pixels * handle->desc.C * sizeof(float)/2);
#endif
if (handle->ofw <= 14) {
handle->use_hybrid_imgofm_parallelization = 1;
handle->fuse_upd_transposes = 0;
} else {
handle->weight_copies = handle->desc.threads;
}
if ((handle->ofmblock % 32 != 0) || (handle->ifmblock % 32 != 0)) {
handle->fuse_upd_transposes = 0;
}
} else {
/* Logistics to pad the accumulation chain length */
handle->use_hybrid_imgofm_parallelization = 1;
handle->weight_copies = 7;
while (handle->desc.threads % handle->weight_copies != 0) {
handle->weight_copies--;
}
compute_pixels = handle->ofw * handle->ofh * (handle->desc.N/handle->weight_copies);
remainder_pixels = (compute_pixels % multiple_target == 0) ? 0 : (compute_pixels/multiple_target+1)*multiple_target - compute_pixels;
handle->remainder_pixels = remainder_pixels;
accum_length_pixels = compute_pixels + remainder_pixels;
handle->output_pixels = accum_length_pixels * handle->weight_copies;
handle->input_pixels = accum_length_pixels * handle->weight_copies;
handle->pixel_blocking = accum_length_pixels;
handle->n_used_pixels = accum_length_pixels;
handle->use_intermediate_f32_wt_tensor = 0;
#if 0
handle->scratch2_size = (size_t) (handle->weight_copies * handle->output_pixels * handle->desc.K * sizeof(float)/2);
handle->scratch3_size = (size_t) (handle->weight_copies * handle->input_pixels * handle->desc.C * sizeof(float)/2);
#endif
}
}
if (handle->upd_linearized_pixels == 0) {
handle->weight_copies = handle->desc.threads;
if (handle->desc.v != 1) {
handle->on_the_fly_input_packing = 1;
}
remainder_pixels = (handle->ofw % multiple_target == 0) ? 0 : (handle->ofw/multiple_target+1)*multiple_target - handle->ofw;
handle->remainder_pixels = remainder_pixels;
handle->ofwp_extended = OFWP + remainder_pixels;
handle->ifwp_extended = IFWP + remainder_pixels;
handle->batchreduce_h_pixels = handle->ofh;
handle->use_intermediate_f32_wt_tensor = (handle->batchreduce_h_pixels == handle->ofh) ? 0 : 1;
#if 0
handle->scratch2_size = (size_t) (handle->desc.N * handle->ofhp*handle->ofwp_extended * handle->desc.K * sizeof(float)/2);
if (handle->use_intermediate_f32_wt_tensor) {
handle->scratch2_size += (size_t) handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K * handle->desc.threads * sizeof(float);
}
handle->scratch3_size = (size_t) (handle->desc.N * handle->ifhp * handle->ifwp_extended * handle->desc.C * sizeof(float)/2);
#endif
}
/* Now that all decisions have been made, JIT the proper kernel... */
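/* beta = 1 accumulates partial products into the intermediate f32 weight tensor,
   beta = 0 lets the kernel overwrite the output directly */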
beta = (handle->use_intermediate_f32_wt_tensor) ? (float)1.0 : (float)0.0;
if (handle->upd_linearized_pixels == 0) {
LDA = handle->ofmblock;
LDB = IFHP*handle->ifwp_extended;
LDC = handle->ofmblock;
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
unroll_hint = handle->batchreduce_h_pixels;
stride_a = handle->ofwp_extended * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in);
stride_b = handle->desc.u * handle->ifwp_extended * libxsmm_dnn_typesize(handle->datatype_in);
handle->upd_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->ofw+handle->remainder_pixels, &LDA, &LDB, &LDC, NULL, &beta, &l_tc_flags, NULL);
handle->upd_compute_kernel_brgemm_no_linearized_pixels = libxsmm_bsmmdispatch_reducebatch_strd_unroll(handle->ofmblock, handle->ifmblock, handle->ofw+handle->remainder_pixels,
(libxsmm_blasint)stride_a, (libxsmm_blasint)stride_b, unroll_hint, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
} else {
LDA = handle->ofmblock;
LDB = handle->input_pixels;
LDC = handle->ofmblock;
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
if (handle->use_hybrid_imgofm_parallelization == 0) {
handle->upd_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_tc_flags, NULL);
handle->upd_compute_kernel_gemm_linearized_pixels_no_hybrid_par = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
} else {
if (handle->pack_to_cnhw == 1) {
handle->upd_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_tc_flags, NULL);
handle->upd_compute_kernel_gemm_linearized_pixels_hybrid_par_cnhw = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
} else {
/* TODO: hoist the hybrid parallelization logic here; then we could also pass an unroll hint to the BRGEMM dispatch */
stride_a = handle->blocksofm * handle->output_pixels * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in);
stride_b = handle->blocksifm * handle->ifmblock * handle->input_pixels * libxsmm_dnn_typesize(handle->datatype_in);
handle->upd_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_tc_flags, NULL);
handle->upd_compute_kernel_brgemm_linearized_pixels_hybrid_par_no_cnhw = libxsmm_bsmmdispatch_reducebatch_strd(handle->ofmblock, handle->ifmblock, handle->pixel_blocking,
(libxsmm_blasint)stride_a, (libxsmm_blasint)stride_b, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
}
}
}
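/* If the minibatch N does not match the thread count, fall back to a conservative setup:
   force the intermediate f32 weight tensor, disable the hybrid parallelization, and use
   min(N, threads) weight copies. */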
if (handle->desc.N != handle->desc.threads) {
handle->use_intermediate_f32_wt_tensor = 1;
handle->use_hybrid_imgofm_parallelization = 0;
handle->weight_copies = LIBXSMM_MIN(handle->desc.N, handle->desc.threads);
}
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_upd_padding_copy( libxsmm_dnn_layer* handle ) {
int result = 0;
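/* note: the copy-based padding path is only enabled when BOTH the H and the W padding
   differ from the input padding (logical AND, not OR) */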
if ( (handle->desc.pad_h != handle->desc.pad_h_in) && (handle->desc.pad_w != handle->desc.pad_w_in) ) {
result = 1;
}
return result;
}
LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_upd_scratch( libxsmm_dnn_layer* handle ) {
handle->upd_packing_padding_scratch_size = 0;
/* packing of input */
if ( handle->upd_pack_input != 0 ) {
handle->upd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C *
handle->desc.H/handle->desc.u *
handle->desc.W/handle->desc.v *
libxsmm_dnn_typesize(handle->datatype_in);
}
/* logical padding with on-the-fly copying */
if ( handle->upd_padding_copy != 0 ) {
handle->upd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C *
(handle->desc.H + 2*handle->desc.pad_h) *
(handle->desc.W + 2*handle->desc.pad_w) *
libxsmm_dnn_typesize(handle->datatype_in);
}
/* scratch for output/input transposes when using bf16 */
if ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16 ) {
if (handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
int OFHP = (handle->upd_padding_copy == 1) ? handle->ofhp + 2 * handle->desc.pad_h : handle->ofhp;
int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2 * handle->desc.pad_h : handle->ifhp;
if (handle->upd_linearized_pixels == 1) {
handle->upd_lp_output_full_scratch_size = (size_t) (handle->desc.N * handle->output_pixels * handle->desc.K * sizeof(handle->datatype_in));
handle->upd_lp_input_full_scratch_size = (size_t) (handle->desc.N * handle->input_pixels * handle->desc.C * sizeof(handle->datatype_in));
}
if (handle->upd_linearized_pixels == 0) {
handle->upd_lp_output_full_scratch_size = (size_t) (handle->desc.N * OFHP * handle->ofwp_extended * handle->desc.K * sizeof(handle->datatype_in));
handle->upd_lp_input_full_scratch_size = (size_t) (handle->desc.N * IFHP * handle->ifwp_extended * handle->desc.C * sizeof(handle->datatype_in));
}
} else {
const int multiple_target = 2;
int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2 * handle->desc.pad_h : handle->ifhp;
int IFWP = (handle->upd_padding_copy == 1) ? handle->ifwp + 2 * handle->desc.pad_w : handle->ifwp;
int OFHP = (handle->upd_padding_copy == 1) ? handle->ofhp + 2 * handle->desc.pad_h : handle->ofhp;
int OFWP = (handle->upd_padding_copy == 1) ? handle->ofwp + 2 * handle->desc.pad_w : handle->ofwp;
if (handle->upd_linearized_pixels == 1) {
int compute_pixels = handle->ofw * handle->ofh + 2 * handle->desc.pad_w * (handle->ofh-1);
int remainder_pixels = (compute_pixels % multiple_target == 0) ? 0 : (compute_pixels/multiple_target+1)*multiple_target - compute_pixels;
int accum_length_pixels = compute_pixels + remainder_pixels;
int max_init_offset = 2 * handle->desc.pad_h * IFWP + 2 * handle->desc.pad_w;
int max_compute_offset_input = max_init_offset + accum_length_pixels;
int input_compute_pad = (max_compute_offset_input > IFWP*IFHP) ? max_compute_offset_input - IFWP*IFHP : 0;
int input_pixels = IFWP * IFHP + input_compute_pad;
if (handle->upd_pack_input_upfront == 1) {
input_pixels = accum_length_pixels;
}
handle->upd_lp_output_full_scratch_size = (size_t) (handle->desc.N * accum_length_pixels * handle->desc.K * sizeof(handle->datatype_in));
handle->upd_lp_input_full_scratch_size = (size_t) (handle->desc.N * input_pixels * handle->desc.C * sizeof(handle->datatype_in));
}
if (handle->upd_linearized_pixels == 0) {
int remainder_pixels = (handle->ofw % multiple_target == 0) ? 0 : (handle->ofw/multiple_target+1)*multiple_target - handle->ofw;
int ofwp_extended = OFWP + remainder_pixels;
int ifwp_extended = IFWP + remainder_pixels;
handle->upd_lp_output_full_scratch_size = (size_t) (handle->desc.N * OFHP * ofwp_extended * handle->desc.K * sizeof(handle->datatype_in));
handle->upd_lp_input_full_scratch_size = (size_t) (handle->desc.N * IFHP * ifwp_extended * handle->desc.C * sizeof(handle->datatype_in));
}
}
handle->upd_lp_filter_full_scratch_size = (size_t)handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K * handle->desc.threads *
libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32);
} else {
handle->upd_lp_output_full_scratch_size = 0;
handle->upd_lp_input_full_scratch_size = 0;
handle->upd_lp_filter_full_scratch_size = 0;
}
/* filter scratch */
handle->upd_filter_scratch_size = (size_t) handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K * LIBXSMM_MAX(handle->desc.threads, handle->desc.N) * sizeof(float);
/* align sizes to full cacheline */
handle->upd_packing_padding_scratch_size += ( handle->upd_packing_padding_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
LIBXSMM_CACHELINE - (handle->upd_packing_padding_scratch_size % LIBXSMM_CACHELINE);
handle->upd_lp_output_full_scratch_size += ( handle->upd_lp_output_full_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
LIBXSMM_CACHELINE - (handle->upd_lp_output_full_scratch_size % LIBXSMM_CACHELINE);
handle->upd_lp_input_full_scratch_size += ( handle->upd_lp_input_full_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
LIBXSMM_CACHELINE - (handle->upd_lp_input_full_scratch_size % LIBXSMM_CACHELINE);
handle->upd_filter_scratch_size += ( handle->upd_filter_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
LIBXSMM_CACHELINE - (handle->upd_filter_scratch_size % LIBXSMM_CACHELINE);
handle->upd_lp_filter_full_scratch_size += ( handle->upd_lp_filter_full_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
LIBXSMM_CACHELINE - (handle->upd_lp_filter_full_scratch_size % LIBXSMM_CACHELINE);
/* calculate offsets */
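/* the sub-buffers are laid out back-to-back, so each offset is the running sum of the
   cacheline-aligned sizes before it */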
handle->upd_packing_padding_scratch_offset = 0;
handle->upd_lp_output_full_scratch_offset = handle->upd_packing_padding_scratch_size;
handle->upd_lp_input_full_scratch_offset = handle->upd_lp_output_full_scratch_offset + handle->upd_lp_output_full_scratch_size;
handle->upd_filter_scratch_offset = handle->upd_lp_input_full_scratch_offset + handle->upd_lp_input_full_scratch_size;
handle->upd_lp_filter_full_scratch_offset = handle->upd_filter_scratch_offset + handle->upd_filter_scratch_size;
/* set overall scratch size for update */
handle->upd_scratch_size = handle->upd_packing_padding_scratch_size +
handle->upd_lp_output_full_scratch_size +
handle->upd_lp_input_full_scratch_size +
handle->upd_filter_scratch_size +
handle->upd_lp_filter_full_scratch_size;
}
LIBXSMM_API_INLINE libxsmm_dnn_err_t libxsmm_dnn_convolution_setup( libxsmm_dnn_layer* handle ) {
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
libxsmm_blasint _ldi = 64, _ldo = 64;
libxsmm_blasint ldx;
libxsmm_blasint ldA;
libxsmm_blasint ldC;
int beta_int;
float beta;
int l_flags;
int l_tc_flags;
/* init libxsmm */
LIBXSMM_INIT
/* Generic parameter setup */
handle->target_archid = libxsmm_target_archid;
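/* bf16 on SPR: if C or K is not a multiple of 16, fall back to the CPX (AVX512) path,
   since the AMX kernels here assume 16-divisible blocked feature maps */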
if ( ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) && (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && ((handle->desc.C % 16 != 0) || (handle->desc.K % 16 != 0)) ) {
handle->target_archid = LIBXSMM_X86_AVX512_CPX;
}
handle->ifmblock = libxsmm_dnn_convolution_setup_ifmblock(handle);
handle->ofmblock = libxsmm_dnn_convolution_setup_ofmblock(handle);
handle->fm_lp_block = libxsmm_dnn_convolution_setup_fm_lp_block(handle);
handle->blocksifm = libxsmm_dnn_convolution_setup_blocksifm(handle);
handle->blocksofm = libxsmm_dnn_convolution_setup_blocksofm(handle);
/* On SPR and newer, generate the tile-release kernel */
if (handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
int l_tr_flags = LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
handle->tilerelease_kernel = libxsmm_bsmmdispatch(handle->ifmblock, handle->ifmblock, handle->ifmblock, NULL, NULL, NULL, NULL, NULL, &l_tr_flags, NULL);
}
/* FWD parameter setup */
handle->fwd_ofw_rb = libxsmm_dnn_convolution_setup_fwd_ofw_rb(handle);
handle->pack_input = libxsmm_dnn_convolution_setup_pack_input_fwd(handle);
handle->fwd_ofh_rb = libxsmm_dnn_convolution_setup_fwd_ofh_rb(handle);
handle->fwd_gemm_pixels = libxsmm_dnn_convolution_setup_fwd_pixels_gemm(handle);
handle->block_fwd_oj = libxsmm_dnn_convolution_setup_fwd_block_H(handle);
handle->loop_order = libxsmm_dnn_convolution_setup_loop_order_fwd(handle);
handle->blocksifm_blocking = libxsmm_dnn_convolution_setup_blocksifm_blocking(handle);
handle->block_fwd_ofm = libxsmm_dnn_convolution_setup_block_fwd_OFM(handle);
handle->block_fwd_ifm = libxsmm_dnn_convolution_setup_block_fwd_IFM(handle);
handle->avoid_fmas_in_rim = libxsmm_dnn_convolution_setup_avoid_rim_fmas_fwd(handle);
handle->use_ofm_parallelization = libxsmm_dnn_convolution_setup_use_ofm_parallelization(handle);
handle->shuffle_filter_accesses = libxsmm_dnn_convolution_setup_shuffle_filter_accesses(handle);
handle->avoid_acc_load = libxsmm_dnn_convolution_setup_avoid_acc_load(handle);
handle->fwd_flags = libxsmm_dnn_convolution_setup_init_fwd_gemm_flags(handle);
handle->use_fallback_fwd_loops = libxsmm_dnn_convolution_setup_fallback_loops_fwd(handle);
handle->fwd_padding_copy = libxsmm_dnn_convolution_setup_fwd_padding_copy(handle);
#if 0
if ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 ) {
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int brgemm_pf_oob = 0;
const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
handle->block_fwd_ofm = 1;
handle->block_fwd_oj = handle->fwd_ofh_rb;
ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock;
ldA = handle->ofmblock;
ldC = handle->ofmblock;
beta = (handle->avoid_acc_load) ? (float)0.0 : (float)1.0;
l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags;
if ( 0 != env_brgemm_pf_oob ) {
brgemm_pf_oob = atoi(env_brgemm_pf_oob);
}
if (brgemm_pf_oob > 0) {
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
}
handle->fwd_compute_kernel_offs_f32 = NULL;
handle->fwd_compute_kernel_strd_f32 = NULL;
handle->fwd_compute_kernel_addr_a_f32 = NULL;
handle->fwd_compute_kernel_addr_b_f32 = NULL;
if (handle->desc.R == 1 && handle->desc.S == 1) {
const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp;
const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp;
int stride_a = handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in);
int stride_b = IFW * IFH * handle->ifmblock * libxsmm_dnn_typesize(handle->datatype_in);
handle->fwd_compute_kernel_strd_f32 = libxsmm_smmdispatch_reducebatch_strd_unroll(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, stride_a, stride_b, handle->blocksifm_blocking, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
} else {
const int IFW = (handle->fwd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : ( (handle->pack_input == 1) ? handle->ofwp : handle->ifwp );
const int IFH = (handle->fwd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : ( (handle->pack_input == 1) ? handle->ofhp : handle->ifhp );
int n_blocks = handle->desc.R * handle->desc.S * handle->blocksifm_blocking;
int i = 0, ifm, ki, kj;
handle->A_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
handle->B_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
for (ifm = 0; ifm < handle->blocksifm_blocking; ifm++) {
for (kj = 0; kj < handle->desc.R; kj++) {
for (ki = 0; ki < handle->desc.S; ki++) {
handle->A_offsets[i] = (ifm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock +
kj * handle->desc.S * handle->ifmblock * handle->ofmblock +
ki * handle->ifmblock * handle->ofmblock) * libxsmm_dnn_typesize(handle->datatype_in);
handle->B_offsets[i] = (ifm * IFH * IFW * handle->ifmblock +
kj * IFW * handle->ifmblock +
ki * handle->ifmblock) * libxsmm_dnn_typesize(handle->datatype_in);
i++;
}
}
}
handle->fwd_compute_kernel_offs_f32 = libxsmm_smmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
}
handle->fwd_compute_kernel_addr_a_f32 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
handle->fwd_compute_kernel_addr_b_f32 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
}
#endif
if ( ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) && (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) ) {
handle->block_fwd_ofm = 1;
handle->block_fwd_oj = handle->fwd_ofh_rb;
ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock;
ldA = handle->ofmblock;
ldC = handle->ofmblock;
beta = (handle->avoid_acc_load) ? (float)0.0 : (float)1.0;
l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG;
l_tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
handle->fwd_compute_kernel_addr = NULL;
handle->fwd_compute_kernel_offs_a = NULL;
handle->fwd_compute_kernel_offs_b = NULL;
handle->fwd_compute_kernel_strd = NULL;
if (handle->desc.R == 1 && handle->desc.S == 1) {
const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp;
const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp;
size_t stride_a = handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in);
size_t stride_b = IFW * IFH * handle->ifmblock * libxsmm_dnn_typesize(handle->datatype_in);
handle->fwd_compute_kernel_strd = libxsmm_bmmdispatch_reducebatch_strd_unroll(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock,
(libxsmm_blasint)stride_a, (libxsmm_blasint)stride_b, handle->blocksifm_blocking, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
} else {
const int IFW = (handle->fwd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : ( (handle->pack_input == 1) ? handle->ofwp : handle->ifwp );
const int IFH = (handle->fwd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : ( (handle->pack_input == 1) ? handle->ofhp : handle->ifhp );
int n_blocks = handle->desc.R * handle->desc.S * handle->blocksifm_blocking;
int i = 0, ifm, ki, kj;
handle->A_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
handle->B_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
for (ifm = 0; ifm < handle->blocksifm_blocking; ifm++) {
for (kj = 0; kj < handle->desc.R; kj++) {
for (ki = 0; ki < handle->desc.S; ki++) {
handle->A_offsets[i] = (ifm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock +
kj * handle->desc.S * handle->ifmblock * handle->ofmblock +
ki * handle->ifmblock * handle->ofmblock) * libxsmm_dnn_typesize(handle->datatype_in);
handle->B_offsets[i] = (ifm * IFH * IFW * handle->ifmblock +
kj * IFW * handle->ifmblock +
ki * handle->ifmblock) * libxsmm_dnn_typesize(handle->datatype_in);
i++;
}
}
}
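/* A_offsets[i]/B_offsets[i] hold the byte displacement of the i-th (ifm, kj, ki) sub-block
   relative to the base A/B pointers; the offset-based BRGEMM variants below walk these
   arrays instead of a fixed stride, which accommodates the shifted input rows of an
   R x S filter. */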
handle->fwd_compute_kernel_offs_a = libxsmm_bmmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
handle->fwd_compute_kernel_offs_b = libxsmm_bsmmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
}
handle->fwd_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_tc_flags, NULL);
}
handle->code_fwd[0].ptr = 0;
handle->code_fwd[1].ptr = 0;
handle->code_fwd[2].ptr = 0;
/* JIT the fp32 -> bf16 conversion (eltwise) kernel for fwd convolutions */
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) {
_ldi = handle->ofmblock * handle->ofwp;
_ldo = handle->ofmblock * handle->ofwp;
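/* a unary IDENTITY kernel with a bf16 output type acts as the fp32 -> bf16 down-convert
   over an (ofmblock * fwd_ofw_rb) x fwd_ofh_rb tile */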
handle->fwd_cvtfp32bf16_kernel = libxsmm_dispatch_meltw_unary(handle->ofmblock * handle->fwd_ofw_rb, handle->fwd_ofh_rb, &_ldi, &_ldo, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_BF16, LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_IDENTITY);
}
/* Create strided BRGEMMs for i8i32 convolutions */
if ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_I32)) {
ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock;
ldA = handle->ofmblock;
ldC = handle->ofmblock;
beta_int = (handle->avoid_acc_load) ? 0 : 1;
l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | handle->fwd_flags;
if (handle->desc.R == 1 && handle->desc.S == 1) {
const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp;
const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp;
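/* for 1x1 filters the batch-reduce dimension runs over input-feature blocks: stride_A
   advances by one Cb x Kb filter block, stride_B by one input plane of a channel block */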
libxsmm_blasint stride_A = handle->ifmblock * handle->ofmblock * sizeof(char);
libxsmm_blasint stride_B = handle->ifmblock * IFW * IFH * sizeof(char);
handle->gemm_fwd.xgemm.subimrs = libxsmm_subimmdispatch_reducebatch_strd(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, stride_A, stride_B, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL);
} else {
const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp;
const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp;
if (handle->avoid_fmas_in_rim == 0) {
int n_blocks = handle->desc.R * handle->desc.S * handle->blocksifm_blocking;
int i = 0, ifm, ki, kj;
handle->A_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
handle->B_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
for (ifm = 0; ifm < handle->blocksifm_blocking; ifm++) {
for (kj = 0; kj < handle->desc.R; kj++) {
for (ki = 0; ki < handle->desc.S; ki++) {
handle->A_offsets[i] = (ifm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock +
kj * handle->desc.S * handle->ifmblock * handle->ofmblock +
ki * handle->ifmblock * handle->ofmblock) * sizeof(char);
handle->B_offsets[i] = (ifm * IFH * IFW * handle->ifmblock +
kj * IFW * handle->ifmblock +
ki * handle->ifmblock) * sizeof(char);
i++;
}
}
}
handle->gemm_fwd.xgemm.subimro = libxsmm_subimmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL);
} else {
libxsmm_blasint stride_A = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(char);
libxsmm_blasint stride_B = handle->ifmblock * IFW * IFH * sizeof(char);
handle->gemm_fwd.xgemm.subimrs = libxsmm_subimmdispatch_reducebatch_strd(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, stride_A, stride_B, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL);
handle->gemm_fwd2.xgemm.subimrs = libxsmm_subimmdispatch_reducebatch_strd(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, stride_A, stride_B, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL);
}
}
} else if ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_I8)) {
ldx = (libxsmm_blasint)handle->desc.v*handle->ifmblock;
ldA = handle->ofmblock;
ldC = handle->ofmblock;
beta_int = 0;
l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | handle->fwd_flags;
if (handle->desc.R == 1 && handle->desc.S == 1) {
const int IFW = handle->ifwp;
const int IFH = handle->ifhp;
libxsmm_blasint stride_A = handle->ifmblock * handle->ofmblock * sizeof(char);
libxsmm_blasint stride_B = handle->ifmblock * IFW * IFH * sizeof(char);
handle->gemm_fwd.xgemm.sububmrs = libxsmm_sububmmdispatch_reducebatch_strd(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, stride_A, stride_B, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL);
} else {
const int IFW = handle->ifwp;
const int IFH = handle->ifhp;
int n_blocks = handle->desc.R * handle->desc.S * handle->blocksifm_blocking;
int i = 0, ifm, ki, kj;
handle->A_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
handle->B_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
for (ifm = 0; ifm < handle->blocksifm_blocking; ifm++) {
for (kj = 0; kj < handle->desc.R; kj++) {
for (ki = 0; ki < handle->desc.S; ki++) {
handle->A_offsets[i] = (ifm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock +
kj * handle->desc.S * handle->ifmblock * handle->ofmblock +
ki * handle->ifmblock * handle->ofmblock) * sizeof(char);
handle->B_offsets[i] = (ifm * IFH * IFW * handle->ifmblock +
kj * IFW * handle->ifmblock +
ki * handle->ifmblock) * sizeof(char);
i++;
}
}
}
handle->gemm_fwd.xgemm.sububmro = libxsmm_sububmmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL);
}
}
#if 0
/* Spit out FWD parameters that are selected... */
printf("FWD params...\n");
printf("Fwd_ofw_rb = %d\n", handle->fwd_ofw_rb);
printf("Fwd_ofh_rb = %d\n", handle->fwd_ofh_rb);
printf("Pack input = %d\n", handle->pack_input);
printf("Block oj = %d\n", handle->block_fwd_oj);
printf("Loop order = %d\n", handle->loop_order);
printf("Blocksifm_blocking = %d\n", handle->blocksifm_blocking);
printf("Block fwd ofm = %d\n", handle->block_fwd_ofm);
printf("Block fwd ifm = %d\n", handle->block_fwd_ifm);
printf("Avoid rim fmas = %d\n", handle->avoid_fmas_in_rim);
printf("Ofm parallelization = %d\n", handle->use_ofm_parallelization);
printf("Shuffle filter accesses = %d\n", handle->shuffle_filter_accesses);
printf("Avoid acc load = %d\n", handle->avoid_acc_load);
printf("Fwd GEMM flags = %d\n", handle->fwd_flags);
#endif
/* BWD parameter setup */
handle->bwd_ofw_rb = libxsmm_dnn_convolution_setup_bwd_ofw_rb(handle);
handle->bwd_ofh_rb = libxsmm_dnn_convolution_setup_bwd_ofh_rb(handle);
handle->bwd_gemm_pixels = libxsmm_dnn_convolution_setup_bwd_pixels_gemm(handle);
handle->pack_input_bwd = libxsmm_dnn_convolution_setup_pack_input_bwd(handle);
handle->spread_input_bwd = libxsmm_dnn_convolution_setup_spread_input_bwd(handle);
handle->blocksofm_blocking = libxsmm_dnn_convolution_setup_blocksofm_blocking(handle);
handle->avoid_acc_load_bwd = libxsmm_dnn_convolution_setup_avoid_acc_load_bwd(handle);
handle->use_ifm_parallelization = libxsmm_dnn_convolution_setup_use_ifm_parallelization(handle);
handle->block_bwd_ofm = libxsmm_dnn_convolution_setup_block_bwd_OFM(handle);
handle->block_bwd_ifm = libxsmm_dnn_convolution_setup_block_bwd_IFM(handle);
handle->block_bwd_oj = libxsmm_dnn_convolution_setup_bwd_block_H(handle);
handle->use_fallback_bwd_loops = libxsmm_dnn_convolution_setup_fallback_loops_bwd(handle);
handle->bwd_flags = libxsmm_dnn_convolution_setup_init_bwd_gemm_flags(handle);
if ( ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) && (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) ) {
handle->block_bwd_ifm = 1;
handle->block_bwd_oj = handle->bwd_ofh_rb;
ldx = ((libxsmm_blasint)handle->ofmblock);
ldA = handle->ifmblock;
ldC = (handle->spread_input_bwd == 1) ? handle->ifmblock * handle->desc.v : handle->ifmblock;
beta = (handle->avoid_acc_load_bwd) ? (float)0.0 : (float)1.0;
l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG;
l_tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
handle->bwd_compute_kernel_addr = NULL;
handle->bwd_compute_kernel_offs = NULL;
handle->bwd_compute_kernel_strd = NULL;
if (handle->desc.R == 1 && handle->desc.S == 1) {
size_t stride_a = handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in);
size_t stride_b = handle->ofwp * handle->ofhp * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in);
handle->bwd_compute_kernel_strd = libxsmm_bsmmdispatch_reducebatch_strd_unroll(handle->ifmblock, handle->bwd_gemm_pixels, handle->ofmblock,
(libxsmm_blasint)stride_a, (libxsmm_blasint)stride_b, handle->blocksofm_blocking, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
} else {
int n_blocks = handle->desc.R * handle->desc.S * handle->blocksofm_blocking;
int i = 0, ofm, ki, kj;
handle->A_offsets_bwd = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
handle->B_offsets_bwd = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
for (ofm = 0; ofm < handle->blocksofm_blocking; ofm++) {
for (kj = 0; kj < handle->desc.R; kj++) {
for (ki = 0; ki < handle->desc.S; ki++) {
handle->A_offsets_bwd[i] = (ofm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock +
kj * handle->desc.S * handle->ifmblock * handle->ofmblock +
ki * handle->ifmblock * handle->ofmblock) * libxsmm_dnn_typesize(handle->datatype_in);
handle->B_offsets_bwd[i] = (ofm * handle->ofhp * handle->ofwp * handle->ofmblock +
kj * handle->ofwp * handle->ofmblock +
ki * handle->ofmblock) * libxsmm_dnn_typesize(handle->datatype_in);
i++;
}
}
}
handle->bwd_compute_kernel_offs = libxsmm_bsmmdispatch_reducebatch_offs(handle->ifmblock, handle->bwd_gemm_pixels, handle->ofmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
}
handle->bwd_config_kernel = libxsmm_bsmmdispatch(handle->ifmblock, handle->bwd_gemm_pixels, handle->ofmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_tc_flags, NULL);
}
#if 0
/* Spit out BWD parameters that are selected... */
printf("BWD params...\n");
printf("Bwd_ofw_rb = %d\n", handle->bwd_ofw_rb);
printf("Bwd_ofh_rb = %d\n", handle->bwd_ofh_rb);
printf("Pack input = %d\n", handle->pack_input_bwd);
printf("Spread input = %d\n", handle->spread_input_bwd);
printf("Blocksofm_blocking = %d\n", handle->blocksofm_blocking);
printf("Avoid acc load = %d\n", handle->avoid_acc_load_bwd);
printf("Ifm parallelization = %d\n", handle->use_ifm_parallelization);
printf("Block bwd ofm = %d\n", handle->block_bwd_ofm);
printf("Block bwd ifm = %d\n", handle->block_bwd_ifm);
printf("Block oj = %d\n", handle->block_bwd_oj);
#endif
handle->code_bwd[0].ptr = 0;
handle->code_bwd[1].ptr = 0;
handle->code_bwd[2].ptr = 0;
/* Transpose kernel used for filter transpose in bwd pass */
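/* NORM_TO_NORMT with m = 64, n = 16 transposes a 64x16 fp32 tile (_ldi = _ldo = 64) */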
handle->tr_kernel = libxsmm_dispatch_meltw_unary(64, 16, &(_ldi), &(_ldo), LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
/* UPD parameter setup */
handle->upd_linearized_tasklist = libxsmm_dnn_convolution_setup_linearized_tasklist_upd(handle);
handle->upd_avoid_rim_fmas = libxsmm_dnn_convolution_setup_avoid_rim_fmas_upd(handle);
handle->upd_pack_input = libxsmm_dnn_convolution_setup_pack_input_upd(handle);
handle->upd_use_batchreduce = libxsmm_dnn_convolution_setup_use_batchreduce_upd(handle);
handle->upd_ofw_rb = libxsmm_dnn_convolution_setup_upd_ofw_rb(handle);
handle->upd_ofh_rb = libxsmm_dnn_convolution_setup_upd_ofh_rb(handle);
handle->upd_loop_order = libxsmm_dnn_convolution_setup_loop_order_upd(handle);
handle->weight_copies = libxsmm_dnn_convolution_setup_weight_copies_upd(handle);
handle->block_upd_ofm = libxsmm_dnn_convolution_setup_block_upd_OFM(handle);
handle->block_upd_ifm = libxsmm_dnn_convolution_setup_block_upd_IFM(handle);
handle->upd_loop_order = libxsmm_dnn_convolution_setup_loop_order_upd(handle);
handle->upd_padding_copy = libxsmm_dnn_convolution_setup_upd_padding_copy(handle);
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) {
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) {
libxsmm_dnn_convolution_setup_bf16_upd_amx(handle);
} else {
libxsmm_dnn_convolution_setup_bf16_upd(handle);
}
}
#if 0
/* Spit out UPD parameters that are selected... */
printf("UPD params...\n");
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) {
printf("BF16 path...\n");
printf("UPD use_hybrid_imgofm_parallelization = %d\n", handle->use_hybrid_imgofm_parallelization);
printf("UPD linearized_pixels = %d\n", handle->upd_linearized_pixels);
printf("UPD upd_trans_w_only = %d\n", handle->upd_trans_w_only);
printf("UPD on_the_fly_input_packing = %d\n", handle->on_the_fly_input_packing);
printf("UPD use_intermediate_f32_wt_tensor = %d\n", handle->use_intermediate_f32_wt_tensor);
printf("UPD pack to CNHW format = %d\n", handle->pack_to_cnhw);
printf("UPD batchreduce H pixels = %d\n", handle->batchreduce_h_pixels);
}
printf("UPD linearized tasks = %d\n", handle->upd_linearized_tasklist);
printf("UPD avoid rim fmas = %d\n", handle->upd_avoid_rim_fmas);
printf("UPD Pack input = %d\n", handle->upd_pack_input);
printf("UPD use batch-reduce GEMM = %d\n", handle->upd_use_batchreduce);
printf("Upd_ofw_rb = %d\n", handle->upd_ofw_rb);
printf("Upd_ofh_rb = %d\n", handle->upd_ofh_rb);
printf("UPD loop order = %d\n", handle->upd_loop_order);
printf("UPD weight_copies = %d\n", handle->weight_copies);
printf("Block upd ofm = %d\n", handle->block_upd_ofm);
printf("Block upd ifm = %d\n", handle->block_upd_ifm);
#endif
handle->code_upd[0].ptr = 0;
handle->code_upd[1].ptr = 0;
/* prepare barrier */
handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
/* set up scratch */
libxsmm_dnn_convolution_setup_fwd_scratch( handle );
libxsmm_dnn_convolution_setup_bwd_scratch( handle );
libxsmm_dnn_convolution_setup_upd_scratch( handle );
handle->scratch = 0;
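/* fwd/bwd/upd reuse one scratch allocation, so only the maximum of the three sizes is needed */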
handle->scratch_size = LIBXSMM_MAX( handle->fwd_scratch_size, LIBXSMM_MAX( handle->bwd_scratch_size, handle->upd_scratch_size ) );
return status;
}
#undef MIXED
#undef KHWC
#undef HWKC
#undef CHWK
#undef HWCK
LIBXSMM_API libxsmm_dnn_layer* libxsmm_dnn_create_conv_layer(
libxsmm_dnn_conv_desc conv_desc,
libxsmm_dnn_err_t* status)
{
libxsmm_dnn_layer* handle = 0;
*status = LIBXSMM_DNN_SUCCESS;
/* currently we don't support NCHW */
if ( (conv_desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCHW) > 0 ) {
*status = LIBXSMM_DNN_ERR_INVALID_FORMAT_NCHW;
return 0;
}
/* currently we don't support KCRS */
if ( (conv_desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_KCRS) > 0 ) {
*status = LIBXSMM_DNN_ERR_INVALID_FORMAT_KCRS;
return 0;
}
/* we only support physical padding for now */
/* @TODO: add logical padding support for other datatypes than FP32 */
if ( ( ( conv_desc.pad_h != conv_desc.pad_h_in ) ||
( conv_desc.pad_w != conv_desc.pad_w_in ) ||
( conv_desc.pad_h != conv_desc.pad_h_out ) ||
( conv_desc.pad_w != conv_desc.pad_w_out ) ) &&
( conv_desc.datatype_in != LIBXSMM_DNN_DATATYPE_F32 ) && (conv_desc.datatype_in != LIBXSMM_DNN_DATATYPE_BF16) ) {
*status = LIBXSMM_DNN_ERR_INVALID_PADDING;
return 0;
}
/* zero entire content; not only safer but also sets data and code pointers to NULL */
handle = (libxsmm_dnn_layer*)calloc(1, sizeof(libxsmm_dnn_layer));
if (0 != handle) {
/* initialize known handle components */
handle->desc = conv_desc;
handle->datatype_in = conv_desc.datatype_in;
handle->datatype_out = conv_desc.datatype_out;
/* validate the datatype combination; an intermediate format only applies to integer types (note: the error branches below are currently no-ops) */
if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_F32) ) {
/* error */
} else if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_BF16) ) {
/* error */
} else if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_I16) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_F32) ) {
/* error */
} else if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_I32) ) {
/* error */
} else if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_I8) ) {
/* error */
} else if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_F32) ) {
/* error */
} else {
/* fine, no error */
}
handle->buffer_format = conv_desc.buffer_format;
handle->filter_format = conv_desc.filter_format;
handle->fuse_ops = conv_desc.fuse_ops;
handle->options = conv_desc.options;
/* derive additional values */
handle->ifhp = conv_desc.H + 2*conv_desc.pad_h_in;
handle->ifwp = conv_desc.W + 2*conv_desc.pad_w_in;
handle->ofh = (conv_desc.H + 2*conv_desc.pad_h - conv_desc.R) / conv_desc.u + 1;
handle->ofw = (conv_desc.W + 2*conv_desc.pad_w - conv_desc.S) / conv_desc.v + 1;
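/* e.g. H = W = 56, pad = 1, R = S = 3, u = v = 1 -> ofh = ofw = (56 + 2 - 3) / 1 + 1 = 56 */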
handle->ofhp = handle->ofh + 2*conv_desc.pad_h_out;
handle->ofwp = handle->ofw + 2*conv_desc.pad_w_out;
handle->ifmblock = 1;
handle->ofmblock = 1;
handle->blocksifm = conv_desc.C;
handle->blocksofm = conv_desc.K;
handle->fwd_ofw_rb = 1;
handle->fwd_ofh_rb = 1;
handle->bwd_ofw_rb = 1;
handle->bwd_ofh_rb = 1;
handle->upd_ofw_rb = 1;
handle->upd_ofh_rb = 1;
handle->fm_lp_block = 1;
handle->blocksifm_blocking = 1;
handle->blocksofm_blocking = 1;
/* Set algorithm to use */
if (conv_desc.algo == LIBXSMM_DNN_CONV_ALGO_AUTO) {
handle->algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
} else {
handle->algo = conv_desc.algo;
}
if ( handle->algo != LIBXSMM_DNN_CONV_ALGO_DIRECT ) {
*status = LIBXSMM_DNN_ERR_INVALID_ALGO;
free(handle);
handle = 0;
return 0;
}
*status = libxsmm_dnn_convolution_setup(handle);
}
else {
*status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
}
/* if setup failed, return NULL rather than a partially initialized handle */
if ( LIBXSMM_DNN_SUCCESS != *status ) {
handle = 0;
}
return handle;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_conv_layer(const libxsmm_dnn_layer* handle)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
if (0 != handle) {
/* Deallocate barrier */
if (handle->barrier != 0 ) { libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier); }
/* deallocate handle structure itself */
free(/*remove constness*/(libxsmm_dnn_layer*)handle);
}
return status;
}
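#if 0
/* Minimal usage sketch for the create/destroy pair above (illustrative only:
   LIBXSMM_DNN_CONV_FUSE_NONE and LIBXSMM_DNN_CONV_OPTION_NONE are assumed to be the
   "no fusion"/"no options" enum names; tensor binding and execution are omitted). */
static void example_conv_layer_lifecycle(void) {
libxsmm_dnn_conv_desc desc;
libxsmm_dnn_err_t status;
libxsmm_dnn_layer* layer;
memset(&desc, 0, sizeof(desc)); /* requires <string.h> */
desc.N = 28; desc.C = 64; desc.H = 56; desc.W = 56; desc.K = 64;
desc.R = 3; desc.S = 3; desc.u = 1; desc.v = 1;
desc.pad_h = desc.pad_w = 1;
desc.pad_h_in = desc.pad_w_in = 1; /* physical input padding */
desc.pad_h_out = desc.pad_w_out = 1; /* physical output padding */
desc.threads = 28;
desc.algo = LIBXSMM_DNN_CONV_ALGO_AUTO;
desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; /* assumed enum name */
desc.options = LIBXSMM_DNN_CONV_OPTION_NONE; /* assumed enum name */
layer = libxsmm_dnn_create_conv_layer(desc, &status);
if (LIBXSMM_DNN_SUCCESS == status) {
/* ... bind tensors, run fwd/bwd/upd, ... */
libxsmm_dnn_destroy_conv_layer(layer);
}
}
#endif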
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_create_tensor_datalayout(const libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
libxsmm_dnn_tensor_datalayout* layout;
*status = LIBXSMM_DNN_SUCCESS;
layout = 0;
if (handle != 0) {
/* zero entire content; not only safer but also sets data and code pointers to NULL */
layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout));
if (layout != 0) {
if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ||
(type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
layout->format = handle->buffer_format;
layout->tensor_type = LIBXSMM_DNN_ACTIVATION;
if ((handle->buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
if ( ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) {
layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 5;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
layout->dim_size[0] = handle->ifmblock;
layout->dim_size[1] = handle->ifwp;
layout->dim_size[2] = handle->ifhp;
layout->dim_size[3] = handle->blocksifm;
layout->dim_size[4] = handle->desc.N;
} else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
layout->dim_size[0] = handle->ofmblock;
layout->dim_size[1] = handle->ofwp;
layout->dim_size[2] = handle->ofhp;
layout->dim_size[3] = handle->blocksofm;
layout->dim_size[4] = handle->desc.N;
} else {
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
/* @TODO this needs to change */
} else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_I32) ) {
if ( ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_INPUT) ) ) {
layout->datatype = handle->datatype_in;
} else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
layout->datatype = handle->datatype_out;
}
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 5;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
layout->dim_size[0] = handle->ifmblock;
layout->dim_size[1] = handle->ifwp;
layout->dim_size[2] = handle->ifhp;
layout->dim_size[3] = handle->blocksifm;
layout->dim_size[4] = handle->desc.N;
} else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
layout->dim_size[0] = handle->ofmblock;
layout->dim_size[1] = handle->ofwp;
layout->dim_size[2] = handle->ofhp;
layout->dim_size[3] = handle->blocksofm;
layout->dim_size[4] = handle->desc.N;
} else {
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
}
} else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 5;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
layout->dim_size[0] = handle->ifmblock;
layout->dim_size[1] = handle->ifwp;
layout->dim_size[2] = handle->ifhp;
layout->dim_size[3] = handle->blocksifm;
layout->dim_size[4] = handle->desc.N;
} else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
layout->dim_size[0] = handle->ofmblock;
layout->dim_size[1] = handle->ofwp;
layout->dim_size[2] = handle->ofhp;
layout->dim_size[3] = handle->blocksofm;
layout->dim_size[4] = handle->desc.N;
} else { /* coverity[dead_error_begin] */
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
}
} else if ( ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_I16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) ) {
if ( ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_INPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) ) ) {
layout->datatype = handle->datatype_in;
} else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) ) {
layout->datatype = handle->datatype_out;
}
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
layout->num_dims = 5;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
layout->dim_size[0] = handle->ifmblock;
layout->dim_size[1] = handle->ifwp;
layout->dim_size[2] = handle->ifhp;
layout->dim_size[3] = handle->blocksifm;
layout->dim_size[4] = handle->desc.N;
} else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
layout->num_dims = 5;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
layout->dim_size[0] = handle->ofmblock;
layout->dim_size[1] = handle->ofwp;
layout->dim_size[2] = handle->ofhp;
layout->dim_size[3] = handle->blocksofm;
layout->dim_size[4] = handle->desc.N;
} else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
layout->num_dims = 5;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
layout->dim_size[0] = handle->ofmblock;
layout->dim_size[1] = handle->ofwp;
layout->dim_size[2] = handle->ofhp;
layout->dim_size[3] = handle->blocksofm;
layout->dim_size[4] = handle->desc.N;
} else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
layout->num_dims = 5;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
layout->dim_size[0] = handle->ifmblock;
layout->dim_size[1] = handle->ifwp;
layout->dim_size[2] = handle->ifhp;
layout->dim_size[3] = handle->blocksifm;
layout->dim_size[4] = handle->desc.N;
} else { /* coverity[dead_error_begin] */
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
} else if ((handle->buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
if ( ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) {
layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 4;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
layout->dim_size[0] = handle->ifmblock * handle->blocksifm;
layout->dim_size[1] = handle->ifwp;
layout->dim_size[2] = handle->ifhp;
layout->dim_size[3] = handle->desc.N;
} else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
layout->dim_size[0] = handle->ofmblock * handle->blocksofm;
layout->dim_size[1] = handle->ofwp;
layout->dim_size[2] = handle->ofhp;
layout->dim_size[3] = handle->desc.N;
} else { /* coverity[dead_error_begin] */
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
}
} else if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) || (type == LIBXSMM_DNN_FILTER) ) {
layout->format = handle->filter_format;
layout->tensor_type = LIBXSMM_DNN_FILTER;
if ((handle->filter_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 6;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_size[0] = handle->ofmblock;
layout->dim_size[1] = handle->ifmblock;
layout->dim_size[2] = handle->desc.S;
layout->dim_size[3] = handle->desc.R;
layout->dim_size[4] = handle->blocksifm;
layout->dim_size[5] = handle->blocksofm;
}
} else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(7*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(7*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 7;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[6] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_size[0] = handle->fm_lp_block;
layout->dim_size[1] = handle->ofmblock;
layout->dim_size[2] = handle->ifmblock/handle->fm_lp_block;
layout->dim_size[3] = handle->desc.S;
layout->dim_size[4] = handle->desc.R;
layout->dim_size[5] = handle->blocksifm;
layout->dim_size[6] = handle->blocksofm;
}
} else if ( ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_I16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8 ) ) {
if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_FILTER) ) {
layout->datatype = handle->datatype_in;
} else if (type == LIBXSMM_DNN_GRADIENT_FILTER) {
layout->datatype = handle->datatype_out;
}
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(7*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(7*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
if ((type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_FILTER)) {
layout->num_dims = 7;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[6] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_size[0] = handle->fm_lp_block;
layout->dim_size[1] = handle->ofmblock;
layout->dim_size[2] = handle->ifmblock/handle->fm_lp_block;
layout->dim_size[3] = handle->desc.S;
layout->dim_size[4] = handle->desc.R;
layout->dim_size[5] = handle->blocksifm;
layout->dim_size[6] = handle->blocksofm;
}
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
} else if ((handle->filter_format & LIBXSMM_DNN_TENSOR_FORMAT_RSCK) > 0) {
if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 4;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
layout->dim_size[0] = handle->ofmblock * handle->blocksofm;
layout->dim_size[1] = handle->ifmblock * handle->blocksifm;
layout->dim_size[2] = handle->desc.S;
layout->dim_size[3] = handle->desc.R;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
}
} else if ( type == LIBXSMM_DNN_REGULAR_FILTER_TRANS ) {
layout->format = handle->filter_format;
layout->tensor_type = LIBXSMM_DNN_REGULAR_FILTER_TRANS;
if ((handle->filter_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 6;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_size[0] = handle->ifmblock;
layout->dim_size[1] = handle->ofmblock;
layout->dim_size[2] = handle->desc.S;
layout->dim_size[3] = handle->desc.R;
layout->dim_size[4] = handle->blocksofm;
layout->dim_size[5] = handle->blocksifm;
}
} else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(7*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(7*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 7;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[6] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_size[0] = handle->fm_lp_block;
layout->dim_size[1] = handle->ifmblock;
layout->dim_size[2] = handle->ofmblock/handle->fm_lp_block;
layout->dim_size[3] = handle->desc.S;
layout->dim_size[4] = handle->desc.R;
layout->dim_size[5] = handle->blocksofm;
layout->dim_size[6] = handle->blocksifm;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
#if 0
} else if ((handle->filter_format & LIBXSMM_DNN_TENSOR_FORMAT_RSCK) > 0) {
if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 4;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
layout->dim_size[0] = handle->ofmblock * handle->blocksofm;
layout->dim_size[1] = handle->ifmblock * handle->blocksifm;
layout->dim_size[2] = handle->desc.S;
layout->dim_size[3] = handle->desc.R;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
#endif
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
}
} else if ( (type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) || (type == LIBXSMM_DNN_CHANNEL_BIAS) ) {
layout->format = handle->buffer_format;
layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR;
if ((handle->buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
if ( handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
layout->datatype = handle->datatype_out;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 2;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_size[0] = handle->ofmblock;
layout->dim_size[1] = handle->blocksofm;
}
#if 0
} else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) ) {
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(3*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(3*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 3;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_size[0] = handle->fm_lp_block;
layout->dim_size[1] = handle->ofmblock;
layout->dim_size[2] = handle->blocksofm;
}
#endif
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
} else if ((handle->buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
layout->datatype = handle->datatype_out;
if ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 ) {
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 1;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_size[0] = handle->ofmblock*handle->blocksofm;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
}
} else if ( type == LIBXSMM_DNN_BATCH_STATS ) {
layout->format = handle->buffer_format;
layout->tensor_type = LIBXSMM_DNN_BATCH_STATS;
if ((handle->buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
if ( (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) || (handle->datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 4;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
layout->dim_size[0] = handle->ofmblock;
layout->dim_size[1] = handle->desc.N;
layout->dim_size[2] = handle->blocksofm;
layout->dim_size[3] = 2;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
}
} else if (type == LIBXSMM_DNN_MAX_STATS_FWD) {
layout->format = handle->buffer_format;
layout->tensor_type = LIBXSMM_DNN_MAX_STATS_FWD;
layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 2;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
layout->dim_size[0] = handle->ifmblock;
layout->dim_size[1] = handle->desc.N;
}
} else if (type == LIBXSMM_DNN_MAX_STATS_BWD) {
layout->format = handle->buffer_format;
layout->tensor_type = LIBXSMM_DNN_MAX_STATS_BWD;
layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 2;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
layout->dim_size[0] = handle->ifmblock;
layout->dim_size[1] = handle->desc.N;
}
} else if (type == LIBXSMM_DNN_MAX_STATS_UPD) {
layout->format = handle->buffer_format;
layout->tensor_type = LIBXSMM_DNN_MAX_STATS_UPD;
layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 2;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
layout->dim_size[0] = handle->ifmblock;
layout->dim_size[1] = handle->desc.N;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
*status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
}
} else {
*status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
}
return layout;
}
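/* Transposes the registered BF16 filter into the layout expected by the backward pass:
 * input- and output-feature-map blocks are swapped, R and S are reversed, and the VNNI
 * pairing (fm_lp_block) is re-applied along the new innermost dimension. */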
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_trans_reg_bf16_filter(const libxsmm_dnn_layer* handle) {
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
if (handle != 0) {
if ( (handle->reg_filter != 0) && (handle->reg_filter_tr != 0) ) {
/* TODO handle more datatypes */
int ifm1, ifm2, kj, ki, ofm1, ofm2;
int ofmblock_lp = handle->ofmblock/handle->fm_lp_block;
int ifmblock_lp = handle->ifmblock/handle->fm_lp_block;
int lpb = handle->fm_lp_block;
LIBXSMM_VLA_DECL(7, libxsmm_bfloat16, wt, (libxsmm_bfloat16*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, lpb);
LIBXSMM_VLA_DECL(7, libxsmm_bfloat16, tr_wt, (libxsmm_bfloat16*)handle->reg_filter_tr->data, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb);
/* TODO: we might want to do this in parallel */
for ( ifm1 = 0; ifm1 < handle->blocksifm; ++ifm1 ) {
for ( ofm1 = 0; ofm1 < handle->blocksofm; ++ofm1 ) {
for (kj=0; kj < handle->desc.R; ++kj) {
for (ki=0; ki < handle->desc.S; ++ki) {
for ( ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2 ) {
for ( ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2 ) {
LIBXSMM_VLA_ACCESS(7, tr_wt, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, ofm2/lpb, ifm2, ofm2%lpb, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb) =
LIBXSMM_VLA_ACCESS(7, wt, ofm1, ifm1, kj, ki, ifm2/lpb, ofm2, ifm2%lpb, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, lpb);
}
}
}
}
}
}
} else {
status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
}
} else {
status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
}
return status;
}
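/* FP32 counterpart of the BF16 filter transpose above: same block swap and R/S reversal,
 * but on a plain 6-dimensional blocked layout without VNNI pairing. */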
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_trans_reg_filter(const libxsmm_dnn_layer* handle) {
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
if (handle != 0) {
if ( (handle->reg_filter != 0) && (handle->reg_filter_tr != 0) ) {
/* TODO handle more datatypes */
int ifm1, ifm2, kj, ki, ofm1, ofm2;
LIBXSMM_VLA_DECL(6, float, wt, (float*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
LIBXSMM_VLA_DECL(6, float, tr_wt, (float*)handle->reg_filter_tr->data, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock);
/* TODO: we might want to do this in parallel */
for ( ifm1 = 0; ifm1 < handle->blocksifm; ++ifm1 ) {
for ( ofm1 = 0; ofm1 < handle->blocksofm; ++ofm1 ) {
for (kj=0; kj < handle->desc.R; ++kj) {
for (ki=0; ki < handle->desc.S; ++ki) {
for ( ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2 ) {
for ( ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2 ) {
LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, ofm2, ifm2, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock) =
LIBXSMM_VLA_ACCESS(6, wt, ofm1, ifm1, kj, ki, ifm2, ofm2, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
}
}
}
}
}
}
} else {
status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
}
} else {
status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
}
return status;
}
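/* A minimal usage sketch (error handling elided; input_buffer is a placeholder for a
 * user-allocated buffer matching the handle's layout): create the handle's layout,
 * link a buffer to it, then bind.
 *
 *   libxsmm_dnn_err_t st;
 *   libxsmm_dnn_tensor_datalayout* dl = libxsmm_dnn_create_tensor_datalayout(handle, LIBXSMM_DNN_REGULAR_INPUT, &st);
 *   libxsmm_dnn_tensor* in = libxsmm_dnn_link_tensor(dl, input_buffer, &st);
 *   st = libxsmm_dnn_bind_tensor(handle, in, LIBXSMM_DNN_REGULAR_INPUT);
 */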
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_bind_tensor(libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check for tensor type */
if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) &&
(type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
(type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) &&
(type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) &&
(type != LIBXSMM_DNN_REGULAR_FILTER_TRANS) && (type != LIBXSMM_DNN_BATCH_STATS) && (type != LIBXSMM_DNN_MAX_STATS_FWD) && (type != LIBXSMM_DNN_MAX_STATS_BWD) && (type != LIBXSMM_DNN_MAX_STATS_UPD) ) {
status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
return status;
}
if (handle != 0 && tensor != 0) {
libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_create_tensor_datalayout(handle, type, &status);
if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) {
if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
handle->reg_input = (libxsmm_dnn_tensor*)tensor;
} else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
handle->grad_input = (libxsmm_dnn_tensor*)tensor;
} else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
handle->reg_output = (libxsmm_dnn_tensor*)tensor;
} else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
handle->grad_output = (libxsmm_dnn_tensor*)tensor;
} else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) {
handle->reg_filter = (libxsmm_dnn_tensor*)tensor;
} else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) {
handle->grad_filter = (libxsmm_dnn_tensor*)tensor;
} else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) {
handle->reg_bias = (libxsmm_dnn_tensor*)tensor;
} else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) {
handle->grad_bias = (libxsmm_dnn_tensor*)tensor;
} else if ( type == LIBXSMM_DNN_REGULAR_FILTER_TRANS ) {
handle->reg_filter_tr = (libxsmm_dnn_tensor*)tensor;
} else if ( type == LIBXSMM_DNN_BATCH_STATS ) {
handle->batch_stats = (libxsmm_dnn_tensor*)tensor;
} else if ( type == LIBXSMM_DNN_MAX_STATS_FWD ) {
handle->maxstats_fwd = (libxsmm_dnn_tensor*)tensor;
} else if ( type == LIBXSMM_DNN_MAX_STATS_BWD ) {
handle->maxstats_bwd = (libxsmm_dnn_tensor*)tensor;
} else if ( type == LIBXSMM_DNN_MAX_STATS_UPD ) {
handle->maxstats_upd = (libxsmm_dnn_tensor*)tensor;
} else {
/* cannot happen */
}
} else {
status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
}
libxsmm_dnn_destroy_tensor_datalayout( handle_layout );
} else {
status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
}
return status;
}
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_get_tensor(libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status)
{
libxsmm_dnn_tensor* return_tensor = 0;
*status = LIBXSMM_DNN_SUCCESS;
/* check for tensor type */
if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) &&
(type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
(type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) &&
(type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) &&
(type != LIBXSMM_DNN_REGULAR_FILTER_TRANS) && (type != LIBXSMM_DNN_BATCH_STATS) && (type != LIBXSMM_DNN_MAX_STATS_FWD) && (type != LIBXSMM_DNN_MAX_STATS_BWD) && (type != LIBXSMM_DNN_MAX_STATS_UPD) ) {
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
return return_tensor;
}
if (handle != 0) {
if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
return_tensor = handle->reg_input;
} else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
return_tensor = handle->grad_input;
} else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
return_tensor = handle->reg_output;
} else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
return_tensor = handle->grad_output;
} else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) {
return_tensor = handle->reg_filter;
} else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) {
return_tensor = handle->grad_filter;
} else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) {
return_tensor = handle->reg_bias;
} else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) {
return_tensor = handle->grad_bias;
} else if ( type == LIBXSMM_DNN_REGULAR_FILTER_TRANS ) {
return_tensor = handle->reg_filter_tr;
} else if ( type == LIBXSMM_DNN_BATCH_STATS ) {
return_tensor = handle->batch_stats;
} else if ( type == LIBXSMM_DNN_MAX_STATS_FWD ) {
return_tensor = handle->maxstats_fwd;
} else if ( type == LIBXSMM_DNN_MAX_STATS_BWD ) {
return_tensor = handle->maxstats_bwd;
} else if ( type == LIBXSMM_DNN_MAX_STATS_UPD ) {
return_tensor = handle->maxstats_upd;
} else {
/* cannot happen */
}
} else {
*status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
}
return return_tensor;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_release_tensor(libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor_type type)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check for tensor type */
if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) &&
(type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
(type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) &&
(type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) &&
(type != LIBXSMM_DNN_REGULAR_FILTER_TRANS) && (type != LIBXSMM_DNN_BATCH_STATS) && (type != LIBXSMM_DNN_MAX_STATS_FWD) && (type != LIBXSMM_DNN_MAX_STATS_BWD) && (type != LIBXSMM_DNN_MAX_STATS_UPD) ) {
status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
return status;
}
if (handle != 0) {
if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
handle->reg_input = 0;
} else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
handle->grad_input = 0;
} else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
handle->reg_output = 0;
} else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
handle->grad_output = 0;
} else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) {
handle->reg_filter = 0;
} else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) {
handle->grad_filter = 0;
} else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) {
handle->reg_bias = 0;
} else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) {
handle->grad_bias = 0;
} else if ( type == LIBXSMM_DNN_REGULAR_FILTER_TRANS ) {
handle->reg_filter_tr = 0;
} else if ( type == LIBXSMM_DNN_BATCH_STATS ) {
handle->batch_stats = 0;
} else if ( type == LIBXSMM_DNN_MAX_STATS_FWD ) {
handle->maxstats_fwd = 0;
} else if ( type == LIBXSMM_DNN_MAX_STATS_BWD ) {
handle->maxstats_bwd = 0;
} else if ( type == LIBXSMM_DNN_MAX_STATS_UPD ) {
handle->maxstats_upd = 0;
} else {
/* cannot happen */
}
} else {
status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
}
return status;
}
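/* The reported size includes 64 extra bytes so that libxsmm_dnn_bind_scratch() can round
 * a user pointer up to the next 64-byte boundary. A minimal sketch (any allocator works):
 *
 *   libxsmm_dnn_err_t st;
 *   size_t scratch_size = libxsmm_dnn_get_scratch_size(handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &st);
 *   void* scratch = malloc(scratch_size);
 *   st = libxsmm_dnn_bind_scratch(handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch);
 */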
LIBXSMM_API size_t libxsmm_dnn_get_scratch_size(const libxsmm_dnn_layer* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status)
{
size_t l_scratch_size = 0;
*status = LIBXSMM_DNN_SUCCESS;
if (0 != handle) {
switch (kind) {
case LIBXSMM_DNN_COMPUTE_KIND_FWD: break;
case LIBXSMM_DNN_COMPUTE_KIND_BWD: break;
case LIBXSMM_DNN_COMPUTE_KIND_UPD: break;
case LIBXSMM_DNN_COMPUTE_KIND_ALL: break;
default: {
*status = LIBXSMM_DNN_ERR_INVALID_KIND;
}
}
l_scratch_size += handle->scratch_size + 64;
} else {
*status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
}
return l_scratch_size;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_bind_scratch(libxsmm_dnn_layer* handle, const libxsmm_dnn_compute_kind kind, const void* scratch)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
uintptr_t address = (uintptr_t)scratch;
size_t offset = 0;
if (scratch == 0) {
status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
return status;
}
if (0 != handle) {
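/* round the user-provided pointer up to a 64-byte (cacheline) boundary */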
if (address % 64 == 0) {
handle->scratch = (void*)address;
} else {
offset = (64 - address % 64);
handle->scratch = (void*)(address+offset);
}
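/* note: advances past the bound region; the value is not used further in this function */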
address += handle->scratch_size + 64;
switch (kind) {
case LIBXSMM_DNN_COMPUTE_KIND_FWD: break;
case LIBXSMM_DNN_COMPUTE_KIND_BWD: break;
case LIBXSMM_DNN_COMPUTE_KIND_UPD: break;
case LIBXSMM_DNN_COMPUTE_KIND_ALL: break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_KIND;
}
}
} else {
status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
}
return status;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_release_scratch(libxsmm_dnn_layer* handle, const libxsmm_dnn_compute_kind kind)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
if (0 != handle) {
handle->scratch = 0;
switch (kind) {
case LIBXSMM_DNN_COMPUTE_KIND_FWD: break;
case LIBXSMM_DNN_COMPUTE_KIND_BWD: break;
case LIBXSMM_DNN_COMPUTE_KIND_UPD: break;
case LIBXSMM_DNN_COMPUTE_KIND_ALL: break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_KIND;
}
}
} else {
status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
}
return status;
}
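/* Single-threaded execution core: dispatches on algorithm, compute kind, and the
 * buffer/filter format combination; callers supply their own thread id and team. */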
LIBXSMM_API_INLINE libxsmm_dnn_err_t internal_execute_st(libxsmm_dnn_layer* handle,
libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
if (0 != handle) {
switch (handle->algo) {
case LIBXSMM_DNN_CONV_ALGO_DIRECT: {
switch (kind) {
case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
switch (handle->buffer_format) {
case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
switch (handle->filter_format) {
case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
status = libxsmm_dnn_convolve_st_fwd_custom_custom(handle, start_thread, tid);
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
}
}
} break;
case LIBXSMM_DNN_TENSOR_FORMAT_NHWC: {
switch (handle->filter_format) {
case LIBXSMM_DNN_TENSOR_FORMAT_RSCK: {
status = libxsmm_dnn_convolve_st_fwd_nhwc_rsck(handle, start_thread, tid);
} break;
case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
status = libxsmm_dnn_convolve_st_fwd_nhwc_custom(handle, start_thread, tid);
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
}
}
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
}
}
} break;
case LIBXSMM_DNN_COMPUTE_KIND_BWD: {
switch (handle->buffer_format) {
case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
switch (handle->filter_format) {
case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
status = libxsmm_dnn_convolve_st_bwd_custom_custom(handle, start_thread, tid);
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
}
}
} break;
case LIBXSMM_DNN_TENSOR_FORMAT_NHWC: {
switch (handle->filter_format) {
case LIBXSMM_DNN_TENSOR_FORMAT_RSCK: {
status = libxsmm_dnn_convolve_st_bwd_nhwc_rsck(handle, start_thread, tid);
} break;
case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
status = libxsmm_dnn_convolve_st_bwd_nhwc_custom(handle, start_thread, tid);
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
}
}
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
}
}
} break;
case LIBXSMM_DNN_COMPUTE_KIND_UPD: {
switch (handle->buffer_format) {
case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
switch (handle->filter_format) {
case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
status = libxsmm_dnn_convolve_st_upd_custom_custom(handle, start_thread, tid);
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
}
}
} break;
case LIBXSMM_DNN_TENSOR_FORMAT_NHWC: {
switch (handle->filter_format) {
case LIBXSMM_DNN_TENSOR_FORMAT_RSCK: {
status = libxsmm_dnn_convolve_st_upd_nhwc_rsck(handle, start_thread, tid);
} break;
case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
status = libxsmm_dnn_convolve_st_upd_nhwc_custom(handle, start_thread, tid);
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
}
}
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
}
}
} break;
case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: {
switch (handle->buffer_format) {
case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
switch (handle->filter_format) {
case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
status = libxsmm_dnn_convolve_st_upd_custom_custom(handle, start_thread, tid);
status = libxsmm_dnn_convolve_st_bwd_custom_custom(handle, start_thread, tid);
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
}
}
} break;
case LIBXSMM_DNN_TENSOR_FORMAT_NHWC: {
switch (handle->filter_format) {
case LIBXSMM_DNN_TENSOR_FORMAT_RSCK: {
status = libxsmm_dnn_convolve_st_upd_nhwc_rsck(handle, start_thread, tid);
status = libxsmm_dnn_convolve_st_bwd_nhwc_rsck(handle, start_thread, tid);
} break;
case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
status = libxsmm_dnn_convolve_st_upd_nhwc_custom(handle, start_thread, tid);
status = libxsmm_dnn_convolve_st_bwd_nhwc_custom(handle, start_thread, tid);
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
}
}
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
}
}
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_KIND;
}
}
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_ALGO;
}
}
} else {
status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
}
return status;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_execute_st(libxsmm_dnn_layer* handle,
libxsmm_dnn_compute_kind kind, /*unsigned*/int start_thread, /*unsigned*/int tid)
{
return internal_execute_st(handle, kind, start_thread, tid);
}
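/* Convenience wrapper: with OpenMP it spawns handle->desc.threads threads and runs the
 * single-threaded core on each; without OpenMP it degrades to a serial call with tid 0. */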
LIBXSMM_API void libxsmm_dnn_execute(libxsmm_dnn_layer* handle, libxsmm_dnn_compute_kind kind)
{
#if defined(_OPENMP)
# pragma omp parallel num_threads(handle->desc.threads)
{
const int tid = omp_get_thread_num();
internal_execute_st(handle, kind, 0, tid);
}
#else
internal_execute_st(handle, kind, 0/*start_thread*/, 0/*tid*/);
#endif
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Evangelos Georganas, Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_convolution_backward.h"
#include "libxsmm_main.h"
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
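/* Transposes one 16x16 BF16 tile between VNNI (row-pair) layouts using AVX-512 shuffles;
 * each 512-bit load/store covers two logical rows, so eight of them move the whole tile.
 * Strides are given in BF16 elements. */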
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void bf16_vnni_transpose_16x16_kernel(void* source_void, void* dest_void, int source_stride, int dest_stride)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
libxsmm_bfloat16 *source = (libxsmm_bfloat16*)source_void;
libxsmm_bfloat16 *dest = (libxsmm_bfloat16*)dest_void;
__m512i zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7;
__m512i tmp0, tmp1, tmp2, tmp3;
const __m512i abcdefgh_to_abefcdgh = _mm512_set4_epi32(0x0f0e0b0a, 0x0d0c0908, 0x07060302, 0x05040100);
zmm0 = _mm512_load_epi32(source);
zmm1 = _mm512_load_epi32(source + source_stride);
zmm2 = _mm512_load_epi32(source + source_stride*2);
zmm3 = _mm512_load_epi32(source + source_stride*3);
zmm4 = _mm512_load_epi32(source + source_stride*4);
zmm5 = _mm512_load_epi32(source + source_stride*5);
zmm6 = _mm512_load_epi32(source + source_stride*6);
zmm7 = _mm512_load_epi32(source + source_stride*7);
zmm0 = _mm512_shuffle_epi8(zmm0, abcdefgh_to_abefcdgh);
zmm1 = _mm512_shuffle_epi8(zmm1, abcdefgh_to_abefcdgh);
zmm2 = _mm512_shuffle_epi8(zmm2, abcdefgh_to_abefcdgh);
zmm3 = _mm512_shuffle_epi8(zmm3, abcdefgh_to_abefcdgh);
zmm4 = _mm512_shuffle_epi8(zmm4, abcdefgh_to_abefcdgh);
zmm5 = _mm512_shuffle_epi8(zmm5, abcdefgh_to_abefcdgh);
zmm6 = _mm512_shuffle_epi8(zmm6, abcdefgh_to_abefcdgh);
zmm7 = _mm512_shuffle_epi8(zmm7, abcdefgh_to_abefcdgh);
tmp0 = _mm512_unpacklo_epi64(zmm0, zmm1);
tmp1 = _mm512_unpackhi_epi64(zmm0, zmm1);
tmp2 = _mm512_unpacklo_epi64(zmm2, zmm3);
tmp3 = _mm512_unpackhi_epi64(zmm2, zmm3);
zmm0 = _mm512_unpacklo_epi64(zmm4, zmm5);
zmm1 = _mm512_unpackhi_epi64(zmm4, zmm5);
zmm2 = _mm512_unpacklo_epi64(zmm6, zmm7);
zmm3 = _mm512_unpackhi_epi64(zmm6, zmm7);
zmm4 = _mm512_shuffle_i32x4(tmp0, tmp2, 0x88);
zmm6 = _mm512_shuffle_i32x4(tmp0, tmp2, 0xdd);
zmm5 = _mm512_shuffle_i32x4(tmp1, tmp3, 0x88);
zmm7 = _mm512_shuffle_i32x4(tmp1, tmp3, 0xdd);
tmp0 = _mm512_shuffle_i32x4(zmm0, zmm2, 0x88);
tmp1 = _mm512_shuffle_i32x4(zmm0, zmm2, 0xdd);
tmp2 = _mm512_shuffle_i32x4(zmm1, zmm3, 0x88);
tmp3 = _mm512_shuffle_i32x4(zmm1, zmm3, 0xdd);
zmm0 = _mm512_shuffle_i32x4(zmm4, tmp0, 0x88);
zmm1 = _mm512_shuffle_i32x4(zmm5, tmp2, 0x88);
zmm2 = _mm512_shuffle_i32x4(zmm6, tmp1, 0x88);
zmm3 = _mm512_shuffle_i32x4(zmm7, tmp3, 0x88);
zmm4 = _mm512_shuffle_i32x4(zmm4, tmp0, 0xdd);
zmm5 = _mm512_shuffle_i32x4(zmm5, tmp2, 0xdd);
zmm6 = _mm512_shuffle_i32x4(zmm6, tmp1, 0xdd);
zmm7 = _mm512_shuffle_i32x4(zmm7, tmp3, 0xdd);
_mm512_store_epi32(dest, zmm0);
_mm512_store_epi32(dest + dest_stride, zmm1);
_mm512_store_epi32(dest + dest_stride * 2, zmm2);
_mm512_store_epi32(dest + dest_stride * 3, zmm3);
_mm512_store_epi32(dest + dest_stride * 4, zmm4);
_mm512_store_epi32(dest + dest_stride * 5, zmm5);
_mm512_store_epi32(dest + dest_stride * 6, zmm6);
_mm512_store_epi32(dest + dest_stride * 7, zmm7);
#else
LIBXSMM_UNUSED(source_void); LIBXSMM_UNUSED(dest_void); LIBXSMM_UNUSED(source_stride); LIBXSMM_UNUSED(dest_stride);
#endif
}
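/* Tiles an MxN VNNI transpose into 16x16 blocks; M and N are assumed to be multiples
 * of 16 (remainders are not handled here). */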
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void bf16_vnni_transpose_kernel(libxsmm_bfloat16* src, libxsmm_bfloat16* dst, int M, int N, int ld_in, int ld_out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
const int _M = M/16, _N = N/16;
int i = 0, j = 0;
for (i = 0; i < _N; i++) {
for (j = 0; j < _M; j++) {
bf16_vnni_transpose_16x16_kernel((libxsmm_bfloat16*) src+i*16*ld_in+j*32, (libxsmm_bfloat16*) dst+j*16*ld_out+i*32, ld_in*2, ld_out*2);
}
}
#else
LIBXSMM_UNUSED(src); LIBXSMM_UNUSED(dst); LIBXSMM_UNUSED(M); LIBXSMM_UNUSED(N); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
#endif
}
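/* FP32 backward convolution (custom/custom format). The main path dispatches two
 * batch-reduce GEMM kernels (a full row block plus a variant with one fewer output
 * column, presumably for edge iterations); beta is 0 when the accumulator load can be
 * avoided, and the BRGEMM_PF_OOB environment variable optionally enables
 * out-of-bounds prefetching. */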
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if (handle->use_fallback_bwd_loops == 0) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
const libxsmm_blasint ldB = (libxsmm_blasint)handle->ofmblock;
const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v) : (libxsmm_blasint)handle->ifmblock;
const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f);
int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int brgemm_pf_oob = 0;
const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
if ( 0 != env_brgemm_pf_oob ) {
brgemm_pf_oob = atoi(env_brgemm_pf_oob);
}
if (brgemm_pf_oob > 0) {
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
}
{ /* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic.tpl.c"
}
} else {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction gemm_function;
const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v*handle->ifmblock);
{ /* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock, NULL, NULL, &ldC, NULL, NULL, NULL, NULL);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic.tpl.c"
}
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
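/* BF16 backward path for AVX-512 CORE targets where native BF16 FMA instructions are
 * unavailable ('_emu'); dispatches both FP32-accumulating (bsmm) and pure-BF16 (bmm)
 * batch-reduce kernels for use by the included template. */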
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
if (handle->use_fallback_bwd_loops == 0) {
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
{ typedef libxsmm_bfloat16 element_filter_type;
typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16;
const libxsmm_blasint ldB = (libxsmm_blasint)handle->ofmblock;
const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v) : (libxsmm_blasint)handle->ifmblock;
const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f);
int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
/* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
}
} else {
const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v*handle->ifmblock);
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function;
int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16);
int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16);
/* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(handle->ifmblock, handle->ofw, handle->ofmblock, stride_a, stride_b, NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL);
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
if (handle->use_fallback_bwd_loops == 0) {
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
{
typedef libxsmm_bfloat16 element_filter_type;
typedef libxsmm_bsmmfunction gemm_function;
typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs;
typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function_strd;
gemm_br_function_offs br_gemm_kernel_offs = handle->bwd_compute_kernel_offs;
gemm_br_function_strd br_gemm_kernel_strd = handle->bwd_compute_kernel_strd;
gemm_function tile_config_kernel = handle->bwd_config_kernel;
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
}
} else {
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function;
const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v*handle->ifmblock);
int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16);
int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16);
/* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(handle->ifmblock, handle->ofw, handle->ofmblock, stride_a, stride_b, NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL);
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
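/* Native AVX512-BF16 (CPX) variants follow; when the toolchain lacks CPX intrinsics,
 * the functions below simply forward to the '_emu' implementations above. */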
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
if (handle->use_fallback_bwd_loops == 0) {
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
{
typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16;
const libxsmm_blasint ldB = (libxsmm_blasint)handle->ofmblock;
const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v) : (libxsmm_blasint)handle->ifmblock;
const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f);
int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
/* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
}
# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
} else {
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function;
const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v*handle->ifmblock);
int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16);
int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16);
/* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(handle->ifmblock, handle->ofw, handle->ofmblock, stride_a, stride_b, NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL);
# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
return libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu( handle, start_thread, tid );
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
if (handle->use_fallback_bwd_loops == 0) {
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
{
typedef libxsmm_bsmmfunction gemm_function;
typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs;
typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function_strd;
gemm_br_function_offs br_gemm_kernel_offs = handle->bwd_compute_kernel_offs;
gemm_br_function_strd br_gemm_kernel_strd = handle->bwd_compute_kernel_strd;
gemm_function tile_config_kernel = handle->bwd_config_kernel;
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
}
# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
} else {
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function;
const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v*handle->ifmblock);
int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16);
int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16);
/* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(handle->ifmblock, handle->ofw, handle->ofmblock, stride_a, stride_b, NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL);
# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
return libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx( handle, start_thread, tid );
}
#endif
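/* NHWC activations with LIBXSMM-blocked ('custom') filters. Leading dimensions differ
 * from the custom/custom path because feature maps are contiguous across blocks in NHWC. */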
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if (handle->use_fallback_bwd_loops == 0) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v) : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock);
const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f);
int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int brgemm_pf_oob = 0;
const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
if ( 0 != env_brgemm_pf_oob ) {
brgemm_pf_oob = atoi(env_brgemm_pf_oob);
}
if (brgemm_pf_oob > 0) {
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
}
{ /* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
}
} else {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction gemm_function;
const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
const libxsmm_blasint ldC = ( (handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in) ) ?
(libxsmm_blasint)(handle->ifmblock * handle->desc.v) :
(libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v);
/* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock, &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
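/* Same structure as the NHWC/custom path above, but with filters in RSCK format; the
 * shared template is specialized via LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK. */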
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if (handle->use_fallback_bwd_loops == 0) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v) : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock);
const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f);
int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int brgemm_pf_oob = 0;
const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
if ( 0 != env_brgemm_pf_oob ) {
brgemm_pf_oob = atoi(env_brgemm_pf_oob);
}
if (brgemm_pf_oob > 0) {
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
}
{ /* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
}
} else {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction gemm_function;
const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
const libxsmm_blasint ldC = ( (handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in) ) ?
(libxsmm_blasint)(handle->ifmblock * handle->desc.v) :
(libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v);
/* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock, &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
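/* Top-level backward driver for the custom/custom format: verifies that gradient input,
 * gradient output, filter, and scratch are bound, then picks an architecture- and
 * datatype-specific kernel; non-AVX-512 targets use the compiler-generated fallback below. */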
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if we have input, output and filter */
if (handle->grad_input == 0 || handle->grad_output == 0 || handle->reg_filter == 0 || handle->scratch == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) {
if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_convolve_st_bwd_custom_custom_f32_f32( handle, start_thread, tid);
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_CPX ) {
status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu( handle, start_thread, tid);
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CPX && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16( handle, start_thread, tid);
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx( handle, start_thread, tid);
}
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu( handle, start_thread, tid);
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx( handle, start_thread, tid);
}
#endif
else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
if (handle->use_fallback_bwd_loops == 0) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
const libxsmm_blasint ldx = ((libxsmm_blasint)handle->ofmblock);
const libxsmm_blasint ldA = handle->ifmblock;
const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? handle->ifmblock * handle->desc.v : handle->ifmblock;
const float beta = (handle->avoid_acc_load_bwd) ? 0.f : 1.f;
int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int brgemm_pf_oob = 0;
const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
if ( 0 != env_brgemm_pf_oob ) {
brgemm_pf_oob = atoi(env_brgemm_pf_oob);
}
if (brgemm_pf_oob > 0) {
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
}
{ /* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic.tpl.c"
}
} else {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction gemm_function;
const libxsmm_blasint ldx = ((libxsmm_blasint)handle->desc.v*handle->ifmblock);
/* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock, NULL, NULL, &ldx, NULL, NULL, NULL, NULL);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic.tpl.c"
}
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if we have input, output and filter */
if (handle->grad_input == 0 || handle->grad_output == 0 || handle->reg_filter == 0 || handle->scratch == 0) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) {
if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_convolve_st_bwd_nhwc_rsck_f32_f32( handle, start_thread, tid);
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
if (handle->use_fallback_bwd_loops == 0) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v) : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock);
const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f);
int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int brgemm_pf_oob = 0;
const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
if ( 0 != env_brgemm_pf_oob ) {
brgemm_pf_oob = atoi(env_brgemm_pf_oob);
}
if (brgemm_pf_oob > 0) {
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
}
{ /* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
}
} else {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction gemm_function;
const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
const libxsmm_blasint ldC = ( (handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in) ) ?
(libxsmm_blasint)(handle->ifmblock * handle->desc.v) :
(libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v);
/* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or, in other words, M=nbIfm, N=ofw, K=nbOfm (col-major) */
gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock, &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
}
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
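/* dispatcher for the backward convolution with NHWC activations and custom (blocked) filters; the structure mirrors the RSCK variant above */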
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if we have input, output and filter */
if (handle->grad_input == 0 || handle->grad_output == 0 || handle->reg_filter == 0 || handle->scratch == 0) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) {
if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_convolve_st_bwd_nhwc_custom_f32_f32( handle, start_thread, tid);
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
if (handle->use_fallback_bwd_loops == 0) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v) : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock);
const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f);
int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int brgemm_pf_oob = 0;
const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
if (0 != env_brgemm_pf_oob) {
brgemm_pf_oob = atoi(env_brgemm_pf_oob);
}
if (brgemm_pf_oob > 0) {
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
}
{ /* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or, in other words, M=nbIfm, N=ofw, K=nbOfm (col-major) */
gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
}
} else {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction gemm_function;
const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
const libxsmm_blasint ldC = ( (handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in) ) ?
(libxsmm_blasint)(handle->ifmblock * handle->desc.v) :
(libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v);
/* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or, in other words, M=nbIfm, N=ofw, K=nbOfm (col-major) */
gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock, &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
}
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Rajkishore Barik, Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_CONVOLUTION_BACKWARD_H
#define LIBXSMM_DNN_CONVOLUTION_BACKWARD_H
#include <libxsmm_dnn_convolution.h>
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);
#endif /* LIBXSMM_DNN_CONVOLUTION_BACKWARD_H */
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Evangelos Georganas, Hans Pabst (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_convolution_forward.h"
#include "libxsmm_main.h"
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i8(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
#if 1
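/* JIT-dispatch the BRGEMM kernels below; the disabled branch instead reuses kernels prepared on the handle at layer setup */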
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr;
const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock;
const libxsmm_blasint ldA = handle->ofmblock;
const libxsmm_blasint ldC = handle->ofmblock;
const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags;
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int brgemm_pf_oob = 0;
const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
if (0 != env_brgemm_pf_oob) {
brgemm_pf_oob = atoi(env_brgemm_pf_oob);
}
if (brgemm_pf_oob > 0) {
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
}
{ /* let's do an ofmblock x ofw_rb x ifmblock GEMM :-) or, in other words, M=nbOfm, N=ofw, K=nbIfm (col-major) */
gemm_br_function_addr br_gemm_kernel_a_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
gemm_br_function_addr br_gemm_kernel_b_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
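/* the second kernel uses a reduced ofw blocking (ofw_rb-1), e.g. for tiles where the full register-blocked width does not apply */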
#else
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr;
typedef libxsmm_smmfunction_reducebatch_offs gemm_br_function_offs;
typedef libxsmm_smmfunction_reducebatch_strd gemm_br_function_strd;
{
gemm_br_function_addr br_gemm_kernel_a_addr = handle->fwd_compute_kernel_addr_a_f32;
gemm_br_function_addr br_gemm_kernel_b_addr = handle->fwd_compute_kernel_addr_b_f32;
gemm_br_function_offs br_gemm_kernel_offs = handle->fwd_compute_kernel_offs_f32;
gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd_f32;
#endif
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic.tpl.c"
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
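/* BF16 forward path for AVX512_CORE targets: bf16 storage with the fp32 accumulation emulated in software (no native BF16 FMA required) */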
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
{
typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16;
const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock;
const libxsmm_blasint ldA = handle->ofmblock;
const libxsmm_blasint ldC = handle->ofmblock;
const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
int l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | handle->fwd_flags;
/* let's do an ofmblock x ofw_rb x ifmblock GEMM :-) or, in other words, M=nbOfm, N=ofw, K=nbIfm (col-major) */
gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
typedef libxsmm_bsmmfunction gemm_function;
typedef libxsmm_bmmfunction_reducebatch_offs gemm_br_function_offs_a;
typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs_b;
typedef libxsmm_bmmfunction_reducebatch_strd gemm_br_function_strd;
gemm_br_function_offs_a br_gemm_kernel_offs_a = handle->fwd_compute_kernel_offs_a;
gemm_br_function_offs_b br_gemm_kernel_offs_b = handle->fwd_compute_kernel_offs_b;
gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd;
gemm_function tile_config_kernel = handle->fwd_config_kernel;
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16;
const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock;
const libxsmm_blasint ldA = handle->ofmblock;
const libxsmm_blasint ldC = handle->ofmblock;
const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
int l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | handle->fwd_flags;
gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
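/* without AVX512_CPX support at build time, route the native BF16 entry point to the emulation path */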
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
return libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu( handle, start_thread, tid );
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
typedef libxsmm_bsmmfunction gemm_function;
typedef libxsmm_bmmfunction_reducebatch_offs gemm_br_function_offs_a;
typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs_b;
typedef libxsmm_bmmfunction_reducebatch_strd gemm_br_function_strd;
gemm_br_function_offs_a br_gemm_kernel_offs_a = handle->fwd_compute_kernel_offs_a;
gemm_br_function_offs_b br_gemm_kernel_offs_b = handle->fwd_compute_kernel_offs_b;
gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd;
gemm_function tile_config_kernel = handle->fwd_config_kernel;
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
return libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx( handle, start_thread, tid );
}
#endif
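/* int8 forward paths: unsigned 8-bit activations with signed 8-bit weights, producing i32 or u8 outputs */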
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef unsigned char element_input_type;
typedef int element_output_type;
typedef char element_filter_type;
/* basically we only need offset-based and strided BRGEMMs */
libxsmm_subimmfunction_reducebatch_strd br_gemm_kernel_strided = handle->gemm_fwd.xgemm.subimrs;
libxsmm_subimmfunction_reducebatch_strd br_gemm_kernel_strided2 = handle->gemm_fwd2.xgemm.subimrs;
libxsmm_subimmfunction_reducebatch_offs br_gemm_kernel_offset = handle->gemm_fwd.xgemm.subimro;
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_i8i32.tpl.c"
#else
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i8(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef unsigned char element_input_type;
typedef unsigned char element_output_type;
typedef char element_filter_type;
/* basically we only need offset-based and strided BRGEMMs */
libxsmm_sububmmfunction_reducebatch_strd br_gemm_kernel_strided = handle->gemm_fwd.xgemm.sububmrs;
libxsmm_sububmmfunction_reducebatch_offs br_gemm_kernel_offset = handle->gemm_fwd.xgemm.sububmro;
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_i8i8.tpl.c"
#else
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->blocksifm*handle->ifmblock : (libxsmm_blasint)handle->blocksifm*handle->desc.v*handle->ifmblock;
const libxsmm_blasint ldA = handle->ofmblock;
const libxsmm_blasint ldC = handle->blocksofm*handle->ofmblock;
const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags;
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int brgemm_pf_oob = 0;
const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
if (0 != env_brgemm_pf_oob) {
brgemm_pf_oob = atoi(env_brgemm_pf_oob);
}
if (brgemm_pf_oob > 0) {
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
}
{ /* let's do an ofmblock x ofw_rb x ifmblock GEMM :-) or, in other words, M=nbOfm, N=ofw, K=nbIfm (col-major) */
gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->blocksifm*handle->ifmblock : (libxsmm_blasint)handle->blocksifm*handle->desc.v*handle->ifmblock;
const libxsmm_blasint ldA = handle->blocksofm*handle->ofmblock;
const libxsmm_blasint ldC = handle->blocksofm*handle->ofmblock;
const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags;
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int brgemm_pf_oob = 0;
const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
if (0 != env_brgemm_pf_oob) {
brgemm_pf_oob = atoi(env_brgemm_pf_oob);
}
if (brgemm_pf_oob > 0) {
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
}
{ /* let's do an ofmblock x ofw_rb x ifmblock GEMM :-) or, in other words, M=nbOfm, N=ofw, K=nbIfm (col-major) */
gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
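/* top-level forward dispatcher: routes to the f32, int8, or bf16 implementation matching the layer's datatypes and target architecture */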
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if we have input, output and filter */
if (handle->reg_input == 0 || handle->reg_output == 0 || handle->reg_filter == 0) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) {
if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_convolve_st_fwd_custom_custom_f32_f32( handle, start_thread, tid);
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_I32 ) {
status = libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i32( handle, start_thread, tid);
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_I8 ) {
status = libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i8( handle, start_thread, tid);
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
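/* bf16 routing by target architecture: AVX512_CORE uses the emulated path, AVX512_CPX the native BF16 path, and SPR or newer the AMX tile path */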
else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_CPX) {
status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu( handle, start_thread, tid);
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CPX && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16( handle, start_thread, tid);
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx( handle, start_thread, tid);
}
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu( handle, start_thread, tid);
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx( handle, start_thread, tid);
}
#endif
else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
#if 1
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr;
const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock;
const libxsmm_blasint ldA = handle->ofmblock;
const libxsmm_blasint ldC = handle->ofmblock;
const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags;
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int brgemm_pf_oob = 0;
const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
if (0 != env_brgemm_pf_oob) {
brgemm_pf_oob = atoi(env_brgemm_pf_oob);
}
if (brgemm_pf_oob > 0) {
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
}
{ /* let's do an ofmblock x ofw_rb x ifmblock GEMM :-) or, in other words, M=nbOfm, N=ofw, K=nbIfm (col-major) */
gemm_br_function_addr br_gemm_kernel_a_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
gemm_br_function_addr br_gemm_kernel_b_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
#else
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr;
typedef libxsmm_smmfunction_reducebatch_offs gemm_br_function_offs;
typedef libxsmm_smmfunction_reducebatch_strd gemm_br_function_strd;
{
gemm_br_function_addr br_gemm_kernel_a_addr = handle->fwd_compute_kernel_addr_a_f32;
gemm_br_function_addr br_gemm_kernel_b_addr = handle->fwd_compute_kernel_addr_b_f32;
gemm_br_function_offs br_gemm_kernel_offs = handle->fwd_compute_kernel_offs_f32;
gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd_f32;
#endif
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic.tpl.c"
}
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if we have input, output and filter */
if (handle->reg_input == 0 || handle->reg_output == 0 || handle->reg_filter == 0) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) {
if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_convolve_st_fwd_nhwc_custom_f32_f32( handle, start_thread, tid);
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->blocksifm*handle->ifmblock : (libxsmm_blasint)handle->blocksifm*handle->desc.v*handle->ifmblock;
const libxsmm_blasint ldA = handle->ofmblock;
const libxsmm_blasint ldC = handle->blocksofm*handle->ofmblock;
const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags;
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int brgemm_pf_oob = 0;
const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
if (0 != env_brgemm_pf_oob) {
brgemm_pf_oob = atoi(env_brgemm_pf_oob);
}
if (brgemm_pf_oob > 0) {
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
}
{ /* let's do an ofmblock x ofw_rb x ifmblock GEMM :-) or, in other words, M=nbOfm, N=ofw, K=nbIfm (col-major) */
gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
}
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if we have input, output and filter */
if (handle->reg_input == 0 || handle->reg_output == 0 || handle->reg_filter == 0) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) {
if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_convolve_st_fwd_nhwc_rsck_f32_f32( handle, start_thread, tid);
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->blocksifm*handle->ifmblock : (libxsmm_blasint)handle->blocksifm*handle->desc.v*handle->ifmblock;
const libxsmm_blasint ldA = handle->blocksofm*handle->ofmblock;
const libxsmm_blasint ldC = handle->blocksofm*handle->ofmblock;
const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags;
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int brgemm_pf_oob = 0;
const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
if (0 != env_brgemm_pf_oob) {
brgemm_pf_oob = atoi(env_brgemm_pf_oob);
}
if (brgemm_pf_oob > 0) {
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
}
{ /* let's do an ofmblock x ofw_rb x ifmblock GEMM :-) or, in other words, M=nbOfm, N=ofw, K=nbIfm (col-major) */
gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
}
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_CONVOLUTION_FORWARD_H
#define LIBXSMM_DNN_CONVOLUTION_FORWARD_H
#include <libxsmm_dnn_convolution.h>
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid);
#endif /* LIBXSMM_DNN_CONVOLUTION_FORWARD_H */
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Rajkishore Barik, Alexander Heinecke, Ankush Mandal, Jason Sewall (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_convolution_weight_update.h"
#include "libxsmm_main.h"
/* function prototypes for below implementations */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void transpose_32x16(const libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int ld_in, int ld_out)
{
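/* transpose a 16x32 tile of bf16 elements into 32x16: sixteen 512-bit row loads, followed by 16/32/64-bit unpacks, 128-bit lane shuffles, and cross-register 64-bit permutes; the result is written back as thirty-two 256-bit rows */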
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
__m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf;
__m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf;
const int in_width=ld_in, out_width=ld_out;
const __m512i idx_lo = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
const __m512i idx_hi = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10);
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
r9 = _mm512_loadu_si512(in + 9*in_width);
ra = _mm512_loadu_si512(in + 10*in_width);
rb = _mm512_loadu_si512(in + 11*in_width);
rc = _mm512_loadu_si512(in + 12*in_width);
rd = _mm512_loadu_si512(in + 13*in_width);
re = _mm512_loadu_si512(in + 14*in_width);
rf = _mm512_loadu_si512(in + 15*in_width);
t0 = _mm512_unpacklo_epi16(r0,r1);
t1 = _mm512_unpackhi_epi16(r0,r1);
t2 = _mm512_unpacklo_epi16(r2,r3);
t3 = _mm512_unpackhi_epi16(r2,r3);
t4 = _mm512_unpacklo_epi16(r4,r5);
t5 = _mm512_unpackhi_epi16(r4,r5);
t6 = _mm512_unpacklo_epi16(r6,r7);
t7 = _mm512_unpackhi_epi16(r6,r7);
t8 = _mm512_unpacklo_epi16(r8,r9);
t9 = _mm512_unpackhi_epi16(r8,r9);
ta = _mm512_unpacklo_epi16(ra,rb);
tb = _mm512_unpackhi_epi16(ra,rb);
tc = _mm512_unpacklo_epi16(rc,rd);
td = _mm512_unpackhi_epi16(rc,rd);
te = _mm512_unpacklo_epi16(re,rf);
tf = _mm512_unpackhi_epi16(re,rf);
r0 = _mm512_unpacklo_epi32(t0,t2);
r1 = _mm512_unpackhi_epi32(t0,t2);
r2 = _mm512_unpacklo_epi32(t1,t3);
r3 = _mm512_unpackhi_epi32(t1,t3);
r4 = _mm512_unpacklo_epi32(t4,t6);
r5 = _mm512_unpackhi_epi32(t4,t6);
r6 = _mm512_unpacklo_epi32(t5,t7);
r7 = _mm512_unpackhi_epi32(t5,t7);
r8 = _mm512_unpacklo_epi32(t8,ta);
r9 = _mm512_unpackhi_epi32(t8,ta);
ra = _mm512_unpacklo_epi32(t9,tb);
rb = _mm512_unpackhi_epi32(t9,tb);
rc = _mm512_unpacklo_epi32(tc,te);
rd = _mm512_unpackhi_epi32(tc,te);
re = _mm512_unpacklo_epi32(td,tf);
rf = _mm512_unpackhi_epi32(td,tf);
t0 = _mm512_unpacklo_epi64(r0,r4);
t1 = _mm512_unpackhi_epi64(r0,r4);
t2 = _mm512_unpacklo_epi64(r1,r5);
t3 = _mm512_unpackhi_epi64(r1,r5);
t4 = _mm512_unpacklo_epi64(r2,r6);
t5 = _mm512_unpackhi_epi64(r2,r6);
t6 = _mm512_unpacklo_epi64(r3,r7);
t7 = _mm512_unpackhi_epi64(r3,r7);
t8 = _mm512_unpacklo_epi64(r8,rc);
t9 = _mm512_unpackhi_epi64(r8,rc);
ta = _mm512_unpacklo_epi64(r9,rd);
tb = _mm512_unpackhi_epi64(r9,rd);
tc = _mm512_unpacklo_epi64(ra,re);
td = _mm512_unpackhi_epi64(ra,re);
te = _mm512_unpacklo_epi64(rb,rf);
tf = _mm512_unpackhi_epi64(rb,rf);
r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
r1 = _mm512_shuffle_i32x4(t2, t3, 0x88);
r2 = _mm512_shuffle_i32x4(t4, t5, 0x88);
r3 = _mm512_shuffle_i32x4(t6, t7, 0x88);
r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd);
r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd);
r8 = _mm512_shuffle_i32x4(t8, t9, 0x88);
r9 = _mm512_shuffle_i32x4(ta, tb, 0x88);
ra = _mm512_shuffle_i32x4(tc, td, 0x88);
rb = _mm512_shuffle_i32x4(te, tf, 0x88);
rc = _mm512_shuffle_i32x4(t8, t9, 0xdd);
rd = _mm512_shuffle_i32x4(ta, tb, 0xdd);
re = _mm512_shuffle_i32x4(tc, td, 0xdd);
rf = _mm512_shuffle_i32x4(te, tf, 0xdd);
t0 = _mm512_permutex2var_epi64(r0, idx_lo, r8);
t1 = _mm512_permutex2var_epi64(r1, idx_lo, r9);
t2 = _mm512_permutex2var_epi64(r2, idx_lo, ra);
t3 = _mm512_permutex2var_epi64(r3, idx_lo, rb);
t4 = _mm512_permutex2var_epi64(r4, idx_lo, rc);
t5 = _mm512_permutex2var_epi64(r5, idx_lo, rd);
t6 = _mm512_permutex2var_epi64(r6, idx_lo, re);
t7 = _mm512_permutex2var_epi64(r7, idx_lo, rf);
t8 = _mm512_permutex2var_epi64(r8, idx_hi, r0);
t9 = _mm512_permutex2var_epi64(r9, idx_hi, r1);
ta = _mm512_permutex2var_epi64(ra, idx_hi, r2);
tb = _mm512_permutex2var_epi64(rb, idx_hi, r3);
tc = _mm512_permutex2var_epi64(rc, idx_hi, r4);
td = _mm512_permutex2var_epi64(rd, idx_hi, r5);
te = _mm512_permutex2var_epi64(re, idx_hi, r6);
tf = _mm512_permutex2var_epi64(rf, idx_hi, r7);
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 0*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 1*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 2*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 3*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 4*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 5*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 6*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 7*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 8*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 9*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 10*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 11*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 12*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 13*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 14*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 15*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 16*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 17*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 18*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 19*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 20*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 21*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 22*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 23*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 24*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 25*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 26*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 27*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 28*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 29*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 30*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 31*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 1));
#else
LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
#endif
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void transpose_32xcols(const libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int col, int ld_in, int ld_out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
__m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf;
__m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf;
const int in_width=ld_in, out_width=ld_out;
const __m512i idx_lo = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
const __m512i idx_hi = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10);
__mmask16 store_mask = LIBXSMM_INTRINSICS_MM512_CVTU32_MASK16(((unsigned int)1 << col) - 1);
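/* partial-tile variant of transpose_32x16: only the first 'col' input rows are valid; the remaining row registers stay undefined and the store mask keeps just the first 'col' lanes per output row */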
rf = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
if (col == 15) {
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
r9 = _mm512_loadu_si512(in + 9*in_width);
ra = _mm512_loadu_si512(in + 10*in_width);
rb = _mm512_loadu_si512(in + 11*in_width);
rc = _mm512_loadu_si512(in + 12*in_width);
rd = _mm512_loadu_si512(in + 13*in_width);
re = _mm512_loadu_si512(in + 14*in_width);
} else if (col == 14) {
re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
r9 = _mm512_loadu_si512(in + 9*in_width);
ra = _mm512_loadu_si512(in + 10*in_width);
rb = _mm512_loadu_si512(in + 11*in_width);
rc = _mm512_loadu_si512(in + 12*in_width);
rd = _mm512_loadu_si512(in + 13*in_width);
} else if (col == 13) {
rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
r9 = _mm512_loadu_si512(in + 9*in_width);
ra = _mm512_loadu_si512(in + 10*in_width);
rb = _mm512_loadu_si512(in + 11*in_width);
rc = _mm512_loadu_si512(in + 12*in_width);
} else if (col == 12) {
rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
r9 = _mm512_loadu_si512(in + 9*in_width);
ra = _mm512_loadu_si512(in + 10*in_width);
rb = _mm512_loadu_si512(in + 11*in_width);
} else if (col == 11) {
rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
r9 = _mm512_loadu_si512(in + 9*in_width);
ra = _mm512_loadu_si512(in + 10*in_width);
} else if (col == 10) {
ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
r9 = _mm512_loadu_si512(in + 9*in_width);
} else if (col == 9) {
r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
} else if (col == 8) {
r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
} else if (col == 7) {
r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
} else if (col == 6) {
r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
} else if (col == 5) {
r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
} else if (col == 4) {
r4 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
} else if (col == 3) {
r3 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r4 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
} else if (col == 2) {
r2 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r3 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r4 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
} else if (col == 1) {
r1 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r2 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r3 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r4 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r0 = _mm512_loadu_si512(in + 0*in_width);
} else {
r0 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r1 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r2 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r3 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r4 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
}
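/* transpose stage 1: interleave 16-bit elements of adjacent row pairs (within 128-bit lanes) */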
t0 = _mm512_unpacklo_epi16(r0,r1);
t1 = _mm512_unpackhi_epi16(r0,r1);
t2 = _mm512_unpacklo_epi16(r2,r3);
t3 = _mm512_unpackhi_epi16(r2,r3);
t4 = _mm512_unpacklo_epi16(r4,r5);
t5 = _mm512_unpackhi_epi16(r4,r5);
t6 = _mm512_unpacklo_epi16(r6,r7);
t7 = _mm512_unpackhi_epi16(r6,r7);
t8 = _mm512_unpacklo_epi16(r8,r9);
t9 = _mm512_unpackhi_epi16(r8,r9);
ta = _mm512_unpacklo_epi16(ra,rb);
tb = _mm512_unpackhi_epi16(ra,rb);
tc = _mm512_unpacklo_epi16(rc,rd);
td = _mm512_unpackhi_epi16(rc,rd);
te = _mm512_unpacklo_epi16(re,rf);
tf = _mm512_unpackhi_epi16(re,rf);
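/* transpose stage 2: interleave 32-bit elements, combining groups of four rows */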
r0 = _mm512_unpacklo_epi32(t0,t2);
r1 = _mm512_unpackhi_epi32(t0,t2);
r2 = _mm512_unpacklo_epi32(t1,t3);
r3 = _mm512_unpackhi_epi32(t1,t3);
r4 = _mm512_unpacklo_epi32(t4,t6);
r5 = _mm512_unpackhi_epi32(t4,t6);
r6 = _mm512_unpacklo_epi32(t5,t7);
r7 = _mm512_unpackhi_epi32(t5,t7);
r8 = _mm512_unpacklo_epi32(t8,ta);
r9 = _mm512_unpackhi_epi32(t8,ta);
ra = _mm512_unpacklo_epi32(t9,tb);
rb = _mm512_unpackhi_epi32(t9,tb);
rc = _mm512_unpacklo_epi32(tc,te);
rd = _mm512_unpackhi_epi32(tc,te);
re = _mm512_unpacklo_epi32(td,tf);
rf = _mm512_unpackhi_epi32(td,tf);
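/* transpose stage 3: interleave 64-bit elements, combining groups of eight rows */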
t0 = _mm512_unpacklo_epi64(r0,r4);
t1 = _mm512_unpackhi_epi64(r0,r4);
t2 = _mm512_unpacklo_epi64(r1,r5);
t3 = _mm512_unpackhi_epi64(r1,r5);
t4 = _mm512_unpacklo_epi64(r2,r6);
t5 = _mm512_unpackhi_epi64(r2,r6);
t6 = _mm512_unpacklo_epi64(r3,r7);
t7 = _mm512_unpackhi_epi64(r3,r7);
t8 = _mm512_unpacklo_epi64(r8,rc);
t9 = _mm512_unpackhi_epi64(r8,rc);
ta = _mm512_unpacklo_epi64(r9,rd);
tb = _mm512_unpackhi_epi64(r9,rd);
tc = _mm512_unpacklo_epi64(ra,re);
td = _mm512_unpackhi_epi64(ra,re);
te = _mm512_unpacklo_epi64(rb,rf);
tf = _mm512_unpackhi_epi64(rb,rf);
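/* transpose stage 4: shuffle 128-bit lanes (0x88 selects the even lanes, 0xdd the odd lanes) */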
r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
r1 = _mm512_shuffle_i32x4(t2, t3, 0x88);
r2 = _mm512_shuffle_i32x4(t4, t5, 0x88);
r3 = _mm512_shuffle_i32x4(t6, t7, 0x88);
r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd);
r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd);
r8 = _mm512_shuffle_i32x4(t8, t9, 0x88);
r9 = _mm512_shuffle_i32x4(ta, tb, 0x88);
ra = _mm512_shuffle_i32x4(tc, td, 0x88);
rb = _mm512_shuffle_i32x4(te, tf, 0x88);
rc = _mm512_shuffle_i32x4(t8, t9, 0xdd);
rd = _mm512_shuffle_i32x4(ta, tb, 0xdd);
re = _mm512_shuffle_i32x4(tc, td, 0xdd);
rf = _mm512_shuffle_i32x4(te, tf, 0xdd);
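/* transpose stage 5: merge the 256-bit halves across register pairs via idx_lo/idx_hi (set up earlier) */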
t0 = _mm512_permutex2var_epi64(r0, idx_lo, r8);
t1 = _mm512_permutex2var_epi64(r1, idx_lo, r9);
t2 = _mm512_permutex2var_epi64(r2, idx_lo, ra);
t3 = _mm512_permutex2var_epi64(r3, idx_lo, rb);
t4 = _mm512_permutex2var_epi64(r4, idx_lo, rc);
t5 = _mm512_permutex2var_epi64(r5, idx_lo, rd);
t6 = _mm512_permutex2var_epi64(r6, idx_lo, re);
t7 = _mm512_permutex2var_epi64(r7, idx_lo, rf);
t8 = _mm512_permutex2var_epi64(r8, idx_hi, r0);
t9 = _mm512_permutex2var_epi64(r9, idx_hi, r1);
ta = _mm512_permutex2var_epi64(ra, idx_hi, r2);
tb = _mm512_permutex2var_epi64(rb, idx_hi, r3);
tc = _mm512_permutex2var_epi64(rc, idx_hi, r4);
td = _mm512_permutex2var_epi64(rd, idx_hi, r5);
te = _mm512_permutex2var_epi64(re, idx_hi, r6);
tf = _mm512_permutex2var_epi64(rf, idx_hi, r7);
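/* each 512-bit register now holds two transposed 16-element rows; store_mask (set up earlier) limits every store to the valid columns */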
_mm256_mask_storeu_epi16(out + 0*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 0));
_mm256_mask_storeu_epi16(out + 1*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 1));
_mm256_mask_storeu_epi16(out + 2*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 0));
_mm256_mask_storeu_epi16(out + 3*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 1));
_mm256_mask_storeu_epi16(out + 4*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 0));
_mm256_mask_storeu_epi16(out + 5*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 1));
_mm256_mask_storeu_epi16(out + 6*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 0));
_mm256_mask_storeu_epi16(out + 7*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 1));
_mm256_mask_storeu_epi16(out + 8*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 0));
_mm256_mask_storeu_epi16(out + 9*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 1));
_mm256_mask_storeu_epi16(out + 10*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 0));
_mm256_mask_storeu_epi16(out + 11*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 1));
_mm256_mask_storeu_epi16(out + 12*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 0));
_mm256_mask_storeu_epi16(out + 13*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 1));
_mm256_mask_storeu_epi16(out + 14*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 0));
_mm256_mask_storeu_epi16(out + 15*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 1));
_mm256_mask_storeu_epi16(out + 16*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 0));
_mm256_mask_storeu_epi16(out + 17*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 1));
_mm256_mask_storeu_epi16(out + 18*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 0));
_mm256_mask_storeu_epi16(out + 19*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 1));
_mm256_mask_storeu_epi16(out + 20*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 0));
_mm256_mask_storeu_epi16(out + 21*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 1));
_mm256_mask_storeu_epi16(out + 22*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 0));
_mm256_mask_storeu_epi16(out + 23*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 1));
_mm256_mask_storeu_epi16(out + 24*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 0));
_mm256_mask_storeu_epi16(out + 25*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 1));
_mm256_mask_storeu_epi16(out + 26*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 0));
_mm256_mask_storeu_epi16(out + 27*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 1));
_mm256_mask_storeu_epi16(out + 28*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 0));
_mm256_mask_storeu_epi16(out + 29*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 1));
_mm256_mask_storeu_epi16(out + 30*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 0));
_mm256_mask_storeu_epi16(out + 31*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 1));
#else
LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(col); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
#endif
}
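/* For reference, a scalar sketch of what transpose_32xcols computes, assuming
 * store_mask covers the first "col" elements and in_width/out_width correspond
 * to ld_in/ld_out: a col x 32 tile of "in" becomes a 32 x col tile of "out", i.e.
 *   for (j = 0; j < col; ++j)
 *     for (i = 0; i < 32; ++i)
 *       out[i*ld_out + j] = in[j*ld_in + i];
 */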
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void transpose_input_pixels_bf16(const libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int M, int N, int ld_in, int ld_out){
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
int i, j;
int full16_chunks = N/16;
int remainder_cols = N%16;
int _N = N - remainder_cols;
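/* split N into full 16-column chunks plus a remainder; the i-loop advances in steps of 32, so M is expected to be a multiple of 32 */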
if (full16_chunks) {
for (i=0; i<M; i+=32) {
for (j=0; j<_N; j+=16) {
transpose_32x16((const libxsmm_bfloat16*)in + i + ld_in*j, (libxsmm_bfloat16*)out + j + i*ld_out, ld_in, ld_out);
}
}
}
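/* transpose the remaining N%16 columns with the masked 32xcols kernel */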
if (remainder_cols) {
for (i=0; i<M; i+=32) {
transpose_32xcols((const libxsmm_bfloat16*)in + i + ld_in*full16_chunks*16, (libxsmm_bfloat16*)out + full16_chunks*16 + i*ld_out, remainder_cols, ld_in, ld_out);
}
}
#else
LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(M); LIBXSMM_UNUSED(N); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
#endif
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction gemm_function;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic.tpl.c"
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
typedef libxsmm_bsmmfunction gemm_function;
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
typedef libxsmm_bsmmfunction gemm_function;
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function;
gemm_function tile_config_kernel = handle->upd_config_kernel;
gemm_function gemm_kernel = NULL;
gemm_br_function br_gemm_kernel = NULL;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
typedef libxsmm_bsmmfunction gemm_function;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu( handle, start_thread, tid );
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
typedef libxsmm_bsmmfunction gemm_function;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function;
gemm_function tile_config_kernel = handle->upd_config_kernel;
gemm_function gemm_kernel = NULL;
gemm_br_function br_gemm_kernel = NULL;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx( handle, start_thread, tid );
}
#endif
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction gemm_function;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c"
#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction gemm_function;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c"
#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if we have input, output and filter */
if (handle->reg_input == 0 || handle->grad_output == 0 || handle->grad_filter == 0 || handle->scratch == 0) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ((handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) {
if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_convolve_st_upd_custom_custom_f32_f32( handle, start_thread, tid);
}
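/* BF16: pick the native AVX512-BF16 path (CPX), the emulated path (CORE), or the AMX path, depending on the target architecture */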
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_CPX ) {
status = libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu( handle, start_thread, tid);
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CPX && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
status = libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16( handle, start_thread, tid);
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
status = libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx( handle, start_thread, tid);
}
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
status = libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu( handle, start_thread, tid);
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
status = libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx( handle, start_thread, tid);
}
#endif
else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
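/* generic fallback when not targeting AVX512 (or when AVX512 intrinsics are unavailable at compile time) */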
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction gemm_function;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic.tpl.c"
}
else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if we have input, output and filter */
if (handle->reg_input == 0 || handle->grad_output == 0 || handle->grad_filter == 0 || handle->scratch == 0) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) {
if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_convolve_st_upd_nhwc_custom_f32_f32( handle, start_thread, tid);
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction gemm_function;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c"
#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM
}
else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if we have input, output and filter */
if (handle->reg_input == 0 || handle->grad_output == 0 || handle->grad_filter == 0 || handle->scratch == 0) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) {
if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_convolve_st_upd_nhwc_rsck_f32_f32( handle, start_thread, tid);
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction gemm_function;
typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c"
#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK
}
else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Rajkishore Barik, Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_CONVOLUTION_WEIGHT_UPDATE_H
#define LIBXSMM_DNN_CONVOLUTION_WEIGHT_UPDATE_H
#include <libxsmm_dnn_convolution.h>
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);
#endif /* LIBXSMM_DNN_CONVOLUTION_WEIGHT_UPDATE_H */